This page was generated from examples/cd_spot_the_diff_mnist_wine.ipynb.
Interpretable drift detection with the spot-the-diff detector on MNIST and Wine-Quality datasets
Under the hood, drift detectors leverage a function of the data that is expected to be large when drift has occurred and small when it hasn’t. In the Learned drift detectors on CIFAR-10 example notebook we note that we can learn a function satisfying this property by training a classifier to distinguish reference and test samples. We now additionally note that if the classifier is specified in a certain way, then when drift is detected we can inspect the classifier’s weights to shed light on exactly which features of the data were used to distinguish reference from test samples, and therefore caused drift to be detected.
The SpotTheDiffDrift
detector is designed to make this process straightforward. Like the ClassifierDrift
detector, it uses a portion of the available data to train a classifier to discriminate between reference and test instances. Letting \(\hat{p}_T(x)\) represent the probability assigned by the classifier that the instance \(x\) is from the test set rather than reference set, the difference here is that we use a classifier of the form

\[\text{logit}\big(\hat{p}_T(x)\big) = b_0 + \sum_{i=1}^{J} b_i \, k(x, w_i),\]

where \(k(\cdot,\cdot)\) is a kernel specifying a notion of similarity between instances, \(w_i\) are learnable test locations and \(b_i\) are learnable regression coefficients.
The idea here is that if the detector flags drift and \(b_i >0\) then we know that it reached its decision by considering how similar each instance is to the instance \(w_i\), with those being more similar being more likely to be test instances than reference instances. Alternatively if \(b_i < 0\) then instances more similar to \(w_i\) were deemed more likely to be reference instances.
In order to provide less noisy and therefore more interpretable results, we define each test location as

\[w_i = \bar{x} + d_i,\]

where \(\bar{x}\) is the mean reference instance. We may then interpret \(d_i\) as the additive transformation deemed to make the average reference more (\(b_i>0\)) or less (\(b_i<0\)) similar to a test instance. Defining the test locations in this way allows us to instead learn the difference \(d_i\) and apply regularisation such that non-zero values must be justified by improved classification performance. This allows us to more clearly identify which features any detected drift should be attributed to.
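To make this concrete, here is a minimal NumPy sketch of the quantities involved (illustrative only, assuming a Gaussian RBF kernel; the function and variable names are ours, not the detector’s internals):

import numpy as np

def gaussian_rbf(x, w, sigma=1.):
    # kernel k(x, w): similarity between instance x and test location w
    return np.exp(-np.sum((x - w) ** 2) / (2 * sigma ** 2))

def p_test(x, x_bar, diffs, b, b_0, sigma=1.):
    # logit(p_T(x)) = b_0 + sum_i b_i * k(x, x_bar + d_i)
    logit = b_0 + sum(b_i * gaussian_rbf(x, x_bar + d_i, sigma)
                      for b_i, d_i in zip(b, diffs))
    return 1. / (1. + np.exp(-logit))

def l1_penalty(diffs, l1_reg=1e-4):
    # regularisation encouraging sparse, interpretable diffs d_i
    return l1_reg * np.sum(np.abs(diffs))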
This approach to interpretable drift detection is inspired by the work of Jitkrittum et al. (2016); however, several major adaptations have been made.
Backend
The method works with both the PyTorch and TensorFlow frameworks. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for how to do this.
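For example, a CPU-only build can typically be installed via pip, though the exact command for your platform and CUDA version should be taken from the PyTorch docs:

!pip install torch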
Dataset
We start with an image example in order to provide a visual illustration of how the detector works. For this purpose we use the MNIST dataset of 28 by 28 grayscale handwritten digits. To represent the common problem of new classes emerging during the deployment phase, we consider a reference set of ~9,000 instances containing only the digits 1-9 and a test set of 10,000 instances containing all of the digits 0-9. We would like drift to be detected in this scenario because a model trained on the reference instances will not know how to process instances from the new class.
This notebook requires the torchvision
package which can be installed via pip
:
[ ]:
!pip install torchvision
[1]:
import torch
import tensorflow as tf
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from alibi_detect.cd import SpotTheDiffDrift
np.random.seed(0)
torch.manual_seed(0)
tf.random.set_seed(0)
%matplotlib inline
[2]:
MNIST_PATH = 'my_path'
DOWNLOAD = True
MISSING_NUMBER = 0
N = 10000
# Load and shuffle data
mnist_train_ds = torchvision.datasets.MNIST(MNIST_PATH, train=True, download=DOWNLOAD)
all_x, all_y = mnist_train_ds.data, mnist_train_ds.targets
perm = np.random.permutation(len(all_x))
all_x, all_y = all_x[perm], all_y[perm]
all_x = all_x[:, None, :, :].numpy().astype(np.float32) / 255.
# Create a reference and test set
x_ref = all_x[:N]
x = all_x[N:2*N]
# Remove a class from reference set
x_ref = x_ref[all_y[:N] != MISSING_NUMBER]
When instantiating the detector we should specify the number of “diffs” we would like it to use to discriminate reference from test instances. Here there is a trade-off. Using n_diffs=1
is the simplest to interpret and seems to work well in practice. Using more diffs may result in stronger detection power, but the diffs may be harder to interpret due to interactions and conditional dependencies.
The strength of the regularisation (l1_reg
) to apply to the diffs should also be specified. Stronger regularisation results in sparser diffs as the classifier is encouraged to discriminate using fewer features. This may make the diff more interpretable but may again come at the cost of detection power.
We should also specify how the classifier should be trained with standard arguments such as learning_rate
, epochs
and batch_size
. By default a Gaussian RBF is used for the kernel but alternatives can be specified via the kernel
kwarg. Additionally the classifier can be initialised with any desired diffs by passing them with the initial_diffs
kwarg – by default they are initialised with Gaussian noise with
standard deviation equal to that observed in the reference data.
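For illustration, a custom kernel and explicit initial diffs could be passed as shown below. This is a sketch only and is not used in the rest of this example; the GaussianRBF import path and the assumed shape of initial_diffs, (n_diffs, *x_ref.shape[1:]), may differ between versions of Alibi Detect.

from alibi_detect.utils.tensorflow import GaussianRBF

# hypothetical instantiation with an explicit kernel and initial diffs
initial_diffs = np.random.normal(size=(1,) + x_ref.shape[1:]).astype(np.float32) * x_ref.std()
cd_custom = SpotTheDiffDrift(
    x_ref,
    backend='tensorflow',
    kernel=GaussianRBF(),
    n_diffs=1,
    initial_diffs=initial_diffs,
    l1_reg=1e-4,
)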
[3]:
cd = SpotTheDiffDrift(
x_ref,
n_diffs=1,
l1_reg=1e-4,
backend='tensorflow',
verbose=1,
learning_rate=1e-2,
epochs=5,
batch_size=64,
)
When we then call the detector on the deployment/test set it trains the classifier (thereby learning the diffs), after which the is_drift
and p_val
properties can be inspected in the usual way:
[4]:
preds = cd.predict(x)
print(f"Drift? {'Yes' if preds['data']['is_drift'] else 'No'}")
print(f"p-value: {preds['data']['p_val']}")
222/222 [=] - 2s 7ms/step - loss_ma: 0.6926
222/222 [=] - 2s 7ms/step - loss_ma: 0.6905
222/222 [=] - 2s 7ms/step - loss_ma: 0.6869
222/222 [=] - 2s 7ms/step - loss_ma: 0.6831
222/222 [=] - 2s 7ms/step - loss_ma: 0.6818
Drift? Yes
p-value: 6.911625415497573e-09
As expected, drift was detected. However, we may now additionally look at the learned diffs and corresponding coefficients to determine how the detector reached this decision.
[5]:
print(f"Diff coeff: {preds['data']['diff_coeffs']}")
diff = preds['data']['diffs'][0,0]
plt.imshow(diff, cmap='RdBu', vmin=-np.max(np.abs(diff)), vmax=np.max(np.abs(diff)))
plt.colorbar()
Diff coeff: [2.669408]
[5]:
<matplotlib.colorbar.Colorbar at 0x7f31104da730>

The detector has identified the zero that was missing from the reference data – it realised that test instances were on average more (coefficient > 0) similar to an instance with below-average middle pixel values and above-average zero-region pixel values than reference instances were. It used this information to determine that drift had occurred.
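Since each test location is \(w_i = \bar{x} + d_i\), one way to visualise this (a quick sketch added here, not part of the original notebook output) is to plot the mean reference image next to the mean with the learned diff added:

x_bar = x_ref.mean(axis=0)[0]  # mean reference image
fig, axs = plt.subplots(1, 2)
axs[0].imshow(x_bar, cmap='gray')
axs[0].set_title('mean reference')
axs[1].imshow(x_bar + diff, cmap='gray')
axs[1].set_title('mean reference + diff')
plt.show()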
Interpretable Drift Detection on the Wine Quality Dataset
To provide an example on tabular data we consider the Wine Quality Data Set consisting of 4898 and 1599 samples of white and red wine respectively. Each sample has an associated quality (as determined by experts) and 11 numeric features indicating its acidity, density, pH etc. To represent the problem of a model being trained on one distribution and deployed on a subtly different one, we take as a reference set the samples of white wine and consider the red wine samples to form a ‘corrupted’ deployment set.
[6]:
import pandas as pd
red_df = pd.read_csv(
"https://storage.googleapis.com/seldon-datasets/wine_quality/winequality-red.csv", sep=';'
)
white_df = pd.read_csv(
"https://storage.googleapis.com/seldon-datasets/wine_quality/winequality-white.csv", sep=';'
)
white_df.describe()
[6]:
| | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 | 4898.000000 |
mean | 6.854788 | 0.278241 | 0.334192 | 6.391415 | 0.045772 | 35.308085 | 138.360657 | 0.994027 | 3.188267 | 0.489847 | 10.514267 | 5.877909 |
std | 0.843868 | 0.100795 | 0.121020 | 5.072058 | 0.021848 | 17.007137 | 42.498065 | 0.002991 | 0.151001 | 0.114126 | 1.230621 | 0.885639 |
min | 3.800000 | 0.080000 | 0.000000 | 0.600000 | 0.009000 | 2.000000 | 9.000000 | 0.987110 | 2.720000 | 0.220000 | 8.000000 | 3.000000 |
25% | 6.300000 | 0.210000 | 0.270000 | 1.700000 | 0.036000 | 23.000000 | 108.000000 | 0.991723 | 3.090000 | 0.410000 | 9.500000 | 5.000000 |
50% | 6.800000 | 0.260000 | 0.320000 | 5.200000 | 0.043000 | 34.000000 | 134.000000 | 0.993740 | 3.180000 | 0.470000 | 10.400000 | 6.000000 |
75% | 7.300000 | 0.320000 | 0.390000 | 9.900000 | 0.050000 | 46.000000 | 167.000000 | 0.996100 | 3.280000 | 0.550000 | 11.400000 | 6.000000 |
max | 14.200000 | 1.100000 | 1.660000 | 65.800000 | 0.346000 | 289.000000 | 440.000000 | 1.038980 | 3.820000 | 1.080000 | 14.200000 | 9.000000 |
We can see that the data for both red and white wine samples take the same format.
[7]:
red_df.describe()
[7]:
| | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 | 1599.000000 |
mean | 8.319637 | 0.527821 | 0.270976 | 2.538806 | 0.087467 | 15.874922 | 46.467792 | 0.996747 | 3.311113 | 0.658149 | 10.422983 | 5.636023 |
std | 1.741096 | 0.179060 | 0.194801 | 1.409928 | 0.047065 | 10.460157 | 32.895324 | 0.001887 | 0.154386 | 0.169507 | 1.065668 | 0.807569 |
min | 4.600000 | 0.120000 | 0.000000 | 0.900000 | 0.012000 | 1.000000 | 6.000000 | 0.990070 | 2.740000 | 0.330000 | 8.400000 | 3.000000 |
25% | 7.100000 | 0.390000 | 0.090000 | 1.900000 | 0.070000 | 7.000000 | 22.000000 | 0.995600 | 3.210000 | 0.550000 | 9.500000 | 5.000000 |
50% | 7.900000 | 0.520000 | 0.260000 | 2.200000 | 0.079000 | 14.000000 | 38.000000 | 0.996750 | 3.310000 | 0.620000 | 10.200000 | 6.000000 |
75% | 9.200000 | 0.640000 | 0.420000 | 2.600000 | 0.090000 | 21.000000 | 62.000000 | 0.997835 | 3.400000 | 0.730000 | 11.100000 | 6.000000 |
max | 15.900000 | 1.580000 | 1.000000 | 15.500000 | 0.611000 | 72.000000 | 289.000000 | 1.003690 | 4.010000 | 2.000000 | 14.900000 | 8.000000 |
We extract the features and shuffle and normalise them such that they take values in [0,1].
[8]:
white, red = np.asarray(white_df, np.float32)[:, :-1], np.asarray(red_df, np.float32)[:, :-1]
n_white, n_red = white.shape[0], red.shape[0]
col_maxes = white.max(axis=0)
white, red = white / col_maxes, red / col_maxes
white, red = white[np.random.permutation(n_white)], red[np.random.permutation(n_red)]
x, x_corr = white, red
We then split off half of the reference set to act as an unseen sample from the same underlying distribution for which drift should not be detected.
[9]:
x_ref = x[:len(x)//2]
x_h0 = x[len(x)//2:]
We instantiate our detector in the same way as above, but this time using the PyTorch backend for the sake of variety. We then get the predictions of the detector on both the undrifted and corrupted test sets.
[10]:
cd = SpotTheDiffDrift(
x_ref,
n_diffs=1,
l1_reg=1e-4,
backend='pytorch',
verbose=1,
learning_rate=1e-2,
epochs=5,
batch_size=64,
)
preds_h0 = cd.predict(x_h0)
preds_corr = cd.predict(x_corr)
No GPU detected, fall back on CPU.
Epoch 1/5: 100%|██████████| 58/58 [00:00<00:00, 200.84it/s, loss_ma=0.694]
Epoch 2/5: 100%|██████████| 58/58 [00:00<00:00, 201.83it/s, loss_ma=0.694]
Epoch 3/5: 100%|██████████| 58/58 [00:00<00:00, 196.56it/s, loss_ma=0.694]
Epoch 4/5: 100%|██████████| 58/58 [00:00<00:00, 201.99it/s, loss_ma=0.693]
Epoch 5/5: 100%|██████████| 58/58 [00:00<00:00, 194.22it/s, loss_ma=0.694]
Epoch 1/5: 100%|██████████| 48/48 [00:00<00:00, 267.09it/s, loss_ma=0.624]
Epoch 2/5: 100%|██████████| 48/48 [00:00<00:00, 195.68it/s, loss_ma=0.533]
Epoch 3/5: 100%|██████████| 48/48 [00:00<00:00, 193.16it/s, loss_ma=0.464]
Epoch 4/5: 100%|██████████| 48/48 [00:00<00:00, 190.22it/s, loss_ma=0.412]
Epoch 5/5: 100%|██████████| 48/48 [00:00<00:00, 231.42it/s, loss_ma=0.37]
[11]:
print(f"Drift on h0? {'Yes' if preds_h0['data']['is_drift'] else 'No'}")
print(f"p-value on h0: {preds_h0['data']['p_val']}")
print(f"Drift on corrupted? {'Yes' if preds_corr['data']['is_drift'] else 'No'}")
print(f"p-value on corrupted:: {preds_corr['data']['p_val']}")
Drift on h0? No
p-value on h0: 0.7351070494981216
Drift on corrupted? Yes
p-value on corrupted: 3.825295341753544e-201
As expected, drift is detected on the red wine samples but not on the held-out white wine samples from the same distribution. Now we can inspect the returned diff to determine how the detector reached its decision:
[12]:
diff = preds_corr['data']['diffs'][0]
print(f"Diff coeff: {preds_corr['data']['diff_coeffs']}")
plt.barh(white_df.columns[:-1], diff)
plt.xlim((-1.1*np.max(np.abs(diff)), 1.1*np.max(np.abs(diff))))
plt.axvline(0, linestyle='--', color='black')
plt.show()
Diff coeff: [-1.905726]

We see that the detector was able to discriminate the corrupted (red) wine samples from the reference (white) samples by noting that, on average, reference samples (coefficient < 0) contain more sulfur dioxide and residual sugar, but fewer sulphates and chlorides, and have lower pH and lower volatile and fixed acidity.
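We can loosely corroborate this reading by comparing the normalised feature means of the two samples directly (a quick sanity-check sketch, not part of the original notebook):

# difference in normalised feature means: corrupted (red) minus reference (white)
mean_diff = x_corr.mean(axis=0) - x_ref.mean(axis=0)
plt.barh(white_df.columns[:-1], mean_diff)
plt.axvline(0, linestyle='--', color='black')
plt.show()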