Heart Disease UCI
In this notebook, we illustrate black-box model explanation with the medical Heart Disease UCI dataset. There are thirteen features:
- age
- sex
- cp: chest pain type (4 values)
- trestbps: resting blood pressure
- chol: serum cholesterol in mg/dl
- fbs: fasting blood sugar > 120 mg/dl
- restecg: resting electrocardiographic results (values 0,1,2)
- thalach: maximum heart rate achieved
- exang: exercise induced angina
- oldpeak: ST depression induced by exercise relative to rest
- slope: the slope of the peak exercise ST segment
- ca: number of major vessels (0-3) colored by fluoroscopy
- thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
The output is the presence (1) or absence (0) of heart disease.
In [1]:
import ethik
X, y = ethik.datasets.load_heart_disease()
X.head()
Out[1]:
In [2]:
y.head()
Out[2]:
In [3]:
from sklearn import model_selection
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, shuffle=True, random_state=42)
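To make the split above concrete: with `shuffle=True` and a fixed `random_state`, the rows are permuted reproducibly before being cut in two, and scikit-learn holds out 25% of the rows when no `test_size` is given. The sketch below mimics that idea with the standard library only. `shuffled_split` is a hypothetical helper, not scikit-learn's implementation, and it will not reproduce the row order that `train_test_split` produces.

```python
import math
import random


def shuffled_split(rows, test_fraction=0.25, seed=42):
    """Hypothetical helper: shuffle rows reproducibly, then hold out a test fraction."""
    indices = list(range(len(rows)))
    # A dedicated generator seeded with `seed` yields the same permutation on every call.
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(len(rows) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train = [rows[i] for i in train_idx]
    test = [rows[i] for i in test_idx]
    return train, test


rows = list(range(20))  # toy stand-in for the rows of X
train, test = shuffled_split(rows)
train_again, test_again = shuffled_split(rows)

print(len(train), len(test))
print(train == train_again and test == test_again)
```

Fixing the seed is what makes the rest of the notebook repeatable: the same rows end up in the test set on every run.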
In this notebook, we aim to illustrate explainability, so the choice of model is arbitrary: we train a gradient-boosted tree classifier using LightGBM.
In [4]:
import lightgbm as lgb
import pandas as pd
model = lgb.LGBMClassifier(random_state=42).fit(X_train, y_train)
y_pred = model.predict_proba(X_test)[:, 1]
# We use a named pandas series to make plot labels more explicit
y_pred = pd.Series(y_pred, name='has_heart_disease')
y_pred.head()
Out[4]:
In [5]:
from sklearn import metrics
# As `y_test` is binary (0 or 1), we need to make `y_pred` binary as well
# for `metrics.accuracy_score` to work.
print(f'Accuracy score: {metrics.accuracy_score(y_test, y_pred > 0.5):.4f}')
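The thresholding step above can be spelled out by hand. This is a minimal sketch with made-up labels and probabilities (not taken from the dataset): a prediction counts as positive when its probability exceeds the threshold, and accuracy is the share of predictions that match the labels. `accuracy_at_threshold` is a hypothetical helper written for this illustration.

```python
def accuracy_at_threshold(y_true, y_proba, threshold=0.5):
    """Hypothetical helper: binarize probabilities, then compute accuracy."""
    # Same rule as `y_pred > 0.5` in the cell above: strictly greater than the threshold.
    y_hat = [int(p > threshold) for p in y_proba]
    correct = sum(int(t == h) for t, h in zip(y_true, y_hat))
    return correct / len(y_true)


# Invented values, for illustration only.
y_true_toy = [1, 0, 1, 1, 0]
y_proba_toy = [0.9, 0.2, 0.4, 0.7, 0.6]

acc_default = accuracy_at_threshold(y_true_toy, y_proba_toy)
acc_lower = accuracy_at_threshold(y_true_toy, y_proba_toy, threshold=0.3)

print(f'Accuracy at 0.5: {acc_default:.4f}')
print(f'Accuracy at 0.3: {acc_lower:.4f}')
```

The 0.5 cut-off is a convention rather than a property of the model, so the reported accuracy depends on it.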
Let's rank the features by their impact on the predictions, showing the top ten:
In [6]:
explainer = ethik.ClassificationExplainer()
explainer.plot_influence_ranking(
    X_test=X_test,
    y_pred=y_pred,
    n_features=10,
)
The maximum heart rate achieved (thalach) is the most impactful feature on the predicted probability of heart disease. Let's have a look at the details:
In [7]:
explainer.plot_influence(
    X_test=X_test["thalach"],
    y_pred=y_pred,
)
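As a rough cross-check of a plot like the one above, one can bin a feature and average the predicted probability within each bin. This is a simplified stand-in, not how `ethik` computes influence. `mean_prediction_by_bin` is a hypothetical helper, and the values below are invented, so they say nothing about the real relationship between heart rate and the prediction.

```python
def mean_prediction_by_bin(feature, proba, edges):
    """Hypothetical helper: mean predicted probability per half-open bin [edges[i], edges[i+1])."""
    n_bins = len(edges) - 1
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for x, p in zip(feature, proba):
        for i in range(n_bins):
            if edges[i] <= x < edges[i + 1]:
                sums[i] += p
                counts[i] += 1
                break  # each value belongs to at most one bin
    # Empty bins yield None rather than a division by zero.
    return [s / c if c else None for s, c in zip(sums, counts)]


# Invented values, for illustration only (no medical meaning).
feature_toy = [110, 125, 140, 150, 165, 180]
proba_toy = [0.8, 0.6, 0.5, 0.7, 0.3, 0.1]
edges_toy = [100, 130, 160, 190]

bin_means = mean_prediction_by_bin(feature_toy, proba_toy, edges_toy)
for low, high, mean in zip(edges_toy[:-1], edges_toy[1:], bin_means):
    print(f'[{low}, {high}): {mean:.2f}')
```

With the real data, `feature_toy` and `proba_toy` would be replaced by `X_test["thalach"]` and `y_pred`, taking care that both are aligned row by row.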
In [8]:
explainer.plot_influence(
    X_test=X_test["oldpeak"],
    y_pred=y_pred,
)
In [9]:
explainer.plot_influence(
    X_test=X_test["thal"],
    y_pred=y_pred,
)
In [10]:
explainer.plot_influence(
    X_test=X_test["cp"],
    y_pred=y_pred,
)
In [11]:
explainer.plot_influence(
    X_test=X_test[["thalach", "oldpeak", "thal"]],
    y_pred=y_pred,
)