

Learned Kernel


The learned-kernel drift detector (Liu et al., 2020) is an extension of the Maximum Mean Discrepancy (MMD) drift detector in which the kernel used to define the MMD is trained on a portion of the data to maximise an estimate of the resulting test power. Once the kernel has been learned, a permutation test is performed in the usual way on the value of the MMD.
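To make the permutation step concrete, here is a minimal NumPy sketch of an MMD^2 permutation test. It assumes a fixed Gaussian RBF kernel with bandwidth sigma=1 in place of a learned one, and the helper names (gaussian_rbf, mmd2_unbiased, mmd2_permutation_test) are hypothetical, not part of alibi_detect. The observed MMD^2 is compared with the values obtained after randomly reassigning the pooled instances to the two windows; the p-value is the fraction of permuted statistics at least as large as the observed one.

```python
import numpy as np


def gaussian_rbf(x, y, sigma=1.0):
    """Kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))


def mmd2_unbiased(x, y, sigma=1.0):
    """Unbiased estimate of MMD^2 between samples x (n, d) and y (m, d)."""
    n, m = len(x), len(y)
    k_xx, k_yy, k_xy = gaussian_rbf(x, x, sigma), gaussian_rbf(y, y, sigma), gaussian_rbf(x, y, sigma)
    # drop the diagonal (self-similarity) terms for the unbiased estimator
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return float(term_xx + term_yy - 2 * k_xy.mean())


def mmd2_permutation_test(x, y, n_permutations=100, sigma=1.0, seed=0):
    """Return (observed MMD^2, p-value) from a permutation test."""
    rng = np.random.default_rng(seed)
    observed = mmd2_unbiased(x, y, sigma)
    pooled, n = np.concatenate([x, y]), len(x)
    null_stats = []
    for _ in range(n_permutations):
        idx = rng.permutation(len(pooled))  # random reassignment to the two windows
        null_stats.append(mmd2_unbiased(pooled[idx[:n]], pooled[idx[n:]], sigma))
    # fraction of permuted statistics at least as large as the observed one
    p_val = float((np.array(null_stats) >= observed).mean())
    return observed, p_val


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(100, 2))
x_test = rng.normal(loc=1.0, size=(100, 2))  # mean-shifted window
mmd2, p_val = mmd2_permutation_test(x_ref, x_test)
```

The learned-kernel detector replaces the fixed gaussian_rbf with a trained kernel, but the permutation logic stays the same.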

This method is closely related to the classifier drift detector, which trains a classifier to discriminate between instances from the reference window and instances from the test window. The difference here is that we train a kernel to output high similarity between instances from the same window and low similarity between instances from different windows. If this is possible in a generalisable manner, then drift must have occurred.
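The training criterion can be sketched in the same style. Liu et al. (2020) maximise the ratio of the MMD^2 estimate to its estimated standard deviation, a proxy for test power, and var_reg plays the role of the stability constant listed among the keyword arguments. The sketch below assumes equal-sized windows and writes the estimators in terms of the matrix H[i, j] = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(y_i, x_j). The helper names are hypothetical, and instead of gradient-based training it merely compares a few fixed RBF bandwidths, a crude stand-in for learning the kernel rather than the detector's own training loop.

```python
import numpy as np


def rbf(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix between the rows of x and y."""
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))


def power_proxy(x, y, kernel, var_reg=1e-5):
    """MMD^2 estimate divided by sqrt(variance estimate + var_reg)."""
    assert len(x) == len(y), "this estimator assumes equal window sizes"
    n = len(x)
    k_xy = kernel(x, y)
    # H[i, j] = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(y_i, x_j)
    h = kernel(x, x) + kernel(y, y) - k_xy - k_xy.T
    mmd2 = (h.sum() - np.trace(h)) / (n * (n - 1))
    var = 4 * (h.sum(axis=1) ** 2).sum() / n ** 3 - 4 * h.sum() ** 2 / n ** 4
    return float(mmd2 / np.sqrt(var + var_reg))


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 2))
x_test = rng.normal(size=(200, 2)) + np.array([0.5, 0.0])  # drift in the first feature only

# crude "training": keep the bandwidth with the highest power proxy
scores = {s: power_proxy(x_ref, x_test, lambda a, b: rbf(a, b, s)) for s in (0.1, 1.0, 10.0)}
best_sigma = max(scores, key=scores.get)
```

A kernel that separates the two windows well produces a large MMD^2 relative to its variability, which is exactly what maximising this ratio rewards.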

As with the classifier-based approach, we should specify the proportion of data to use for training and testing respectively as well as training arguments such as the learning rate and batch size. Note that a new kernel is trained for each test set that is passed for detection.
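A minimal sketch of such a split, assuming both windows are NumPy arrays: each window is shuffled and divided separately, the first parts would be used to fit the kernel and the held-out parts for the permutation test. Keeping the two apart matters because a kernel evaluated on the very instances it was trained to separate would overstate the MMD. The helper name split_windows and the value train_size=0.75 are illustrative assumptions, not necessarily how, or with which default, the detector splits its data.

```python
import numpy as np


def split_windows(x_ref, x_test, train_size=0.75, seed=0):
    """Randomly split each window into a kernel-training part and a testing part."""
    rng = np.random.default_rng(seed)
    splits = []
    for x in (x_ref, x_test):
        perm = rng.permutation(len(x))
        n_train = int(len(x) * train_size)
        splits.append((x[perm[:n_train]], x[perm[n_train:]]))
    (ref_tr, ref_te), (test_tr, test_te) = splits
    return (ref_tr, test_tr), (ref_te, test_te)


rng = np.random.default_rng(1)
x_ref, x_test = rng.normal(size=(100, 3)), rng.normal(size=(80, 3))
(ref_tr, test_tr), (ref_te, test_te) = split_windows(x_ref, x_test, train_size=0.75)
# ref_tr has 75 rows, ref_te 25, test_tr 60 and test_te 20
```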


Arguments:

  • x_ref: Data used as reference distribution.

  • kernel: A differentiable TensorFlow or PyTorch module that takes two instances as input and returns a scalar notion of similarity as output.

Keyword arguments:

  • backend: Specify the backend (tensorflow or pytorch). This depends on the framework of the kernel. Defaults to tensorflow.

  • p_val: p-value threshold used for the significance of the test.

  • preprocess_x_ref: Whether to already apply the (optional) preprocessing step to the reference data at initialization and store the preprocessed data. Dependent on the preprocessing step, this can reduce the computation time for the predict step significantly, especially when the reference dataset is large. Defaults to True. It is possible that it needs to be set to False if the preprocessing step requires statistics from both the reference and test data, such as the mean or standard deviation.

  • update_x_ref: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, the parameter equals {‘last’: N} while for reservoir sampling {‘reservoir_sampling’: N} is passed. If the input data is of type List[Any], update_x_ref needs to be set to None and the reference set remains fixed.

  • preprocess_fn: Function to preprocess the data before computing the data drift metrics.

  • n_permutations: The number of permutations to use in the permutation test once the MMD has been computed.

  • var_reg: Constant added to the estimated variance of the MMD for stability.

  • reg_loss_fn: The regularisation term reg_loss_fn(kernel) is added to the loss function being optimized.

  • train_size: Optional fraction (float between 0 and 1) of the dataset used to train the kernel. The drift is detected on the remaining 1 - train_size.

  • retrain_from_scratch: Whether the kernel should be retrained from scratch for each set of test data or whether it should instead continue training from where it left off on the previous set.

  • optimizer: Optimizer used during training of the kernel. From torch.optim for PyTorch and tf.keras.optimizers for TensorFlow.

  • learning_rate: Learning rate for the optimizer.

  • batch_size: Batch size used during training of the kernel.

  • preprocess_batch_fn: Optional batch preprocessing function. For example to convert a list of generic objects to a tensor which can be processed by the kernel.

  • epochs: Number of training epochs for the kernel.

  • verbose: Verbosity level during the training of the kernel. 0 is silent and 1 prints a progress bar.

  • train_kwargs: Optional additional kwargs for the built-in TensorFlow (from alibi_detect.models.tensorflow import trainer) or PyTorch (from alibi_detect.models.pytorch import trainer) trainer functions.

  • dataset: Dataset object used during training of the kernel. Defaults to (an instance of for the PyTorch backend and (an instance of tf.keras.utils.Sequence) for the TensorFlow backend. For PyTorch, the dataset should only take the windows x_ref and x_test as input, so when e.g. TorchDataset is passed to the detector at initialisation, during training TorchDataset(x_ref, x_test) is used. For TensorFlow, the dataset is an instance of tf.keras.utils.Sequence, so when e.g. TFDataset is passed to the detector at initialisation, during training TFDataset(x_ref, x_test, batch_size=batch_size, shuffle=True) is used. x_ref and x_test can be of type np.ndarray or List[Any].

  • data_type: Optionally specify the data type (e.g. tabular, image or time-series). Added to metadata.
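The two update_x_ref strategies described above can be sketched as follows. This is an illustrative NumPy version, not alibi_detect's own code: update_last keeps the most recent N instances, while update_reservoir implements standard reservoir sampling (Algorithm R), so that every instance seen so far has the same probability N / n_seen of being in the reference set. Both function names and the explicit n_seen counter are hypothetical.

```python
import numpy as np


def update_last(x_ref, x_new, n):
    """'last' strategy: keep the n most recent instances seen so far."""
    return np.concatenate([x_ref, x_new])[-n:]


def update_reservoir(x_ref, x_new, n, n_seen, rng):
    """Reservoir sampling (Algorithm R) with reservoir size n.

    n_seen is the number of instances processed before x_new arrived.
    Returns the updated reference set and the updated count.
    """
    reservoir = list(x_ref)
    for item in x_new:
        n_seen += 1
        if len(reservoir) < n:
            reservoir.append(item)  # reservoir not yet full
        else:
            j = rng.integers(n_seen)  # uniform integer in [0, n_seen)
            if j < n:
                reservoir[j] = item  # replacement happens with probability n / n_seen
    return np.array(reservoir), n_seen


rng = np.random.default_rng(0)
print(update_last(np.arange(5), np.arange(5, 8), n=5))  # [3 4 5 6 7]
x_ref, n_seen = update_reservoir(np.arange(5), np.arange(5, 105), n=5, n_seen=5, rng=rng)
```

The 'last' strategy tracks recent data closely, whereas the reservoir keeps a uniform sample of the whole history, so it adapts more slowly to change.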

Additional PyTorch keyword arguments:

  • device: cuda or gpu to use the GPU and cpu for the CPU. If the device is not specified, the detector will try to leverage the GPU if possible and otherwise fall back on CPU.

  • dataloader: Dataloader object used during training of the kernel. Defaults to The dataloader is not initialized yet; this is done during initialization of the detector using the batch_size. Custom dataloaders can be passed as well, e.g. for graph data we can use

Defining the kernel

Any differentiable PyTorch or TensorFlow module that takes two instances as input and outputs a scalar (representing similarity) can be used as the kernel for this drift detector. However, in order to ensure that MMD = 0 implies no drift, the kernel should satisfy a characteristic property. This can be guaranteed by defining a kernel as

\[k(x,y) = (1-\epsilon)\, k_a(\Phi(x), \Phi(y)) + \epsilon\, k_b(x,y),\]

where \(\Phi\) is a learnable projection, \(k_a\) and \(k_b\) are simple characteristic kernels (such as a Gaussian RBF), and \(\epsilon>0\) is a small constant. By letting \(\Phi\) be very flexible we can learn powerful kernels in this manner.
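Numerically, this construction is a weighted sum of two kernel matrices. The NumPy sketch below assumes Gaussian RBF kernels with fixed bandwidths for k_a and k_b and replaces the learnable projection Phi with a fixed random linear map followed by tanh, purely for illustration; deep_kernel and the other names are hypothetical and not the alibi_detect API.

```python
import numpy as np


def gaussian_rbf(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix between the rows of x and y."""
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))


def deep_kernel(x, y, phi, eps=0.01, sigma_a=1.0, sigma_b=1.0):
    """k(x, y) = (1 - eps) * k_a(phi(x), phi(y)) + eps * k_b(x, y)."""
    k_a = gaussian_rbf(phi(x), phi(y), sigma_a)  # similarity in the projected space
    k_b = gaussian_rbf(x, y, sigma_b)  # similarity in the input space
    return (1 - eps) * k_a + eps * k_b


# hypothetical stand-in for the learnable projection Phi: fixed random weights
rng = np.random.default_rng(0)
w = rng.normal(size=(5, 3))


def phi(z):
    return np.tanh(z @ w)


x = rng.normal(size=(4, 5))
y = rng.normal(size=(6, 5))
k = deep_kernel(x, y, phi)  # kernel matrix of shape (4, 6)
```

Because both RBF terms equal 1 on the diagonal, k(x, x) = 1 here. The epsilon * k_b term is what keeps the kernel characteristic: even if Phi mapped all inputs to the same point, k would still vary with x and y through k_b.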

Such a kernel is easily implemented using the DeepKernel class provided in alibi_detect. We demonstrate below how we might define a convolutional kernel for images using PyTorch. By default, GaussianRBF kernels are used for \(k_a\) and \(k_b\), and here we specify \(\epsilon=0.01\), but we could alternatively set eps='trainable'.

from torch import nn
from alibi_detect.utils.pytorch.kernels import DeepKernel

# define the projection phi: a small CNN for 3-channel (RGB) images
proj = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(8, 16, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),  # one feature vector per instance
)

# define the kernel
kernel = DeepKernel(proj, eps=0.01)

Instantiating the detector

Instantiating the detector is then as simple as passing the reference data and the kernel as follows:

# instantiate the detector
from alibi_detect.cd import LearnedKernelDrift

cd = LearnedKernelDrift(x_ref, kernel, backend='pytorch', p_val=.05, epochs=10, batch_size=32)

We could have alternatively defined the kernel and instantiated the detector using TensorFlow:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Input
from alibi_detect.cd import LearnedKernelDrift
from alibi_detect.utils.tensorflow.kernels import DeepKernel

# define the projection phi
proj = tf.keras.Sequential(
  [
      Input(shape=(32, 32, 3)),
      Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
      Flatten(),  # one feature vector per instance
  ]
)

# define the kernel
kernel = DeepKernel(proj, eps=0.01)

# instantiate the detector
cd = LearnedKernelDrift(x_ref, kernel, backend='tensorflow', p_val=.05, epochs=10, batch_size=32)

Detect Drift

We detect data drift by simply calling predict on a batch of instances x. Setting return_p_val to True will also return the p-value of the test, setting return_distance to True will return a notion of the strength of the drift, and setting return_kernel to True will also return the trained kernel.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector’s metadata while data is also a dictionary which contains the actual predictions stored in the following keys:

  • is_drift: 1 if the sample tested has drifted from the reference data and 0 otherwise.

  • threshold: the user-defined p-value threshold defining the significance of the test.

  • p_val: the p-value of the test if return_p_val equals True.

  • distance: MMD^2 metric between the reference data and the new batch if return_distance equals True.

  • distance_threshold: MMD^2 metric value from the permutation test which corresponds to the p-value threshold if return_distance equals True.

  • kernel: The trained kernel if return_kernel equals True.

preds = cd.predict(X, return_p_val=True, return_distance=True)
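The returned dictionary can then be inspected with ordinary dictionary access. The helper below is a hypothetical example, not part of alibi_detect; it only relies on the keys listed above and uses .get for p_val and distance, since those are only present when return_p_val and return_distance are set.

```python
def summarise_prediction(preds: dict) -> str:
    """Build a one-line summary from a drift prediction dictionary."""
    data = preds['data']
    parts = ['drift detected' if data['is_drift'] == 1 else 'no drift detected']
    # p_val and distance are only available when explicitly requested
    if data.get('p_val') is not None:
        parts.append(f"p-value {data['p_val']:.3f} (threshold {data['threshold']})")
    if data.get('distance') is not None:
        parts.append(
            f"MMD^2 {data['distance']:.4f} "
            f"(permutation threshold {data['distance_threshold']:.4f})"
        )
    return ', '.join(parts)
```

For example, summarise_prediction(preds) could be logged after every call to predict.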