Comparison with Partial Dependence Plots
In the "Interpretable Machine Learning" book, we can read:
The partial dependence plot (short PDP or PD plot) shows the marginal effect one or two features have on the predicted outcome of a machine learning model (Friedman, Jerome H. “Greedy function approximation: A gradient boosting machine.” Annals of statistics (2001): 1189-1232.). A partial dependence plot can show whether the relationship between the target and a feature is linear, monotonic or more complex. For example, when applied to a linear regression model, partial dependence plots always show a linear relationship.
Later in the book, the same idea is put differently:
Partial Dependence Plots: “Let me show you what the model predicts on average when each data instance has the value v for that feature. I ignore whether the value v makes sense for all data instances.”
Computing a PDP is really straightforward:
- Select a feature (e.g. "age")
- Define a grid on the feature's domain (e.g. 20, 21, 22, ..., 59, 60)
- For each value $v$ of the grid:
  - Replace the feature with $v$ for all data samples
  - Compute the predictions
  - Take the average
- Draw the curve $\text{average prediction} = f(v)$
PDPs are used in Google's What-If Tool. In this notebook, we compare this method with ours, Entropic Variable Boosting (EVB), on the "Adult" dataset (see the dedicated notebook for additional information).
import ethik
import lightgbm as lgb
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from sklearn import model_selection
import sklearn.inspection
# Download the UCI "Adult" census dataset, encode it and train a LightGBM classifier.
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
names = [
    'age', 'workclass', 'fnlwgt', 'education',
    'education-num', 'marital-status', 'occupation',
    'relationship', 'race', 'gender', 'capital-gain',
    'capital-loss', 'hours-per-week', 'native-country',
    'salary'
]
# String-valued columns are loaded as pandas categoricals.
dtypes = {
    'workclass': 'category',
    'education': 'category',
    'marital-status': 'category',
    'occupation': 'category',
    'relationship': 'category',
    'race': 'category',
    'gender': 'category',
    'native-country': 'category'
}
X = pd.read_csv(url, names=names, header=None, dtype=dtypes)
# read_csv is not given skipinitialspace, so the label strings keep the leading
# space that follows each comma in the raw file; the map keys include it.
y = X.pop('salary').map({' <=50K': False, ' >50K': True})
# plot_partial_dependence() doesn't handle strings: replace each categorical
# column by its integer category codes.
cat_columns = X.select_dtypes(['category']).columns
X[cat_columns] = X[cat_columns].apply(lambda x: x.cat.codes)
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, shuffle=True, random_state=42)
model = lgb.LGBMClassifier(random_state=42).fit(X_train, y_train)
# Predicted probability of the positive class (income > $50K), one value per
# row of X_test. train_test_split shuffles the rows, so X_test keeps a shuffled
# subset of the original index. Reuse that index so y_pred stays aligned with
# X_test (and X_test[feature]) under any index-based pandas operation; a
# default RangeIndex would pair predictions with the wrong rows.
y_pred = pd.Series(model.predict_proba(X_test)[:, 1], index=X_test.index, name='>$50k')
Let's define helpers to compare PDP and EVB:
def create_fig(feature=None):
    """Create an empty Plotly figure styled for average-prediction curves.

    Parameters
    ----------
    feature : str, optional
        Feature shown on the x axis, used as the x-axis title. ``None`` (the
        default) leaves the x axis untitled.

    Returns
    -------
    plotly.graph_objs.Figure
        White-background figure whose y axis is fixed to [0, 1] and formatted
        as a percentage.
    """
    fig = go.Figure()
    fig.update_layout(
        margin=dict(t=50, r=50),
        # The title comes from the parameter rather than a global variable, so
        # the function works no matter what is defined at module level.
        xaxis=dict(title=feature, zeroline=False),
        yaxis=dict(title="Average prediction", range=[0, 1], showline=True, tickformat="%"),
        plot_bgcolor="white",
    )
    return fig


def _partial_dependence_curve(X_subset, feature):
    """Return ``(grid, averaged_predictions)`` for ``feature`` computed on ``X_subset``.

    Accepts both return types of ``sklearn.inspection.partial_dependence``:
    the ``(averaged_predictions, values)`` tuple of older scikit-learn releases
    and the ``Bunch`` of newer ones (``average`` plus ``values``, which was
    renamed ``grid_values`` in scikit-learn 1.3).
    """
    result = sklearn.inspection.partial_dependence(
        estimator=model,
        X=X_subset,
        features=[X_subset.columns.get_loc(feature)],
        grid_resolution=41,
    )
    if isinstance(result, tuple):
        averaged_predictions, grid_values = result
    else:
        averaged_predictions = result["average"]
        grid_values = result["grid_values"] if "grid_values" in result else result["values"]
    # Index 0: a single feature was requested, and for a binary classifier
    # scikit-learn returns one output (the positive class).
    return np.asarray(grid_values[0]), np.asarray(averaged_predictions[0])


def plot_partial_dependence(feature, n_samples=1, fig=None, random_state=None):
    """Add the partial dependence plot (PDP) of ``feature`` to a figure.

    The PDP is computed ``n_samples`` times, each time on a random 80% subsample
    of ``X_test``. The mean curve is drawn as the "PDP" trace. When
    ``n_samples > 1``, a shaded band between the pointwise 5% and 95% quantiles
    of the subsample curves is drawn underneath it.

    Parameters
    ----------
    feature : str
        Column of ``X_test`` whose partial dependence is plotted.
    n_samples : int, default 1
        Number of random subsamples (must be >= 1).
    fig : plotly.graph_objs.Figure, optional
        Figure to draw on; a new one from ``create_fig`` is used when ``None``.
    random_state : int, numpy.random.RandomState or None, default None
        Seed of the subsampling. ``None`` draws different subsamples on each call.

    Returns
    -------
    plotly.graph_objs.Figure
        The figure the traces were added to.

    Raises
    ------
    ValueError
        If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    splitter = model_selection.ShuffleSplit(
        n_splits=n_samples, train_size=0.8, random_state=random_state
    )
    curves = [
        _partial_dependence_curve(X_test.iloc[rows], feature)
        for rows, _ in splitter.split(X_test)
    ]

    # scikit-learn derives the grid from the data it is given (percentile-based
    # linspace, or the unique values for low-cardinality features). Grids can
    # therefore differ in position and even length between subsamples.
    # Resample every curve onto the first grid so they can be aggregated
    # pointwise.
    grid = curves[0][0]
    predictions = np.array(
        [np.interp(grid, curve_x, curve_y) for curve_x, curve_y in curves]
    )

    if fig is None:
        fig = create_fig(feature)

    if n_samples > 1:
        low = np.quantile(predictions, q=0.05, axis=0)
        high = np.quantile(predictions, q=0.95, axis=0)
        # Closed polygon: along the grid on the low quantile, back on the high one.
        fig.add_trace(
            go.Scatter(
                x=np.concatenate((grid, grid[::-1])),
                y=np.concatenate((low, high[::-1])),
                name="PDP: 5% - 95%",
                fill="toself",
                fillcolor="#eee",
                line_color="rgba(0, 0, 0, 0)",
                legendgroup="PDP"
            )
        )

    fig.add_trace(go.Scatter(
        x=grid,
        y=predictions.mean(axis=0),
        name="PDP",
        legendgroup="PDP"
    ))
    return fig
def plot_evb(feature, n_samples=1):
    """Plot the Entropic Variable Boosting (EVB) influence of ``feature``.

    Parameters
    ----------
    feature : str
        Column of ``X_test`` to explain.
    n_samples : int, default 1
        Forwarded unchanged to ``ethik.ClassificationExplainer``.

    Returns
    -------
    The figure produced by ``ClassificationExplainer.plot_influence`` for that
    column and the module-level ``y_pred``. It is presumably a Plotly figure,
    since ``plot_all`` adds Plotly traces to it.
    """
    explainer = ethik.ClassificationExplainer(n_samples=n_samples)
    feature_column = X_test[feature]
    return explainer.plot_influence(X_test=feature_column, y_pred=y_pred)
def plot_all(feature, n_samples=1):
    """Overlay the PDP of ``feature`` on its EVB influence plot.

    The EVB figure is built first, then the PDP traces are added to it. Both
    use the same ``n_samples``. Returns the combined figure.
    """
    return plot_partial_dependence(
        feature,
        fig=plot_evb(feature, n_samples=n_samples),
        n_samples=n_samples,
    )
# "age": one PDP curve (a single random 80% subsample of X_test) over the EVB plot.
plot_all("age")
# "age" with n_samples=30: the PDP curve is the mean over 30 random 80%
# subsamples, shown with a 5%-95% band. n_samples is also forwarded to
# ethik.ClassificationExplainer (its exact meaning there is defined by ethik).
plot_all("age", n_samples=30)
# Same two comparisons for "education-num".
plot_all("education-num")
plot_all("education-num", n_samples=30)