This page was generated from examples/cd_ks_cifar10.ipynb.

Kolmogorov-Smirnov data drift detector on CIFAR-10

Method

The drift detector applies feature-wise two-sample Kolmogorov-Smirnov (K-S) tests. For multivariate data, the obtained p-values for each feature are aggregated either via the Bonferroni or the False Discovery Rate (FDR) correction. The Bonferroni correction is more conservative and controls for the probability of at least one false positive. The FDR correction on the other hand allows for an expected fraction of false positives to occur.
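To make the feature-wise test concrete, here is a minimal NumPy sketch of a two-sample K-S test applied column by column. It is an illustration under stated assumptions rather than the detector's implementation: the helper names (`ks_statistic`, `ks_p_value`, `feature_wise_ks`) and the toy data are hypothetical, and the p-value uses the asymptotic Kolmogorov series, which can differ from exact small-sample p-values.

```python
import numpy as np

def ks_statistic(x, y):
    """Largest absolute gap between the two empirical CDFs."""
    x, y = np.sort(x), np.sort(y)
    grid = np.concatenate([x, y])
    cdf_x = np.searchsorted(x, grid, side='right') / x.size
    cdf_y = np.searchsorted(y, grid, side='right') / y.size
    return float(np.max(np.abs(cdf_x - cdf_y)))

def ks_p_value(d, n, m, terms=100):
    """Asymptotic two-sided p-value from the Kolmogorov distribution."""
    lam = d * np.sqrt(n * m / (n + m))
    if lam < 0.2:  # the series converges slowly here; the p-value is 1 to many digits
        return 1.0
    k = np.arange(1, terms + 1)
    p = 2.0 * np.sum((-1.0) ** (k - 1) * np.exp(-2.0 * k ** 2 * lam ** 2))
    return float(np.clip(p, 0.0, 1.0))

def feature_wise_ks(x_ref, x_new):
    """One K-S p-value per feature (column)."""
    n, m = x_ref.shape[0], x_new.shape[0]
    return np.array([
        ks_p_value(ks_statistic(x_ref[:, j], x_new[:, j]), n, m)
        for j in range(x_ref.shape[1])
    ])

# toy data (assumption): 3 features, only the last one is shifted
rng = np.random.default_rng(0)
x_ref_toy = rng.normal(size=(500, 3))
x_new_toy = rng.normal(size=(500, 3))
x_new_toy[:, 2] += 5.0
p_vals = feature_wise_ks(x_ref_toy, x_new_toy)
print(p_vals)
```

The resulting vector of p-values is what the Bonferroni or FDR correction then aggregates into a single drift decision.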

For high-dimensional data, we typically want to reduce the dimensionality before computing the feature-wise univariate K-S tests and aggregating those via the chosen correction method. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE), black-box shift detection using the classifier’s softmax outputs (BBSDs) and PCA as out-of-the-box preprocessing methods. Preprocessing methods which do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift. The adversarial detector which is part of the library can also be transformed into a drift detector that picks up drift reducing the performance of the classification model. We can therefore combine different preprocessing techniques to figure out whether there is drift which hurts the model performance, and whether this drift can be classified as input drift or label shift.

Dataset

CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.

[1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
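from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, InputLayer, Reshape
from tensorflow.keras.regularizers import l1

from alibi_detect.ad import AdversarialAE
from alibi_detect.cd import KSDrift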
from alibi_detect.cd.preprocess import uae, hidden_output
from alibi_detect.models.resnet import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector
from alibi_detect.utils.prediction import predict_batch
from alibi_detect.utils.saving import save_detector, load_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c

Load data

Original CIFAR-10 data:

[2]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)

For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:

[3]:
corruptions = corruption_types_cifar10c()
print(corruptions)
['brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate', 'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise', 'zoom_blur']

Let’s pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.

[4]:
corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255

We split the original test set into a reference dataset and a dataset which should not be rejected under the H0 of the K-S test. We also split the corrupted data by corruption type:

[5]:
np.random.seed(0)
n_test = X_test.shape[0]
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref, y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]
print(X_ref.shape, X_h0.shape)
(5000, 32, 32, 3) (5000, 32, 32, 3)
[6]:
# check that the classes are more or less balanced
classes, counts_ref = np.unique(y_ref, return_counts=True)
counts_h0 = np.unique(y_h0, return_counts=True)[1]
print('Class Ref H0')
for cl, cref, ch0 in zip(classes, counts_ref, counts_h0):
    assert cref + ch0 == n_test // 10
    print('{}     {} {}'.format(cl, cref, ch0))
Class Ref H0
0     472 528
1     510 490
2     498 502
3     492 508
4     501 499
5     495 505
6     493 507
7     501 499
8     516 484
9     522 478
[7]:
X_c = []
n_corr = len(corruption)
for i in range(n_corr):
    X_c.append(X_corr[i * n_test:(i + 1) * n_test])

We can visualise the same instance for each corruption type:

[8]:
i = 1  # index of the test instance to visualise

plt.title('Original')
plt.axis('off')
plt.imshow(X_test[i])
plt.show()
# each corruption type occupies a consecutive block of n_test images in X_corr
for j, c in enumerate(corruption):
    plt.title(c)
    plt.axis('off')
    plt.imshow(X_corr[n_test * j + i])
    plt.show()
../_images/examples_cd_ks_cifar10_13_0.png
../_images/examples_cd_ks_cifar10_13_1.png
../_images/examples_cd_ks_cifar10_13_2.png
../_images/examples_cd_ks_cifar10_13_3.png
../_images/examples_cd_ks_cifar10_13_4.png

We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:

[9]:
dataset = 'cifar10'
model = 'resnet32'
clf = fetch_tf_model(dataset, model)
acc = clf.evaluate(scale_by_instance(X_test), y_test, batch_size=128, verbose=0)[1]
print('Test set accuracy:')
print('Original {:.4f}'.format(acc))
clf_accuracy = {'original': acc}
# every corruption block keeps the original test set order, so y_test is reused as labels
for c, x in zip(corruption, X_c):
    acc = clf.evaluate(scale_by_instance(x), y_test, batch_size=128, verbose=0)[1]
    clf_accuracy[c] = acc
    print('{} {:.4f}'.format(c, acc))
Test set accuracy:
Original 0.9278
gaussian_noise 0.2208
motion_blur 0.6339
brightness 0.8913
pixelate 0.3666

Given the drop in performance, it is important that we detect the harmful data drift!

Detect drift

We are trying to detect data drift on high-dimensional (32x32x3) data using an aggregation of univariate K-S tests. It therefore makes sense to apply dimensionality reduction first. Some dimensionality reduction methods also used in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift are readily available: UAE (Untrained AutoEncoder), BBSDs (black-box shift detection using the classifier’s softmax outputs) and PCA.
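PCA is listed above but not demonstrated in the cells that follow. As a rough idea of what a PCA preprocessing step does, the sketch below fits the principal components on reference data only and applies the same projection to new data. The helper names (`fit_pca`, `apply_pca`) and the toy arrays are hypothetical stand-ins for flattened images, and this is not the library's own PCA preprocessing.

```python
import numpy as np

def fit_pca(x_ref, n_components):
    """Fit PCA on the reference data via SVD of the centred matrix."""
    mean = x_ref.mean(axis=0)
    _, _, vt = np.linalg.svd(x_ref - mean, full_matrices=False)
    return mean, vt[:n_components]

def apply_pca(x, mean, components):
    """Project data onto the components fitted on the reference data."""
    return (x - mean) @ components.T

# toy stand-ins (assumption) for flattened reference and test images
rng = np.random.default_rng(0)
x_ref_toy = rng.normal(size=(200, 48))
x_new_toy = rng.normal(size=(100, 48))

mean, components = fit_pca(x_ref_toy, n_components=8)
z_ref = apply_pca(x_ref_toy, mean, components)
z_new = apply_pca(x_new_toy, mean, components)
print(z_ref.shape, z_new.shape)  # (200, 8) (100, 8)
```

Fitting on the reference data alone keeps the projection fixed, so any difference between `z_ref` and `z_new` reflects the data rather than a change in the transformation.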

Untrained AutoEncoder

First we try UAE:

[10]:
tf.random.set_seed(0)

# define encoder
encoding_dim = 32
encoder_net = tf.keras.Sequential(
  [
      InputLayer(input_shape=(32, 32, 3)),
      Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
      Flatten(),
      Dense(encoding_dim,)
  ]
)

# initialise drift detector
p_val = .05
cd = KSDrift(
    p_val=p_val,        # significance level for the K-S tests
    X_ref=X_ref,        # reference data: half of the original test set
    preprocess_fn=uae,  # UAE for dimensionality reduction
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128},
    alternative='two-sided'  # other options: 'less', 'greater'
)

# we can also save/load an initialised detector
filepath = 'my_path'  # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)

The p-value used by the detector for the multivariate data with encoding_dim features is equal to p_val / encoding_dim because of the Bonferroni correction.

[11]:
assert cd.p_val / cd.n_features == p_val / encoding_dim

Let’s check whether the detector thinks drift occurred within the original test set:

[12]:
preds_h0 = cd.predict(X_h0, return_p_val=True)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
Drift? No!

As expected, no drift is detected. We can also inspect the p-values of the univariate K-S tests for each (encoded) feature before the multivariate correction. All of them are above the \(0.05\) threshold, most by a wide margin:

[13]:
print(preds_h0['data']['p_val'])
[0.94146556 0.14195988 0.6440195  0.05512931 0.37907234 0.25943416
 0.87724036 0.48063537 0.11774229 0.677735   0.48063537 0.7442197
 0.08356539 0.14860499 0.31536394 0.31536394 0.61036026 0.36571503
 0.80732274 0.22020556 0.249175   0.46531922 0.11774229 0.45025542
 0.25943416 0.5936282  0.5604951  0.9571862  0.8642828  0.23921937
 0.90134364 0.6945301 ]

Let’s now check the predictions on the perturbed data:

[14]:
for x, c in zip(X_c, corruption):
    preds = cd.predict(x, return_p_val=True)
    print(f'Corruption type: {c}')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('Feature-wise p-values:')
    print(preds['data']['p_val'])
    print('')
Corruption type: gaussian_noise
Drift? Yes!
Feature-wise p-values:
[4.95750410e-03 7.34658632e-03 5.52912503e-02 1.03268398e-08
 3.38559955e-01 8.17310661e-02 6.79804944e-03 2.71271050e-01
 1.55009609e-02 1.03484383e-02 1.82668434e-03 1.05715834e-01
 5.14261274e-08 1.42579767e-10 1.07582469e-04 1.40577822e-06
 2.90366083e-01 3.92065112e-06 3.53773794e-04 1.99041501e-01
 1.90719927e-03 6.04502577e-03 4.95599561e-10 6.04502577e-03
 5.12230098e-01 5.46741539e-05 3.83900553e-01 1.14905804e-01
 2.80400272e-02 8.04118258e-07 3.30103422e-03 2.05864832e-02]

Corruption type: motion_blur
Drift? Yes!
Feature-wise p-values:
[3.5206401e-07 1.1811157e-01 5.1718007e-04 3.8903630e-03 6.3595712e-01
 5.4199551e-04 3.7114436e-04 1.6068090e-02 1.9909975e-03 1.2925932e-02
 4.6739896e-09 7.1598392e-04 3.8931589e-04 6.6549936e-21 1.4105450e-07
 2.0635051e-05 7.2645240e-02 7.4932752e-05 1.0571583e-01 1.1322660e-04
 2.2843637e-04 1.5437353e-01 6.5380293e-03 3.2056447e-02 3.1007592e-02
 9.2208997e-05 6.6406779e-02 2.0782007e-03 1.1780208e-03 8.4564666e-10
 9.8475325e-04 3.5377379e-04]

Corruption type: brightness
Drift? Yes!
Feature-wise p-values:
[0.0000000e+00 0.0000000e+00 4.9483204e-29 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 2.7021131e-21 4.0524192e-38
 0.0000000e+00 0.0000000e+00 0.0000000e+00 9.2606446e-34 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 3.9701583e-05 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00]

Corruption type: pixelate
Drift? Yes!
Feature-wise p-values:
[3.38559955e-01 2.09074125e-01 1.99041501e-01 8.56629852e-03
 5.97174764e-01 9.44178849e-02 1.66843131e-01 5.21439314e-01
 2.44745538e-02 4.58541155e-01 1.96971247e-04 1.35259181e-01
 1.67503790e-03 1.82668434e-03 1.14905804e-01 9.17565823e-02
 8.31667542e-01 6.25065863e-02 6.64067790e-02 9.24271531e-03
 2.41498753e-01 6.94294155e-01 1.08709060e-01 8.41466635e-02
 3.45860541e-01 1.38920277e-01 3.68379533e-01 3.03615063e-01
 3.76089692e-01 9.24271531e-03 5.03090501e-01 9.96722654e-03]
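The Bonferroni decision can be replayed by hand on the smallest p-value printed for each dataset above. The sketch assumes the rule described earlier, namely that drift is flagged when at least one feature-wise p-value falls below p_val / encoding_dim; the dictionary simply copies the minima from the printed outputs.

```python
p_val = 0.05
encoding_dim = 32
threshold = p_val / encoding_dim  # Bonferroni-corrected threshold: 0.0015625

# smallest feature-wise p-value per dataset, copied from the outputs above
min_p_vals = {
    'h0': 0.05512931,
    'gaussian_noise': 1.42579767e-10,
    'motion_blur': 6.6549936e-21,
    'brightness': 0.0,
    'pixelate': 1.96971247e-04,
}

# drift is flagged if the smallest p-value is below the corrected threshold
decisions = {name: p < threshold for name, p in min_p_vals.items()}
for name, is_drift in decisions.items():
    print(f'{name}: min p-value {min_p_vals[name]:.3e} -> drift: {is_drift}')
```

For pixelate only one of the 32 features falls below the corrected threshold, so this corruption is detected by a much narrower margin than the others.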

BBSDs

For BBSDs, we use the classifier’s softmax outputs for black-box shift detection. This method is based on Detecting and Correcting for Label Shift with Black Box Predictors. The ResNet classifier is trained on data standardised by instance so we need to rescale the data.

[15]:
X_train = scale_by_instance(X_train)
X_test = scale_by_instance(X_test)
for i in range(n_corr):
    X_c[i] = scale_by_instance(X_c[i])
X_ref = scale_by_instance(X_ref)
X_h0 = scale_by_instance(X_h0)
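As a rough illustration of what BBSDs feeds into the K-S tests: every instance is reduced to a vector of class probabilities, so the number of features equals the number of classes. The sketch below uses hypothetical random logits in place of the ResNet outputs and a standard numerically stable softmax.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# hypothetical logits for 6 instances and 10 classes (assumption, not model output)
rng = np.random.default_rng(0)
logits_toy = rng.normal(size=(6, 10))
probs = softmax(logits_toy)
print(probs.shape)  # (6, 10)
```

Each of the 10 columns then gets its own univariate K-S test against the corresponding column of the reference data.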

Initialisation of the drift detector. Here we use the output of the softmax layer to detect the drift, but other hidden layers can be extracted as well by setting ‘layer’ to the index of the desired hidden layer in the model:

[16]:
p_val = .05
cd = KSDrift(
    p_val=p_val,
    X_ref=X_ref,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128}  # use output softmax layer
)

Again we can see that the p-value used by the detector for the multivariate data with 10 features (number of CIFAR-10 classes) is equal to p_val / 10 because of the Bonferroni correction.

[17]:
assert cd.p_val / cd.n_features == p_val / 10

There is no drift on the original held-out test set:

[18]:
preds_h0 = cd.predict(X_h0)
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print(preds_h0['data']['p_val'])
Drift? No!
[0.11774229 0.52796143 0.19387017 0.20236294 0.496191   0.72781175
 0.12345381 0.420929   0.8367454  0.7604178 ]

We compare this with the perturbed data:

[19]:
for x, c in zip(X_c, corruption):
    preds = cd.predict(x)
    print(f'Corruption type: {c}')
    print('Drift? {}'.format(labels[preds['data']['is_drift']]))
    print('Feature-wise p-values:')
    print(preds['data']['p_val'])
    print('')
Corruption type: gaussian_noise
Drift? Yes!
Feature-wise p-values:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Corruption type: motion_blur
Drift? Yes!
Feature-wise p-values:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Corruption type: brightness
Drift? Yes!
Feature-wise p-values:
[0.0000000e+00 4.2024049e-15 2.8963613e-33 4.8499879e-07 2.3718185e-15
 1.2473309e-05 2.9714003e-30 1.0611427e-09 4.6048109e-12 4.1857830e-17]

Corruption type: pixelate
Drift? Yes!
Feature-wise p-values:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Label drift

We can also check what happens when we introduce class imbalances between the reference data X_ref and the tested data X_imb. The reference data will use \(75\)% of the instances of the first 5 classes and only \(25\)% of the last 5. The data used for drift testing then uses respectively \(25\)% and \(75\)% of the test instances for the first and last 5 classes.

[20]:
np.random.seed(0)
# get index for each class in the test set
num_classes = len(np.unique(y_test))
idx_by_class = [np.where(y_test == c)[0] for c in range(num_classes)]
# sample imbalanced data for different classes for X_ref and X_imb
perc_ref = .75
perc_ref_by_class = [perc_ref if c < 5 else 1 - perc_ref for c in range(num_classes)]
n_by_class = n_test // num_classes
X_ref = []
X_imb, y_imb = [], []
for c in range(num_classes):
    idx_class_ref = np.random.choice(n_by_class, size=int(perc_ref_by_class[c] * n_by_class), replace=False)
    idx_ref = idx_by_class[c][idx_class_ref]
    idx_class_imb = np.delete(np.arange(n_by_class), idx_class_ref, axis=0)
    idx_imb = idx_by_class[c][idx_class_imb]
    # the reference and imbalanced index sets must not overlap
    assert np.intersect1d(idx_ref, idx_imb).size == 0
    X_ref.append(X_test[idx_ref])
    X_imb.append(X_test[idx_imb])
    y_imb.append(y_test[idx_imb])
X_ref = np.concatenate(X_ref)
X_imb = np.concatenate(X_imb)
y_imb = np.concatenate(y_imb)
print(X_ref.shape, X_imb.shape, y_imb.shape)
(5000, 32, 32, 3) (5000, 32, 32, 3) (5000,)

Update reference dataset for the detector and make predictions:

[21]:
cd.X_ref = X_ref
[22]:
preds_imb = cd.predict(X_imb)
print('Drift? {}'.format(labels[preds_imb['data']['is_drift']]))
print(preds_imb['data']['p_val'])
Drift? Yes!
[6.10830360e-20 1.32319470e-20 7.62410424e-29 1.05537245e-17
 7.68910424e-23 1.57479264e-15 5.77457112e-19 1.94419707e-20
 5.02102509e-21 4.13147353e-21]

Update reference data

So far we have kept the reference data fixed throughout the experiments. We may instead want to test a new batch against the last N instances, or against a fixed-size batch in which every instance seen so far has the same chance of being included (reservoir sampling). The update_X_ref argument sets the reference data update rule. It is a Dict with the update rule as key (‘last’ for the last N instances or ‘reservoir_sampling’) and the reference batch size N as value. You can also save the detector after the prediction calls to persist the updated reference data.

[23]:
N = 7500
cd = KSDrift(
    p_val=.05,
    X_ref=X_ref,
    update_X_ref={'reservoir_sampling': N},
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128}
)

The reference data is now updated with each predict call. Say we start with our imbalanced reference set and make a prediction on the remaining test set data X_imb, then the drift detector will figure out data drift has occurred.

[24]:
preds_imb = cd.predict(X_imb)
print('Drift? {}'.format(labels[preds_imb['data']['is_drift']]))
Drift? Yes!

We can now see that the reference data consists of N instances, obtained through reservoir sampling.

[25]:
assert cd.X_ref.shape[0] == N
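To see why the reference set ends up with exactly N instances, here is a minimal sketch of classic reservoir sampling (Algorithm R) on integer stand-ins for the instances. The function name `reservoir_update` is hypothetical and the detector's internal update may be implemented differently; the sizes mirror the example, with 5,000 initial reference instances, 5,000 new instances and N = 7500.

```python
import random

def reservoir_update(reservoir, n_seen, batch, max_size, rng):
    """Algorithm R: keep a uniform random sample of at most max_size items seen so far."""
    reservoir = list(reservoir)
    for item in batch:
        n_seen += 1
        if len(reservoir) < max_size:
            reservoir.append(item)       # reservoir not full yet: always keep
        else:
            j = rng.randrange(n_seen)    # keep the item with probability max_size / n_seen
            if j < max_size:
                reservoir[j] = item
    return reservoir, n_seen

rng = random.Random(0)
# initial reference data: 5000 items, below the reservoir size
reservoir, n_seen = reservoir_update([], 0, range(5000), 7500, rng)
# one predict call with 5000 new items fills the reservoir and starts replacing
reservoir, n_seen = reservoir_update(reservoir, n_seen, range(5000, 10000), 7500, rng)
print(len(reservoir), n_seen)  # 7500 10000
```

Once the reservoir is full, each new instance replaces a random existing one with decreasing probability, so older and newer data stay equally likely to be represented.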

We then draw a random sample from the training set and compare it with the updated reference data. This still highlights that there is data drift but will update the reference data again:

[26]:
np.random.seed(0)
perc_train = .5
n_train = X_train.shape[0]
idx_train = np.random.choice(n_train, size=int(perc_train * n_train), replace=False)
[27]:
preds_train = cd.predict(X_train[idx_train])
print('Drift? {}'.format(labels[preds_train['data']['is_drift']]))
Drift? Yes!

When we draw a new sample from the training set, it highlights that it is not drifting anymore against the reservoir in X_ref.

[28]:
np.random.seed(1)
perc_train = .1
idx_train = np.random.choice(n_train, size=int(perc_train * n_train), replace=False)
preds_train = cd.predict(X_train[idx_train])
print('Drift? {}'.format(labels[preds_train['data']['is_drift']]))
Drift? No!

Multivariate correction mechanism

Instead of the Bonferroni correction for multivariate data, we can also use the less conservative False Discovery Rate (FDR) correction. While the Bonferroni correction controls the probability of at least one false positive, the FDR correction controls the expected proportion of false positives among the rejected tests. The p_val argument at initialisation time can be interpreted as the acceptable q-value when the FDR correction is applied.

[29]:
cd = KSDrift(
    p_val=.05,
    correction='fdr',
    X_ref=X_ref,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128}
)

preds_imb = cd.predict(X_imb)
print('Drift? {}'.format(labels[preds_imb['data']['is_drift']]))
Drift? Yes!
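The difference between the two corrections is easiest to see on a small set of p-values where they disagree. The sketch below uses the textbook Benjamini-Hochberg step-up rule for the FDR correction; the function names and the toy p-values are hypothetical, and the library's implementation may differ in details such as strict versus non-strict inequalities.

```python
def bonferroni_drift(p_vals, alpha):
    """Drift if any p-value is below alpha divided by the number of tests."""
    return min(p_vals) < alpha / len(p_vals)

def fdr_drift(p_vals, q):
    """Benjamini-Hochberg: drift if any ranked p-value is below its rank-scaled threshold."""
    d = len(p_vals)
    ranked = sorted(p_vals)
    return any(p <= (i + 1) / d * q for i, p in enumerate(ranked))

# toy p-values (assumption): several moderately small values, none very small
p_toy = [0.02, 0.03, 0.03, 0.5]
print(bonferroni_drift(p_toy, 0.05))  # False: 0.02 is not below 0.05 / 4 = 0.0125
print(fdr_drift(p_toy, 0.05))         # True: third-ranked 0.03 <= 3 / 4 * 0.05 = 0.0375
```

Bonferroni needs a single very small p-value, whereas the FDR correction can also react to several moderately small ones, which is why it is the less conservative choice.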

Train vs. test set drift

Let us do a last check and see whether there is data drift between the original train and test sets. We will use both the BBSDs and Untrained AutoEncoder preprocessing techniques. We start with BBSDs:

[30]:
cd = KSDrift(
    p_val=.05,
    X_ref=X_train,
    preprocess_fn=hidden_output,
    preprocess_kwargs={'model': clf, 'layer': -1, 'batch_size': 128}
)

preds_test = cd.predict(X_test)
print('Drift? {}'.format(labels[preds_test['data']['is_drift']]))
Drift? Yes!

So when we use the softmax output of the classification model, we pick up drift. It is important to notice that this is drift with respect to the classifier output. This means that the detector picked up a difference between the distributions of the softmax probabilities of the model on the train and test sets.

When we use the UAE to preprocess the data we do not care about the distribution of the labels being predicted and only take random projections of the input on the latent dimension into account.

[31]:
cd = KSDrift(
    p_val=.05,
    X_ref=X_train,
    preprocess_fn=uae,
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128}
)

preds_test = cd.predict(X_test)
print('Drift? {}'.format(labels[preds_test['data']['is_drift']]))
Drift? No!

The UAE detector does not pick up drift. This means that the input data of the classification model does not seem to drift between the train and test set, but the model output does!

Adversarial autoencoder as a malicious drift detector

We can leverage the adversarial scores obtained from an adversarial autoencoder trained on normal data and turn them into a data drift detector. The score function of the adversarial autoencoder becomes the preprocessing function for the drift detector. The K-S test is then a simple univariate test on the adversarial scores. Importantly, an adversarial drift detector flags malicious data drift. We can fetch the pretrained adversarial detector from a Google Cloud Bucket or train one from scratch:

[32]:
load_pretrained = True
[33]:
filepath = 'my_path'  # change to (absolute) directory where model is downloaded
if load_pretrained:
    detector_type = 'adversarial'
    detector_name = 'base'
    ad = fetch_detector(filepath, detector_type, dataset, detector_name, model=model)
    filepath = os.path.join(filepath, detector_name)
else:  # train detector from scratch
    # define encoder and decoder networks
    encoder_net = tf.keras.Sequential(
        [
            InputLayer(input_shape=(32, 32, 3)),
            Conv2D(32, 4, strides=2, padding='same',
                   activation=tf.nn.relu, kernel_regularizer=l1(1e-5)),
            Conv2D(64, 4, strides=2, padding='same',
                   activation=tf.nn.relu, kernel_regularizer=l1(1e-5)),
            Conv2D(256, 4, strides=2, padding='same',
                   activation=tf.nn.relu, kernel_regularizer=l1(1e-5)),
            Flatten(),
            Dense(40)
        ]
    )

    decoder_net = tf.keras.Sequential(
        [
            InputLayer(input_shape=(40,)),
            Dense(4 * 4 * 128, activation=tf.nn.relu),
            Reshape(target_shape=(4, 4, 128)),
            Conv2DTranspose(256, 4, strides=2, padding='same',
                            activation=tf.nn.relu, kernel_regularizer=l1(1e-5)),
            Conv2DTranspose(64, 4, strides=2, padding='same',
                            activation=tf.nn.relu, kernel_regularizer=l1(1e-5)),
            Conv2DTranspose(3, 4, strides=2, padding='same',
                            activation=None, kernel_regularizer=l1(1e-5))
        ]
    )

    # initialise and train detector
    ad = AdversarialAE(
        encoder_net=encoder_net,
        decoder_net=decoder_net,
        model=clf
    )
    ad.fit(X_train, epochs=50, batch_size=128, verbose=True)

    # save the trained adversarial detector
    save_detector(ad, filepath)
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
WARNING:tensorflow:Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.
WARNING:alibi_detect.ad.adversarialae:No threshold level set. Need to infer threshold using `infer_threshold`.

Initialise the drift detector:

[34]:
np.random.seed(0)
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
X_ref = scale_by_instance(X_test[idx])

cd = KSDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=ad.score,
    preprocess_kwargs={'batch_size': 128}
)

Make drift predictions on the original test set and corrupted data:

[35]:
clf_accuracy['h0'] = clf.evaluate(X_h0, y_h0, batch_size=128, verbose=0)[1]
preds_h0 = cd.predict(X_h0)
print('H0: Accuracy {:.4f} -- Drift? {}'.format(
    clf_accuracy['h0'], labels[preds_h0['data']['is_drift']]))
clf_accuracy['imb'] = clf.evaluate(X_imb, y_imb, batch_size=128, verbose=0)[1]
preds_imb = cd.predict(X_imb)
print('imbalance: Accuracy {:.4f} -- Drift? {}'.format(
    clf_accuracy['imb'], labels[preds_imb['data']['is_drift']]))
for x, c in zip(X_c, corruption):
    preds = cd.predict(x)
    print('{}: Accuracy {:.4f} -- Drift? {}'.format(
        c, clf_accuracy[c], labels[preds['data']['is_drift']]))
H0: Accuracy 0.9286 -- Drift? No!
imbalance: Accuracy 0.9282 -- Drift? No!
gaussian_noise: Accuracy 0.2208 -- Drift? Yes!
motion_blur: Accuracy 0.6339 -- Drift? Yes!
brightness: Accuracy 0.8913 -- Drift? Yes!
pixelate: Accuracy 0.3666 -- Drift? Yes!

While X_imb clearly exhibits input data drift due to the introduced class imbalances, it is not flagged by the adversarial drift detector since the performance of the classifier is not affected and the drift is not malicious. We can visualise this by plotting the adversarial scores together with the harmfulness of the data corruption as reflected by the drop in classifier accuracy:

[36]:
adv_scores = {}
score = ad.score(X_ref, batch_size=128)
adv_scores['original'] = {'mean': score.mean(), 'std': score.std()}
score = ad.score(X_h0, batch_size=128)
adv_scores['h0'] = {'mean': score.mean(), 'std': score.std()}
score = ad.score(X_imb, batch_size=128)
adv_scores['imb'] = {'mean': score.mean(), 'std': score.std()}

for x, c in zip(X_c, corruption):
    score_x = ad.score(x, batch_size=128)
    adv_scores[c] = {'mean': score_x.mean(), 'std': score_x.std()}
[38]:
mu = [v['mean'] for _, v in adv_scores.items()]
stdev = [v['std'] for _, v in adv_scores.items()]
xlabels = list(adv_scores.keys())
acc = [clf_accuracy[label] for label in xlabels]
xticks = np.arange(len(mu))

width = .35

fig, ax = plt.subplots()
ax2 = ax.twinx()

p1 = ax.bar(xticks, mu, width, yerr=stdev, capsize=2)
color = 'tab:red'
p2 = ax2.bar(xticks + width, acc, width, color=color)

ax.set_title('Adversarial Scores and Accuracy by Corruption Type')
ax.set_xticks(xticks + width / 2)
ax.set_xticklabels(xlabels, rotation=45)
ax.legend((p1[0], p2[0]), ('Score', 'Accuracy'), loc='upper right', ncol=2)
ax.set_ylabel('Adversarial Score')

color = 'tab:red'
ax2.set_ylabel('Accuracy')
ax2.set_ylim((-.26,1.2))
ax.set_ylim((-2,9))

plt.show()
../_images/examples_cd_ks_cifar10_68_0.png

We can therefore use the scores of the detector itself to quantify the harmfulness of the drift! We can generalise this to all the corruptions at each severity level in CIFAR-10-C:

[39]:
def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return (y_true == y_pred).astype(int).sum() / y_true.shape[0]
[40]:
severities = [1, 2, 3, 4, 5]

score_drift = {
    1: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
    2: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
    3: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
    4: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
    5: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
}

y_pred = predict_batch(clf, X_test, batch_size=256, return_class=True)
score_x = ad.score(X_test, batch_size=256)

for s in severities:
    print('\nSeverity: {} of {}'.format(s, len(severities)))

    print('Loading corrupted dataset...')
    X_corr, y_corr = fetch_cifar10c(corruption=corruptions, severity=s, return_X_y=True)
    X_corr = X_corr.astype('float32')

    print('Preprocess data...')
    X_corr = scale_by_instance(X_corr)

    print('Make predictions on corrupted dataset...')
    y_pred_corr = predict_batch(clf, X_corr, batch_size=256, return_class=True)

    print('Compute adversarial scores on corrupted dataset...')
    score_corr = ad.score(X_corr, batch_size=256)

    print('Get labels for malicious corruptions...')
    labels_corr = np.zeros(score_corr.shape[0])
    repeat = y_corr.shape[0] // y_test.shape[0]
    y_pred_repeat = np.tile(y_pred, (repeat,))
    # malicious/harmful corruption: original prediction correct but
    # prediction on corrupted data incorrect
    idx_orig_right = np.where(y_pred_repeat == y_corr)[0]
    idx_corr_wrong = np.where(y_pred_corr != y_corr)[0]
    idx_harmful = np.intersect1d(idx_orig_right, idx_corr_wrong)
    labels_corr[idx_harmful] = 1
    # stored under a separate name so the ['No!', 'Yes!'] list in `labels` is not overwritten
    labels_all = np.concatenate([np.zeros(X_test.shape[0]), labels_corr]).astype(int)
    # harmless corruption: original prediction correct and prediction
    # on corrupted data correct
    idx_corr_right = np.where(y_pred_corr == y_corr)[0]
    idx_harmless = np.intersect1d(idx_orig_right, idx_corr_right)

    score_drift[s]['all'] = score_corr
    score_drift[s]['harm'] = score_corr[idx_harmful]
    score_drift[s]['noharm'] = score_corr[idx_harmless]
    score_drift[s]['acc'] = accuracy(y_corr, y_pred_corr)
Severity: 1 of 5
Loading corrupted dataset...
Preprocess data...
Make predictions on corrupted dataset...
Compute adversarial scores on corrupted dataset...
Get labels for malicious corruptions...

Severity: 2 of 5
Loading corrupted dataset...
Preprocess data...
Make predictions on corrupted dataset...
Compute adversarial scores on corrupted dataset...
Get labels for malicious corruptions...

Severity: 3 of 5
Loading corrupted dataset...
Preprocess data...
Make predictions on corrupted dataset...
Compute adversarial scores on corrupted dataset...
Get labels for malicious corruptions...

Severity: 4 of 5
Loading corrupted dataset...
Preprocess data...
Make predictions on corrupted dataset...
Compute adversarial scores on corrupted dataset...
Get labels for malicious corruptions...

Severity: 5 of 5
Loading corrupted dataset...
Preprocess data...
Make predictions on corrupted dataset...
Compute adversarial scores on corrupted dataset...
Get labels for malicious corruptions...
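The harmful and harmless index sets built in the loop above can be checked on a toy example. The arrays below are made up for illustration; the logic follows the loop: an instance is harmful when the original prediction was correct and the prediction on the corrupted version is wrong, and harmless when both predictions are correct.

```python
import numpy as np

# toy labels and predictions (assumption, for illustration only)
y_true = np.array([0, 1, 2, 3, 4, 5])
y_pred_orig = np.array([0, 1, 2, 3, 0, 0])  # last two already wrong before corruption
y_pred_corr = np.array([0, 9, 2, 9, 4, 9])  # predictions on the corrupted instances

orig_right = y_pred_orig == y_true
corr_right = y_pred_corr == y_true

# harmful: correct before the corruption, incorrect after it
idx_harmful = np.where(orig_right & ~corr_right)[0]
# harmless: correct both before and after the corruption
idx_harmless = np.where(orig_right & corr_right)[0]
print(idx_harmful, idx_harmless)  # [1 3] [0 2]
```

Instances that were already misclassified before the corruption (indices 4 and 5 here) fall into neither group.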

We now compute mean scores and standard deviations per severity level and plot the results. The plot shows the mean adversarial scores (lhs) and ResNet-32 accuracies (rhs) for increasing data corruption severity levels. Level 0 corresponds to the original test set. Harmful scores are scores from instances whose prediction flipped from correct to incorrect because of the corruption. Not harmful scores come from instances that were predicted correctly both before and after the corruption.

[41]:
mu_noharm, std_noharm = [], []
mu_harm, std_harm = [], []
acc = [clf_accuracy['original']]
for k, v in score_drift.items():
    mu_noharm.append(v['noharm'].mean())
    std_noharm.append(v['noharm'].std())
    mu_harm.append(v['harm'].mean())
    std_harm.append(v['harm'].std())
    acc.append(v['acc'])
[42]:
plot_labels = ['0', '1', '2', '3', '4', '5']

N = 6
ind = np.arange(N)
width = .35

fig_bar_cd, ax = plt.subplots()
ax2 = ax.twinx()

p0 = ax.bar(ind[0], score_x.mean(), yerr=score_x.std(), capsize=2)
p1 = ax.bar(ind[1:], mu_noharm, width, yerr=std_noharm, capsize=2)
p2 = ax.bar(ind[1:] + width, mu_harm, width, yerr=std_harm, capsize=2)

ax.set_title('Adversarial Scores and Accuracy by Corruption Severity')
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(plot_labels)
ax.set_ylim((-1,6))
ax.legend((p1[0], p2[0]), ('Not Harmful', 'Harmful'), loc='upper right', ncol=2)
ax.set_ylabel('Score')
ax.set_xlabel('Corruption Severity')

color = 'tab:red'
ax2.set_ylabel('Accuracy', color=color)
ax2.plot(acc, color=color)
ax2.tick_params(axis='y', labelcolor=color)

plt.show()
../_images/examples_cd_ks_cifar10_74_0.png