alibi_detect.cd.context_aware module

class alibi_detect.cd.context_aware.ContextMMDDrift(x_ref, c_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=None, c_kernel=None, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)

Bases: DriftConfigMixin

__init__(x_ref, c_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=None, c_kernel=None, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)

A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD). It only detects differences between samples that cannot be attributed to differences between their associated sets of contexts. p-values are computed using a conditional permutation test.

Parameters:
  • x_ref (Union[ndarray, list]) – Data used as reference distribution.

  • c_ref (ndarray) – Context for the reference distribution.

  • backend (str) – Backend used for the MMD implementation.

  • p_val (float) – Significance level of the permutation test. Drift is flagged when the computed p-value falls below this value.

  • x_ref_preprocessed (bool) – Whether the given reference data x_ref has already been preprocessed. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.

  • preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.

  • update_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last N instances seen by the detector. The parameter should be passed as a dictionary {'last': N}.

  • preprocess_fn (Optional[Callable]) – Function to preprocess the data before computing the data drift metrics.

  • x_kernel (Callable) – Kernel defined on the input data, defaults to Gaussian RBF kernel.

  • c_kernel (Callable) – Kernel defined on the context data, defaults to Gaussian RBF kernel.

  • n_permutations (int) – Number of permutations used in the permutation test.

  • prop_c_held (float) – Proportion of contexts held out to condition on.

  • n_folds (int) – Number of cross-validation folds used when tuning the regularisation parameters.

  • batch_size (Optional[int]) – If not None, the MMDs are computed in batches of this size rather than all at once.

  • device (Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None]) – Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by passing either 'cuda', 'gpu', 'cpu' or an instance of torch.device. Only relevant for the 'pytorch' backend.

  • input_shape (Optional[tuple]) – Shape of input data.

  • data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.

  • verbose (bool) – Whether to print progress messages.
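The constructor arguments above can be exercised with a small synthetic example. This is a minimal sketch, not an official recipe: `make_batch` and `run_detector` are hypothetical helpers, the data are synthetic, and the result keys 'is_drift' and 'p_val' are assumed rather than taken from this page. The import path `alibi_detect.cd.ContextMMDDrift` is also an assumption; the full path in the class signature is `alibi_detect.cd.context_aware.ContextMMDDrift`.

```python
import random


def make_batch(n, shift=0.0, seed=0):
    """Return (x, c) as nested lists: one-dimensional context c and data x = c + noise + shift."""
    rng = random.Random(seed)
    c = [[rng.gauss(0.0, 1.0)] for _ in range(n)]
    x = [[ci[0] + rng.gauss(0.0, 0.1) + shift] for ci in c]
    return x, c


def run_detector(n_ref=200, n_test=200):
    """Fit a detector on reference data and test a shifted batch (needs alibi-detect installed)."""
    # Imported lazily so that make_batch can be used without these packages.
    import numpy as np
    from alibi_detect.cd import ContextMMDDrift  # assumed import path

    x_ref, c_ref = (np.asarray(a, dtype=np.float32) for a in make_batch(n_ref, seed=0))
    # The test batch keeps the same context distribution but shifts x given c.
    x, c = (np.asarray(a, dtype=np.float32) for a in make_batch(n_test, shift=1.0, seed=1))

    cd = ContextMMDDrift(x_ref, c_ref, backend='tensorflow', p_val=0.05, n_permutations=1000)
    preds = cd.predict(x, c)
    # Key names below are assumptions about the layout of the 'data' dictionary.
    return preds['data']['is_drift'], preds['data']['p_val']
```

Calling `run_detector()` requires alibi-detect and the chosen backend to be installed; `make_batch` on its own depends only on the standard library.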

predict(x, c, return_p_val=True, return_distance=True, return_coupling=False)

Predict whether a batch of data has drifted from the reference data, given the provided context.

Parameters:
  • x (Union[ndarray, list]) – Batch of instances.

  • c (ndarray) – Context associated with batch of instances.

  • return_p_val (bool) – Whether to return the p-value of the permutation test.

  • return_distance (bool) – Whether to return the conditional MMD test statistic between the new batch and reference data.

  • return_coupling (bool) – Whether to return the coupling matrices.

Return type:

Dict[Dict[str, str], Dict[str, Union[int, float]]]

Returns:

Dictionary containing 'meta' and 'data' dictionaries:

  • 'meta' contains the detector's metadata.

  • 'data' contains the drift prediction and optionally the p-value, threshold, conditional MMD test statistic and coupling matrices.
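As an illustration of how the returned dictionary might be consumed, the sketch below formats a result into a one-line summary. `summarise_prediction` is a hypothetical helper, the key names inside 'data' ('is_drift', 'p_val', 'threshold', 'distance') are assumptions not listed on this page, and the example values are made up.

```python
def summarise_prediction(preds):
    """Turn a predict()-style result into a one-line, human-readable summary."""
    data = preds['data']
    parts = ['drift' if data['is_drift'] else 'no drift']
    # The p-value, threshold and distance are optional, so read them defensively.
    if 'p_val' in data:
        parts.append(f"p={data['p_val']:.3f}")
    if 'threshold' in data:
        parts.append(f"threshold={data['threshold']:.2f}")
    if 'distance' in data:
        parts.append(f"distance={data['distance']:.4f}")
    return ', '.join(parts)


# Hand-written stand-in for a detector result (values are made up for illustration).
example = {
    'meta': {'name': 'ContextMMDDrift'},
    'data': {'is_drift': 1, 'p_val': 0.012, 'threshold': 0.05, 'distance': 0.0831},
}
print(summarise_prediction(example))
```

Reading the optional entries with membership checks keeps the helper usable when return_p_val or return_distance is set to False.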

score(x, c)

Compute the MMD-based conditional test statistic, and perform a conditional permutation test to obtain a p-value representing the test statistic's extremity under the null hypothesis.

Parameters:
  • x (Union[ndarray, list]) – Batch of instances.

  • c (ndarray) – Context associated with batch of instances.

Return type:

Tuple[float, float, float, Tuple]

Returns:

p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test statistic threshold above which drift is flagged, and a tuple containing the coupling matrices \((W_{ref,ref}, W_{test,test}, W_{ref,test})\).
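The relationship between the first three returned values can be sketched with generic permutation-test bookkeeping. This is an illustration under assumptions, not the library's implementation: `permutation_summary` is a hypothetical helper, and the conditional resampling that produces the permuted statistics is not shown.

```python
def permutation_summary(stat, permuted_stats, p_val=0.05):
    """Return (p-value, statistic threshold) from an observed statistic and its permuted values."""
    n = len(permuted_stats)
    # p-value: fraction of permuted statistics at least as extreme as the observed one.
    p = sum(s >= stat for s in permuted_stats) / n
    # Threshold: the permuted statistic sitting at the p_val quantile from the top.
    idx = min(int(p_val * n), n - 1)
    threshold = sorted(permuted_stats, reverse=True)[idx]
    return p, threshold


# Toy permuted statistics 0..99 and an observed statistic of 97.
permuted = list(range(100))
p, threshold = permutation_summary(97, permuted)
print(p, threshold, p < 0.05)
```

In this toy case the observed statistic exceeds the threshold and the p-value falls below 0.05, so both views of the test agree that drift would be flagged.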