The following is a Random Forest model that aims to predict the next day's white blood cell count (WCC) for a patient in hospital using the patient's most up-to-date WCC results (regression), as well as to determine whether those results will be within the normal range and whether another blood test is recommended (classification).
See here for an 8-minute presentation on the context of the problem, an overview of the solution, and the intended future development and deployment.
This model was created as part of the 2022 IntelliHQ x ANZICS Healthcare Datathon in Australia, where it won 1st place at the regional level (Victoria) and 2nd place at the national level. This is a collaborative work between:
Important: The following is a work in progress and does not constitute the finalised tool. The code is presented as it was developed, from inception (October 15th, 2022) until the day of the national finals (October 20th, 2022). If you're interested in this problem and prototype, feel free to reach out!
import sys
import os
sys.path.append("../")
os.environ['AWS_STS_REGIONAL_ENDPOINTS']='regional'
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from utils.glue import Catalogue
from datetime import timedelta
from datetime import datetime
from utils.athena_query import Athena_Query
# Data-source setup: build the project's catalogue helper and two Athena query
# clients that the extraction cells below use.
catalogue = Catalogue()
# Bare expression: shows the available database names when run as a notebook cell.
catalogue.database_names
# NOTE(review): the database and table are picked by position (index 0 and
# index 10), so they depend on the ordering the catalogue returns — confirm
# these still point at the intended entries.
database_name = catalogue.database_names[0]
database = catalogue.database(database_name)
table_name = database.table_names[10]
# `table` is not referenced again in the visible code.
table = database.table(table_name)
# Client created with default arguments; used later to query `{database_name}.icustays`.
athena = Athena_Query()
# Overrides the catalogue-derived table name chosen above.
table_name = 'labevents'
# Client bound to the 'mimic_iv' database; used for labevents, patients and inputevents.
athena1 = Athena_Query(database='mimic_iv')
# Sets the same database name that was already passed to the constructor
# (presumably redundant — verify against Athena_Query).
athena1.database = 'mimic_iv'
# --- Pull the raw tables needed for the model from Athena -------------------

# White blood cell count results: itemid 51301 is the main WCC test.
_wcc_sql = f"SELECT * FROM {table_name} WHERE itemid IN (51301);"
labevents = athena1.query_as_pandas(_wcc_sql)

# Patient age (anchor_age), keyed by subject_id.
_age_sql = "SELECT subject_id, anchor_age FROM patients"
age = athena1.query_as_pandas(_age_sql)

# Albumin infusion start times per ICU stay (later reduced to a binary flag);
# rows without a recorded amount are excluded.
_albumin_sql = (
    "SELECT stay_id, starttime FROM inputevents "
    "WHERE itemid IN (220862, 220861, 220863, 220864) "
    "and amount is not null;"
)
alb_products = athena1.query_as_pandas(_albumin_sql)

# ICU stays, which carry the patient/admission/stay identifiers and in/out times.
_icustays_sql = f"SELECT * FROM {database_name}.icustays;"
icustays = athena.query_as_pandas(_icustays_sql)
icustays.head()
# Parse every timestamp column so that datetime arithmetic works below.
alb_products['starttime'] = pd.to_datetime(alb_products['starttime'])
labevents[['charttime','storetime']] = labevents[['charttime','storetime']].apply(pd.to_datetime)
icustays[['intime','outtime']] = icustays[['intime','outtime']].apply(pd.to_datetime)

# Pair each WCC result with the ICU stay(s) of the same patient and admission.
# how='right' keeps every ICU stay, including stays without any WCC result.
merged = labevents.merge(icustays, on=['subject_id','hadm_id'], how='right')

# Keep only results charted while the patient was in the ICU (bounds inclusive).
# Rows with a missing charttime compare False and are dropped here.
# .copy() makes icubloods an independent frame, so the column assignments below
# write to it directly instead of to a slice of `merged`.
in_icu = (merged['intime'] <= merged['charttime']) & (merged['charttime'] <= merged['outtime'])
icubloods = merged[in_icu].copy()
icubloods['time_since_icu'] = icubloods['charttime'] - icubloods['intime']
icubloods.reset_index(inplace=True, drop=True)
icubloods.head(4)

#Determining time from first test
# transform('min') returns the earliest charttime of each
# (hadm_id, stay_id, itemid) group on icubloods' own index, so the subtraction
# below lines up row for row. The previous outer merge re-sorted the rows by
# key and renumbered them, so the index-aligned subtraction could pair a row
# with another group's first-test time.
ftt = icubloods.groupby(['hadm_id','stay_id','itemid'])['charttime'].transform('min').rename('first_test_time')
icubloods['time_since_first_test'] = icubloods['charttime'] - ftt
icubloods.head()
# Perform resampling to create a regular time series of WCC per patient stay:
# within each (hadm_id, stay_id, itemid) group, bucket rows into
# `resamplingHours`-wide bins of time since the first test, keep the last
# observation of each bin, then forward-fill empty bins.
# Only the first 10,000 rows of icubloods are resampled, so everything
# downstream is built from that subset.
# NOTE(review): ffill() runs on the combined frame after resampling, not per
# group — confirm no values can carry over from one stay into the next.
resamplingHours = 2
icuresampled = (
    icubloods.head(10000)
    .groupby(['hadm_id','stay_id','itemid'])
    .resample(timedelta(hours=resamplingHours), on = 'time_since_first_test')
    .last()
    .ffill()
)

# The grouping keys and the resampling key live in the index; drop any column
# copies so reset_index() below cannot collide with an existing column.
# errors='ignore' tolerates pandas versions that do not keep these columns.
icuresampled.drop(
    columns=['time_since_first_test', 'hadm_id', 'stay_id', 'itemid'],
    inplace=True,
    errors='ignore',
)

# Move the resampled time axis (the innermost index level) back into a column.
icuresampled2 = icuresampled.reset_index(level='time_since_first_test')

# Express the time axis in hours; vectorised division by a one-hour Timedelta
# gives the same floats as the earlier row-by-row apply.
icuresampled2['hours_since_first_test'] = icuresampled2['time_since_first_test'] / pd.Timedelta(hours=1)
#Subset the data to patients with ICU stay longer than 1 day
# One boolean per (hadm_id, stay_id, itemid): True when the largest
# time_since_icu seen for that stay exceeds one day.
fish = (
    icuresampled2.groupby(['hadm_id','stay_id', 'itemid'])['time_since_icu'].max()
    > timedelta(days=1)
).rename('long_stay').to_frame()

# Attach the flag to every resampled row of the stay (joined on the shared
# index level names).
fish = fish.merge(icuresampled2, how ='left', on = ['hadm_id','stay_id', 'itemid'])

# .copy() gives longfish its own data, so the lag columns added to it later are
# written to a real frame rather than to a boolean-mask slice of `fish`
# (which triggers SettingWithCopyWarning).
longfish = fish[fish.long_stay].copy()
# Plot the first 10,000 resampled rows as WCC time series, one coloured line
# per ICU stay.
fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(18, 16))
sns.lineplot(
    data=longfish.head(10000),
    x='hours_since_first_test',
    y='value',
    hue='stay_id',
    palette="tab10",
    ax=ax1,
)
# Replace the per-stay legend with an empty one; it would otherwise list every stay.
ax1.legend([], [], frameon=False)
ax1.set(
    xlabel='ICU stay length (hours)',
    ylabel='White Blood Cell (K/ul)',
    title='Visualisation of Time Series of WCC, first 10,000 entries',
);
# Add lagged WCC values as features to give the model serial dependence.
# Lags run from 7 to 42 samples within each stay; with resamplingHours = 2
# that is 14 h to 84 h before the current row (the smallest lag is 7 samples,
# not 6, so nothing closer than 14 h is used).
_lag_keys = ['itemid', 'stay_id', 'hadm_id']
for lag in range(7, 43):
    longfish[f'value_lag{lag}'] = (
        longfish.groupby(_lag_keys, observed=True)['valuenum'].shift(lag)
    )

# Bring index level 1 back as an ordinary column.
longfish.reset_index(level=[1], inplace=True)
# Build the modelling table: strip identifiers, raw lab metadata and helper
# columns, leaving the target ('value'), the lag features, 'stay_id' and the
# columns still needed for the joins below.
pre_X = longfish.drop(columns =['long_stay',
    'time_since_first_test','labevent_id','specimen_id','storetime','valuenum','valueuom','ref_range_lower','ref_range_upper','flag',
    'priority','comments','first_careunit','last_careunit','outtime','los','time_since_icu','hours_since_first_test'])

# Add patient age (anchor_age), then drop the key and timestamp no longer needed.
pre_X = pre_X.merge(age, how = 'left', on ='subject_id')
pre_X = pre_X.drop(columns =['subject_id', 'charttime'])

# Albumin flag. alb_products holds one row per infusion, so merging it directly
# repeated every sample row once per infusion of that stay. Those duplicates
# inflate the data set and can land on both sides of the train/test split.
# Reducing to the earliest infusion per stay keeps exactly one row per sample
# and gives the same flag value: "any infusion started within a day of
# admission" is the same as "the earliest one did".
first_alb = alb_products.groupby('stay_id', as_index=False)['starttime'].min()
pre_X = pre_X.merge(first_alb, on = 'stay_id', how= 'left')
# True when the earliest albumin infusion started less than one day after ICU
# admission (including before admission). Stays with no infusion have NaT here,
# which compares False. The window is relative to `intime`, not to the time of
# the sample.
pre_X['alb'] = pre_X['starttime'] - pre_X['intime'] < timedelta(days=1)
pre_X = pre_X.drop(columns =['intime', 'starttime'], errors='ignore')

# Missing values (e.g. lags that reach before the first sample) become -1.
# NOTE(review): this also fills any missing target 'value' with -1 — confirm
# that is intended.
pre_X.fillna(value=-1, inplace = True)
pre_X.head()

#Remove unrequired columns from training cols
X = pre_X.drop(columns = ['value', 'stay_id'])
y = pre_X['value']
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
import time

# plot_roc_curve was deprecated in scikit-learn 1.0 and removed in 1.2, so
# importing it directly fails on current releases. Prefer its replacement,
# RocCurveDisplay.from_estimator, which takes the same (estimator, X, y, ax=...)
# arguments and returns the same display object. Fall back to the legacy
# function on older releases that lack from_estimator. The name
# `plot_roc_curve` is kept so later cells work unchanged.
try:
    from sklearn.metrics import RocCurveDisplay
    plot_roc_curve = RocCurveDisplay.from_estimator
except (ImportError, AttributeError):
    from sklearn.metrics import plot_roc_curve

# Hold out 30% of the rows for evaluation (fixed seed for reproducibility).
# NOTE(review): the split is row-wise, so resampled rows from one ICU stay can
# fall in both train and test; a group-wise split on stay_id may give a more
# honest estimate — confirm intent.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123)

# Regressor predicting the WCC value; fit time in seconds is printed.
clf = RandomForestRegressor(n_estimators=200, max_depth=5, random_state=0)
start = time.time()
clf.fit(X_train, y_train)
print(time.time()-start)

# Score the regressor on the held-out rows.
y_pred = clf.predict(X_test)
r2 = r2_score(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
#print out metrics
print("r2: ", r2, '\n'
      "mse: ", mse, '\n'
      "mae: ", mae)
# Fit a classifier on a binary target: 1 when the WCC value lies inside the
# inclusive range [2.0, 11.0] used here as "normal", otherwise 0.
_WCC_NORMAL_LOW, _WCC_NORMAL_HIGH = 2.0, 11.0
y_train_clas = y_train.between(_WCC_NORMAL_LOW, _WCC_NORMAL_HIGH).astype(int).tolist()
y_test_clas = y_test.between(_WCC_NORMAL_LOW, _WCC_NORMAL_HIGH).astype(int).tolist()

# Same forest settings as the regressor; fit time in seconds is printed.
clf_2 = RandomForestClassifier(n_estimators=200, max_depth=5, random_state=0)
start = time.time()
clf_2.fit(X_train, y_train_clas)
print(time.time() - start)

# Predicted classes for the held-out rows, with a small sample shown for inspection.
y_pred_clas = clf_2.predict(X_test)
y_pred_clas[13:25]

from sklearn.metrics import confusion_matrix

# Rows are true classes and columns are predicted classes.
matrix = confusion_matrix(y_test_clas, y_pred_clas)
print(matrix)
# Share of all test rows (in %) that are truly in range but predicted out of range.
round(matrix[1][0] / len(y_test_clas) * 100, 2)
# Create a ROC curve for the classifier on the held-out rows (requires the
# trained classifier clf_2).
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8.5, 8))
roc_curve = plot_roc_curve(clf_2, X_test, y_test_clas, ax=ax)

# Larger fonts for presentation.
ax.set_title('ROC Curve for White Blood Cell Count', fontsize=24)
ax.set_xlabel('False Positive Rate', fontsize=20)
ax.set_ylabel('True Positive Rate', fontsize=20)
ax.legend(fontsize=12)

plt.show()
# Combine regressor and classifier outputs for the held-out rows:
# predicted WCC level, predicted in-range class, and the class probabilities.
pred_level = clf.predict(X_test)
pred_norm = clf_2.predict(X_test)
norm_proba = clf_2.predict_proba(X_test)

# Global feature importance of the classifier over all training features,
# most important first.
_importances = pd.Series(clf_2.feature_importances_, index=X_train.columns)
feature_scores = _importances.sort_values(ascending=False)

# Horizontal bar chart: one bar per feature.
f, ax = plt.subplots(figsize=(15, 10))
ax = sns.barplot(x=feature_scores, y=feature_scores.index)
ax.set_yticklabels(feature_scores.index)
ax.set_title('Feature Scores for White Blood Cell Count Model', fontsize=24)
ax.set_xlabel('Feature Importance', fontsize=20)
ax.set_ylabel('Feature', fontsize=20)

plt.show()
# Example output: one row per held-out sample with the predicted WCC level,
# the blood-test recommendation and the classifier's confidence (largest class
# probability, in percent).
# The 0/1 class is translated to text on the 'pred_norm' column only. A
# frame-wide replace would also overwrite any predicted level or confidence
# that happened to equal 0 or 1.
_bloods_labels = {0: 'BLOODS REQUIRED', 1: 'BLOODS NOT REQUIRED'}
WBC_preds = pd.DataFrame({
    'pred_level': pred_level,
    'pred_norm': pd.Series(pred_norm).map(_bloods_labels),
    'confid_percent': np.amax(norm_proba, axis=1) * 100,
})
WBC_preds.iloc[13:25]