MNIST digit classification¶
We'll train a simple CNN on the MNIST dataset by copy/pasting this example from the Keras documentation.
First, we load the data:
import warnings
import ethik
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import numpy as np
batch_size = 128
num_classes = 10
epochs = 12
# input image dimensions
img_rows, img_cols = 28, 28
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
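As a quick sanity check (this line is not part of the original Keras example), we can verify the shapes of the preprocessed arrays:
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)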
Model creation¶
Now let's create the CNN:
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
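Optionally (this is not in the original example), we can print a summary of the architecture to check the layer shapes and parameter counts:
model.summary()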
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])

def train(model):
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    return model.evaluate(x_test, y_test, verbose=0)

def predict(model):
    return model.predict_proba(x_test)
Let's train the model and predict the class probabilities for the test set. The training and prediction steps take a long time, so we cache the predictions:
# loss, accuracy = train(model)
# y_pred = predict(model)
# np.save("cache/mnist.npy", y_pred)
y_pred = np.load("cache/mnist.npy")
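Note that np.save() will fail if the cache directory doesn't exist yet; a minimal sketch to create it beforehand:
import os

# Create the cache directory used above if it doesn't exist yet
os.makedirs("cache", exist_ok=True)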
y_pred is a (n_samples, n_features) (i.e. (n_samples, n_digits)) numpy array of probabilities:
y_pred.shape
np.isclose(np.sum(y_pred[0]), 1)
Bias explanation¶
Now that we have the data and the predictions, we can use ethik to explain them. There is one feature per pixel:
explainer = ethik.ImageClassificationExplainer()
explainer.plot_influence(x_test, y_pred)
You can adjust the size of the plot as well:
explainer.plot_influence(x_test, y_pred, cell_width=100)
The previous plot highlights the regions of importance for identifying each digit. More precisely, the intensity of each pixel corresponds to the increase in predicted probability when that pixel is saturated. A value of 0.28 means that saturating the pixel increases the probability predicted by the model by 0.28. Note that we do not saturate and desaturate the pixels independently. Instead, our method captures which pixels are linked together and saturates them in a realistic manner.

The images above show that the CNN seems to be using the same visual cues as a human. However, we can see that it relies on very specific regions of the images to identify particular digits. For instance, the top-right region of an image seems to trigger the "5" digit, whereas the bottom parts of the images seem to be linked with the "7" digit.
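As a rough visual point of comparison (plain numpy/matplotlib, not part of ethik), we can plot the average test image for each digit and check it against the influence maps above:
import matplotlib.pyplot as plt

labels = y_test.argmax(axis=1)
fig, axes = plt.subplots(1, num_classes, figsize=(15, 2))
for digit, ax in enumerate(axes):
    # Average all test images of the given digit and drop the channel axis
    mean_img = x_test[labels == digit].mean(axis=0).squeeze()
    ax.imshow(mean_img, cmap='gray')
    ax.set_title(str(digit))
    ax.axis('off')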
Performance explanation¶
y_test is an array of n_samples one-hot-encoded vectors:
y_test
To get the labels, we use .argmax():
y_test.argmax(axis=1)
Then we can explain the model's performance with ethik. Because of the sklearn.metrics.accuracy_score() API, we need to convert the prediction vectors into their corresponding labels:
from sklearn import metrics
explainer.plot_performance(
    X_test=x_test,
    y_test=y_test.argmax(axis=1),
    y_pred=y_pred.argmax(axis=1),
    metric=metrics.accuracy_score
)
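For reference, the overall test accuracy can be computed directly with the same label conversion:
metrics.accuracy_score(y_test.argmax(axis=1), y_pred.argmax(axis=1))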
TODO: analysis
The log loss metric deals with vectors of probabilities, so we don't need to extract the labels:
explainer.plot_performance(
    X_test=x_test,
    y_test=y_test,
    y_pred=y_pred,
    metric=metrics.log_loss
)
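Likewise, the overall log loss on the test set can be computed directly from the probability vectors:
metrics.log_loss(y_test, y_pred)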
TODO: analysis and comparison with accuracy_score()