A Decision Tool for Reducing Blood Tests in Hospitalised Patients

The following is a Random Forest model that aims to predict a hospital patient's next-day white blood cell count (WCC) from their most recent WCC results (regression), and to determine whether that result will be within the normal range and whether another blood test is recommended (classification).

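Schematically, the intended decision flow looks like this (a minimal sketch; regressor, classifier, and features stand in for the models and feature matrix built below):

def recommend_next_test(features, regressor, classifier):
    predicted_wcc = regressor.predict(features)               # next-day WCC level (regression)
    within_range = classifier.predict(features)               # 1 = expected within normal range
    confidence = classifier.predict_proba(features).max(axis=1)
    # Another blood test is recommended when the next result is expected to fall outside the range
    return predicted_wcc, within_range, confidence
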
See here for an 8-minute presentation on the context of the problem, an overview of the solution, and the intended future development and deployment.

This model was created as part of the 2022 IntelliHQ x ANZICS Healthcare Datathon in Australia, where it placed 1st at the regional level (Victoria) and 2nd at the national level. It is a collaborative work between:

Important: The following is a work in progress and does not constitute the finalised tool. The code is presented as it was developed, from inception (October 15th, 2022) until the day of the national finals (October 20th, 2022). If you're interested in this problem and prototype, feel free to reach out!

Importing libraries and utilities

In [61]:
import sys
import os
sys.path.append("../")
os.environ['AWS_STS_REGIONAL_ENDPOINTS']='regional'

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from utils.glue import Catalogue
from datetime import timedelta 
from datetime import datetime
from utils.athena_query import Athena_Query
In [111]:
catalogue = Catalogue()
catalogue.database_names
Out[111]:
['mimic_iv']

Connect with the relevant database and import the data

In [63]:
database_name = catalogue.database_names[0]
database = catalogue.database(database_name)
In [64]:
table_name = database.table_names[10]
table = database.table(table_name)
In [65]:
athena = Athena_Query()
In [ ]:
table_name = 'labevents'

athena1 = Athena_Query(database='mimic_iv')
athena1.database = 'mimic_iv'

#itemid 51301 is the main test for WCC
labevents = athena1.query_as_pandas(f"SELECT * FROM {table_name} WHERE itemid IN (51301);")
In [67]:
#Extract age
age = athena1.query_as_pandas("SELECT subject_id, anchor_age FROM patients;")
#Extract albumin infusions (to be converted to a binary flag: any administration in the last 24 hours)
alb_products = athena1.query_as_pandas("SELECT stay_id, starttime FROM inputevents WHERE itemid IN (220862, 220861, 220863, 220864) AND amount IS NOT NULL;")
.........
In [68]:
#Import ICU stay identifiers and times
icustays = athena.query_as_pandas(f"SELECT * FROM {database_name}.icustays;")
icustays.head()
Out[68]:
subject_id hadm_id stay_id first_careunit last_careunit intime outtime los
0 17233909 24700140 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836
1 19185346 22127136 37084815 Cardiac Vascular Intensive Care Unit (CVICU) Cardiac Vascular Intensive Care Unit (CVICU) 2110-04-07 10:56:47 2110-04-08 13:39:36 1.113067
2 13498867 24046990 37084966 Coronary Care Unit (CCU) Trauma SICU (TSICU) 2151-07-16 13:40:11 2151-07-20 22:53:29 4.384236
3 17030600 25070620 37085562 Surgical Intensive Care Unit (SICU) Surgical Intensive Care Unit (SICU) 2110-11-19 06:24:00 2110-11-20 17:16:26 1.453079
4 12408912 27098428 37086015 Medical Intensive Care Unit (MICU) Medical Intensive Care Unit (MICU) 2146-02-22 22:19:51 2146-02-23 14:56:08 0.691863

Data preprocessing

In [69]:
alb_products['starttime'] = pd.to_datetime(alb_products['starttime'])
In [70]:
labevents[['charttime','storetime']] = labevents[['charttime','storetime']].apply(pd.to_datetime)
In [71]:
icustays[['intime','outtime']] = icustays[['intime','outtime']].apply(pd.to_datetime)
In [72]:
merged = labevents.merge(icustays, on=['subject_id','hadm_id'], how='right')
In [73]:
#Keep only blood tests charted during the ICU stay (copy to avoid chained-assignment warnings)
icubloods = merged[(merged['intime'] <= merged['charttime']) & (merged['charttime'] <= merged['outtime'])].copy()
In [ ]:
icubloods['time_since_icu'] = icubloods['charttime'] - icubloods['intime']
In [75]:
icubloods.reset_index(inplace=True, drop=True)
In [76]:
icubloods.head(4)
Out[76]:
labevent_id subject_id hadm_id specimen_id itemid charttime storetime value valuenum valueuom ... flag priority comments stay_id first_careunit last_careunit intime outtime los time_since_icu
0 89730556.0 17233909 24700140.0 53888401.0 51301.0 2157-10-31 03:12:00 2157-10-31 03:35:00 4.3 4.3 K/uL ... NaN STAT NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 0 days 01:59:09
1 89730575.0 17233909 24700140.0 2987617.0 51301.0 2157-11-01 03:55:00 2157-11-01 04:28:00 4.5 4.5 K/uL ... NaN ROUTINE NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 1 days 02:42:09
2 89730628.0 17233909 24700140.0 73099342.0 51301.0 2157-11-02 03:51:00 2157-11-02 05:01:00 10.8 10.8 K/uL ... abnormal STAT NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 2 days 02:38:09
3 113963710.0 19185346 22127136.0 35549681.0 51301.0 2110-04-07 11:35:00 2110-04-07 11:46:00 18.7 18.7 K/uL ... abnormal STAT NaN 37084815 Cardiac Vascular Intensive Care Unit (CVICU) Cardiac Vascular Intensive Care Unit (CVICU) 2110-04-07 10:56:47 2110-04-08 13:39:36 1.113067 0 days 00:38:13

4 rows × 22 columns

In [ ]:
#Determine time from first test; transform('min') keeps the result aligned with icubloods rows
first_test_time = icubloods.groupby(['hadm_id','stay_id','itemid'])['charttime'].transform('min')
icubloods['time_since_first_test'] = icubloods['charttime'] - first_test_time
In [112]:
icubloods.head()
Out[112]:
labevent_id subject_id hadm_id specimen_id itemid charttime storetime value valuenum valueuom ... priority comments stay_id first_careunit last_careunit intime outtime los time_since_icu time_since_first_test
0 89730556.0 17233909 24700140.0 53888401.0 51301.0 2157-10-31 03:12:00 2157-10-31 03:35:00 4.3 4.3 K/uL ... STAT NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 0 days 01:59:09 0 days 00:00:00
1 89730575.0 17233909 24700140.0 2987617.0 51301.0 2157-11-01 03:55:00 2157-11-01 04:28:00 4.5 4.5 K/uL ... ROUTINE NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 1 days 02:42:09 1 days 00:43:00
2 89730628.0 17233909 24700140.0 73099342.0 51301.0 2157-11-02 03:51:00 2157-11-02 05:01:00 10.8 10.8 K/uL ... STAT NaN 37084004 Trauma SICU (TSICU) Trauma SICU (TSICU) 2157-10-31 01:12:51 2157-11-02 15:48:08 2.607836 2 days 02:38:09 2 days 00:39:00
3 113963710.0 19185346 22127136.0 35549681.0 51301.0 2110-04-07 11:35:00 2110-04-07 11:46:00 18.7 18.7 K/uL ... STAT NaN 37084815 Cardiac Vascular Intensive Care Unit (CVICU) Cardiac Vascular Intensive Care Unit (CVICU) 2110-04-07 10:56:47 2110-04-08 13:39:36 1.113067 0 days 00:38:13 0 days 00:00:00
4 113963737.0 19185346 22127136.0 21501963.0 51301.0 2110-04-07 12:17:00 2110-04-07 13:29:00 23.3 23.3 K/uL ... STAT NaN 37084815 Cardiac Vascular Intensive Care Unit (CVICU) Cardiac Vascular Intensive Care Unit (CVICU) 2110-04-07 10:56:47 2110-04-08 13:39:36 1.113067 0 days 01:20:13 0 days 00:42:00

5 rows × 23 columns

In [78]:
# Perform resampling to create time series of WCC from patient data and forward-fill missing values
resamplingHours = 2
icuresampled = icubloods.head(10000).groupby(['hadm_id','stay_id','itemid']).resample(timedelta(hours=resamplingHours), on = 'time_since_first_test').last().ffill()
icuresampled.drop(columns=['time_since_first_test'], inplace=True)
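
A quick illustration of the resampling step (a toy sketch, not part of the pipeline): groupby-resample-ffill snaps irregular test times onto a regular 2-hour grid and carries the last known result forward into empty bins.

toy = pd.DataFrame({
    'stay_id': [1, 1, 1],
    'time_since_first_test': pd.to_timedelta(['0h', '3h', '7h']),
    'value': [4.3, 4.5, 10.8],
})
toy_resampled = (toy.groupby('stay_id')
                    .resample(timedelta(hours=2), on='time_since_first_test')
                    .last()
                    .ffill())
# Bins at 0h, 2h, 4h and 6h hold 4.3, 4.5, 4.5 (forward-filled) and 10.8
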
In [79]:
icuresampled.drop(columns = ['hadm_id','stay_id', 'itemid'], inplace = True, errors='ignore')
In [80]:
icuresampled2 = icuresampled.reset_index(level=[3])
In [81]:
icuresampled2['hours_since_first_test'] = icuresampled2['time_since_first_test'] / pd.Timedelta(hours=1)
In [82]:
#Subset the data to patients with an ICU stay longer than one day
fish = pd.DataFrame(icuresampled2.groupby(['hadm_id','stay_id', 'itemid'])['time_since_icu'].max() > timedelta(days=1))
fish.rename(columns = {'time_since_icu':'long_stay'},inplace = True)
In [83]:
fish = fish.merge(icuresampled2, how ='left', on = ['hadm_id','stay_id', 'itemid'])
In [84]:
longfish = fish[fish.long_stay]
In [85]:
#Visualise the first 10,000 entries of the time series; each line is one ICU stay
fig, (ax1) = plt.subplots(1, 1, figsize=(18,16))
sns.lineplot(data = longfish.head(10000), x = 'hours_since_first_test', y='value', hue='stay_id', palette="tab10", ax=ax1)
ax1.legend([],[], frameon=False)
ax1.set(
        xlabel ='ICU stay length (hours)', 
        ylabel = 'White Blood Cell (K/ul)',
        title = 'Visualisation of Time Series of WCC, first 10,000 entries'
        );
In [ ]:
#Add lagged values to create serial dependence (lags 7-42 on the 2-hour grid, i.e. 14-84 hours prior)
for i in range(6,42):
    longfish[f'value_lag{i+1}'] = longfish.groupby(['itemid', 'stay_id', 'hadm_id'], observed=True)['valuenum'].shift(i+1)
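
As a quick sketch of what the shift produces (toy data; the real loop applies the same idea per stay and item): shifting by k rows on the 2-hourly grid yields the value observed 2k hours earlier.

toy_lags = pd.DataFrame({'stay_id': [1]*4, 'valuenum': [4.3, 4.3, 4.5, 10.8]})
toy_lags['value_lag1'] = toy_lags.groupby('stay_id')['valuenum'].shift(1)
# On the 2-hour grid, value_lag1 is the WCC from 2 hours earlier: [NaN, 4.3, 4.3, 4.5]
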
In [87]:
longfish.reset_index(level=[1], inplace=True)

Prepare the data to train the model

In [88]:
pre_X = longfish.drop(columns =['long_stay',
 'time_since_first_test','labevent_id','specimen_id','storetime','valuenum','valueuom','ref_range_lower','ref_range_upper','flag',
 'priority','comments','first_careunit','last_careunit','outtime','los','time_since_icu','hours_since_first_test'])
In [89]:
pre_X = pre_X.merge(age, how = 'left', on ='subject_id')
In [90]:
pre_X = pre_X.drop(columns =['subject_id', 'charttime'])
In [91]:
pre_X = pre_X.merge(alb_products, on = 'stay_id', how= 'left')
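
One caveat with this merge: a stay with several albumin infusions gains one feature row per infusion. A hedged alternative (a sketch, not what was run here) is to collapse alb_products to the earliest starttime per stay and merge that in place of alb_products above; it answers the same "any albumin within the first 24 hours?" question without duplicating rows.

# Alternative sketch: collapse to one row per stay before merging
first_alb = alb_products.groupby('stay_id', as_index=False)['starttime'].min()
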
In [92]:
#Binary flag: albumin administered within the first 24 hours of the ICU stay (NaT rows evaluate to False)
pre_X['alb'] = pre_X['starttime'] - pre_X['intime'] < timedelta(days=1)
In [93]:
pre_X = pre_X.drop(columns =['intime', 'starttime'], errors='ignore')
In [94]:
#Fill missing lag values with a -1 sentinel so the trees can treat "not yet observed" as its own signal
pre_X.fillna(value=-1, inplace = True)
pre_X.head()
Out[94]:
stay_id value value_lag7 value_lag8 value_lag9 value_lag10 value_lag11 value_lag12 value_lag13 value_lag14 ... value_lag35 value_lag36 value_lag37 value_lag38 value_lag39 value_lag40 value_lag41 value_lag42 anchor_age alb
0 38291712 9.5 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 ... -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 67 False
1 38291712 9.5 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 ... -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 67 False
2 38291712 9.5 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 ... -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 67 False
3 38291712 7.7 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 ... -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 67 False
4 38291712 7.7 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 ... -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 67 False

5 rows × 40 columns

Train the model

In [95]:
#Separate the features from the target and drop columns not used for training
X = pre_X.drop(columns = ['value', 'stay_id'])
y = pre_X['value']
In [96]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error, RocCurveDisplay
from sklearn.model_selection import train_test_split
import time
In [97]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123)
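
Note that a plain random split can place resampled rows from the same ICU stay on both sides of the split. A grouped split (a hedged alternative sketch, not what was used for the results below) would keep each stay intact:

from sklearn.model_selection import GroupShuffleSplit

# Keep all rows from one ICU stay on the same side of the split
gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=123)
train_idx, test_idx = next(gss.split(X, y, groups=pre_X['stay_id']))
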
In [98]:
clf = RandomForestRegressor(n_estimators=200, max_depth=5, random_state=0)
start = time.time()
clf.fit(X_train, y_train)
print(time.time()-start)
138.63172268867493
In [99]:
y_pred = clf.predict(X_test)
r2 = r2_score(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)
In [100]:
#print out metrics
print("r2: ", r2, '\n'
      "mse: ", mse, '\n'
      "mae: ", mae)
r2:  0.8640473234871016 
mse:  14.30193593398133 
mae:  2.0901265023908504
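
For context, these numbers can be compared against a naive carry-forward baseline that simply repeats the closest available lag (a hedged sanity check added here, not part of the original run; value_lag7 is the nearest lag feature, 14 hours prior, with the -1 sentinel replaced by the mean):

baseline = X_test['value_lag7'].where(X_test['value_lag7'] != -1, y_test.mean())
print("baseline mae: ", mean_absolute_error(y_test, baseline))
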
In [101]:
# Fit a classifier: label 1 = WCC within the normal range (2.0-11.0 K/uL)
y_train_clas = (y_train <= 11.0) & (y_train >= 2.0)
y_train_clas = [int(x) for x in y_train_clas]

y_test_clas = (y_test <= 11.0) & (y_test >= 2.0)
y_test_clas = [int(x) for x in y_test_clas]

clf_2 = RandomForestClassifier(n_estimators = 200, max_depth = 5, random_state=0)
start = time.time()
clf_2.fit(X_train, y_train_clas)
print(time.time()-start)
27.226195335388184
In [102]:
y_pred_clas = clf_2.predict(X_test)
y_pred_clas[13:25]
Out[102]:
array([1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0])

Evaluate the model

In [103]:
from sklearn.metrics import confusion_matrix

matrix = confusion_matrix(y_test_clas, y_pred_clas)
print(matrix)
[[28038  4597]
 [ 2839 31553]]
In [104]:
# Percentage of all test results that were normal but flagged as requiring another test
round((matrix[1][0]/(len(y_test_clas)))*100,2)
Out[104]:
4.24
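
Unpacking the matrix into standard rates makes the trade-off explicit (a small sketch; with label 1 = "within normal range", a false positive is the clinically risky case where an abnormal result would have been predicted normal):

tn, fp, fn, tp = matrix.ravel()
sensitivity = tp / (tp + fn)   # normal results correctly identified (tests safely avoidable)
specificity = tn / (tn + fp)   # abnormal results correctly flagged as needing a test
print(f"sensitivity: {sensitivity:.3f}, specificity: {specificity:.3f}")
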
In [105]:
# Create a ROC curve (requires a trained classifier)
fig, ax = plt.subplots(1,1, figsize=(8.5,8))
RocCurveDisplay.from_estimator(clf_2, X_test, y_test_clas, ax=ax)
ax.set_title('ROC Curve for White Blood Cell Count', fontsize = 24)
ax.set_xlabel('False Positive Rate', fontsize = 20)
ax.set_ylabel('True Positive Rate', fontsize = 20)
ax.legend(fontsize = 12)
plt.show()
In [106]:
# Combine regressor and classifier
pred_level = clf.predict(X_test)
pred_norm  = clf_2.predict(X_test)
norm_proba = clf_2.predict_proba(X_test)
In [107]:
# Global feature importances from the trained classifier
feature_scores = pd.Series(clf_2.feature_importances_, index=X_train.columns).sort_values(ascending=False)
f, ax = plt.subplots(figsize=(15, 10))
ax = sns.barplot(x=feature_scores, y=feature_scores.index)
ax.set_yticklabels(feature_scores.index)
ax.set_title('Feature Scores for White Blood Cell Count Model', fontsize = 24)
ax.set_xlabel('Feature Importance', fontsize = 20)
ax.set_ylabel('Feature', fontsize = 20)
plt.show()
In [108]:
# Example output
WBC_preds = pd.DataFrame({'pred_level':pred_level, 'pred_norm':pred_norm, 'confid_percent':np.amax(norm_proba, axis=1)*100})
WBC_preds = WBC_preds.replace(0, 'BLOODS REQUIRED').replace(1, 'BLOODS NOT REQUIRED')
WBC_preds.iloc[13:25]
Out[108]:
pred_level pred_norm confid_percent
13 8.134128 BLOODS NOT REQUIRED 95.693072
14 17.009877 BLOODS REQUIRED 98.627366
15 11.962857 BLOODS REQUIRED 72.852535
16 14.409101 BLOODS REQUIRED 81.519891
17 11.724305 BLOODS REQUIRED 51.125276
18 18.943167 BLOODS REQUIRED 98.242839
19 8.134128 BLOODS NOT REQUIRED 95.693072
20 13.057116 BLOODS REQUIRED 59.561591
21 15.293528 BLOODS REQUIRED 89.281453
22 22.503490 BLOODS REQUIRED 98.373760
23 15.374689 BLOODS REQUIRED 95.133055
24 17.009877 BLOODS REQUIRED 81.558589
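
In a deployed tool, the recommendation would likely be gated on the classifier's confidence rather than the raw predicted class (a hedged sketch; the 0.90 threshold is an illustrative assumption, not a validated cut-off):

CONFIDENCE_THRESHOLD = 0.90  # illustrative only; a real cut-off would need clinical validation
p_normal = norm_proba[:, 1]  # probability the next WCC is within the normal range
skip_test = p_normal >= CONFIDENCE_THRESHOLD
print(f"Tests potentially avoidable at this threshold: {skip_test.mean()*100:.1f}%")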