Ethik AI

MNIST digit classification

We'll train a simple CNN on the MNIST dataset by copy/pasting this example from the Keras documentation.

First, we load the data:

In [1]:
import warnings

import ethik

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

import numpy as np

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    
    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if K.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
Using TensorFlow backend.
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
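The preprocessing above does three things: it reshapes the images to 4-D, scales pixel values to [0, 1], and one-hot encodes the labels. The sketch below repeats those steps with NumPy alone on a few synthetic 28×28 images, so nothing needs to be downloaded. `one_hot` and `preprocess` are hypothetical helpers; `one_hot` is assumed to mimic what `keras.utils.to_categorical` returns for integer labels.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper mimicking keras.utils.to_categorical for integer labels."""
    return np.eye(num_classes, dtype="float32")[labels]

def preprocess(images, labels, num_classes=10, channels_first=False):
    """Reshape uint8 images to 4-D, scale them to [0, 1] and one-hot encode labels."""
    n, rows, cols = images.shape
    shape = (n, 1, rows, cols) if channels_first else (n, rows, cols, 1)
    x = images.reshape(shape).astype("float32") / 255
    y = one_hot(labels, num_classes)
    return x, y

# Synthetic stand-in for mnist.load_data(): 5 random 28x28 "images".
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
labels = np.array([7, 2, 1, 0, 4])

x, y = preprocess(images, labels)
print(x.shape, y.shape)  # (5, 28, 28, 1) (5, 10)
```

Only the position of the channel axis depends on `K.image_data_format()`. The scaling and the encoding are the same in both layouts.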

Model creation

Now let's create the CNN:

In [2]:
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
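The arithmetic below traces the tensor shapes and parameter counts through this network. It assumes Keras's defaults for these layers: `padding='valid'` and stride 1 for `Conv2D`, and a stride equal to `pool_size` for `MaxPooling2D`. `model.summary()` should report the same numbers. `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel):
    """Output size of a 'valid' (unpadded), stride-1 convolution."""
    return size - kernel + 1

h = w = 28
c = 1
params = {}

# Conv2D(32, (3, 3)): 3*3*1 weights per filter, plus one bias per filter.
params["conv1"] = 3 * 3 * c * 32 + 32
h, w, c = conv_out(h, 3), conv_out(w, 3), 32   # 26 x 26 x 32

# Conv2D(64, (3, 3)): each filter now spans the 32 input channels.
params["conv2"] = 3 * 3 * c * 64 + 64
h, w, c = conv_out(h, 3), conv_out(w, 3), 64   # 24 x 24 x 64

# MaxPooling2D((2, 2)) halves each spatial dimension and has no weights.
h, w = h // 2, w // 2                           # 12 x 12 x 64

# Dropout layers have no weights either.
flat = h * w * c                                # Flatten -> 9216 features
params["dense1"] = flat * 128 + 128             # Dense(128)
params["dense2"] = 128 * 10 + 10                # Dense(num_classes)

print(flat, params, sum(params.values()))
# 9216 {'conv1': 320, 'conv2': 18496, 'dense1': 1179776, 'dense2': 1290} 1199882
```

Nearly all of the roughly 1.2 million weights sit in the first `Dense` layer, because it connects every one of the 9216 flattened features to 128 units.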
In [3]:
def train(model):
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    
    return model.evaluate(x_test, y_test, verbose=0)
    
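def predict(model):
    # The last layer is a softmax, so predict() already returns class probabilities
    # (predict_proba() is deprecated and absent from recent Keras versions).
    return model.predict(x_test)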

Let's train the model and predict the class probabilities for the test set. Training and prediction take a long time, so we cache the predictions:

In [4]:
# loss, accuracy = train(model)
# y_pred = predict(model)
# np.save("cache/mnist.npy", y_pred)

y_pred = np.load("cache/mnist.npy")
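The cell above switches between computing and loading by commenting lines in and out. The same "load if present, otherwise compute and save" pattern can be wrapped in a small function. `cached` is a hypothetical helper, and `fake_predict` stands in for `predict(model)` so that the sketch runs without Keras.

```python
import os
import tempfile

import numpy as np

def cached(path, compute):
    """Return the array stored at `path`, computing and saving it on a cache miss."""
    if os.path.exists(path):
        return np.load(path)
    result = compute()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    np.save(path, result)
    return result

calls = []

def fake_predict():
    # Stands in for predict(model) and records each call, so we can see the cache working.
    calls.append(1)
    return np.full((3, 10), 0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache", "mnist.npy")
    first = cached(path, fake_predict)   # computes and writes the file
    second = cached(path, fake_predict)  # loads from disk instead
    print(len(calls), np.array_equal(first, second))  # 1 True
```

In this notebook, the call would be something like `y_pred = cached("cache/mnist.npy", lambda: predict(model))`, after the model has been trained.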

y_pred is a NumPy array of shape (n_samples, n_classes), i.e. (n_samples, n_digits). Each row holds the predicted probability of each digit for one test image:

In [5]:
y_pred.shape
Out[5]:
(10000, 10)
In [6]:
np.isclose(np.sum(y_pred[0]), 1)
Out[6]:
True
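The cell above only checks the first row. Because the last layer is a softmax, every row should sum to 1, and `np.allclose` can check all rows at once. The sketch below uses synthetic softmax outputs in place of the cached `y_pred`. `softmax` is written out here for illustration.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Synthetic stand-in for y_pred: 10000 rows of 10 class probabilities.
rng = np.random.default_rng(0)
probs = softmax(rng.normal(size=(10000, 10)))

print(probs.shape)                        # (10000, 10)
print(np.allclose(probs.sum(axis=1), 1))  # True: every row is a probability distribution
print(probs.argmax(axis=1)[:5])           # most likely digit for the first five images
```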

Bias explanation

Now that we have the test images and the predicted probabilities, we can use ethik to explain the model's predictions. Each pixel is treated as one feature:

In [7]:
explainer = ethik.ImageClassificationExplainer()
explainer.plot_influence(x_test, y_pred)
100%|██████████| 18420/18420 [00:47<00:00, 384.45it/s]
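To make "one feature per pixel" concrete: each 28×28 image can be viewed as a row of 784 pixel features. The sketch below flattens a synthetic batch with NumPy and shows how a (row, col) position maps to a feature index. It only illustrates the idea and does not claim to show how ethik stores the data internally.

```python
import numpy as np

# Synthetic stand-in for x_test in channels_last layout: (n_samples, 28, 28, 1).
x_test = np.zeros((4, 28, 28, 1), dtype="float32")
x_test[0, 5, 10, 0] = 1.0  # saturate a single pixel of the first image

# One feature per pixel: flatten each image into a row of 784 values.
features = x_test.reshape(len(x_test), -1)
print(features.shape)  # (4, 784)

# With NumPy's default row-major order, pixel (row=5, col=10) becomes feature 5 * 28 + 10 = 150.
print(features[0].argmax())  # 150
```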

You can also adjust the size of the plot with the cell_width parameter:

In [8]:
explainer.plot_influence(x_test, y_pred, cell_width=100)

The previous plot highlights the regions that matter most for identifying each digit. More precisely, the intensity of each pixel is the change in the digit's predicted probability when the pixel is saturated compared with when it is not. For example, a value of 0.28 means that saturating the pixel increases the probability predicted by the model by 0.28. Note that we do not saturate and desaturate the pixels independently. Instead, our method takes into account which pixels are linked together and saturates them in a realistic manner.

The images suggest that the CNN relies on visual cues similar to the ones a human would use. However, it also uses very specific regions of the image to identify particular digits. For instance, the top-right region of an image seems to trigger the "5" digit, whereas the bottom parts of the images seem to be linked with the "7" digit.
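The paragraph above says that pixels are not saturated independently. One way to get this behaviour, and the assumption behind this sketch (it is not ethik's code), is to leave the images untouched and reweight whole test images instead, so that a chosen pixel's average intensity reaches a target value. Because each image stays intact, pixels that tend to light up together move together. The weights take the exponential form w_i ∝ exp(λ·x_i), λ is found by bisection, and the influence is read as the change in the weighted mean predicted probability. `tilted_weights` and `influence` are hypothetical names.

```python
import numpy as np

def tilted_weights(x, target, lo=-50.0, hi=50.0, n_iter=100):
    """Weights w_i proportional to exp(lam * x_i) whose weighted mean of x equals target.

    The weighted mean increases monotonically with lam, so bisection finds lam.
    `target` must lie strictly between x.min() and x.max().
    """
    def weights(lam):
        z = lam * x
        w = np.exp(z - z.max())  # subtract the max to avoid overflow
        return w / w.sum()

    for _ in range(n_iter):
        mid = (lo + hi) / 2
        if weights(mid) @ x < target:
            lo = mid
        else:
            hi = mid
    return weights((lo + hi) / 2)

def influence(pixel, proba, target):
    """Change in mean predicted probability when the pixel's mean intensity is moved to target."""
    w = tilted_weights(pixel, target)
    return w @ proba - proba.mean()

# Synthetic data: one pixel's intensity across 1000 images, and a model whose
# probability for some digit grows linearly with that intensity.
rng = np.random.default_rng(0)
pixel = rng.uniform(0, 1, size=1000)
proba = 0.2 + 0.5 * pixel

w = tilted_weights(pixel, target=0.8)
print(round(w @ pixel, 6))                  # 0.8
print(influence(pixel, proba, target=0.8))
```

With `proba = 0.2 + 0.5 * pixel`, the influence is exactly 0.5 × (0.8 − mean intensity). This makes it easy to see that a value such as 0.28 in the plot is a change in predicted probability, not a pixel value.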

Performance explanation

y_test is an array of n_samples one-hot encoded vectors, with one row per test image and one column per digit:

In [9]:
y_test
Out[9]:
array([[0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

To recover each label, we take the index of the 1 in each row with .argmax(axis=1):

In [10]:
y_test.argmax(axis=1)
Out[10]:
array([7, 2, 1, ..., 4, 5, 6])
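`argmax(axis=1)` undoes the one-hot encoding because each row contains a single 1. Applied to `y_pred`, the same call picks the most probable digit. Here is a small round trip on synthetic labels:

```python
import numpy as np

labels = np.array([7, 2, 1, 0, 4])

# Labels -> one-hot rows (the encoding keras.utils.to_categorical produces).
one_hot = np.eye(10, dtype="float32")[labels]

# One-hot rows -> labels: the index of the 1 in each row.
recovered = one_hot.argmax(axis=1)

print(one_hot[0])   # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
print(recovered)    # [7 2 1 0 4]
```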

We can then explain the model's performance with ethik. Because sklearn.metrics.accuracy_score() expects class labels rather than vectors, we first convert both the one-hot targets and the predicted probability vectors into their corresponding labels:

In [11]:
from sklearn import metrics

explainer.plot_performance(
    X_test=x_test,
    y_test=y_test.argmax(axis=1),
    y_pred=y_pred.argmax(axis=1),
    metric=metrics.accuracy_score
)
100%|██████████| 1842/1842 [00:07<00:00, 238.74it/s]
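Accuracy compares labels one for one, which is why both arrays are reduced with `argmax` first. The sketch below writes the same computation out in NumPy on synthetic data. `accuracy` is a hypothetical helper, not sklearn's implementation.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label matches the true label."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

# Synthetic one-hot targets and predicted probabilities for 4 images.
y_true_onehot = np.eye(10)[[7, 2, 1, 0]]
proba = np.full((4, 10), 0.05)
proba[[0, 1, 2, 3], [7, 2, 5, 0]] = 0.55  # the third image is misclassified as a 5

print(accuracy(y_true_onehot.argmax(axis=1), proba.argmax(axis=1)))  # 0.75
```

Accuracy ignores how confident each prediction is. In this example, a 0.55 and a 0.99 on the right digit count the same.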

TODO: analysis

The log loss metric works directly on vectors of probabilities, so we don't need to convert them into labels:

In [12]:
explainer.plot_performance(
    X_test=x_test,
    y_test=y_test,
    y_pred=y_pred,
    metric=metrics.log_loss
)
100%|██████████| 1842/1842 [00:54<00:00, 33.90it/s]
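Unlike accuracy, log loss scores the probability assigned to the true class. It therefore uses the full `y_pred` rows and penalises a confident mistake more than a hesitant one. The sketch below implements the standard definition in NumPy. `log_loss` here is a hypothetical re-implementation, and `sklearn.metrics.log_loss` may differ in its clipping details.

```python
import numpy as np

def log_loss(y_true_onehot, proba, eps=1e-15):
    """Mean negative log-probability assigned to the true class."""
    p = np.clip(proba, eps, 1 - eps)       # avoid log(0)
    p = p / p.sum(axis=1, keepdims=True)   # renormalise rows after clipping
    return float(-np.mean(np.sum(y_true_onehot * np.log(p), axis=1)))

y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
proba = np.array([[0.8, 0.2], [0.4, 0.6]])

loss = log_loss(y_true, proba)
print(round(loss, 4))  # 0.367, i.e. -(ln 0.8 + ln 0.6) / 2
```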

TODO: analysis and comparison with accuracy_score()
