import logging
import numpy as np
import torch
from typing import Callable, Dict, Optional, Tuple, Union
from alibi_detect.cd.base import BaseContextMMDDrift
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.pytorch.kernels import GaussianRBF
from alibi_detect.utils.warnings import deprecated_alias
from alibi_detect.utils.frameworks import Framework
from alibi_detect.cd._domain_clf import _SVCDomainClf
from alibi_detect.utils._types import TorchDeviceType
from tqdm import tqdm
logger = logging.getLogger(__name__)
[docs]class ContextMMDDriftTorch(BaseContextMMDDrift):
lams: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
[docs] @deprecated_alias(preprocess_x_ref='preprocess_at_init')
def __init__(
self,
x_ref: Union[np.ndarray, list],
c_ref: np.ndarray,
p_val: float = .05,
x_ref_preprocessed: bool = False,
preprocess_at_init: bool = True,
update_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
x_kernel: Callable = GaussianRBF,
c_kernel: Callable = GaussianRBF,
n_permutations: int = 1000,
prop_c_held: float = 0.25,
n_folds: int = 5,
batch_size: Optional[int] = 256,
device: TorchDeviceType = None,
input_shape: Optional[tuple] = None,
data_type: Optional[str] = None,
verbose: bool = False,
) -> None:
"""
A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD).
Only detects differences between samples that can not be attributed to differences between associated
sets of contexts. p-values are computed using a conditional permutation test.
Parameters
----------
x_ref
Data used as reference distribution.
c_ref
Context for the reference distribution.
p_val
p-value used for the significance of the permutation test.
x_ref_preprocessed
Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
data will also be preprocessed.
preprocess_at_init
Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
update_ref
Reference data can optionally be updated to the last N instances seen by the detector.
The parameter should be passed as a dictionary *{'last': N}*.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
x_kernel
Kernel defined on the input data, defaults to Gaussian RBF kernel.
c_kernel
Kernel defined on the context data, defaults to Gaussian RBF kernel.
n_permutations
Number of permutations used in the permutation test.
prop_c_held
Proportion of contexts held out to condition on.
n_folds
Number of cross-validation folds used when tuning the regularisation parameters.
batch_size
If not None, then compute batches of MMDs at a time (rather than all at once).
device
Device type used. The default tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
``torch.device``. Only relevant for 'pytorch' backend.
input_shape
Shape of input data.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
verbose
Whether or not to print progress during configuration.
"""
super().__init__(
x_ref=x_ref,
c_ref=c_ref,
p_val=p_val,
x_ref_preprocessed=x_ref_preprocessed,
preprocess_at_init=preprocess_at_init,
update_ref=update_ref,
preprocess_fn=preprocess_fn,
x_kernel=x_kernel,
c_kernel=c_kernel,
n_permutations=n_permutations,
prop_c_held=prop_c_held,
n_folds=n_folds,
batch_size=batch_size,
input_shape=input_shape,
data_type=data_type,
verbose=verbose,
)
self.meta.update({'backend': Framework.PYTORCH.value})
# set device
self.device = get_device(device)
# initialize kernel
self.x_kernel = x_kernel(init_sigma_fn=_sigma_median_diag) if x_kernel == GaussianRBF else x_kernel
self.c_kernel = c_kernel(init_sigma_fn=_sigma_median_diag) if c_kernel == GaussianRBF else c_kernel
# Initialize classifier (hardcoded for now)
self.clf = _SVCDomainClf(self.c_kernel)
[docs] def score(self, # type: ignore[override]
x: Union[np.ndarray, list], c: np.ndarray) -> Tuple[float, float, float, Tuple]:
"""
Compute the MMD based conditional test statistic, and perform a conditional permutation test to obtain a
p-value representing the test statistic's extremity under the null hypothesis.
Parameters
----------
x
Batch of instances.
c
Context associated with batch of instances.
Returns
-------
p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test \
statistic threshold above which drift is flagged, and a tuple containing the coupling matrices \
(W_{ref,ref}, W_{test,test}, W_{ref,test}).
"""
x_ref, x = self.preprocess(x)
x_ref = torch.from_numpy(x_ref).to(self.device) # type: ignore[assignment]
c_ref = torch.from_numpy(self.c_ref).to(self.device)
# Hold out a portion of contexts for conditioning on
n, n_held = len(c), int(len(c)*self.prop_c_held)
inds_held = np.random.choice(n, n_held, replace=False)
inds_test = np.setdiff1d(np.arange(n), inds_held)
c_held = torch.as_tensor(c[inds_held]).to(self.device)
c = torch.as_tensor(c[inds_test]).to(self.device) # type: ignore[assignment]
x = torch.as_tensor(x[inds_test]).to(self.device) # type: ignore[assignment]
n_ref, n_test = len(x_ref), len(x)
bools = torch.cat([torch.zeros(n_ref), torch.ones(n_test)]).to(self.device)
# Compute kernel matrices
x_all = torch.cat([x_ref, x], dim=0) # type: ignore[list-item]
c_all = torch.cat([c_ref, c], dim=0) # type: ignore[list-item]
K = self.x_kernel(x_all, x_all)
L = self.c_kernel(c_all, c_all)
L_held = self.c_kernel(c_held, c_all)
# Fit and calibrate the domain classifier
c_all_np, bools_np = c_all.cpu().numpy(), bools.cpu().numpy()
self.clf.fit(c_all_np, bools_np)
self.clf.calibrate(c_all_np, bools_np)
# Obtain n_permutations conditional reassignments
prop_scores = torch.as_tensor(self.clf.predict(c_all_np))
self.redrawn_bools = [torch.bernoulli(prop_scores) for _ in range(self.n_permutations)]
iters = tqdm(self.redrawn_bools, total=self.n_permutations) if self.verbose else self.redrawn_bools
# Compute test stat on original and reassigned data
stat, coupling_xx, coupling_yy, coupling_xy = self._cmmd(K, L, bools, L_held=L_held)
permuted_stats = torch.stack([self._cmmd(K, L, perm_bools, L_held=L_held)[0] for perm_bools in iters])
# Compute p-value
p_val = (stat <= permuted_stats).float().mean()
coupling = (coupling_xx.numpy(), coupling_yy.numpy(), coupling_xy.numpy())
# compute distance threshold
idx_threshold = int(self.p_val * len(permuted_stats))
distance_threshold = torch.sort(permuted_stats, descending=True).values[idx_threshold]
return p_val.numpy().item(), stat.numpy().item(), distance_threshold.numpy(), coupling
def _cmmd(self, K: torch.Tensor, L: torch.Tensor, bools: torch.Tensor, L_held: torch.Tensor = None) \
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Private method to compute the MMD-ADiTT test statistic.
"""
# Get ref/test indices
idx_0, idx_1 = torch.where(bools == 0)[0], torch.where(bools == 1)[0]
n_ref, n_test = len(idx_0), len(idx_1)
# Form kernel matrices
L_0, L_1 = L[idx_0][:, idx_0], L[idx_1][:, idx_1]
K_0, K_1 = K[idx_0][:, idx_0], K[idx_1][:, idx_1]
# Initialise regularisation parameters
# Implemented only for first _cmmd call which corresponds to original window assignment
if self.lams is None:
possible_lams = torch.tensor([2**(-i) for i in range(20)]).to(K.device)
lam_0 = self._pick_lam(possible_lams, K_0, L_0, n_folds=self.n_folds)
lam_1 = self._pick_lam(possible_lams, K_1, L_1, n_folds=self.n_folds)
self.lams = (lam_0, lam_1)
# Compute stat
L_0_inv = torch.linalg.inv(L_0 + n_ref*self.lams[0]*torch.eye(int(n_ref)).to(L_0.device))
L_1_inv = torch.linalg.inv(L_1 + n_test*self.lams[1]*torch.eye(int(n_test)).to(L_1.device))
A_0 = L_held[:, idx_0] @ L_0_inv
A_1 = L_held[:, idx_1] @ L_1_inv
# Allow batches of MMDs to be computed at a time (rather than all)
if self.batch_size is not None:
bs = self.batch_size
coupling_xx = torch.stack([torch.einsum('ij,ik->ijk', A_0_i, A_0_i).mean(0)
for A_0_i in A_0.split(bs)]).mean(0)
coupling_yy = torch.stack([torch.einsum('ij,ik->ijk', A_1_i, A_1_i).mean(0)
for A_1_i in A_1.split(bs)]).mean(0)
coupling_xy = torch.stack([
torch.einsum('ij,ik->ijk', A_0_i, A_1_i).mean(0) for A_0_i, A_1_i in zip(A_0.split(bs), A_1.split(bs))
]).mean(0)
else:
coupling_xx = torch.einsum('ij,ik->ijk', A_0, A_0).mean(0)
coupling_yy = torch.einsum('ij,ik->ijk', A_1, A_1).mean(0)
coupling_xy = torch.einsum('ij,ik->ijk', A_0, A_1).mean(0)
sim_xx = (K[idx_0][:, idx_0]*coupling_xx).sum()
sim_yy = (K[idx_1][:, idx_1]*coupling_yy).sum()
sim_xy = (K[idx_0][:, idx_1]*coupling_xy).sum()
stat = sim_xx + sim_yy - 2*sim_xy
return stat.cpu(), coupling_xx.cpu(), coupling_yy.cpu(), coupling_xy.cpu()
def _pick_lam(self, lams: torch.Tensor, K: torch.Tensor, L: torch.Tensor, n_folds: int = 5) -> torch.Tensor:
"""
The conditional mean embedding is estimated as the solution of a regularised regression problem.
This private method function uses cross validation to select the regularisation parameter that
minimises squared error on the out-of-fold instances. The error is a distance in the RKHS and is
therefore an MMD-like quantity itself.
"""
n = len(L)
fold_size = n // n_folds
K, L = K.type(torch.float64), L.type(torch.float64)
perm = torch.randperm(n)
K, L = K[perm][:, perm], L[perm][:, perm]
losses = torch.zeros_like(lams, dtype=torch.float).to(K.device)
for fold in range(n_folds):
inds_oof = list(np.arange(n)[(fold*fold_size):((fold+1)*fold_size)])
inds_if = list(np.setdiff1d(np.arange(n), inds_oof))
K_if, L_if = K[inds_if][:, inds_if], L[inds_if][:, inds_if]
n_if = len(K_if)
L_inv_lams = torch.stack(
[torch.linalg.inv(L_if + n_if*lam*torch.eye(n_if).to(L.device)) for lam in lams]) # n_lam x n_if x n_if
KW = torch.einsum('ij,ljk->lik', K_if, L_inv_lams)
lW = torch.einsum('ij,ljk->lik', L[inds_oof][:, inds_if], L_inv_lams)
lWKW = torch.einsum('lij,ljk->lik', lW, KW)
lWKWl = torch.einsum('lkj,jk->lk', lWKW, L[inds_if][:, inds_oof]) # n_lam x n_oof
lWk = torch.einsum('lij,ji->li', lW, K[inds_if][:, inds_oof]) # n_lam x n_oof
kxx = torch.ones_like(lWk).to(lWk.device) * torch.max(K)
losses += (lWKWl + kxx - 2*lWk).sum(-1)
return lams[torch.argmin(losses)]
def _sigma_median_diag(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
"""
Private version of the bandwidth estimation function :py:func:`~alibi_detect.utils.pytorch.kernels.sigma_median`,
with the +n (and -1) term excluded to account for the diagonal of the kernel matrix.
Parameters
----------
x
Tensor of instances with dimension [Nx, features].
y
Tensor of instances with dimension [Ny, features].
dist
Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`.
Returns
-------
The computed bandwidth, `sigma`.
"""
n_median = np.prod(dist.shape) // 2
sigma = (.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** .5
return sigma