alibi_detect.cd.pytorch.context_aware module

class alibi_detect.cd.pytorch.context_aware.ContextMMDDriftTorch(x_ref, c_ref, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=<class 'alibi_detect.utils.pytorch.kernels.GaussianRBF'>, c_kernel=<class 'alibi_detect.utils.pytorch.kernels.GaussianRBF'>, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)[source]

Bases: BaseContextMMDDrift

__init__(x_ref, c_ref, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=<class 'alibi_detect.utils.pytorch.kernels.GaussianRBF'>, c_kernel=<class 'alibi_detect.utils.pytorch.kernels.GaussianRBF'>, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)[source]

A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD). It only detects differences between samples that cannot be attributed to differences between the associated sets of contexts. p-values are computed using a conditional permutation test.

Parameters:
  • x_ref (Union[ndarray, list]) – Data used as reference distribution.

  • c_ref (ndarray) – Context for the reference distribution.

  • p_val (float) – Significance level (p-value threshold) used for the permutation test.

  • x_ref_preprocessed (bool) – Whether the given reference data x_ref has already been preprocessed. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.

  • preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.

  • update_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last N instances seen by the detector. The parameter should be passed as a dictionary {'last': N}.

  • preprocess_fn (Optional[Callable]) – Function to preprocess the data before computing the data drift metrics.

  • x_kernel (Callable) – Kernel defined on the input data, defaults to Gaussian RBF kernel.

  • c_kernel (Callable) – Kernel defined on the context data, defaults to Gaussian RBF kernel.

  • n_permutations (int) – Number of permutations used in the permutation test.

  • prop_c_held (float) – Proportion of contexts held out to condition on.

  • n_folds (int) – Number of cross-validation folds used when tuning the regularisation parameters.

  • batch_size (Optional[int]) – If not None, MMDs are computed in batches of this size rather than all at once.

  • device (Union[Literal['cuda', 'gpu', 'cpu'], device, None]) – Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by passing either 'cuda', 'gpu', 'cpu' or an instance of torch.device. Only relevant for the 'pytorch' backend.

  • input_shape (Optional[tuple]) – Shape of input data.

  • data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.

  • verbose (bool) – Whether or not to print progress during configuration.
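
The x_kernel and c_kernel parameters both default to a Gaussian RBF kernel. As a rough illustration of what such a kernel computes, the following NumPy sketch builds the kernel matrices for a small reference set and its contexts. The function gaussian_rbf, the fixed bandwidth sigma and the array shapes are assumptions made for this example; it is not the library's GaussianRBF class and does not reproduce how that class handles its bandwidth.

```python
import numpy as np


def gaussian_rbf(x, y, sigma=1.0):
    """Kernel matrix with entries exp(-||x_i - y_j||^2 / (2 * sigma^2)).

    Hypothetical helper for illustration only, with a fixed bandwidth.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Squared Euclidean distance between every row of x and every row of y.
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(6, 3))  # assumed: 6 reference instances, 3 features
c_ref = rng.normal(size=(6, 1))  # assumed: one context value per instance

K_x = gaussian_rbf(x_ref, x_ref)  # kernel on the input data
K_c = gaussian_rbf(c_ref, c_ref)  # kernel on the context data
```

Both matrices are symmetric with ones on the diagonal, since each instance is at distance zero from itself.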

lams: Tuple[torch.Tensor, torch.Tensor] | None = None

score(x, c)[source]

Compute the MMD-based conditional test statistic, and perform a conditional permutation test to obtain a p-value representing the test statistic's extremity under the null hypothesis.

Parameters:
  • x (Union[ndarray, list]) – Batch of instances.

  • c (ndarray) – Context associated with batch of instances.

Return type:

Tuple[float, float, float, Tuple]

Returns:

p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test statistic threshold above which drift is flagged, and a tuple containing the coupling matrices (W_{ref,ref}, W_{test,test}, W_{ref,test}).
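
To show how a p-value and a threshold of this kind relate to a set of permuted statistics, here is a minimal NumPy sketch of a generic permutation-test summary. The helper permutation_summary, its alpha argument and the rule of flagging drift when the p-value falls below alpha are assumptions for illustration; the sketch does not perform the conditional permutation itself and is not taken from the detector's source.

```python
import numpy as np


def permutation_summary(stat, permuted_stats, alpha=0.05):
    """Summarise a permutation test (hypothetical helper, illustration only).

    Returns the p-value, the statistic threshold at level alpha, and a
    0/1 drift flag.
    """
    permuted_stats = np.asarray(permuted_stats, dtype=float)
    # Fraction of permuted statistics at least as large as the observed one.
    p_val = float(np.mean(permuted_stats >= stat))
    # Value exceeded by roughly a fraction alpha of the permuted statistics.
    descending = np.sort(permuted_stats)[::-1]
    threshold = float(descending[int(alpha * len(permuted_stats))])
    # Assumed decision rule: flag drift when the p-value is below alpha.
    drift = int(p_val < alpha)
    return p_val, threshold, drift


# Deterministic stand-in for permuted statistics: 0.00, 0.01, ..., 0.99.
permuted = np.arange(100) / 100.0

extreme = permutation_summary(2.0, permuted)  # observed statistic above all permutations
typical = permutation_summary(0.5, permuted)  # observed statistic in the middle
```

An observed statistic larger than every permuted value yields a p-value of zero and a drift flag, while one in the middle of the permuted values does not.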