This page was generated from doc/source/methods/classifierdrift.ipynb.




The classifier-based drift detector (Lopez-Paz and Oquab, 2017) simply tries to correctly classify instances from the reference data vs. the test set. If the classifier cannot distinguish the reference data from the test set well enough, as measured by a chosen metric (the classifier accuracy by default) compared against a threshold, no drift is flagged. If it can, the test set differs from the reference data and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied, and the out-of-fold predictions are then used to compute the drift metric. Note that a new classifier is trained for each test set, and for each fold when cross-validation is used.
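The procedure above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the alibi-detect implementation: it swaps in a hand-rolled logistic-regression classifier for the Keras model, and the names `classifier_drift`, `stratified_folds`, `fit_logreg` and `predict_logreg` are hypothetical. Reference instances get label 0, test instances label 1, and the out-of-fold accuracy is compared with the threshold.

```python
import numpy as np


def stratified_folds(y, n_folds, rng):
    """Assign each instance a fold id so every fold keeps the label ratio."""
    folds = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        rng.shuffle(idx)
        # Deal the shuffled indices of this class round-robin over the folds.
        folds[idx] = np.arange(len(idx)) % n_folds
    return folds


def fit_logreg(X, y, lr=0.1, steps=500):
    """Gradient-descent logistic regression (hypothetical stand-in classifier)."""
    Xb = np.hstack([X, np.ones((len(X), 1))])  # append a bias column
    w = np.zeros(Xb.shape[1])
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-Xb @ w))
        w -= lr * Xb.T @ (p - y) / len(y)
    return w


def predict_logreg(w, X):
    """Hard 0/1 predictions from the fitted weights."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    return (Xb @ w > 0).astype(int)


def classifier_drift(X_ref, X_test, threshold=0.55, n_folds=5, seed=0):
    """Flag drift if a classifier separates reference from test data."""
    rng = np.random.default_rng(seed)
    X = np.vstack([X_ref, X_test])
    y = np.concatenate([np.zeros(len(X_ref)), np.ones(len(X_test))])
    folds = stratified_folds(y, n_folds, rng)
    y_pred = np.empty(len(y))
    for k in range(n_folds):
        train, test = folds != k, folds == k
        # A fresh classifier is trained for every fold.
        w = fit_logreg(X[train], y[train])
        y_pred[test] = predict_logreg(w, X[test])
    # Out-of-fold predictions cover every instance exactly once.
    accuracy = float(np.mean(y_pred == y))
    return {'is_drift': int(accuracy > threshold), 'accuracy': accuracy,
            'threshold': threshold}


# Usage: identical vs. mean-shifted 2-D Gaussian data.
rng = np.random.default_rng(42)
X_ref = rng.normal(size=(1000, 2))
X_same = rng.normal(size=(1000, 2))
X_shift = rng.normal(loc=2.0, size=(1000, 2))
print(classifier_drift(X_ref, X_same))   # accuracy near 0.5: no drift expected
print(classifier_drift(X_ref, X_shift))  # clearly separable: drift expected
```

Because every instance is predicted by a classifier that never saw it during training, the accuracy is an honest estimate of how separable the two sets are; an accuracy near 0.5 means the classifier does no better than chance.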




  • threshold: Threshold for the drift metric (default is accuracy). Values above the threshold are classified as drift.

  • model: Classification model used for drift detection.

  • X_ref: Data used as reference distribution.

  • preprocess_X_ref: Whether to preprocess the reference data with the preprocess_fn once at initialization and store the result. Typically set to True, since it reduces the time needed to detect drift during the predict call. It may need to be set to False if the preprocessing step requires statistics from both the reference and test data, such as the mean or standard deviation.

  • update_X_ref: Reference data can optionally be updated to the last N instances seen by the detector or via reservoir sampling with size N. For the former, pass {'last': N}; for reservoir sampling, pass {'reservoir_sampling': N}.

  • preprocess_fn: Function to preprocess the data before computing the data drift metrics.

  • preprocess_kwargs: Keyword arguments for preprocess_fn.

  • metric_fn: Function computing the drift metric. Takes y_true and y_pred as input and returns a float: metric_fn(y_true, y_pred). Defaults to accuracy.

  • metric_name: Optional name for the metric_fn used in the return dict. Defaults to metric_fn.__name__.

  • train_size: Optional fraction (float between 0 and 1) of the dataset used to train the classifier. The drift is detected on the remaining fraction (1 - train_size). Ignored when n_folds is specified.

  • n_folds: Optional number of stratified folds used for training. The metric is then calculated on all the out-of-fold predictions. This makes it possible to use all the reference and test data for drift detection, at the expense of longer computation. If both train_size and n_folds are specified, n_folds takes precedence.

  • seed: Optional random seed for fold selection.

  • optimizer: Optimizer used during training of the classifier.

  • compile_kwargs: Optional additional kwargs for model.compile() when compiling the classifier.

  • batch_size: Batch size used during training of the classifier.

  • epochs: Number of training epochs for the classifier. Applies to each fold if n_folds is specified.

  • verbose: Verbosity level during the training of the classifier. 0 is silent, 1 shows a progress bar and 2 prints the statistics after each epoch.

  • fit_kwargs: Optional additional kwargs passed when fitting the classifier.

  • data_type: Optionally specify the data type (e.g. tabular, image or time-series). Added to metadata.
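The two update_X_ref strategies can be illustrated with a small NumPy sketch. The function update_reference below is hypothetical and is not the library's internal code; it assumes the dictionary format described above ({'last': N} or {'reservoir_sampling': N}) and uses the classic reservoir-sampling scheme (Algorithm R), in which each of the n instances seen so far ends up in the size-N reservoir with probability N / n.

```python
import numpy as np


def update_reference(X_ref, X_new, n_seen, update, rng):
    """Return the updated reference set and the new count of instances seen.

    X_ref:  current reference data, shape (n_ref, ...).
    X_new:  batch of newly observed instances.
    n_seen: number of instances seen so far (reference + previous batches).
    update: {'last': N} or {'reservoir_sampling': N} (hypothetical helper).
    """
    if 'last' in update:
        N = update['last']
        # Keep only the N most recent instances.
        return np.concatenate([X_ref, X_new])[-N:], n_seen + len(X_new)
    N = update['reservoir_sampling']
    X_ref = X_ref.copy()
    for x in X_new:
        n_seen += 1
        if len(X_ref) < N:
            # Reservoir not yet full: always keep the instance.
            X_ref = np.concatenate([X_ref, x[None]])
        else:
            # Replace a random slot with probability N / n_seen.
            j = rng.integers(n_seen)
            if j < N:
                X_ref[j] = x
    return X_ref, n_seen


# Usage: 5 reference instances followed by a batch of 7 new ones.
rng = np.random.default_rng(0)
X_ref = np.arange(5, dtype=float).reshape(5, 1)
X_batch = np.arange(5, 12, dtype=float).reshape(7, 1)
last, _ = update_reference(X_ref, X_batch, 5, {'last': 5}, rng)
res, n = update_reference(X_ref, X_batch, 5, {'reservoir_sampling': 5}, rng)
print(last.ravel())  # [ 7.  8.  9. 10. 11.]
print(res.ravel(), n)
```

The 'last' strategy tracks recent data closely but forgets the original reference distribution entirely, whereas reservoir sampling keeps a uniform sample over everything seen so far.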

Example of initializing the drift detector:

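import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input

from alibi_detect.cd import ClassifierDrift

# X_ref: reference data, e.g. a NumPy array of shape (N, 32, 32, 3)
model = tf.keras.Sequential(
  [
      Input(shape=(32, 32, 3)),
      Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
      Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
      Flatten(),
      # Two output classes: reference vs. test instances
      Dense(2, activation='softmax')
  ]
)

cd = ClassifierDrift(threshold=.55, model=model, X_ref=X_ref, n_folds=5, epochs=2)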

Detect Drift

We detect data drift by simply calling predict on a batch of instances X. Setting return_metric to True also returns the drift metric (e.g. accuracy) and the threshold used by the detector.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata, while data is itself a dictionary containing the actual predictions under the following keys:

  • is_drift: 1 if the sample tested has drifted from the reference data and 0 otherwise.

  • threshold: user-defined drift threshold for the chosen drift metric.

  • metric_fn.__name__, or the value of the optional metric_name kwarg: the drift metric value, returned if return_metric equals True.
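A custom metric_fn only needs the metric_fn(y_true, y_pred) signature and a float return value. The sketch below defines a hypothetical balanced-accuracy metric in NumPy and mimics how the result key defaults to metric_fn.__name__ unless metric_name is given. Whether y_pred arrives as class labels or as per-class probabilities is an assumption here, so the function accepts either by taking an argmax over 2-D input.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall; less sensitive to unequal reference/test sizes."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Assumption: 2-D inputs are one-hot labels or class probabilities.
    if y_true.ndim > 1:
        y_true = y_true.argmax(axis=1)
    if y_pred.ndim > 1:
        y_pred = y_pred.argmax(axis=1)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


def metric_key(metric_fn, metric_name=None):
    """Hypothetical helper: key under which the metric value would be stored."""
    return metric_fn.__name__ if metric_name is None else metric_name


# Usage: 4 reference (0) and 2 test (1) instances.
y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 0])
print(balanced_accuracy(y_true, y_pred))             # 0.625
print(balanced_accuracy(y_true, np.eye(2)[y_pred]))  # 0.625 (one-hot input)
print(metric_key(balanced_accuracy))                 # balanced_accuracy
print(metric_key(balanced_accuracy, 'bacc'))         # bacc
```

Plain accuracy on the same predictions is 4/6 ≈ 0.667; balanced accuracy weights both classes equally, which matters when the reference and test sets differ in size.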

preds_drift = cd.predict(X, return_metric=True)

Saving and loading

The drift detectors can be saved and loaded in the same way as other detectors:

from alibi_detect.utils.saving import save_detector, load_detector

filepath = 'my_path'
save_detector(cd, filepath)
cd = load_detector(filepath)