# Maximum Mean Discrepancy drift detector on CIFAR-10

## Method

The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions $$p$$ and $$q$$, based on the mean embeddings $$\mu_{p}$$ and $$\mu_{q}$$ in a reproducing kernel Hilbert space $$F$$:

\begin{align} MMD^2(F, p, q) & = || \mu_{p} - \mu_{q} ||^2_{F} \end{align}

We can compute unbiased estimates of $$MMD^2$$ from samples of the two distributions after applying the kernel trick. By default we use a radial basis function kernel, but users are free to pass their own kernel to the detector. We obtain a $$p$$-value via a permutation test on the values of $$MMD^2$$. This method is also described in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift.
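The estimate and the permutation test can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the detector's implementation: the function names `rbf_kernel`, `mmd2_unbiased` and `permutation_p_value` are hypothetical, the bandwidth is fixed, and the $$p$$-value is taken to be the fraction of permuted statistics that are at least as large as the observed one.

```python
import numpy as np


def rbf_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of x and the rows of y."""
    sq_dist = (x ** 2).sum(axis=1)[:, None] + (y ** 2).sum(axis=1)[None, :] - 2.0 * x @ y.T
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * sigma ** 2))


def mmd2_unbiased(x, y, sigma=1.0):
    """Unbiased estimate of the squared MMD between samples x and y."""
    n, m = len(x), len(y)
    k_xx, k_yy, k_xy = rbf_kernel(x, x, sigma), rbf_kernel(y, y, sigma), rbf_kernel(x, y, sigma)
    # exclude the diagonal terms k(x_i, x_i) from the within-sample averages
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2.0 * k_xy.mean()


def permutation_p_value(x, y, n_permutations=100, sigma=1.0, seed=0):
    """Return the observed MMD^2 and a permutation-test p-value (assumed definition)."""
    rng = np.random.default_rng(seed)
    observed = mmd2_unbiased(x, y, sigma)
    pooled = np.concatenate([x, y])
    n = len(x)
    n_larger = 0
    for _ in range(n_permutations):
        # reassign the pooled instances to the two samples at random
        perm = rng.permutation(len(pooled))
        permuted = mmd2_unbiased(pooled[perm[:n]], pooled[perm[n:]], sigma)
        n_larger += int(permuted >= observed)
    return observed, n_larger / n_permutations


# toy usage: two samples from the same Gaussian, and one with a shifted mean
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2))
print(permutation_p_value(x, rng.normal(size=(100, 2))))
print(permutation_p_value(x, rng.normal(loc=1.5, size=(100, 2))))
```

Under the null hypothesis the labels of the pooled instances are exchangeable, which is why shuffling them gives a reference distribution for the statistic.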

## Dataset

CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness changes etc. at different levels of severity, leading to a gradual decline in classification model performance. We also check for drift against the original test set with class imbalances.

[1]:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer, Reshape

from alibi_detect.cd import MMDDrift
from alibi_detect.cd.preprocess import uae, hidden_output
from alibi_detect.models.resnet import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model
from alibi_detect.utils.kernels import gaussian_kernel
from alibi_detect.utils.prediction import predict_batch
from alibi_detect.utils.saving import save_detector  # used below to save the detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c



Original CIFAR-10 data:

[2]:

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)


For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:

[3]:

corruptions = corruption_types_cifar10c()
print(corruptions)

['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']


Let’s pick a subset of the corruptions at severity level 5. Each corruption type consists of perturbations on all of the original test set images.

[4]:

corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255


We split the original test set into a reference dataset and a held-out dataset for which the null hypothesis H0 of the MMD test (no drift) should not be rejected. We also split the corrupted data by corruption type:

[5]:

np.random.seed(0)
n_test = X_test.shape[0]
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref, y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]
print(X_ref.shape, X_h0.shape)

(5000, 32, 32, 3) (5000, 32, 32, 3)

[6]:

# check that the classes are more or less balanced
classes, counts_ref = np.unique(y_ref, return_counts=True)
counts_h0 = np.unique(y_h0, return_counts=True)[1]
print('Class Ref H0')
for cl, cref, ch0 in zip(classes, counts_ref, counts_h0):
    assert cref + ch0 == n_test // 10
    print('{}     {} {}'.format(cl, cref, ch0))

Class Ref H0
0     472 528
1     510 490
2     498 502
3     492 508
4     501 499
5     495 505
6     493 507
7     501 499
8     516 484
9     522 478

[7]:

X_c = []
n_corr = len(corruption)
for i in range(n_corr):
    X_c.append(X_corr[i * n_test:(i + 1) * n_test])


We can visualise the same instance for each corruption type:

[8]:

i = 4

n_test = X_test.shape[0]
plt.title('Original')
plt.axis('off')
plt.imshow(X_test[i])
plt.show()
for j, c in enumerate(corruption):
    plt.title(c)
    plt.axis('off')
    plt.imshow(X_corr[n_test * j + i])
    plt.show()


We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:

[9]:

dataset = 'cifar10'
model = 'resnet32'
clf = fetch_tf_model(dataset, model)
acc = clf.evaluate(scale_by_instance(X_test), y_test, batch_size=128, verbose=0)[1]
print('Test set accuracy:')
print('Original {:.4f}'.format(acc))
clf_accuracy = {'original': acc}
for j, c in enumerate(corruption):
    acc = clf.evaluate(scale_by_instance(X_c[j]), y_test, batch_size=128, verbose=0)[1]
    clf_accuracy[c] = acc
    print('{} {:.4f}'.format(c, acc))

Test set accuracy:
Original 0.9278
gaussian_noise 0.2208
motion_blur 0.6339
brightness 0.8913
pixelate 0.3666


Given the drop in performance, it is important that we detect the harmful data drift!

## Detect drift

We are trying to detect data drift on high-dimensional (32x32x3) data using a multivariate MMD permutation test. It therefore makes sense to apply dimensionality reduction first. Some dimensionality reduction methods also used in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift are readily available: UAE (Untrained AutoEncoder), BBSDs (black-box shift detection using the classifier’s softmax outputs) and PCA.
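Of the three methods, PCA is not demonstrated in this notebook. A minimal NumPy sketch of the idea is given here; `fit_pca` and `project` are hypothetical helpers rather than alibi-detect functions, and fitting the projection on the reference data only before applying it to both samples is an assumption of this sketch.

```python
import numpy as np


def fit_pca(x_ref, n_components):
    """Fit PCA on flattened reference images; return the mean and the principal axes."""
    x_flat = x_ref.reshape(len(x_ref), -1).astype(np.float64)
    mean = x_flat.mean(axis=0)
    # rows of vt are the principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(x_flat - mean, full_matrices=False)
    return mean, vt[:n_components]


def project(x, mean, components):
    """Project flattened images onto the fitted principal axes."""
    return (x.reshape(len(x), -1) - mean) @ components.T


# toy usage on random "images"; the real data would be X_ref and the test batch
rng = np.random.default_rng(0)
x_ref_demo = rng.random((64, 8, 8, 3))
x_new_demo = rng.random((16, 8, 8, 3))
mean, components = fit_pca(x_ref_demo, n_components=32)
print(project(x_ref_demo, mean, components).shape, project(x_new_demo, mean, components).shape)
```

The low-dimensional projections, rather than the raw pixels, would then be passed to the two-sample test.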

### Untrained AutoEncoder

First we try UAE:

[10]:

tf.random.set_seed(0)

# define encoder
encoding_dim = 32
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, 3)),
        Flatten(),
        Dense(encoding_dim,)
    ]
)

# initialise drift detector
cd = MMDDrift(
    p_val=.05,          # significance threshold for the permutation test
    X_ref=X_ref,        # reference data to test against
    preprocess_fn=uae,  # UAE for dimensionality reduction
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128},
    kernel=gaussian_kernel,  # use the default Gaussian kernel in MMD
    kernel_kwargs={'sigma': np.array([1.])},
    chunk_size=1000,
    n_permutations=5    # number of permutations in the test; set to 5 to keep the
)                       # runtime low, should be much higher for a real test

# we can also save/load an initialised detector
filepath = 'my_path'  # change to directory where detector is saved
save_detector(cd, filepath)


The optional chunk_size argument is used to compute the maximum mean discrepancy between the two samples in chunks using dask, which avoids potential out-of-memory errors. In terms of speed, the optimal chunk_size depends on the application and the hardware, so it is often worth testing a few different values, including None. With None, the computation is done in memory in NumPy.
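The idea behind chunking can be illustrated without dask. The sketch below uses a plain Python loop over NumPy blocks, so it is a stand-in for the dask-based computation and not the library's code; `chunked_kernel_mean` is a hypothetical name. Only one block of the kernel matrix is held in memory at a time, and the result matches the in-memory computation.

```python
import numpy as np


def rbf_kernel(x, y, sigma=1.0):
    """Gaussian kernel matrix between the rows of x and the rows of y."""
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def chunked_kernel_mean(x, y, chunk_size=None, sigma=1.0):
    """Mean of the kernel matrix k(x, y), optionally accumulated block by block."""
    if chunk_size is None:
        # in-memory path: build the full len(x) by len(y) matrix at once
        return rbf_kernel(x, y, sigma).mean()
    total = 0.0
    for i in range(0, len(x), chunk_size):
        for j in range(0, len(y), chunk_size):
            # only a block of at most chunk_size by chunk_size entries exists here
            total += rbf_kernel(x[i:i + chunk_size], y[j:j + chunk_size], sigma).sum()
    return total / (len(x) * len(y))


rng = np.random.default_rng(0)
a, b = rng.normal(size=(250, 4)), rng.normal(size=(130, 4))
print(np.isclose(chunked_kernel_mean(a, b), chunked_kernel_mean(a, b, chunk_size=64)))
```

Smaller chunks lower peak memory at the cost of more loop overhead, which is one reason the best value depends on the setup.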

Let’s check whether the detector thinks drift occurred within the original test set:

[11]:

preds_h0 = cd.predict(X_h0, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))

Drift? No!


As expected, no drift occurred. The p-value of the permutation test is above the $$0.05$$ threshold:

[12]:

print(preds_h0['data']['p_val'])

0.8
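The value 0.8 is one of only six values the test can return with this configuration. Assuming the p-value is the fraction of the `n_permutations` permuted statistics that are at least as large as the observed one (an assumption about the implementation that is consistent with the outputs in this notebook), it is always a multiple of `1 / n_permutations`. The hypothetical helper below lists the attainable values:

```python
def attainable_p_values(n_permutations):
    """All p-values a permutation test with n_permutations shuffles can return."""
    return [k / n_permutations for k in range(n_permutations + 1)]


print(attainable_p_values(5))  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
```

With 5 permutations, the only attainable p-value below the $$0.05$$ threshold is 0.0, which is another reason to use many more permutations in a real test.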


Let’s now check the predictions on the perturbed data:

[13]:

for x, c in zip(X_c, corruption):
    preds = cd.predict(x, return_p_val=True)
    print(f'Corruption type: {c}')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('p-value:')
    print(preds['data']['p_val'])
    print('')

Corruption type: gaussian_noise
Drift? Yes!
p-value:
0.0

Corruption type: motion_blur
Drift? Yes!
p-value:
0.0

Corruption type: brightness
Drift? Yes!
p-value:
0.0

Corruption type: pixelate
Drift? Yes!
p-value:
0.0



### BBSDs

For BBSDs, we use the classifier’s softmax outputs for black-box shift detection. This method is based on Detecting and Correcting for Label Shift with Black Box Predictors. The ResNet classifier is trained on data standardised by instance so we need to rescale the data.

[14]:

X_train = scale_by_instance(X_train)
X_test = scale_by_instance(X_test)
for i in range(n_corr):
    X_c[i] = scale_by_instance(X_c[i])
X_ref = scale_by_instance(X_ref)
X_h0 = scale_by_instance(X_h0)
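Standardising by instance means that each image is shifted and scaled using its own mean and standard deviation. The sketch below shows the idea in NumPy; `standardise_by_instance` is a hypothetical stand-in, and the exact behaviour of `scale_by_instance` (for example how constant images are handled) is not confirmed here.

```python
import numpy as np


def standardise_by_instance(x, eps=1e-12):
    """Give every instance in the batch zero mean and unit standard deviation."""
    axes = tuple(range(1, x.ndim))  # all axes except the batch axis
    mean = x.mean(axis=axes, keepdims=True)
    std = x.std(axis=axes, keepdims=True)
    return (x - mean) / (std + eps)  # eps guards against division by zero


rng = np.random.default_rng(0)
batch = rng.random((5, 32, 32, 3))
scaled = standardise_by_instance(batch)
print(scaled.mean(axis=(1, 2, 3)).round(6), scaled.std(axis=(1, 2, 3)).round(6))
```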


We now initialise the drift detector. Here we use the output of the softmax layer to detect drift, but other hidden layers can be extracted as well by setting ‘layer’ to the index of the desired hidden layer in the model:

[15]:

cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128},  # use output softmax layer
    kernel_kwargs={'sigma': np.array([1.])},
    chunk_size=1000,
    n_permutations=5
)


There is no drift on the original held-out test set:

[16]:

preds_h0 = cd.predict(X_h0)
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print(preds_h0['data']['p_val'])

Drift? No!
0.4


We compare this with the perturbed data:

[17]:

for x, c in zip(X_c, corruption):
    preds = cd.predict(x)
    print(f'Corruption type: {c}')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('p-value:')
    print(preds['data']['p_val'])
    print('')

Corruption type: gaussian_noise
Drift? Yes!
p-value:
0.0

Corruption type: motion_blur
Drift? Yes!
p-value:
0.0

Corruption type: brightness
Drift? Yes!
p-value:
0.0

Corruption type: pixelate
Drift? Yes!
p-value:
0.0



## Kernel bandwidth

So far we have defined a specific bandwidth sigma for the Gaussian kernel. We can however also sum over a number of different kernel bandwidths or infer sigma from X_ref and X using the following heuristic: compute the pairwise distances between each of the instances in X_ref and X, and set sigma to the median distance.

Let’s first try a range of bandwidths:

[18]:

cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128},
    kernel_kwargs={'sigma': np.array([.5, 1., 5.])},
    chunk_size=1000,
    n_permutations=5
)

[19]:

preds_h0 = cd.predict(X_h0)
print('Original test set sample')
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print(preds_h0['data']['p_val'])
print('')

for x, c in zip(X_c, corruption):
    preds = cd.predict(x)
    print(f'Corruption type: {c}')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('p-value:')
    print(preds['data']['p_val'])
    print('')

Original test set sample
Drift? No!
0.2

Corruption type: gaussian_noise
Drift? Yes!
p-value:
0.0

Corruption type: motion_blur
Drift? Yes!
p-value:
0.0

Corruption type: brightness
Drift? Yes!
p-value:
0.0

Corruption type: pixelate
Drift? Yes!
p-value:
0.0
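One way to read "summing over bandwidths" is that the kernel matrix is a sum of Gaussian kernels, one for each value in sigma. The NumPy sketch below illustrates that reading; `multi_bandwidth_kernel` is a hypothetical name, and whether the library sums or averages the individual kernels is not confirmed here.

```python
import numpy as np


def multi_bandwidth_kernel(x, y, sigmas):
    """Sum of Gaussian kernels between the rows of x and y, one per bandwidth."""
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    # shape (n_sigmas, 1, 1) so that every bandwidth broadcasts over the distance matrix
    sigmas = np.asarray(sigmas, dtype=float).reshape(-1, 1, 1)
    return np.exp(-sq_dist[None] / (2.0 * sigmas ** 2)).sum(axis=0)


rng = np.random.default_rng(0)
u, v = rng.normal(size=(6, 3)), rng.normal(size=(4, 3))
print(multi_bandwidth_kernel(u, v, [.5, 1., 5.]).shape)
```

Narrow bandwidths respond to local differences and wide ones to global differences, so combining several makes the test less dependent on picking a single good value.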



A bandwidth can also be inferred from X_ref and X using the heuristic:

[20]:

cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128},
    chunk_size=1000,
    n_permutations=5
)

[21]:

preds_h0 = cd.predict(X_h0)
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print(preds_h0['data']['p_val'])

Drift? No!
0.2

[22]:

print('Inferred bandwidth: {:.4f}'.format(cd.permutation_test.keywords['sigma'].item()))

Inferred bandwidth: 1.4132
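The median heuristic itself takes only a few lines of NumPy. The sketch below assumes Euclidean distances between the (preprocessed) rows of the reference sample and the rows of the test sample; `median_bandwidth` is a hypothetical name, and the library's variant may differ in details such as using squared distances.

```python
import numpy as np


def median_bandwidth(x_ref, x):
    """Median Euclidean distance between each row of x_ref and each row of x."""
    diff = x_ref[:, None, :] - x[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    return float(np.median(dist))


# toy example: the six pairwise distances are 5, 5, 10, 5, 5, 10
x_ref_demo = np.array([[0., 0.], [0., 0.]])
x_demo = np.array([[3., 4.], [3., 4.], [6., 8.]])
print(median_bandwidth(x_ref_demo, x_demo))  # 5.0
```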


## Label drift

We can also check what happens when we introduce class imbalances between the reference data X_ref and the tested data X_imb. The reference data will use $$75$$% of the instances of the first 5 classes and only $$25$$% of the last 5. The data used for drift testing then uses respectively $$25$$% and $$75$$% of the test instances for the first and last 5 classes.

[23]:

np.random.seed(0)
# get index for each class in the test set
num_classes = len(np.unique(y_test))
idx_by_class = [np.where(y_test == c)[0] for c in range(num_classes)]
# sample imbalanced data for different classes for X_ref and X_imb
perc_ref = .75
perc_ref_by_class = [perc_ref if c < 5 else 1 - perc_ref for c in range(num_classes)]
n_by_class = n_test // num_classes
X_ref = []
X_imb, y_imb = [], []
for c in range(num_classes):
    idx_class_ref = np.random.choice(n_by_class, size=int(perc_ref_by_class[c] * n_by_class), replace=False)
    idx_ref = idx_by_class[c][idx_class_ref]
    idx_class_imb = np.delete(np.arange(n_by_class), idx_class_ref, axis=0)
    idx_imb = idx_by_class[c][idx_class_imb]
    # the reference and test indices of a class must not overlap
    assert np.intersect1d(idx_ref, idx_imb).size == 0
    X_ref.append(X_test[idx_ref])
    X_imb.append(X_test[idx_imb])
    y_imb.append(y_test[idx_imb])
X_ref = np.concatenate(X_ref)
X_imb = np.concatenate(X_imb)
y_imb = np.concatenate(y_imb)
print(X_ref.shape, X_imb.shape, y_imb.shape)

(5000, 32, 32, 3) (5000, 32, 32, 3) (5000,)
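A quick sanity check for such a split is to compare the class frequencies of the two sets. The sketch below does this on synthetic labels that mimic the split above (1,000 test instances per class, 75% and 25% sampling); the labels are made up for illustration and `class_fractions` is a hypothetical helper.

```python
import numpy as np


def class_fractions(y, num_classes):
    """Fraction of the instances in y that belong to each class."""
    return np.bincount(y, minlength=num_classes) / len(y)


num_classes, n_by_class, perc_ref = 10, 1000, .75
# number of reference instances per class: 750 for classes 0-4, 250 for classes 5-9
n_ref = [int((perc_ref if c < 5 else 1 - perc_ref) * n_by_class) for c in range(num_classes)]
y_ref_demo = np.repeat(np.arange(num_classes), n_ref)
y_imb_demo = np.repeat(np.arange(num_classes), [n_by_class - n for n in n_ref])
print(class_fractions(y_ref_demo, num_classes))
print(class_fractions(y_imb_demo, num_classes))
```

Each of the first five classes makes up 15% of the reference labels and 5% of the tested labels, and the proportions are reversed for the last five, so the label distribution clearly differs between the two sets.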


Update reference dataset for the detector and make predictions:

[24]:

cd.X_ref = X_ref

[25]:

preds_imb = cd.predict(X_imb)
print('Drift? {}'.format(labels[preds_imb['data']['is_drift']]))
print(preds_imb['data']['p_val'])

Drift? Yes!
0.0