In [160]:
import os 
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sklearn
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import  RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
from sklearn.model_selection import cross_val_score
In [161]:
data = pd.read_csv('brain_stroke.csv')
data.head(10)
Out[161]:
gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
1 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked 1
2 Female 49.0 0 0 Yes Private Urban 171.23 34.4 smokes 1
3 Female 79.0 1 0 Yes Self-employed Rural 174.12 24.0 never smoked 1
4 Male 81.0 0 0 Yes Private Urban 186.21 29.0 formerly smoked 1
5 Male 74.0 1 1 Yes Private Rural 70.09 27.4 never smoked 1
6 Female 69.0 0 0 No Private Urban 94.39 22.8 never smoked 1
7 Female 78.0 0 0 Yes Private Urban 58.57 24.2 Unknown 1
8 Female 81.0 1 0 Yes Private Rural 80.43 29.7 never smoked 1
9 Female 61.0 0 1 Yes Govt_job Rural 120.46 36.8 smokes 1
In [162]:
data.isna().sum()
# 数据中不存在空值
Out[162]:
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
dtype: int64
In [163]:
data.duplicated().sum()
# 数据中不存在重复值
Out[163]:
0
In [164]:
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4981 entries, 0 to 4980
Data columns (total 11 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   gender             4981 non-null   object 
 1   age                4981 non-null   float64
 2   hypertension       4981 non-null   int64  
 3   heart_disease      4981 non-null   int64  
 4   ever_married       4981 non-null   object 
 5   work_type          4981 non-null   object 
 6   Residence_type     4981 non-null   object 
 7   avg_glucose_level  4981 non-null   float64
 8   bmi                4981 non-null   float64
 9   smoking_status     4981 non-null   object 
 10  stroke             4981 non-null   int64  
dtypes: float64(3), int64(3), object(5)
memory usage: 428.2+ KB
In [165]:
data.describe()
Out[165]:
age hypertension heart_disease avg_glucose_level bmi stroke
count 4981.000000 4981.000000 4981.000000 4981.000000 4981.000000 4981.000000
mean 43.419859 0.096165 0.055210 105.943562 28.498173 0.049789
std 22.662755 0.294848 0.228412 45.075373 6.790464 0.217531
min 0.080000 0.000000 0.000000 55.120000 14.000000 0.000000
25% 25.000000 0.000000 0.000000 77.230000 23.700000 0.000000
50% 45.000000 0.000000 0.000000 91.850000 28.100000 0.000000
75% 61.000000 0.000000 0.000000 113.860000 32.600000 0.000000
max 82.000000 1.000000 1.000000 271.740000 48.900000 1.000000

1. 探索性数据分析¶

  • 这个数据集几乎没有经过预处理,我丢弃了异常值和非常罕见的分类值。 我还删除了“id”列。 我建议对于这个数据集,删除小于 38 岁的“年龄”特征。

  • avg_glucose_level 值的前 25% 非常高

In [166]:
def create_comparison_graph(feature: str, bins=2, ticks=True):
    fig, ax = plt.subplots(1, 2, figsize=(7,4), sharey=True, constrained_layout=True)

    fig.suptitle('Stroke patient based on {}'.format(feature), fontsize=16)

    sns.histplot(data[data['stroke'] == 0][feature], bins=bins, ax=ax[0])
    ax[0].set_ylabel('Count')
    ax[0].set_xlabel('No Stroke')
    if bins == 2:
        ax[0].set_xticks([0,1])
        if ticks: ax[0].set_xticklabels(['No', 'Yes'])

    sns.histplot(data[data['stroke'] == 1][feature], bins=bins, ax=ax[1])
    ax[1].set_xlabel('Stroke')
    if bins == 2:
        ax[1].set_xticks([0,1])
        if ticks: ax[1].set_xticklabels(['No', 'Yes'])

    # fig.show() # 这行可以移除,在jupyter中会自动显示
In [167]:
sns.histplot(data=data,x="avg_glucose_level",kde=True)
Out[167]:
<Axes: xlabel='avg_glucose_level', ylabel='Count'>
In [168]:
sns.histplot(data=data,x="age")
#数据很可能是从大量人群中抽取的
Out[168]:
<Axes: xlabel='age', ylabel='Count'>
In [169]:
columns=data[["age","gender","stroke"]]
sns.pairplot(columns, hue="gender")
plt.show()
# 我们可以删除年龄特征
In [170]:
data = data.drop(["age"],axis=1)
data.head(1)
Out[170]:
gender hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 Male 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
In [171]:
# 绘制相关性矩阵
plt.figure(figsize=(15,10))
sns.heatmap(data.corr(numeric_only=True), annot=True, cmap="Blues")
Out[171]:
<Axes: >
In [172]:
#Relationship between stroke and avg_gluose_level - lower glucose level => lower chance of stroke
create_comparison_graph('avg_glucose_level',40)
In [173]:
#Relationship between stroke and gender - N/A
create_comparison_graph('gender',ticks=False)
In [174]:
#Relationship between stroke and heart_disease - no heart disease => lower chance of getting stroke
#The conclusion is not confounding  ref: https://www.cdc.gov/stroke/risk_factors.htm#:~:text=Heart%20disease,rich%20blood%20to%20the%20brain.
create_comparison_graph('heart_disease',ticks=False)
In [175]:
#Relationship between stroke and work_type - confounding?
create_comparison_graph('work_type',ticks=False)
In [176]:
columns = data[["heart_disease","avg_glucose_level","work_type"]]
sns.pairplot(columns,hue="work_type")
plt.show()
#work_type is confounding
In [177]:
#Relationship between stroke and residence - N/A
create_comparison_graph('Residence_type',ticks=False)
In [178]:
#Relationship between stroke and hypertension - lower hypertension => lower chance of stroke
create_comparison_graph('hypertension',ticks=False)
In [179]:
#Relationship between stroke and married status - yes
create_comparison_graph('ever_married',ticks=False)

2. 训练模型¶

In [180]:
#drop gender, residence, work_type columns
data = data.drop(["gender","Residence_type","work_type"],axis=1)
data.head(1)
Out[180]:
hypertension heart_disease ever_married avg_glucose_level bmi smoking_status stroke
0 0 1 Yes 228.69 36.6 formerly smoked 1
In [181]:
#convert non-object types to categorical values
encoder = LabelEncoder()
data['ever_married'] = encoder.fit_transform(data['ever_married'])
ever_married = {index : label for index, label in enumerate(encoder.classes_)}
data['smoking_status'] = encoder.fit_transform(data['smoking_status'])
smoking_status = {index : label for index, label in enumerate(encoder.classes_)}
In [182]:
x = data.drop('stroke',axis=1)
y = data['stroke']
In [183]:
scaler = MinMaxScaler(copy=True, feature_range=(0, 1))
X = scaler.fit_transform(x)
In [184]:
#train, test split
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=0)

Decison Tree Classifier¶

In [185]:
dt = DecisionTreeClassifier(max_depth=6)
dt.fit(x_train,y_train)
Out[185]:
DecisionTreeClassifier(max_depth=6)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeClassifier(max_depth=6)
In [186]:
y_predict = dt.predict(x_test)
In [187]:
print("Decision Tree Accuracy:")
print(accuracy_score(y_test,y_predict))
decision_tree_accuracy = accuracy_score(y_test,y_predict)
Decision Tree Accuracy:
0.9458375125376128
In [188]:
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 计算决策树的准确度
print("Decision Tree Accuracy:")
decision_tree_accuracy = accuracy_score(y_test, y_predict)
print(decision_tree_accuracy)

# 计算 ROC AUC 分数
decision_tree_roc_auc = roc_auc_score(y_test, y_predict)
print("Decision Tree ROC AUC Score:")
print(decision_tree_roc_auc)

# 计算并绘制混淆矩阵
cm = confusion_matrix(y_test, y_predict)
sns.heatmap(cm, annot=True, fmt='d')
plt.title('Decision Tree - Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
Decision Tree Accuracy:
0.9458375125376128
Decision Tree ROC AUC Score:
0.4978880675818374

Random Forest Classifier¶

In [189]:
rf = RandomForestClassifier()
rf.fit(x_train,y_train)
Out[189]:
RandomForestClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestClassifier()
In [190]:
y_predict = rf.predict(x_test)
In [191]:
# 计算随机森林的准确度
print("Random Forest Accuracy:")
decision_tree_accuracy = accuracy_score(y_test, y_predict)
print(decision_tree_accuracy)

# 计算 ROC AUC 分数
decision_tree_roc_auc = roc_auc_score(y_test, y_predict)
print("Random Forest ROC AUC Score:")
print(decision_tree_roc_auc)

# 计算并绘制混淆矩阵
cm = confusion_matrix(y_test, y_predict)
sns.heatmap(cm, annot=True, fmt='d')
plt.title('Random Forest - Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
Random Forest Accuracy:
0.9438314944834504
Random Forest ROC AUC Score:
0.4968321013727561

SVM based classifier¶

In [192]:
svc = SVC(kernel='rbf', gamma=1, C=2)
svc.fit(x_train, y_train)
Out[192]:
SVC(C=2, gamma=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
SVC(C=2, gamma=1)
In [193]:
y_predict = svc.predict(x_test)
In [194]:
# 计算随机森林的准确度
print("SVM Accuracy:")
decision_tree_accuracy = accuracy_score(y_test, y_predict)
print(decision_tree_accuracy)

# 计算 ROC AUC 分数
decision_tree_roc_auc = roc_auc_score(y_test, y_predict)
print("SVM ROC AUC Score:")
print(decision_tree_roc_auc)

# 计算并绘制混淆矩阵
cm = confusion_matrix(y_test, y_predict)
sns.heatmap(cm, annot=True, fmt='d')
plt.title('SVM - Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
SVM Accuracy:
0.9428284854563691
SVM ROC AUC Score:
0.4963041182682154

XGBoost Classifier¶

In [195]:
# 计算准确度
xgboost_accuracy = accuracy_score(y_test, y_predict)
print('XGBoost Accuracy:', xgboost_accuracy)

# 计算 ROC AUC 分数
xgboost_roc_auc = roc_auc_score(y_test, y_predict)
print("XGBoost ROC AUC Score:", xgboost_roc_auc)

# 计算并绘制混淆矩阵
cm = confusion_matrix(y_test, y_predict)
sns.heatmap(cm, annot=True, fmt='d')
plt.title('XGBoost - Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
XGBoost Accuracy: 0.9428284854563691
XGBoost ROC AUC Score: 0.4963041182682154
In [196]:
import pandas as pd
import matplotlib.pyplot as plt

# 假设的准确度数据
accuracy_decision_tree = decision_tree_accuracy
accuracy_random_forest = random_forest_accuracy
accuracy_svc = svm_accuracy
accuracy_xgboost = xgboost_accuracy

# 准确度数据和模型名称
accuracies = [accuracy_decision_tree, accuracy_random_forest, accuracy_svc, accuracy_xgboost]
model_names = ['Decision Tree', 'Random Forest', 'SVM', 'XGBoost']

# 创建 DataFrame
accuracy_df = pd.DataFrame({'Model': model_names, 'Accuracy': accuracies})

# 显示表格
accuracy_df
Out[196]:
Model Accuracy
0 Decision Tree 0.942828
1 Random Forest 0.945838
2 SVM 0.942828
3 XGBoost 0.942828