alibi_detect.cd.context_aware module
- class alibi_detect.cd.context_aware.ContextMMDDrift(x_ref, c_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=None, c_kernel=None, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)
Bases: DriftConfigMixin
- __init__(x_ref, c_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_ref=None, preprocess_fn=None, x_kernel=None, c_kernel=None, n_permutations=1000, prop_c_held=0.25, n_folds=5, batch_size=256, device=None, input_shape=None, data_type=None, verbose=False)
A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD). It only detects differences between samples that cannot be attributed to differences between the associated sets of contexts. p-values are computed using a conditional permutation test.
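The conditional statistic builds on the ordinary MMD between kernel mean embeddings. As a point of reference only, the sketch below computes the plain (unconditional) biased MMD² estimate with a Gaussian RBF kernel in NumPy. It is not the detector's conditional statistic; the fixed bandwidth `sigma` and the function names `gaussian_rbf` and `mmd2_biased` are assumptions made for illustration.

```python
import numpy as np


def gaussian_rbf(x, y, sigma=1.0):
    """Kernel matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 sigma^2)).

    Fixed bandwidth sigma is an assumption for this sketch.
    """
    sq_dists = (
        np.sum(x**2, axis=1)[:, None]
        + np.sum(y**2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    # Clamp tiny negative values caused by floating-point rounding.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma**2))


def mmd2_biased(x, y, sigma=1.0):
    """Biased estimate of the squared (unconditional) MMD between x and y."""
    k_xx = gaussian_rbf(x, x, sigma)
    k_yy = gaussian_rbf(y, y, sigma)
    k_xy = gaussian_rbf(x, y, sigma)
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 2))
x_same = rng.normal(size=(200, 2))             # same distribution as x_ref
x_shift = rng.normal(loc=1.0, size=(200, 2))   # mean-shifted distribution
print(mmd2_biased(x_ref, x_same), mmd2_biased(x_ref, x_shift))
```

The shifted sample should give a clearly larger estimate than the sample drawn from the reference distribution. The context-aware detector instead asks whether such a difference remains after accounting for differences in the contexts.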
- Parameters:
  - x_ref (Union[ndarray, list]) – Data used as reference distribution.
  - c_ref (ndarray) – Context for the reference distribution.
  - backend (str) – Backend used for the MMD implementation ('tensorflow' or 'pytorch').
  - p_val (float) – p-value threshold (significance level) used for the permutation test.
  - x_ref_preprocessed (bool) – Whether the given reference data x_ref has already been preprocessed. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.
  - preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.
  - update_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last N instances seen by the detector. The parameter should be passed as a dictionary {'last': N}.
  - preprocess_fn (Optional[Callable]) – Function to preprocess the data before computing the data drift metrics.
  - x_kernel (Callable) – Kernel defined on the input data, defaults to Gaussian RBF kernel.
  - c_kernel (Callable) – Kernel defined on the context data, defaults to Gaussian RBF kernel.
  - n_permutations (int) – Number of permutations used in the permutation test.
  - prop_c_held (float) – Proportion of contexts held out to condition on.
  - n_folds (int) – Number of cross-validation folds used when tuning the regularisation parameters.
  - batch_size (Optional[int]) – If not None, compute the MMDs batch_size at a time rather than all at once.
  - device (Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None]) – Device type used. By default, the GPU is used if available, with a fallback to the CPU. Can be specified by passing either 'cuda', 'gpu', 'cpu' or an instance of torch.device. Only relevant for the 'pytorch' backend.
  - data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.
  - verbose (bool) – Whether to print progress messages.
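The update_ref={'last': N} option keeps only the most recent N instances. The sketch below illustrates that rule with a hypothetical helper, `update_reference_last`, which is not part of alibi_detect. It assumes, as a design choice for this example, that the reference contexts are trimmed together with the reference instances so that each instance stays paired with its context.

```python
import numpy as np


def update_reference_last(x_ref, c_ref, x_new, c_new, update_ref):
    """Hypothetical helper mimicking update_ref={'last': N}; not library code.

    Appends the new batch and its contexts, then keeps the last N rows.
    """
    n = update_ref["last"]
    x_all = np.concatenate([x_ref, x_new], axis=0)
    c_all = np.concatenate([c_ref, c_new], axis=0)
    # Slicing both arrays identically keeps instances aligned with contexts.
    return x_all[-n:], c_all[-n:]


x_ref = np.arange(10).reshape(5, 2)
c_ref = np.arange(5).reshape(5, 1)
x_new = np.arange(10, 16).reshape(3, 2)
c_new = np.arange(5, 8).reshape(3, 1)
x_upd, c_upd = update_reference_last(x_ref, c_ref, x_new, c_new, {"last": 4})
print(x_upd, c_upd)
```

Here the updated reference set holds the last reference row plus the three new rows, with the matching contexts.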
- predict(x, c, return_p_val=True, return_distance=True, return_coupling=False)
Predict whether a batch of data has drifted from the reference data, given the provided context.
- Parameters:
  - x (Union[ndarray, list]) – Batch of instances.
  - c (ndarray) – Context associated with the batch of instances.
  - return_p_val (bool) – Whether to return the p-value of the permutation test.
  - return_distance (bool) – Whether to return the conditional MMD test statistic between the new batch and reference data.
  - return_coupling (bool) – Whether to return the coupling matrices.
- Return type:
- Returns:
Dictionary containing 'meta' and 'data' dictionaries. 'meta' has the model's metadata. 'data' contains the drift prediction and, optionally, the p-value, threshold, conditional MMD test statistic and coupling matrices.
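A minimal sketch of how such a 'meta'/'data' output can be assembled from a test result. The drift decision follows the usual rule that drift is flagged when the p-value falls below the detector's p_val. The helper `assemble_prediction` and the key names inside 'data' are hypothetical and chosen for illustration; they are not guaranteed to match the library's exact output.

```python
def assemble_prediction(p_val, distance, distance_threshold,
                        p_val_threshold=0.05,
                        return_p_val=True, return_distance=True):
    """Hypothetical helper: builds a 'meta'/'data' dictionary from a test result."""
    # Flag drift when the permutation-test p-value is below the significance level.
    data = {"is_drift": int(p_val < p_val_threshold)}
    if return_p_val:
        data["p_val"] = p_val
        data["threshold"] = p_val_threshold
    if return_distance:
        data["distance"] = distance
        data["distance_threshold"] = distance_threshold
    # 'meta' would hold the detector's metadata; left empty in this sketch.
    return {"meta": {}, "data": data}


pred = assemble_prediction(p_val=0.01, distance=0.12, distance_threshold=0.05)
print(pred["data"]["is_drift"])
```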
- score(x, c)
Compute the MMD-based conditional test statistic, and perform a conditional permutation test to obtain a p-value representing the test statistic's extremity under the null hypothesis.
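- Parameters:
  - x (Union[ndarray, list]) – Batch of instances.
  - c (ndarray) – Context associated with the batch of instances.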
- Return type:
- Returns:
p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test statistic threshold above which drift is flagged, and a tuple containing the coupling matrices \((W_{ref,ref}, W_{test,test}, W_{ref,test})\).
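The detector's permutation test is conditional on the contexts, which is too involved for a short example. As a simplified stand-in, the sketch below runs an ordinary (unconditional) permutation test on the biased MMD² statistic to show how both a p-value and a statistic threshold can be read off a permutation distribution. The function names, the fixed kernel bandwidth and the (1 - p_val) quantile rule for the threshold are assumptions, not the library's implementation.

```python
import numpy as np


def rbf_kernel_matrix(z, sigma=1.0):
    # Pairwise squared Euclidean distances, then the Gaussian RBF kernel.
    sq = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * sigma**2))


def mmd2_from_kernel(k, idx_x, idx_y):
    # Biased MMD^2 computed from sub-blocks of a precomputed kernel matrix.
    return (k[np.ix_(idx_x, idx_x)].mean()
            + k[np.ix_(idx_y, idx_y)].mean()
            - 2.0 * k[np.ix_(idx_x, idx_y)].mean())


def permutation_test(x, y, n_permutations=500, p_val=0.05, sigma=1.0, seed=0):
    """Unconditional permutation test returning (p-value, statistic, threshold)."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x, y], axis=0)
    k = rbf_kernel_matrix(z, sigma)  # computed once, reused for every permutation
    n_x = len(x)
    idx = np.arange(len(z))
    observed = mmd2_from_kernel(k, idx[:n_x], idx[n_x:])
    permuted = np.empty(n_permutations)
    for i in range(n_permutations):
        perm = rng.permutation(len(z))
        permuted[i] = mmd2_from_kernel(k, perm[:n_x], perm[n_x:])
    # p-value: share of permuted statistics at least as large as the observed one.
    p = float(np.mean(permuted >= observed))
    # Threshold: the (1 - p_val) quantile of the permutation distribution.
    threshold = float(np.quantile(permuted, 1.0 - p_val))
    return p, float(observed), threshold


rng = np.random.default_rng(1)
x = rng.normal(size=(50, 2))
y = rng.normal(loc=1.5, size=(50, 2))
p, stat, thr = permutation_test(x, y)
print(p, stat, thr)
```

Precomputing the kernel matrix once and only re-indexing it per permutation keeps the cost of each permutation low, which matters when n_permutations is large.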