import logging
import numpy as np
from typing import Callable, Dict, Optional, Union, Tuple
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator, Framework
from alibi_detect.utils.warnings import deprecated_alias
from alibi_detect.base import DriftConfigMixin
from alibi_detect.utils._types import TorchDeviceType

# Backend-specific detector implementations are imported only when the
# corresponding framework is available.
if has_pytorch:
    from alibi_detect.cd.pytorch.context_aware import ContextMMDDriftTorch

if has_tensorflow:
    from alibi_detect.cd.tensorflow.context_aware import ContextMMDDriftTF

logger = logging.getLogger(__name__)

class ContextMMDDrift(DriftConfigMixin):
    """
    Backend-agnostic front end for the context-aware MMD drift detector.

    The constructor validates the requested backend and instantiates either
    `ContextMMDDriftTF` or `ContextMMDDriftTorch`; `predict` and `score` are forwarded
    to that backend detector.
    """

    @deprecated_alias(preprocess_x_ref='preprocess_at_init')
    def __init__(
            self,
            x_ref: Union[np.ndarray, list],
            c_ref: np.ndarray,
            backend: str = 'tensorflow',
            p_val: float = .05,
            x_ref_preprocessed: bool = False,
            preprocess_at_init: bool = True,
            update_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
            x_kernel: Optional[Callable] = None,
            c_kernel: Optional[Callable] = None,
            n_permutations: int = 1000,
            prop_c_held: float = 0.25,
            n_folds: int = 5,
            batch_size: Optional[int] = 256,
            device: TorchDeviceType = None,
            input_shape: Optional[tuple] = None,
            data_type: Optional[str] = None,
            verbose: bool = False
    ) -> None:
        """
        A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD).
        Only detects differences between samples that can not be attributed to differences between associated
        sets of contexts. p-values are computed using a conditional permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        c_ref
            Context for the reference distribution.
        backend
            Backend used for the MMD implementation. Matched case-insensitively.
        p_val
            p-value used for the significance of the permutation test.
        x_ref_preprocessed
            Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
            the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
            data will also be preprocessed.
        preprocess_at_init
            Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
            data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
        update_ref
            Reference data can optionally be updated to the last N instances seen by the detector.
            The parameter should be passed as a dictionary *{'last': N}*.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        x_kernel
            Kernel defined on the input data, defaults to Gaussian RBF kernel.
        c_kernel
            Kernel defined on the context data, defaults to Gaussian RBF kernel.
        n_permutations
            Number of permutations used in the permutation test.
        prop_c_held
            Proportion of contexts held out to condition on.
        n_folds
            Number of cross-validation folds used when tuning the regularisation parameters.
        batch_size
            If not None, then compute batches of MMDs at a time (rather than all at once).
        device
            Device type used. The default tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
            ``torch.device``. Only relevant for 'pytorch' backend.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        verbose
            Whether to print progress messages.
        """
        super().__init__()

        # Set config. This must stay ahead of any new local assignment so that the
        # snapshot handed to `_set_config` holds only the constructor arguments
        # (plus `self` and `__class__`), with `backend` still as the caller passed it.
        self._set_config(locals())

        backend = backend.lower()
        BackendValidator(
            backend_options={Framework.TENSORFLOW: [Framework.TENSORFLOW],
                             Framework.PYTORCH: [Framework.PYTORCH]},
            construct_name=self.__class__.__name__
        ).verify_backend(backend)

        # Keyword arguments forwarded to the backend detector. They are listed explicitly
        # rather than derived by mutating `locals()`, so the set of forwarded names cannot
        # change with interpreter version or with an active debugger/trace function.
        kwargs = {
            'p_val': p_val,
            'x_ref_preprocessed': x_ref_preprocessed,
            'preprocess_at_init': preprocess_at_init,
            'update_ref': update_ref,
            'preprocess_fn': preprocess_fn,
            'x_kernel': x_kernel,
            'c_kernel': c_kernel,
            'n_permutations': n_permutations,
            'prop_c_held': prop_c_held,
            'n_folds': n_folds,
            'batch_size': batch_size,
            'input_shape': input_shape,
            'data_type': data_type,
            'verbose': verbose,
        }

        # Fall back to the selected backend's Gaussian RBF kernel for any kernel not supplied.
        # The import is deferred so only the chosen framework's kernel module is loaded.
        if x_kernel is None or c_kernel is None:
            if backend == Framework.TENSORFLOW:
                from alibi_detect.utils.tensorflow.kernels import GaussianRBF
            else:
                from alibi_detect.utils.pytorch.kernels import GaussianRBF  # type: ignore[assignment]
            if x_kernel is None:
                kwargs['x_kernel'] = GaussianRBF
            if c_kernel is None:
                kwargs['c_kernel'] = GaussianRBF

        if backend == Framework.TENSORFLOW:
            # `device` is deliberately not forwarded to the TensorFlow detector.
            self._detector = ContextMMDDriftTF(x_ref, c_ref, **kwargs)
        else:
            self._detector = ContextMMDDriftTorch(x_ref, c_ref, device=device, **kwargs)
        self.meta = self._detector.meta

    def predict(self, x: Union[np.ndarray, list], c: np.ndarray,
                return_p_val: bool = True, return_distance: bool = True, return_coupling: bool = False) \
            -> Dict[str, dict]:
        """
        Predict whether a batch of data has drifted from the reference data, given the provided context.

        Parameters
        ----------
        x
            Batch of instances.
        c
            Context associated with batch of instances.
        return_p_val
            Whether to return the p-value of the permutation test.
        return_distance
            Whether to return the conditional MMD test statistic between the new batch and reference data.
        return_coupling
            Whether to return the coupling matrices.

        Returns
        -------
        Dictionary containing ``'meta'`` and ``'data'`` dictionaries.

         - ``'meta'`` has the model's metadata.

         - ``'data'`` contains the drift prediction and optionally the p-value, threshold, conditional MMD test \
        statistic and coupling matrices.
        """
        return self._detector.predict(x, c, return_p_val, return_distance, return_coupling)

    def score(self, x: Union[np.ndarray, list], c: np.ndarray) -> Tuple[float, float, float, Tuple]:
        """
        Compute the MMD based conditional test statistic, and perform a conditional permutation test to obtain a
        p-value representing the test statistic's extremity under the null hypothesis.

        Parameters
        ----------
        x
            Batch of instances.
        c
            Context associated with batch of instances.

        Returns
        -------
        p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test \
        statistic threshold above which drift is flagged, and a tuple containing the coupling matrices \
        :math:`(W_{ref,ref}, W_{test,test}, W_{ref,test})`.
        """
        return self._detector.score(x, c)