Heart Disease UCI¶
In this notebook, we illustrate black-box model explanation with the medical Heart Disease UCI dataset. There are thirteen features:
age
sex
cp: chest pain type (4 values)
trestbps: resting blood pressure
chol: serum cholesterol in mg/dl
fbs: fasting blood sugar > 120 mg/dl
restecg: resting electrocardiographic results (values 0, 1, 2)
thalach: maximum heart rate achieved
exang: exercise induced angina
oldpeak: ST depression induced by exercise relative to rest
slope: the slope of the peak exercise ST segment
ca: number of major vessels (0-3) colored by fluoroscopy
thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
The output is the presence (1) or absence (0) of heart disease.
In [1]:
import ethik
X, y = ethik.datasets.load_heart_disease()
X.head()
Out[1]:
In [2]:
y.head()
Out[2]:
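As a quick sanity check before modelling, we can look at how balanced the target is. This is a minimal sketch, assuming y is the pandas Series loaded above; it is not used in the rest of the notebook:
# Proportion of patients with and without heart disease in the full dataset
y.value_counts(normalize=True)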
In [3]:
from sklearn import model_selection
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, shuffle=True, random_state=42)
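Note that train_test_split shuffles but does not stratify by default. If the classes were noticeably imbalanced, one could preserve the class ratio in both splits; a sketch (with separate variable names, so it does not affect the split actually used below):
# Optional: keep the same proportion of positive cases in train and test
X_train_s, X_test_s, y_train_s, y_test_s = model_selection.train_test_split(
    X, y, shuffle=True, stratify=y, random_state=42
)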
In this notebook, we aim to illustrate explainability, so we arbitrarily train a gradient boosting model with LightGBM.
In [4]:
import lightgbm as lgb
import pandas as pd
model = lgb.LGBMClassifier(random_state=42).fit(X_train, y_train)
y_pred = model.predict_proba(X_test)[:, 1]
# We use a named pandas series to make plot labels more explicit
y_pred = pd.Series(y_pred, name='has_heart_disease')
y_pred.head()
Out[4]:
In [5]:
from sklearn import metrics
# As `y_test` is binary (0 or 1), we need to make `y_pred` binary as well
# for `metrics.accuracy_score` to work.
print(f'Accuracy score: {metrics.accuracy_score(y_test, y_pred > 0.5):.4f}')
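Since y_pred is a probability, a threshold-free metric such as ROC AUC can complement the accuracy at the 0.5 cut-off. A minimal sketch using scikit-learn:
# ROC AUC works directly on the predicted probabilities, no thresholding needed
print(f'ROC AUC: {metrics.roc_auc_score(y_test, y_pred):.4f}')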
Let's plot the ten most impactful features on the predictions:
In [6]:
explainer = ethik.ClassificationExplainer()
explainer.plot_influence_ranking(
X_test=X_test,
y_pred=y_pred,
n_features=10,
)
The maximum heart rate achieved (thalach) is the most impactful feature on the predicted probability of having heart disease. Let's have a look at the details:
In [7]:
explainer.plot_influence(
X_test=X_test["thalach"],
y_pred=y_pred,
)
In [8]:
explainer.plot_influence(
X_test=X_test["oldpeak"],
y_pred=y_pred,
)
In [9]:
explainer.plot_influence(
X_test=X_test["thal"],
y_pred=y_pred,
)
In [10]:
explainer.plot_influence(
X_test=X_test["cp"],
y_pred=y_pred,
)
In [11]:
explainer.plot_influence(
X_test=X_test[["thalach", "oldpeak", "thal"]],
y_pred=y_pred,
)