This page was generated from doc/source/methods/mmddrift.ipynb.

# Maximum Mean Discrepancy

## Overview

The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions *p* and *q* based on the mean embeddings \(\mu_{p}\) and \(\mu_{q}\) in a reproducing kernel Hilbert space \(F\):

\[MMD^2(F, p, q) = \| \mu_{p} - \mu_{q} \|^2_{F}\]

We can compute unbiased estimates of \(MMD^2\) from the samples of the two distributions after applying the kernel trick. By default we use a radial basis function (Gaussian) kernel, but users are free to pass their own kernel to the detector. We obtain a \(p\)-value via a permutation test on the values of \(MMD^2\).
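The estimate and the permutation test described above can be sketched in a few lines. This is a minimal, dependency-free illustration rather than the detector's actual implementation: the function names, the \(1/(2\sigma^2)\) kernel parameterization and the p-value convention (fraction of permuted statistics at least as large as the observed one) are assumptions made for this example.

```python
import math
import random


def gaussian_gram(z, sigma):
    """Gram matrix K[i][j] = exp(-||z_i - z_j||^2 / (2 * sigma^2))."""
    n = len(z)
    return [
        [
            math.exp(-sum((a - b) ** 2 for a, b in zip(z[i], z[j])) / (2 * sigma ** 2))
            for j in range(n)
        ]
        for i in range(n)
    ]


def mmd2_unbiased(K, idx_x, idx_y):
    """Unbiased MMD^2 estimate from a Gram matrix over the pooled sample."""
    n, m = len(idx_x), len(idx_y)
    # Within-sample terms exclude the diagonal (i == j) to remove the bias.
    k_xx = sum(K[i][j] for i in idx_x for j in idx_x if i != j) / (n * (n - 1))
    k_yy = sum(K[i][j] for i in idx_y for j in idx_y if i != j) / (m * (m - 1))
    k_xy = sum(K[i][j] for i in idx_x for j in idx_y) / (n * m)
    return k_xx + k_yy - 2 * k_xy


def mmd_permutation_test(x, y, sigma=1.0, n_permutations=100, seed=0):
    """Return (observed MMD^2, p-value) for two samples given as lists of vectors."""
    rng = random.Random(seed)
    n = len(x)
    K = gaussian_gram(list(x) + list(y), sigma)  # computed once, reused below
    idx = list(range(len(x) + len(y)))
    observed = mmd2_unbiased(K, idx[:n], idx[n:])
    exceed = 0
    for _ in range(n_permutations):
        rng.shuffle(idx)  # randomly reassign instances to the two groups
        if mmd2_unbiased(K, idx[:n], idx[n:]) >= observed:
            exceed += 1
    return observed, exceed / n_permutations
```

Computing the Gram matrix once and permuting only the indices keeps each permutation cheap; a practical implementation would vectorize these sums.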

For high-dimensional data, we typically want to reduce the dimensionality before computing the permutation test. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE), black-box shift detection using the classifier's softmax outputs (BBSDs) and PCA as out-of-the-box preprocessing methods. Preprocessing methods which do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift.

## Usage

### Initialize

Parameters:

- `p_val`: p-value used for significance of the permutation test.
- `X_ref`: Data used as reference distribution.
- `update_X_ref`: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, the parameter equals *{'last': N}*, while for reservoir sampling *{'reservoir_sampling': N}* is passed.
- `preprocess_fn`: Function to preprocess the data before computing the data drift metrics. Typically a dimensionality reduction technique. The out-of-the-box methods UAE, BBSDs and PCA are illustrated in the example notebook.
- `preprocess_kwargs`: Keyword arguments for `preprocess_fn`. Again, see the notebook for concrete examples.
- `kernel`: Kernel function used when computing the MMD. Defaults to a Gaussian kernel.
- `kernel_kwargs`: Keyword arguments for the kernel function. For the Gaussian kernel this is the kernel bandwidth `sigma`. We can also sum over a number of different kernel bandwidths; `sigma` then becomes an array with different values. If `sigma` is not specified, the detector infers it by computing the pairwise distances between the instances in the two samples and setting `sigma` to the median distance.
- `n_permutations`: Number of permutations used in the permutation test.
- `chunk_size`: Used to optionally compute the MMD between the two samples in chunks using dask to avoid potential out-of-memory errors. In terms of speed, the optimal chunk size is application and hardware dependent, so it is often worth testing a few different values, including *None*. *None* means that the computation is done in-memory in NumPy.
- `data_type`: Optionally specifies the data type, which is added to the metadata, e.g. *'tabular'* or *'image'*.
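The `sigma` behaviour described for `kernel_kwargs` can be illustrated with a small sketch. Both helpers below are hypothetical and written only for this page: `median_sigma` assumes the pairwise distances are taken between instances of the two samples (Euclidean), and `gaussian_kernel_sum` assumes a \(\exp(-d^2 / (2\sigma^2))\) parameterization; the library's own kernel may differ in either respect.

```python
import math
import statistics


def median_sigma(x, y):
    """Median Euclidean distance over all pairs (a in x, b in y)."""
    return statistics.median(math.dist(a, b) for a in x for b in y)


def gaussian_kernel_sum(a, b, sigmas):
    """Sum of Gaussian kernels between vectors a and b, one term per bandwidth."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return sum(math.exp(-sq_dist / (2 * s ** 2)) for s in sigmas)


x = [[0.0, 0.0]]
y = [[3.0, 4.0], [6.0, 8.0], [0.0, 1.0]]
sigma = median_sigma(x, y)  # distances are 5, 10 and 1
k = gaussian_kernel_sum(x[0], y[0], [0.5 * sigma, sigma, 2 * sigma])
```

Summing over several bandwidths makes the statistic less sensitive to a single poorly chosen `sigma`, at the cost of evaluating the kernel once per bandwidth.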

Initialized drift detector example:

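```
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, InputLayer

from alibi_detect.cd import MMDDrift
from alibi_detect.cd.preprocess import uae  # Untrained AutoEncoder
# Import path assumed for this release; adjust it if your version differs.
from alibi_detect.utils.kernels import gaussian_kernel

# Untrained encoder that maps 32x32x3 images to a 32-dimensional encoding.
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, 3)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(32)
    ]
)

# X_ref holds the reference data and is assumed to be loaded already.
cd = MMDDrift(
    p_val=.05,
    X_ref=X_ref,
    preprocess_fn=uae,
    preprocess_kwargs={'encoder_net': encoder_net, 'batch_size': 128},
    kernel=gaussian_kernel,
    kernel_kwargs={'sigma': np.array([.5, 1., 5.])},  # sum over 3 bandwidths
    chunk_size=1000,
    n_permutations=1000
)
```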
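The two `update_X_ref` strategies listed among the parameters can be sketched as follows. This is an illustrative stand-in, not the detector's internal code: `keep_last` and `reservoir_update` are hypothetical helper names, and the reservoir variant uses the classic "Algorithm R", which may differ in detail from what the library does.

```python
import random


def keep_last(x_ref, batch, n):
    """'last' strategy: keep only the n most recent instances."""
    return (list(x_ref) + list(batch))[-n:]


def reservoir_update(reservoir, n_seen, batch, size, rng):
    """'reservoir_sampling' strategy: every instance seen so far has an equal
    chance of being in the reservoir. Updates `reservoir` in place and
    returns the new number of instances seen."""
    for item in batch:
        n_seen += 1
        if len(reservoir) < size:
            reservoir.append(item)  # still filling up
        else:
            j = rng.randrange(n_seen)  # uniform over all instances seen
            if j < size:
                reservoir[j] = item  # replace a random slot
    return n_seen


rng = random.Random(0)
reservoir, n_seen = [], 0
for start in range(0, 100, 10):  # ten batches of ten instances
    batch = list(range(start, start + 10))
    n_seen = reservoir_update(reservoir, n_seen, batch, size=20, rng=rng)
```

Keeping the last N instances tracks recent data closely, whereas the reservoir retains a uniform sample of everything seen so far.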

### Detect Drift

We detect data drift by simply calling `predict` on a batch of instances `X`. We can return the p-value of the permutation test by setting `return_p_val` to *True*.

The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:

- `is_drift`: 1 if the sample tested has drifted from the reference data and 0 otherwise.
- `p_val`: contains the p-value if `return_p_val` equals *True*.

```
preds_drift = cd.predict(X, return_p_val=True)
```
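Reading the returned dictionary might look like the sketch below. To keep it runnable without the library, the `preds_drift` dictionary here is hand-built to mimic the structure described above; its values and the `data_type` entry under `meta` are made up for illustration, and `summarize_prediction` is a hypothetical helper.

```python
def summarize_prediction(preds):
    """Turn a prediction dictionary into a short human-readable verdict."""
    data = preds['data']
    verdict = 'drift detected' if data['is_drift'] == 1 else 'no drift detected'
    p_val = data.get('p_val')  # only present if return_p_val was True
    if p_val is None:
        return verdict
    return f'{verdict} (p-value {p_val:.4f})'


# Hand-built stand-in for the output of cd.predict(X, return_p_val=True).
preds_drift = {
    'meta': {'data_type': 'image'},
    'data': {'is_drift': 1, 'p_val': 0.01},
}
print(summarize_prediction(preds_drift))
```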