This page was generated from doc/source/methods/mmddrift.ipynb.


Maximum Mean Discrepancy


The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions p and q, based on their mean embeddings \(\mu_{p}\) and \(\mu_{q}\) in a reproducing kernel Hilbert space \(F\):

\begin{align} MMD^2(F, p, q) & = || \mu_{p} - \mu_{q} ||^2_{F} \end{align}

We can compute unbiased estimates of \(MMD^2\) from samples of the two distributions after applying the kernel trick. By default we use a radial basis function kernel, but users are free to pass their own kernel to the detector. We obtain a \(p\)-value via a permutation test on the values of \(MMD^2\).
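The estimate and the permutation test can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the detector's actual implementation: the function names, the exact form of the Gaussian kernel, the default of 200 permutations and the synthetic data are all chosen here for the example.

```python
import numpy as np


def gaussian_kernel(x, y, sigma=1.0):
    """RBF kernel matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
    sq_dist = (
        np.sum(x ** 2, axis=1)[:, None]
        + np.sum(y ** 2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    # Clip small negative values caused by floating point round-off.
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * sigma ** 2))


def mmd2_unbiased(k, n):
    """Unbiased MMD^2 from the kernel matrix of the stacked samples.

    The first n rows and columns of k belong to the first sample.
    """
    m = k.shape[0] - n
    k_xx, k_yy, k_xy = k[:n, :n], k[n:, n:], k[:n, n:]
    # Dropping the diagonal terms k(x_i, x_i) makes the estimate unbiased.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2.0 * k_xy.mean()


def mmd_permutation_test(x, y, sigma=1.0, n_permutations=200, seed=0):
    """Return the observed MMD^2 and the permutation-test p-value."""
    rng = np.random.default_rng(seed)
    xy = np.concatenate([x, y], axis=0)
    n = x.shape[0]
    k = gaussian_kernel(xy, xy, sigma)  # computed once, reused per permutation
    observed = mmd2_unbiased(k, n)
    permuted = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(xy.shape[0])
        # Shuffling rows and columns together reassigns instances to samples.
        permuted[i] = mmd2_unbiased(k[np.ix_(idx, idx)], n)
    p_val = float(np.mean(permuted >= observed))
    return observed, p_val


rng = np.random.default_rng(42)
x_ref = rng.normal(0.0, 1.0, size=(100, 5))
x_same = rng.normal(0.0, 1.0, size=(100, 5))   # same distribution
x_shift = rng.normal(1.0, 1.0, size=(100, 5))  # mean shifted by 1

mmd_same, p_same = mmd_permutation_test(x_ref, x_same, sigma=2.0)
mmd_shift, p_shift = mmd_permutation_test(x_ref, x_shift, sigma=2.0)
print(f"same:  MMD^2 = {mmd_same:.4f}, p-value = {p_same:.3f}")
print(f"shift: MMD^2 = {mmd_shift:.4f}, p-value = {p_shift:.3f}")
```

The p-value is the fraction of permuted \(MMD^2\) values at least as large as the observed one, so a small p-value means the observed discrepancy is unlikely if both samples come from the same distribution.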

For high-dimensional data, we typically want to reduce the dimensionality before running the permutation test. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE), black-box shift detection using the classifier’s softmax outputs (BBSDs) and PCA as out-of-the-box preprocessing methods. Preprocessing methods which do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift.
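As an illustration of the preprocessing idea, the sketch below fits PCA on the reference data and projects both the reference and the test instances onto the same components before any two-sample test is run. It is a simplified NumPy version; fit_pca and apply_pca are hypothetical helper names, and the detector's built-in PCA preprocessing may work differently.

```python
import numpy as np


def fit_pca(x_ref, n_components):
    """Return the mean and the top principal directions of the reference data."""
    mean = x_ref.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(x_ref - mean, full_matrices=False)
    return mean, vt[:n_components]


def apply_pca(x, mean, components):
    """Project instances onto the principal directions fitted on the reference."""
    return (x - mean) @ components.T


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 64))   # synthetic reference data
x_test = rng.normal(size=(50, 64))   # synthetic test batch

mean, components = fit_pca(x_ref, n_components=8)
z_ref = apply_pca(x_ref, mean, components)
z_test = apply_pca(x_test, mean, components)
print(z_ref.shape, z_test.shape)
```

Fitting the projection on the reference data only keeps the preprocessing fixed, so any difference between z_ref and z_test reflects the data rather than a refitted transform.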




  • p_val: p-value threshold used to determine the significance of the permutation test.

  • X_ref: Data used as reference distribution.

  • update_X_ref: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, the parameter equals {'last': N} while for reservoir sampling {'reservoir_sampling': N} is passed.

  • preprocess_fn: Function to preprocess the data before computing the data drift metrics. Typically a dimensionality reduction technique. The out-of-the-box methods UAE, BBSDs and PCA are illustrated in the example notebook.

  • preprocess_kwargs: Keyword arguments for preprocess_fn. Again see the notebook for concrete examples.

  • kernel: Kernel function used when computing the MMD. Defaults to a Gaussian kernel.

  • kernel_kwargs: Keyword arguments for the kernel function. For the Gaussian kernel this is the kernel bandwidth sigma. We can also sum over a number of different kernel bandwidths, in which case sigma becomes an array with different values. If sigma is not specified, the detector infers it by computing the pairwise distances between the instances in the two samples and setting sigma to the median distance.

  • n_permutations: Number of permutations used in the permutation test.

  • chunk_size: Used to optionally compute the MMD between the two samples in chunks using dask to avoid potential out-of-memory errors. In terms of speed, the optimal chunk size is application and hardware dependent, so it is often worth testing a few different values, including None. None means that the computation is done in-memory in NumPy.

  • data_type: Optionally specifies the data type added to the metadata, e.g. 'tabular' or 'image'.
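Two of the parameters above describe small algorithms that are easy to sketch: the reference data update behind update_X_ref (last N instances or reservoir sampling) and the bandwidth handling behind kernel_kwargs (median heuristic and a sum over several bandwidths). The code below is an illustrative sketch only. All function names are hypothetical, the Gaussian kernel form exp(-d^2 / (2 sigma^2)) is an assumption, and the detector's internal code may differ.

```python
import numpy as np


def keep_last(x_ref, batch, n):
    """'last' update: keep the n most recently seen instances."""
    return np.concatenate([x_ref, batch], axis=0)[-n:]


def reservoir_update(reservoir, n_seen, batch, size, rng):
    """Reservoir sampling (Algorithm R): uniform sample of `size` stream items."""
    reservoir = list(reservoir)
    for item in batch:
        n_seen += 1
        if len(reservoir) < size:
            reservoir.append(item)        # fill the reservoir first
        else:
            j = rng.integers(0, n_seen)   # uniform integer in [0, n_seen)
            if j < size:
                reservoir[j] = item       # replace with probability size / n_seen
    return reservoir, n_seen


def median_sigma(x, y):
    """Median of the pairwise Euclidean distances between rows of x and y."""
    diff = x[:, None, :] - y[None, :, :]
    return float(np.median(np.sqrt((diff ** 2).sum(axis=-1))))


def multi_bandwidth_kernel(x, y, sigmas):
    """Sum of Gaussian kernels, one per bandwidth in sigmas."""
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return sum(np.exp(-sq_dist / (2.0 * s ** 2)) for s in sigmas)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 3))
y = rng.normal(size=(30, 3))

sigma = median_sigma(x, y)                                    # inferred bandwidth
k = multi_bandwidth_kernel(x, y, np.array([0.5, 1.0, 5.0]))   # shape (20, 30)

reservoir, n_seen = reservoir_update([], 0, list(range(1000)), size=50, rng=rng)
print(sigma, k.shape, len(reservoir), n_seen)
```

Reservoir sampling keeps every instance seen so far in the reference set with equal probability, whereas the 'last' update tracks only the most recent data.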

Initialized drift detector example:

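import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, InputLayer
# Assumed module paths (alibi-detect); adjust to your installed version.
from alibi_detect.cd import MMDDrift
from alibi_detect.cd.preprocess import uae  # Untrained AutoEncoder

encoder_net = tf.keras.Sequential([
    InputLayer(input_shape=(32, 32, 3)),
    Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
    Flatten(),  # assumed final layer so the encoder returns one vector per instance
])

# X_ref is the reference data, defined elsewhere; p_val=.05 is an illustrative value.
cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=uae,
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128},
    kernel_kwargs={'sigma': np.array([.5, 1., 5.])},
)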

Detect Drift

We detect data drift by simply calling predict on a batch of instances X. We can return the p-value of the permutation test by setting return_p_val to True.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector’s metadata while data is also a dictionary which contains the actual predictions stored in the following keys:

  • is_drift: 1 if the sample tested has drifted from the reference data and 0 otherwise.

  • p_val: contains the p-value if return_p_val equals True.

preds_drift = cd.predict(X, return_p_val=True)
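The returned dictionary can then be inspected as sketched below. To keep the snippet runnable without a fitted detector, preds_drift is a hand-written stand-in with made-up values that follows the structure described above. The summarize helper is hypothetical and not part of the library.

```python
# Stand-in for the output of cd.predict(X, return_p_val=True).
# The values are invented for illustration and the metadata is left empty.
preds_drift = {
    'meta': {},
    'data': {'is_drift': 1, 'p_val': 0.01},
}


def summarize(preds):
    """Turn a drift prediction dictionary into a short message."""
    data = preds['data']
    status = 'drift detected' if data['is_drift'] == 1 else 'no drift detected'
    p_val = data.get('p_val')  # only present when return_p_val is True
    if p_val is None:
        return status
    return f'{status} (p-value = {p_val:.3f})'


print(summarize(preds_drift))  # prints "drift detected (p-value = 0.010)"
```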