Dealing with Imbalanced Data in HealthTech ML/AI – 1. Stroke Prediction

  • A stroke occurs when blood flow to part of the brain is abruptly interrupted. Early Warning Systems (EWS) can provide valuable information for predicting stroke and promoting a healthy life.
  • In this post, we will use machine learning (ML) to design a robust stroke risk prediction system.
  • The key challenge is dealing with the highly imbalanced stroke data using the imblearn library.
  • This library contains data resampling techniques divided into the following 4 categories (one representative technique from each category is sketched right after this list):
  • Under-sampling the majority class(es).
  • Over-sampling the minority class.
  • Combining over- and under-sampling.
  • Creating ensemble balanced sets.
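A minimal, illustrative sketch of one sampler from each category, applied to a synthetic imbalanced dataset (make_classification is only a stand-in here, not the stroke data used below):

# One representative technique per imblearn category, on a synthetic imbalanced toy dataset
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import RandomUnderSampler      # 1. under-sampling
from imblearn.over_sampling import SMOTE                    # 2. over-sampling
from imblearn.combine import SMOTETomek                     # 3. combined over- and under-sampling
from imblearn.ensemble import BalancedBaggingClassifier     # 4. ensemble of balanced sets

X_toy, y_toy = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=42)
print('original:', Counter(y_toy))

for name, sampler in [('under-sampled', RandomUnderSampler(random_state=42)),
                      ('over-sampled', SMOTE(random_state=42)),
                      ('combined', SMOTETomek(random_state=42))]:
    X_res, y_res = sampler.fit_resample(X_toy, y_toy)
    print(name, Counter(y_res))

# The ensemble route balances internally at fit time instead of resampling up front
clf = BalancedBaggingClassifier(random_state=42).fit(X_toy, y_toy)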

Specifically, we will compare (1) a SMOTE-balanced Torch neural network (trained with the cross-entropy loss and the Adam optimizer) against (2) scikit-learn classifiers in the spirit of Sinnott’s Python workflow, with both approaches validated using scikit-learn metrics such as AUC, precision, recall, F-measure, and accuracy.

Table of Contents

  1. Data Analysis
  2. Torch ANN Training
  3. SciKit-Learn Training
  4. In-Depth QC Analysis
  5. In-Depth Ada QC Insights
  6. Summary
  7. Explore More

Our Jupyter notebook and the entire Python project will be stored in the working directory STROKE

import os
os.chdir('STROKE')
os.getcwd()

Data Analysis

Importing Libraries
import pandas as pd
import numpy as np
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

Importing the Kaggle Stroke Dataset
dataset = pd.read_csv('healthcare-dataset-stroke-data.csv')

dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB

dataset.shape

(5110, 12)

dataset.isnull().sum()

id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64

Let’s replace the missing bmi values with the median value

dataset.bmi.replace(to_replace=np.nan, value=dataset.bmi.median(), inplace=True)

dataset.describe().T

Input data descriptive statistics

Let’s look at correlations

dataset.corr()

Data correlations table

The data correlation heatmap is

corr = dataset.corr()

mask = np.triu(np.ones_like(corr, dtype=bool))

f, ax = plt.subplots(figsize=(11, 9))

cmap = sns.diverging_palette(230, 20, as_cmap=True)

sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0, square=True, annot=True, linewidths=.5, cbar_kws={"shrink": .5})

Data correlation matrix

Let’s plot the gender column

print(dataset.gender.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="gender")
plt.show()

Female    2994
Male      2115
Other        1
Name: gender, dtype: int64
Gender column

The hypertension column plot is

print(dataset.hypertension.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="hypertension")
plt.show()

0    4612
1     498
Name: hypertension, dtype: int64
The hypertension column plot

The ever_married column plot is

print(dataset.ever_married.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="ever_married")
plt.show()

Yes    3353
No     1757
Name: ever_married, dtype: int64
The ever_married column plot

The work_type column plot is

print(dataset.work_type.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="work_type")
plt.show()

Private          2925
Self-employed     819
children          687
Govt_job          657
Never_worked       22
Name: work_type, dtype: int64
The work_type column plot

The Residence_type column plot is

print(dataset.Residence_type.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="Residence_type")
plt.show()

Urban    2596
Rural    2514
Name: Residence_type, dtype: int64
The Residence_type column plot

The smoking_status column plot is

print(dataset.smoking_status.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="smoking_status")
ax.set_xticklabels(ax.get_xticklabels(), fontsize=10)
plt.tight_layout()
plt.show()

never smoked       1892
Unknown            1544
formerly smoked     885
smokes              789
Name: smoking_status, dtype: int64
The smoking_status column plot

The target variable plot is

print(dataset.stroke.value_counts())
sns.set_theme(style="darkgrid")
ax = sns.countplot(data=dataset, x="stroke")
plt.show()

0    4861
1     249
Name: stroke, dtype: int64
The target variable plot

Let’s look at avg_glucose_level

fig = plt.figure(figsize=(7,7))
sns.distplot(dataset.avg_glucose_level, color="green", label="avg_glucose_level", kde=True)
plt.legend()

Let’s plot the bmi histogram

fig = plt.figure(figsize=(7,7))
sns.distplot(dataset.bmi, color="orange", label="bmi", kde=True)
plt.legend()

BMI histogram

Let’s compare No Stroke vs Stroke by BMI

plt.figure(figsize=(12,10))

sns.distplot(dataset[dataset['stroke'] == 0]["bmi"], color='green') # No Stroke - green
sns.distplot(dataset[dataset['stroke'] == 1]["bmi"], color='red') # Stroke - red

plt.title('No Stroke vs Stroke by BMI', fontsize=15)
plt.xlim([10,100])
plt.legend(['No Stroke','Stroke'])
plt.show()

Histograms No Stroke vs Stroke by BMI

Let’s compare No Stroke vs Stroke by Avg. Glucose Level

plt.figure(figsize=(12,10))

sns.distplot(dataset[dataset['stroke'] == 0]["avg_glucose_level"], color='green') # No Stroke - green
sns.distplot(dataset[dataset['stroke'] == 1]["avg_glucose_level"], color='red') # Stroke - red

plt.title('No Stroke vs Stroke by Avg. Glucose Level', fontsize=15)
plt.legend(['No Stroke','Stroke'])
plt.xlim([30,330])
plt.show()

Histograms No Stroke vs Stroke by Avg. Glucose Level

Let’s compare No Stroke vs Stroke by Age

plt.figure(figsize=(12,10))

sns.distplot(dataset[dataset['stroke'] == 0]["age"], color='green') # No Stroke - green
sns.distplot(dataset[dataset['stroke'] == 1]["age"], color='red') # Stroke - red

plt.title('No Stroke vs Stroke by Age', fontsize=15)
plt.legend(['No Stroke','Stroke'])
plt.xlim([18,100])
plt.show()

Compare No Stroke vs Stroke by Age

Let’s plot violin plots of our model features vs target variable

plt.figure(figsize=(13,13))
sns.set_theme(style="darkgrid")
plt.subplot(2,3,1)
sns.violinplot(x = 'gender', y = 'stroke', data = dataset)
plt.subplot(2,3,2)
sns.violinplot(x = 'hypertension', y = 'stroke', data = dataset)
plt.subplot(2,3,3)
sns.violinplot(x = 'heart_disease', y = 'stroke', data = dataset)
plt.subplot(2,3,4)
sns.violinplot(x = 'ever_married', y = 'stroke', data = dataset)
plt.subplot(2,3,5)
sns.violinplot(x = 'work_type', y = 'stroke', data = dataset)
plt.xticks(fontsize=9, rotation=45)
plt.subplot(2,3,6)
sns.violinplot(x = 'Residence_type', y = 'stroke', data = dataset)
plt.show()

Model features vs target variable violin plots

Torch ANN Training

Importing libraries

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
plt.style.use('ggplot')
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

from sklearn.preprocessing import (StandardScaler,
LabelEncoder,
OneHotEncoder)
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, accuracy_score,
auc,
precision_score,
recall_score,
f1_score,
roc_auc_score,
confusion_matrix)
from sklearn.model_selection import (GridSearchCV,
StratifiedKFold,
cross_val_score)

from sklearn.decomposition import PCA

import pylab as pl

from imblearn.datasets import make_imbalance
from imblearn.under_sampling import (RandomUnderSampler,
ClusterCentroids,
TomekLinks,
NeighbourhoodCleaningRule,
EditedNearestNeighbours,
NearMiss)

from imblearn.over_sampling import (SMOTE,
ADASYN)

from sklearn.ensemble import (RandomForestClassifier,
AdaBoostClassifier,
GradientBoostingClassifier)
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

from sklearn.tree import ExtraTreeClassifier
from sklearn.svm import OneClassSVM

Reading the csv file in variable

df = pd.read_csv('healthcare-dataset-stroke-data.csv')

df.dtypes

id                     int64
gender                object
age                  float64
hypertension           int64
heart_disease          int64
ever_married          object
work_type             object
Residence_type        object
avg_glucose_level    float64
bmi                  float64
smoking_status        object
stroke                 int64
dtype: object

df['bmi'].nunique()

418

Handling missing values
df['bmi'].fillna(df['bmi'].mean(), inplace=True)
df.isnull().sum()

id                   0
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
dtype: int64

Let’s look at the unique values of the categorical columns

list_col = ['smoking_status','work_type','Residence_type','gender']

for col in list_col:
    print('{} :{} '.format(col.upper(), df[col].unique()))

SMOKING_STATUS :['formerly smoked' 'never smoked' 'smokes' 'Unknown'] 
WORK_TYPE :['Private' 'Self-employed' 'Govt_job' 'children' 'Never_worked'] 
RESIDENCE_TYPE :['Urban' 'Rural'] 
GENDER :['Male' 'Female' 'Other'] 

Let’s look at the avg_glucose_level boxplot

fig = px.box(data_frame = df,
x = "avg_glucose_level",
width = 800,
height = 300)
fig.update_layout({"template":"plotly_dark"})
fig.show()

The avg_glucose_level boxplot

Binning of numerical variables

df['bmi_cat'] = pd.cut(df['bmi'], bins = [0, 19, 25, 30, 10000], labels = ['Underweight', 'Ideal', 'Overweight', 'Obesity'])
df['age_cat'] = pd.cut(df['age'], bins = [0, 13, 18, 45, 60, 200], labels = ['Children', 'Teens', 'Adults', 'Mid Adults', 'Elderly'])
df['glucose_cat'] = pd.cut(df['avg_glucose_level'], bins = [0, 90, 160, 230, 500], labels = ['Low', 'Normal', 'High', 'Very High'])

Let’s create lists of categorical and continuous variables

cat_cols = ["gender","hypertension","heart_disease","ever_married","work_type","Residence_type","smoking_status","stroke"]
cont_cols = ["age","avg_glucose_level","bmi"]

Let’s plot the correlation matrix

import matplotlib
matplotlib.rc('xtick', labelsize=20)
matplotlib.rc('ytick', labelsize=20)
cr = df[cont_cols].corr()
plt.figure(figsize = (10,10))
sns.heatmap(cr, cmap="viridis", annot = True)
plt.show()

Correlation matrix heatmap

Let’s plot the BMI histogram

bmi = list(df['bmi'].values)
hist_data = [bmi]
group_labels = ["bmi"]
colors = ['Red']
fig = ff.create_distplot(hist_data,group_labels,show_hist = True,colors=colors)
fig.show()

the BMI histogram

Let’s check the gender value count

df['gender'].value_counts()

Female    2994
Male      2115
Other        1
Name: gender, dtype: int64

df.drop(df[df['gender'] == 'Other'].index, inplace = True)
df['gender'].unique()

array(['Male', 'Female'], dtype=object)

print("The shape before removing the BMI outliers : ", df.shape)
df.drop(df[df['bmi'] > 47].index, inplace = True)
print("The shape after removing the BMI outliers : ", df.shape)

The shape before removing the BMI outliers :  (5109, 15)
The shape after removing the BMI outliers :  (4992, 15)

Let’s plot the BMI histogram after removing outliers

bmi = list(df['bmi'].values)
hist_data = [bmi]
group_labels = ["bmi"]
colors = ['Red']
fig = ff.create_distplot(hist_data,group_labels,show_hist = True,colors=colors)
fig.show()

The BMI histogram after removing outliers

Let’s perform Label Encoding of the categorical variables

from sklearn.preprocessing import LabelEncoder
object_cols = ["gender","ever_married","work_type","Residence_type","smoking_status"]
label_encoder = LabelEncoder()
for col in object_cols:
    label_encoder.fit(df[col])
    df[col] = label_encoder.transform(df[col])

df.drop(['bmi_cat', 'age_cat', 'glucose_cat'], axis=1, inplace=True)

Using SMOTE
from imblearn.over_sampling import SMOTE
sampler = SMOTE(random_state = 42)
X = df.drop(['stroke'], axis=1)
y = df[['stroke']]
X, y = sampler.fit_resample(X, y['stroke'].values.ravel())
y = pd.DataFrame({'stroke': y})
sns.countplot(data = y, x = 'stroke')
plt.show()

Target variable after SMOTE

Joining the resampled features and target back into a single DataFrame
df = pd.concat([X,y],axis = 1)
df.head()

Input data after editing

df = df.sample(frac = 1)

Importing torch

import torch
import torch.nn as nn

and creating columns

cat_cols = ["gender","hypertension","heart_disease","ever_married","work_type","Residence_type","smoking_status"]
cont_cols = ["age","avg_glucose_level","bmi"]
y_col = ["stroke"]

for cat in cat_cols:
    df[cat] = df[cat].astype('category')

Stacking the categorical columns
cats = np.stack([df[col].cat.codes.values for col in cat_cols], 1)

Converting the stack into tensor
cats = torch.tensor(cats, dtype = torch.int64)

Stacking the continuous columns & converting to tensor
conts = np.stack([df[col].values for col in cont_cols], 1)
conts = torch.tensor(conts, dtype=torch.float)

Converting target variable to tensor and flattening since CrossEntropyLoss expects a 1-d tensor:
y = torch.tensor(df[y_col].values).flatten()

print(cats.shape)
print(conts.shape)
print(y.shape)

torch.Size([9492, 7])
torch.Size([9492, 3])
torch.Size([9492])

cat_szs = [len(df[col].cat.categories) for col in cat_cols]
emb_szs = [(size, min(50, (size+1)//2)) for size in cat_szs]
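For intuition, this rule of thumb gives each categorical column an embedding dimension of roughly half its cardinality, capped at 50. A quick illustrative check (the pairs should mirror the Embedding(ni, nf) layers in the model printout below, e.g. work_type with 5 categories gets a 3-dimensional embedding):

print(dict(zip(cat_cols, emb_szs)))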

Let’s introduce the class TabularModel

class TabularModel(nn.Module):

    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
        super().__init__()
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
        self.emb_drop = nn.Dropout(p)
        self.bn_cont = nn.BatchNorm1d(n_cont)

        layerlist = []
        n_emb = sum(nf for ni, nf in emb_szs)
        n_in = n_emb + n_cont

        for i in layers:
            layerlist.append(nn.Linear(n_in, i))
            layerlist.append(nn.ReLU(inplace=True))
            layerlist.append(nn.BatchNorm1d(i))
            layerlist.append(nn.Dropout(p))
            n_in = i
        layerlist.append(nn.Linear(layers[-1], out_sz))
        self.layers = nn.Sequential(*layerlist)

    def forward(self, x_cat, x_cont):
        embeddings = []
        for i, e in enumerate(self.embeds):
            embeddings.append(e(x_cat[:, i]))
        x = torch.cat(embeddings, 1)
        x = self.emb_drop(x)

        x_cont = self.bn_cont(x_cont)
        x = torch.cat([x, x_cont], 1)
        x = self.layers(x)
        return x

Let’s create the torch model

torch.manual_seed(42)
model = TabularModel(emb_szs, conts.shape[1], 2, [400,200,100], p=0.2)
model

TabularModel(
  (embeds): ModuleList(
    (0): Embedding(2, 1)
    (1): Embedding(2, 1)
    (2): Embedding(2, 1)
    (3): Embedding(2, 1)
    (4): Embedding(5, 3)
    (5): Embedding(2, 1)
    (6): Embedding(4, 2)
  )
  (emb_drop): Dropout(p=0.2, inplace=False)
  (bn_cont): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): Linear(in_features=13, out_features=400, bias=True)
    (1): ReLU(inplace=True)
    (2): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=400, out_features=200, bias=True)
    (5): ReLU(inplace=True)
    (6): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=200, out_features=100, bias=True)
    (9): ReLU(inplace=True)
    (10): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Dropout(p=0.2, inplace=False)
    (12): Linear(in_features=100, out_features=2, bias=True)
  )
)

Let’s define the loss criterion, the optimizer, and the train/test split

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

batch_size = 9000
test_size = 492

cat_train = cats[:batch_size-test_size]
cat_test = cats[batch_size-test_size:batch_size]
con_train = conts[:batch_size-test_size]
con_test = conts[batch_size-test_size:batch_size]
y_train = y[:batch_size-test_size]
y_test = y[batch_size-test_size:batch_size]

Let’s train the model

import time
start_time = time.time()

epochs = 320
losses = []

for i in range(epochs):
    i += 1
    y_pred = model(cat_train, con_train)
    loss = criterion(y_pred, y_train)
    losses.append(loss.item())  # store the scalar loss value, not the graph-attached tensor

    if i%25 == 1:
        print(f'epoch: {i:3}  loss: {loss.item():10.8f}')

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'epoch: {i:3}  loss: {loss.item():10.8f}')
print(f'\nDuration: {time.time() - start_time:.0f} seconds')

epoch:   1  loss: 0.78527218
epoch:  26  loss: 0.36030686
epoch:  51  loss: 0.34402612
epoch:  76  loss: 0.33592215
epoch: 101  loss: 0.32512504
epoch: 126  loss: 0.31637788
epoch: 151  loss: 0.30911028
epoch: 176  loss: 0.30483150
epoch: 201  loss: 0.29670492
epoch: 226  loss: 0.29235491
epoch: 251  loss: 0.28425759
epoch: 276  loss: 0.27836165
epoch: 301  loss: 0.26823270
epoch: 320  loss: 0.26557177

Duration: 32 seconds

The corresponding loss vs epoch curve with error bars is as follows

Loss vs Epoch with error bars
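The plotting call itself is not shown in the post; a minimal sketch that reproduces the basic curve from the losses list collected above (the error bars in the figure would require aggregating several training runs):

plt.figure(figsize=(8,5))
plt.plot(range(1, epochs + 1), losses)   # one scalar loss per epoch
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.title('Training loss vs epoch')
plt.show()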

The CE loss is given by

with torch.no_grad():
    y_val = model(cat_test, con_test)
    loss = criterion(y_val, y_test)
print(f'CE Loss: {loss:.8f}')

CE Loss: 0.29475877

Let’s define the utility function

def Convert(a):
    it = iter(a)
    res_dct = dict(zip(it, it))
    return res_dct

Let’s create our predictions

rows = 200
correct = 0
groundTruth = []
predictedValues = []
print(f'{"MODEL OUTPUT":26} ARGMAX Y_TEST')
for i in range(rows):
    print(f'{str(y_val[i]):26} {y_val[i].argmax():^7}{y_test[i]:^7}')
    predictedValues.append(y_val[i].argmax().item())
    groundTruth.append(y_test[i])
    if y_val[i].argmax().item() == y_test[i]:
        correct += 1

Let’s print the classification report

from sklearn.metrics import f1_score
print("The F1-score is :", f1_score(groundTruth, predictedValues))
print("\n")
print(confusion_matrix(groundTruth, predictedValues))
print("\n")
print(classification_report(groundTruth, predictedValues))

The F1-score is : 0.8938053097345132


[[ 75  15]
 [  9 101]]


              precision    recall  f1-score   support

           0       0.89      0.83      0.86        90
           1       0.87      0.92      0.89       110

    accuracy                           0.88       200
   macro avg       0.88      0.88      0.88       200
weighted avg       0.88      0.88      0.88       200

Let’s plot the normalized confusion matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(groundTruth, predictedValues)
target_names=['No Stroke','Stroke']

cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=target_names, yticklabels=target_names)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show(block=False)

SciKit-Learn Training

Let’s prepare the input data for ML

df = pd.read_csv('healthcare-dataset-stroke-data.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB

df.dropna(subset = ['bmi'], inplace = True)

text_cols = ['gender', 'ever_married', 'work_type',
'Residence_type', 'smoking_status']
[print(df[i].unique()) for i in df[text_cols]]

['Male' 'Female' 'Other']
['Yes' 'No']
['Private' 'Self-employed' 'Govt_job' 'children' 'Never_worked']
['Urban' 'Rural']
['formerly smoked' 'never smoked' 'smokes' 'Unknown']
[None, None, None, None, None]

df['gender'] = np.where(df['gender'] == 'Male', 1, 0)
df['ever_married'] = np.where(df['ever_married'] == 'Yes', 1, 0)
df['Residence_type'] = np.where(df['Residence_type'] == 'Urban', 1, 0)
df['smoking_status'] = np.where(((df['smoking_status'] == 'smokes') | (df['smoking_status'] == 'formerly smoked')), 1, 0)

Applying LabelEncoder

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
df['work_type'] = le.fit_transform(df['work_type'])

and plotting the target variable

sns.displot(df.stroke)

Target variable as sns.displot

The corresponding value count is

df.stroke.value_counts()

0    4700
1     209
Name: stroke, dtype: int64

Let’s invoke sklearn.ensemble.IsolationForest. Its contamination parameter only controls the threshold on the decision function used to flag a scored data point as an outlier; it has no impact on the fitted model itself.

from sklearn.ensemble import IsolationForest
iso = IsolationForest(n_estimators = 1000, contamination = 0.03)
# the contamination value determines the outlier cut-off value
# we can adjust this value to ensure we do not further
# imbalance our target
outs = pd.Series(iso.fit_predict(df[['bmi', 'avg_glucose_level']]),
name = 'outliers')
outs.value_counts()

 1    4761
-1     148
Name: outliers, dtype: int64

Let’s add the outlier flags to the input DataFrame and keep only the inlier rows

df = pd.concat([outs.reset_index(), df.reset_index()], axis = 1,
ignore_index = False).drop(columns = 'index')
df = df[df['outliers'] == 1]
df['stroke'].value_counts()

0    4566
1     195
Name: stroke, dtype: int64

Let’s apply StandardScaler and split the data into train and test sets (train_test_split with its default 75:25 split, stratified on the target)

from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
num_cols = ['age', 'avg_glucose_level', 'bmi']
df[num_cols] = ss.fit_transform(df[num_cols])
X = df.drop(columns = ['stroke', 'id', 'outliers'])
y = df.stroke
X_train, X_test, y_train, y_test = train_test_split(X, y,
stratify = y)
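The evaluation cells below call classification_eval, a helper whose definition is not shown in this post; a minimal sketch of what it is assumed to do, reusing the metric functions already imported above and reporting the null accuracy (the majority-class baseline) alongside the headline metrics:

def classification_eval(y_true, y_pred):
    # hypothetical reconstruction of the helper used below
    print('accuracy  =', round(accuracy_score(y_true, y_pred), 3))
    print('precision =', round(precision_score(y_true, y_pred), 3))
    print('recall    =', round(recall_score(y_true, y_pred), 3))
    print('f1-score  =', round(f1_score(y_true, y_pred), 3))
    print('roc auc   =', round(roc_auc_score(y_true, y_pred), 3))
    null_acc = pd.Series(y_true).value_counts(normalize=True).max()
    print('null accuracy =', round(null_acc, 3))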

Let’s apply RandomOverSampler and run RandomForestClassifier()

from imblearn.over_sampling import RandomOverSampler
rs = RandomOverSampler()
X, y = rs.fit_resample(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y,
stratify = y)
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
classification_eval(y_test, y_pred)

accuracy  = 0.997
precision = 0.994
recall    = 1.0
f1-score  = 0.997
roc auc   = 0.997
null accuracy = 0.5

This is a good result, since the accuracy is well above the null accuracy (the accuracy of a baseline that always predicts the majority class, which is 0.5 after resampling).

Let’s apply SMOTE and run RandomForestClassifier()

from imblearn.over_sampling import SMOTE
smote = SMOTE()
X, y = smote.fit_resample(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y,
stratify = y)
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
classification_eval(y_test, y_pred)

accuracy  = 0.993
precision = 0.986
recall    = 1.0
f1-score  = 0.993
roc auc   = 0.993
null accuracy = 0.5

Let’s apply ADASYN and run RandomForestClassifier()

from imblearn.over_sampling import ADASYN
ada = ADASYN()
X, y = ada.fit_resample(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y,
stratify = y)
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
classification_eval(y_test, y_pred)

accuracy  = 0.994
precision = 0.988
recall    = 1.0
f1-score  = 0.994
roc auc   = 0.994
null accuracy = 0.5

Let’s apply SMOTENC and run RandomForestClassifier()

from imblearn.over_sampling import SMOTENC
sm = SMOTENC(categorical_features = [0, 2])
#here we have to define our categorical columns
X, y = sm.fit_resample(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y,
stratify = y)
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
classification_eval(y_test, y_pred)

accuracy  = 0.992
precision = 0.984
recall    = 1.0
f1-score  = 0.992
roc auc   = 0.992
null accuracy = 0.5

Let’s plot the normalized confusion matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
target_names=['No Stroke','Stroke']

cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=target_names, yticklabels=target_names)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show(block=False)
plt.savefig('stroke_confusionmatrix_best.png')

Normalized confusion matrix:  SMOTENC and RandomForestClassifier()

Let’s plot the histogram of predicted probabilities

y_pred_prob = rf.predict_proba(X_test)
sns.displot(y_pred_prob[:, 1]);
plt.title('Histogram of predicted probabilities');

Histogram of predicted probabilities.

Let’s plot top ranking features by importance

feat_names = [i for i in X_train]
classed = [i for i in y_train]
feat_import_df = pd.DataFrame({'importances': rf.feature_importances_,
'name': feat_names}).sort_values('importances')
x = feat_import_df['importances'].tail(9)
y = feat_import_df['name'].tail(9)
sns.barplot(x = x, y = y).set_title('Top ranking features by importance');

Top ranking features by importance

The RFC ROC curve is

Y_test_probs = rf.predict_proba(X_test)

import scikitplot as skplt
skplt.metrics.plot_roc_curve(y_test, Y_test_probs,
title="RF ROC Curve", figsize=(12,6));

RFC ROC curve

The RFC precision-recall curve is

skplt.metrics.plot_precision_recall_curve(y_test, Y_test_probs,
title="RF Precision-Recall Curve", figsize=(12,6));

RFC precision-recall curve

The RFC classification report heatmap is

from yellowbrick.classifier import ClassificationReport
from sklearn.tree import DecisionTreeClassifier

target_names=['No Stroke','Stroke']

viz = ClassificationReport(rf,
classes=target_names,
support=True,
fig=plt.figure(figsize=(8,6)))

viz.fit(X_train, y_train)

viz.score(X_test, y_test)

viz.show();

RFC classification report

In-Depth QC Analysis

Model 1: Torch ANN Training

from sklearn.metrics import cohen_kappa_score
cohen_kappa_score(groundTruth, predictedValues)

0.7560975609756098

from sklearn.metrics import hamming_loss
hamming_loss(groundTruth, predictedValues)

0.12

from sklearn.metrics import jaccard_score
jaccard_score(groundTruth, predictedValues)

0.808

from sklearn.metrics import matthews_corrcoef
matthews_corrcoef(groundTruth, predictedValues)

0.7575070874982563

Let’s plot the DTC ROC curves

ROC curves for Decision Tree Classifier
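The call producing this figure is not shown in the post; a minimal sketch that assumes the same yellowbrick roc_auc quick method and the same categorical-feature tensors used for the XGBClassifier() curve below:

from yellowbrick.classifier import roc_auc

# assumption: reuse the categorical tensors and split from the Torch section, as in the XGB call below
fig = plt.figure(figsize=(10,7))
roc_auc(DecisionTreeClassifier(random_state=2),
        cat_train, y_train,
        cat_test, y_test,
        classes=target_names);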

Similarly, the ROC curve for XGBClassifier() is

fig = plt.figure(figsize=(10,7))

roc_auc(XGBClassifier(),
cat_train,y_train,
cat_test,y_test,
classes=target_names
);

ROC curve for XGBClassifier()

To plot learning curves, let’s first import the key scikit-plot and scikit-learn libraries

import scikitplot as skplt

import sklearn

from sklearn.model_selection import train_test_split

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt

import sys
import warnings
warnings.filterwarnings("ignore")

print("Scikit Plot Version : ", skplt.__version__)
print("Scikit Learn Version : ", sklearn.__version__)
print("Python Version : ", sys.version)

%matplotlib inline

Scikit Plot Version :  0.3.7
Scikit Learn Version :  1.2.2
Python Version :  3.9.16 (main, Jan 11 2023, 16:16:36) [MSC v.1916 64 bit (AMD64)]

The DTC learning curve is

skplt.estimators.plot_learning_curve(DecisionTreeClassifier(random_state=2), cat_train,y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="DT Classification Learning Curve");

DTC learning curve

Similarly, the XGB learning curve is

skplt.estimators.plot_learning_curve(XGBClassifier(), cat_train,y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="XGB Classification Learning Curve");

XGB learning curve

Let’s plot the DTC classification report

DTC classification report
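The code behind this report is not shown; a minimal sketch following the same yellowbrick ClassificationReport pattern used earlier, assuming the decision tree is fit on the same categorical tensors as the other Model 1 plots:

from yellowbrick.classifier import ClassificationReport

viz = ClassificationReport(DecisionTreeClassifier(random_state=2),
                           classes=target_names,
                           support=True,
                           fig=plt.figure(figsize=(8,6)))
viz.fit(cat_train, y_train)
viz.score(cat_test, y_test)
viz.show();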

Model 2: SciKit-Learn Training

Let’s look at the following scikit-learn performance metrics

from sklearn.metrics import cohen_kappa_score
cohen_kappa_score(y_test, y_pred)

0.9877300613496932

from sklearn.metrics import hamming_loss
hamming_loss(y_test, y_pred)

0.006134969325153374

from sklearn.metrics import matthews_corrcoef
matthews_corrcoef(y_test, y_pred)

0.9878044218151566

from sklearn.metrics import log_loss
log_loss(y_test, y_pred)

0.22112670790869438

from sklearn.metrics import jaccard_score
jaccard_score(y_test, y_pred)

0.9878787878787879

Let’s plot the DTC learning curve for our second (supervised ML) training model

skplt.estimators.plot_learning_curve(DecisionTreeClassifier(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="Decision Tree Classification Learning Curve");

Supervised ML model: DTC learning curve

Similarly, the RFC learning curve is

skplt.estimators.plot_learning_curve(rf, X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="Random Forest Classification Learning Curve");

RFC learning curve

The Logistic Regression (LR) Classification Learning Curve is

skplt.estimators.plot_learning_curve(LogisticRegression(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="Logistic Regression Classification Learning Curve");

The Logistic Regression (LR) Classification Learning Curve

The SVC Classification Learning Curve is

from sklearn.svm import SVC
skplt.estimators.plot_learning_curve(SVC(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="SVC Classification Learning Curve");

SVC Classification Learning Curve

The AdaBoost (Ada) Classification Learning Curve is

from sklearn.ensemble import AdaBoostClassifier
skplt.estimators.plot_learning_curve(AdaBoostClassifier(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="Ada Classification Learning Curve");

Ada Classification Learning Curve

The GNB Classification Learning Curve is

from sklearn.naive_bayes import GaussianNB
skplt.estimators.plot_learning_curve(GaussianNB(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="GNB Classification Learning Curve");

The GNB Classification Learning Curve

The QDA Classification Learning Curve is

from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
skplt.estimators.plot_learning_curve(QuadraticDiscriminantAnalysis(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="QDA Classification Learning Curve");

The QDA Classification Learning Curve

The KNN Classification Learning Curve is

from sklearn.neighbors import KNeighborsClassifier
skplt.estimators.plot_learning_curve(KNeighborsClassifier(), X_train, y_train,
cv=7, shuffle=True, scoring="accuracy",
n_jobs=-1, figsize=(6,4), title_fontsize="large", text_fontsize="large",
title="KNN Classification Learning Curve");

The KNN Classification Learning Curve

In-Depth Ada QC Insights

Based on the comparison of learning curves above, let’s evaluate AdaBoostClassifier() further

ada = AdaBoostClassifier()
ada.fit(X_train, y_train)
y_pred = ada.predict(X_test)
classification_eval(y_test, y_pred)

accuracy  = 0.799
precision = 0.755
recall    = 0.885
f1-score  = 0.815
roc auc   = 0.799
null accuracy = 0.5

from sklearn.metrics import cohen_kappa_score
cohen_kappa_score(y_test, y_pred)

0.5977212971078002

from sklearn.metrics import hamming_loss
hamming_loss(y_test, y_pred)

0.2011393514460999

from sklearn.metrics import matthews_corrcoef
matthews_corrcoef(y_test, y_pred)

0.6068345800502636

from sklearn.metrics import jaccard_score
jaccard_score(y_test, y_pred)

0.6875425459496256

from yellowbrick.classifier import ClassificationReport
from sklearn.tree import DecisionTreeClassifier

target_names=['No Stroke','Stroke']

viz = ClassificationReport(ada,
classes=target_names,
support=True,
fig=plt.figure(figsize=(8,6)))

viz.fit(X_train, y_train)

viz.score(X_test, y_test)

viz.show();

Ada Classification Report

Let’s compare the Ada feature importance weights

feat_names = [i for i in X_train]
classed = [i for i in y_train]
feat_import_df = pd.DataFrame({'importances': ada.feature_importances_,
'name': feat_names}).sort_values('importances')
x = feat_import_df['importances'].tail(9)
y = feat_import_df['name'].tail(9)
sns.barplot(x = x, y = y).set_title('Ada top ranking features by importance');

Ada feature importance weights

Let’s plot the Ada histogram of predicted probabilities

y_pred_prob = ada.predict_proba(X_test)
sns.displot(y_pred_prob[:, 1]);
plt.title('Ada histogram of predicted probabilities');

Ada histogram of predicted probabilities

The Ada confusion matrix is

from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
target_names=['No Stroke','Stroke']

cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=target_names, yticklabels=target_names)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show(block=False)

Ada confusion matrix

Let’s plot the Ada ROC curve

Y_test_probs = ada.predict_proba(X_test)

skplt.metrics.plot_roc_curve(y_test, Y_test_probs,
title="Ada ROC Curve", figsize=(12,6));

Ada ROC curve

The Ada Precision-Recall Curve is

skplt.metrics.plot_precision_recall_curve(y_test, Y_test_probs,
title="Ada Precision-Recall Curve", figsize=(12,6));

Ada Precision-Recall Curve

Summary

  • In this post, we have explored several supervised ML and DL/NN models for predicting the likelihood that a patient will suffer a stroke, a severe condition that disrupts the normal functioning of the brain.
  • We have implemented and tested a hybrid ML/DL pipeline which can be used for imbalanced datasets. Previous studies have mainly focused on stroke prediction with balanced data.
  • We have evaluated our approach on a highly imbalanced data set with only 4.8% stroke cases.
  • After using data balancing techniques, the sensitivity and AUC considerably improved.
  • Feature importance scores have shown that Age, BMI, and Avg_Glucose_Level are the most important model features.
  • Our study suggests that ML/AI methods with data balancing techniques are effective tools for stroke prediction with imbalanced data.
  • The proposed approach can effectively increase sensitivity and specificity while maintaining accurate predictions using interpretable training models, indicating its potential for clinical use in fast, low-cost stroke diagnosis to minimize the disease’s sequelae.

Explore More

AI-Powered Stroke Prediction

The Power of AIHealth: Comparison of 12 ML Breast Cancer Classification Models

HealthTech ML/AI Use-Cases

