alibi_detect.cd.tensorflow.context_aware module

class alibi_detect.cd.tensorflow.context_aware.ContextMMDDriftTF(x_ref, c_ref, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=<class 'alibi_detect.utils.tensorflow.kernels.GaussianRBF'>, c_kernel=<class 'alibi_detect.utils.tensorflow.kernels.GaussianRBF'>, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, input_shape=None, data_type=None, verbose=False)

Bases: BaseContextMMDDrift

__init__(x_ref, c_ref, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=<class 'alibi_detect.utils.tensorflow.kernels.GaussianRBF'>, c_kernel=<class 'alibi_detect.utils.tensorflow.kernels.GaussianRBF'>, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, input_shape=None, data_type=None, verbose=False)

A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD). It only detects differences between samples that cannot be attributed to differences between the associated sets of contexts. p-values are computed using a conditional permutation test.

Parameters:
  • x_ref (Union[ndarray, list]) – Data used as reference distribution.

  • c_ref (ndarray) – Context for the reference distribution.

  • p_val (float) – p-value threshold (significance level) used for the permutation test.

  • x_ref_preprocessed (bool) – Whether the given reference data x_ref has already been preprocessed. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.

  • preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.

  • update_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last N instances seen by the detector. The parameter should be passed as a dictionary {'last': N}.

  • preprocess_fn (Optional[Callable]) – Function to preprocess the data before computing the data drift metrics.

  • x_kernel (Callable) – Kernel defined on the input data, defaults to Gaussian RBF kernel.

  • c_kernel (Callable) – Kernel defined on the context data, defaults to Gaussian RBF kernel.

  • n_permutations (int) – Number of permutations used in the permutation test.

  • prop_c_held (float) – Proportion of contexts held out to condition on.

  • n_folds (int) – Number of cross-validation folds used when tuning the regularisation parameters.

  • batch_size (Optional[int]) – If not None, then compute batches of MMDs at a time (rather than all at once).

  • input_shape (Optional[tuple]) – Shape of input data.

  • data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.

  • verbose (bool) – Whether or not to print progress during configuration.
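
As an illustration of what prop_c_held controls, the following minimal sketch splits a set of instance indices into a held-out part and a remainder. It is plain Python and not the library's code: the helper name split_held_out, its seed argument, and the use of uniform random sampling are all assumptions made for this example.

```python
import random
from typing import List, Tuple


def split_held_out(
    n_instances: int, prop_c_held: float = 0.25, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (held_out_indices, remaining_indices)."""
    if not 0.0 <= prop_c_held < 1.0:
        raise ValueError("prop_c_held must lie in [0, 1)")
    # Number of contexts to hold out, rounded down (an assumption).
    n_held = int(n_instances * prop_c_held)
    rng = random.Random(seed)
    # Sample held-out indices without replacement.
    held = sorted(rng.sample(range(n_instances), n_held))
    held_set = set(held)
    # Everything not held out stays in the remainder.
    remaining = [i for i in range(n_instances) if i not in held_set]
    return held, remaining


held, remaining = split_held_out(100, prop_c_held=0.25)
print(len(held), len(remaining))
```

With the default prop_c_held=0.25, a quarter of the indices end up in the held-out part and the rest in the remainder; the two parts never overlap.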

score(x, c)

Compute the MMD-based conditional test statistic and perform a conditional permutation test to obtain a p-value representing the test statistic's extremity under the null hypothesis.

Parameters:
  • x (Union[ndarray, list]) – Batch of instances.

  • c (ndarray) – Context associated with batch of instances.

Return type:

Tuple[float, float, float, Tuple]

Returns:

p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test statistic threshold above which drift is flagged, and a tuple containing the coupling matrices (W_{ref,ref}, W_{test,test}, W_{ref,test}).
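
To show how the first three returned values relate to one another, here is a minimal sketch that derives a p-value and a threshold from a list of permuted statistics. The function summarise_permutation_test is hypothetical and the quantile convention it uses is an assumption, so the library's exact behaviour may differ. It also takes the permuted statistics as given and does not perform the conditional permutation itself.

```python
from typing import List, Tuple


def summarise_permutation_test(
    stat: float, permuted_stats: List[float], p_val: float = 0.05
) -> Tuple[float, float]:
    """Hypothetical helper: return (p_value, threshold) for an observed statistic."""
    if not permuted_stats:
        raise ValueError("permuted_stats must not be empty")
    n = len(permuted_stats)
    # Fraction of permuted statistics at least as extreme as the observed one.
    p_value = sum(1 for s in permuted_stats if s >= stat) / n
    # Threshold: the permuted statistic at the upper p_val quantile (assumed convention).
    ranked = sorted(permuted_stats, reverse=True)
    threshold = ranked[min(int(p_val * n), n - 1)]
    return p_value, threshold


# Toy permuted statistics 0.0, 1.0, ..., 99.0 and an observed statistic of 50.0.
permuted = [float(i) for i in range(100)]
p_value, threshold = summarise_permutation_test(50.0, permuted, p_val=0.05)
print(p_value, threshold)
```

Under this convention, an observed statistic well above the threshold yields a p-value below p_val, which would be read as drift.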