alibi_detect.cd.mmd module
- class alibi_detect.cd.mmd.MMDDrift(x_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, kernel=None, sigma=None, configure_kernel_from_x_ref=True, n_permutations=100, batch_size_permutations=1000000, device=None, input_shape=None, data_type=None)
Bases: DriftConfigMixin
- __init__(x_ref, backend='tensorflow', p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, kernel=None, sigma=None, configure_kernel_from_x_ref=True, n_permutations=100, batch_size_permutations=1000000, device=None, input_shape=None, data_type=None)
Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
- Parameters:
x_ref (Union[ndarray, list]) – Data used as reference distribution.
backend (str) – Backend used for the MMD implementation.
p_val (float) – p-value used for the significance of the permutation test.
x_ref_preprocessed (bool) – Whether the given reference data x_ref has been preprocessed yet. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.
preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.
update_x_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last n instances seen by the detector or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn (Optional[Callable]) – Function to preprocess the data before computing the data drift metrics.
kernel (Callable) – Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
sigma (Optional[ndarray]) – Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array. The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref (bool) – Whether to already configure the kernel bandwidth from the reference data.
n_permutations (int) – Number of permutations used in the permutation test.
batch_size_permutations (int) – KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations. Only relevant for 'keops' backend.
device (Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None]) – Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by passing either 'cuda', 'gpu', 'cpu' or an instance of torch.device. Only relevant for 'pytorch' backend.
data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.
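To make the roles of sigma, n_permutations and p_val concrete, the following is a minimal NumPy sketch of an MMD permutation test with a Gaussian RBF kernel averaged over several bandwidths. It is not the library's implementation: the function names (rbf_kernel, mmd2, mmd_permutation_test), the unbiased MMD^2 estimator, the seed argument and the returned keys are assumptions chosen for illustration.

```python
import numpy as np


def rbf_kernel(a, b, sigmas):
    """Gaussian RBF kernel matrix between rows of a and b, averaged over bandwidths."""
    # Squared Euclidean distance between every row of a and every row of b.
    sq_dist = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    sq_dist = np.maximum(sq_dist, 0.0)  # guard against small negative round-off
    return np.mean([np.exp(-sq_dist / (2.0 * s**2)) for s in sigmas], axis=0)


def mmd2(k, n_ref):
    """Unbiased MMD^2 estimate from a joint kernel matrix (first n_ref rows are the reference)."""
    n_test = k.shape[0] - n_ref
    k_xx, k_yy, k_xy = k[:n_ref, :n_ref], k[n_ref:, n_ref:], k[:n_ref, n_ref:]
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n_ref * (n_ref - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n_test * (n_test - 1))
    return term_xx + term_yy - 2.0 * k_xy.mean()


def mmd_permutation_test(x_ref, x, sigmas, n_permutations=100, p_val=0.05, seed=0):
    """Hypothetical helper: flag drift when the permutation p-value falls below p_val."""
    rng = np.random.default_rng(seed)
    n_ref = len(x_ref)
    joint = np.concatenate([x_ref, x], axis=0)
    k = rbf_kernel(joint, joint, sigmas)  # computed once, then re-indexed per permutation
    observed = mmd2(k, n_ref)
    permuted = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(len(joint))  # shuffle reference/test membership
        permuted[i] = mmd2(k[np.ix_(idx, idx)], n_ref)
    # Fraction of permuted statistics at least as large as the observed one.
    p = float(np.mean(permuted >= observed))
    return {"is_drift": int(p < p_val), "p_val": p, "distance": float(observed)}


# Usage on synthetic data: the test batch is shifted away from the reference.
rng = np.random.default_rng(42)
x_ref = rng.normal(0.0, 1.0, size=(100, 2))
x_shifted = rng.normal(3.0, 1.0, size=(100, 2))
print(mmd_permutation_test(x_ref, x_shifted, sigmas=[0.5, 1.0, 2.0]))
```

Computing the joint kernel matrix once and re-indexing it for each permutation avoids recomputing kernel values, which is why a larger n_permutations mainly costs indexing and summation in this sketch.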
- predict(x, return_p_val=True, return_distance=True)
Predict whether a batch of data has drifted from the reference data.
- Parameters:
- Return type:
- Returns:
Dictionary containing 'meta' and 'data' dictionaries. 'meta' has the model’s metadata. 'data' contains the drift prediction and optionally the p-value, threshold and MMD metric.
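As a sketch of how the returned dictionary might be consumed, the helper below formats a prediction and reads the optional entries defensively. The key names 'is_drift', 'p_val', 'threshold' and 'distance' are assumptions not stated in this section, summarize_drift is a hypothetical helper, and the example dictionary is hand-written for illustration rather than real detector output.

```python
def summarize_drift(preds):
    """Format a predict-style dictionary as a one-line summary."""
    data = preds["data"]
    parts = ["drift detected" if data["is_drift"] else "no drift detected"]
    # The p-value, threshold and MMD metric are optional, so only report those present.
    for key in ("p_val", "threshold", "distance"):
        if key in data:
            parts.append(f"{key}={data[key]:.4f}")
    return ", ".join(parts)


# Hand-written dictionary for illustration only; these are not real detector outputs.
example = {
    "meta": {"name": "MMDDrift"},
    "data": {"is_drift": 1, "p_val": 0.01, "threshold": 0.05, "distance": 0.2},
}
print(summarize_drift(example))
```

With a detector instance cd, the same helper would be called as summarize_drift(cd.predict(x)). Passing return_p_val=False or return_distance=False would be expected to drop the corresponding entries, which the membership check tolerates.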