import logging
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from typing import Callable, Dict, Optional, Tuple, Union, List
from alibi_detect.cd.base import BaseContextMMDDrift
from alibi_detect.utils.tensorflow.kernels import GaussianRBF
from alibi_detect.utils.warnings import deprecated_alias
from alibi_detect.utils.frameworks import Framework
from alibi_detect.cd._domain_clf import _SVCDomainClf
from tqdm import tqdm
logger = logging.getLogger(__name__)
class ContextMMDDriftTF(BaseContextMMDDrift):
lams: Optional[Tuple[tf.Tensor, tf.Tensor]]
    @deprecated_alias(preprocess_x_ref='preprocess_at_init')
def __init__(
self,
x_ref: Union[np.ndarray, list],
c_ref: np.ndarray,
p_val: float = .05,
x_ref_preprocessed: bool = False,
preprocess_at_init: bool = True,
update_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
x_kernel: Callable = GaussianRBF,
c_kernel: Callable = GaussianRBF,
n_permutations: int = 1000,
prop_c_held: float = 0.25,
n_folds: int = 5,
batch_size: Optional[int] = 256,
input_shape: Optional[tuple] = None,
data_type: Optional[str] = None,
verbose: bool = False,
) -> None:
"""
A context-aware drift detector based on a conditional analogue of the maximum mean discrepancy (MMD).
        Only detects differences between samples that cannot be attributed to differences between the associated
sets of contexts. p-values are computed using a conditional permutation test.
Parameters
----------
x_ref
Data used as reference distribution.
c_ref
Context for the reference distribution.
p_val
p-value used for the significance of the permutation test.
x_ref_preprocessed
Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
data will also be preprocessed.
preprocess_at_init
Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
update_ref
Reference data can optionally be updated to the last N instances seen by the detector.
The parameter should be passed as a dictionary *{'last': N}*.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
x_kernel
Kernel defined on the input data, defaults to Gaussian RBF kernel.
c_kernel
Kernel defined on the context data, defaults to Gaussian RBF kernel.
n_permutations
Number of permutations used in the permutation test.
prop_c_held
Proportion of contexts held out to condition on.
n_folds
Number of cross-validation folds used when tuning the regularisation parameters.
batch_size
If not None, then compute batches of MMDs at a time (rather than all at once).
input_shape
Shape of input data.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
verbose
Whether or not to print progress during configuration.
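
        Examples
        --------
        A minimal usage sketch via the backend-agnostic :py:class:`~alibi_detect.cd.ContextMMDDrift` wrapper
        (array shapes and variable names here are illustrative assumptions, not requirements of this class):

        >>> import numpy as np
        >>> from alibi_detect.cd import ContextMMDDrift
        >>> x_ref, c_ref = np.random.randn(100, 5), np.random.randn(100, 2)
        >>> cd = ContextMMDDrift(x_ref, c_ref, backend='tensorflow')
        >>> x, c = np.random.randn(100, 5), np.random.randn(100, 2)
        >>> preds = cd.predict(x, c)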
"""
super().__init__(
x_ref=x_ref,
c_ref=c_ref,
p_val=p_val,
x_ref_preprocessed=x_ref_preprocessed,
preprocess_at_init=preprocess_at_init,
update_ref=update_ref,
preprocess_fn=preprocess_fn,
x_kernel=x_kernel,
c_kernel=c_kernel,
n_permutations=n_permutations,
prop_c_held=prop_c_held,
n_folds=n_folds,
batch_size=batch_size,
input_shape=input_shape,
data_type=data_type,
verbose=verbose
)
self.meta.update({'backend': Framework.TENSORFLOW.value})
# initialize kernel
self.x_kernel = x_kernel(init_sigma_fn=_sigma_median_diag) if x_kernel == GaussianRBF else x_kernel
self.c_kernel = c_kernel(init_sigma_fn=_sigma_median_diag) if c_kernel == GaussianRBF else c_kernel
# Initialize classifier (hardcoded for now)
self.clf = _SVCDomainClf(self.c_kernel)
    def score(self,  # type: ignore[override]
              x: Union[np.ndarray, list], c: np.ndarray) -> Tuple[float, float, float, Tuple]:
"""
        Compute the MMD-based conditional test statistic, and perform a conditional permutation test to obtain a
p-value representing the test statistic's extremity under the null hypothesis.
Parameters
----------
x
Batch of instances.
c
Context associated with batch of instances.
Returns
-------
p-value obtained from the conditional permutation test, the conditional MMD test statistic, the test \
statistic threshold above which drift is flagged, and a tuple containing the coupling matrices \
(W_{ref,ref}, W_{test,test}, W_{ref,test}).
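
        Examples
        --------
        A sketch of calling `score` directly on an already-configured detector `cd` (in practice this method
        is invoked via `predict`; `x` and `c` are assumed to be appropriately shaped arrays):

        >>> p_val, stat, threshold, coupling = cd.score(x, c)  # doctest: +SKIP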
"""
x_ref, x = self.preprocess(x)
# Hold out a portion of contexts for conditioning on
n, n_held = len(c), int(len(c)*self.prop_c_held)
inds_held = np.random.choice(n, n_held, replace=False)
inds_test = np.setdiff1d(np.arange(n), inds_held)
c_held = c[inds_held]
c, x = c[inds_test], x[inds_test]
n_ref, n_test = len(x_ref), len(x)
bools = tf.concat([tf.zeros(n_ref), tf.ones(n_test)], axis=0)
# Compute kernel matrices
x_all = tf.concat([x_ref, x], axis=0)
c_all = tf.concat([self.c_ref, c], axis=0)
K = self.x_kernel(x_all, x_all)
L = self.c_kernel(c_all, c_all)
L_held = self.c_kernel(c_held, c_all)
# Fit and calibrate the domain classifier
c_all_np, bools_np = c_all.numpy(), bools.numpy()
self.clf.fit(c_all_np, bools_np)
self.clf.calibrate(c_all_np, bools_np)
# Obtain n_permutations conditional reassignments
prop_scores = self.clf.predict(c_all_np)
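        # prop_scores[i] is the calibrated probability that instance i belongs to the test window given its
        # context; sampling from Bernoulli(prop_scores) yields window reassignments that preserve the
        # context distribution under the null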
self.redrawn_bools = [tfp.distributions.Bernoulli(probs=prop_scores).sample()
for _ in range(self.n_permutations)]
iters = tqdm(self.redrawn_bools, total=self.n_permutations) if self.verbose else self.redrawn_bools
# Compute test stat on original and reassigned data
stat, coupling_xx, coupling_yy, coupling_xy = self._cmmd(K, L, bools, L_held=L_held)
permuted_stats = tf.stack([self._cmmd(K, L, perm_bools, L_held=L_held)[0] for perm_bools in iters])
# Compute p-value
p_val = tf.reduce_mean(tf.cast(stat <= permuted_stats, float))
coupling = (coupling_xx.numpy(), coupling_yy.numpy(), coupling_xy.numpy())
# compute distance threshold
idx_threshold = int(self.p_val * len(permuted_stats))
distance_threshold = np.sort(permuted_stats)[::-1][idx_threshold]
return p_val.numpy().item(), stat.numpy().item(), distance_threshold, coupling
    def _cmmd(self, K: tf.Tensor, L: tf.Tensor, bools: tf.Tensor, L_held: Optional[tf.Tensor] = None) \
-> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Private method to compute the MMD-ADiTT test statistic.
"""
# Get ref/test indices
idx_0, idx_1 = np.where(bools == 0)[0], np.where(bools == 1)[0]
n_ref, n_test = len(idx_0), len(idx_1)
# Form kernel matrices
L_0, L_1 = tf.gather(tf.gather(L, idx_0), idx_0, axis=1), tf.gather(tf.gather(L, idx_1), idx_1, axis=1)
K_0, K_1 = tf.gather(tf.gather(K, idx_0), idx_0, axis=1), tf.gather(tf.gather(K, idx_1), idx_1, axis=1)
        # Avoid using tf.gather_nd since this would require [n_ref, n_ref, 2] and [n_test, n_test, 2] idx tensors
        # Initialise regularisation parameters
        # (done only on the first _cmmd call, which corresponds to the original window assignment)
if self.lams is None:
possible_lams = tf.convert_to_tensor([2**(-i) for i in range(20)], dtype=tf.float64)
lam_0 = self._pick_lam(possible_lams, K_0, L_0, n_folds=self.n_folds)
lam_1 = self._pick_lam(possible_lams, K_1, L_1, n_folds=self.n_folds)
self.lams = (lam_0, lam_1)
# Compute stat
L_0_inv = tf.linalg.inv(L_0 + n_ref*self.lams[0]*tf.eye(int(n_ref)))
L_1_inv = tf.linalg.inv(L_1 + n_test*self.lams[1]*tf.eye(int(n_test)))
A_0 = tf.gather(L_held, idx_0, axis=1) @ L_0_inv
A_1 = tf.gather(L_held, idx_1, axis=1) @ L_1_inv
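        # Each row of A_0 / A_1 holds the regression weights mapping one held-out context onto the
        # reference / test instances; the coupling matrices below average outer products of these rows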
        # Allow batches of MMDs to be computed at a time (rather than all at once)
if self.batch_size is not None:
bs = self.batch_size
coupling_xx = tf.reduce_mean(tf.stack([tf.reduce_mean(tf.einsum('ij,ik->ijk', A_0_i, A_0_i), axis=0)
for A_0_i in tf.split(A_0, _split_chunks(len(A_0), bs))]), axis=0)
coupling_yy = tf.reduce_mean(tf.stack([tf.reduce_mean(tf.einsum('ij,ik->ijk', A_1_i, A_1_i), axis=0)
for A_1_i in tf.split(A_1, _split_chunks(len(A_1), bs))]), axis=0)
coupling_xy = tf.reduce_mean(tf.stack([
tf.reduce_mean(tf.einsum('ij,ik->ijk', A_0_i, A_1_i), axis=0)
for A_0_i, A_1_i in zip(tf.split(A_0, _split_chunks(len(A_0), bs)),
tf.split(A_1, _split_chunks(len(A_1), bs)))
]), axis=0)
else:
coupling_xx = tf.reduce_mean(tf.einsum('ij,ik->ijk', A_0, A_0), axis=0)
coupling_yy = tf.reduce_mean(tf.einsum('ij,ik->ijk', A_1, A_1), axis=0)
coupling_xy = tf.reduce_mean(tf.einsum('ij,ik->ijk', A_0, A_1), axis=0)
sim_xx = tf.reduce_sum(tf.gather(tf.gather(K, idx_0), idx_0, axis=1)*coupling_xx)
sim_yy = tf.reduce_sum(tf.gather(tf.gather(K, idx_1), idx_1, axis=1)*coupling_yy)
sim_xy = tf.reduce_sum(tf.gather(tf.gather(K, idx_0), idx_1, axis=1)*coupling_xy)
stat = sim_xx + sim_yy - 2*sim_xy
return stat, coupling_xx, coupling_yy, coupling_xy
def _pick_lam(self, lams: tf.Tensor, K: tf.Tensor, L: tf.Tensor, n_folds: int = 5) -> tf.Tensor:
"""
        The conditional mean embedding is estimated as the solution of a regularised regression problem.
        This private method uses cross-validation to select the regularisation parameter that
        minimises the squared error on the out-of-fold instances. The error is a distance in the RKHS and is
        therefore an MMD-like quantity itself.
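
        Notes
        -----
        For each candidate `lam`, the in-fold weights for an out-of-fold context `c` are
        `w(c) = (L_if + n_if*lam*I)^{-1} l_if(c)`, and the out-of-fold loss accumulated below is
        `sum_j w(c_j)^T K_if w(c_j) - 2*w(c_j)^T k_if(x_j) + k(x_j, x_j)`, with `k(x_j, x_j)`
        approximated by the maximum kernel value (a constant, so it does not affect the argmin).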
"""
n = len(L)
fold_size = n // n_folds
        K, L = tf.cast(K, tf.float64), tf.cast(L, tf.float64)
perm = tf.random.shuffle(range(n))
K, L = tf.gather(tf.gather(K, perm), perm, axis=1), tf.gather(tf.gather(L, perm), perm, axis=1)
losses = tf.zeros_like(lams, dtype=tf.float64)
for fold in range(n_folds):
inds_oof = np.arange(n)[(fold*fold_size):((fold+1)*fold_size)]
inds_if = np.setdiff1d(np.arange(n), inds_oof)
K_if = tf.gather(tf.gather(K, inds_if), inds_if, axis=1)
L_if = tf.gather(tf.gather(L, inds_if), inds_if, axis=1)
n_if = len(K_if)
L_inv_lams = tf.stack(
[tf.linalg.inv(L_if + n_if*lam*tf.eye(n_if, dtype=tf.float64)) for lam in lams]) # n_lam x n_if x n_if
KW = tf.einsum('ij,ljk->lik', K_if, L_inv_lams)
lW = tf.einsum('ij,ljk->lik', tf.gather(tf.gather(L, inds_oof), inds_if, axis=1), L_inv_lams)
lWKW = tf.einsum('lij,ljk->lik', lW, KW)
lWKWl = tf.einsum('lkj,jk->lk', lWKW, tf.gather(tf.gather(L, inds_if), inds_oof, axis=1)) # n_lam x n_oof
lWk = tf.einsum('lij,ji->li', lW, tf.gather(tf.gather(K, inds_if), inds_oof, axis=1)) # n_lam x n_oof
kxx = tf.ones_like(lWk) * tf.reduce_max(K)
losses += tf.reduce_sum(lWKWl + kxx - 2*lWk, axis=-1)
return tf.cast(lams[tf.argmin(losses)], tf.float32)
def _split_chunks(n: int, p: int) -> List[int]:
"""
    Private function to calculate chunk sizes for `tf.split()`. Aims to split an array/tensor of length `n`
    into `p` chunks of roughly equal size.
Parameters
----------
n
Size of array/tensor to be split.
p
Number of chunks.
Returns
-------
List containing the size of each chunk.
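
    Examples
    --------
    Illustrative chunkings (the split is sum-preserving, with chunk sizes differing by at most one):

    >>> _split_chunks(10, 3)
    [4, 3, 3]
    >>> _split_chunks(2, 5)
    [2]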
"""
if p >= n:
chunks = [n]
else:
chunks = [n // p + 1] * (n % p) + [n // p] * (p - n % p)
return chunks
def _sigma_median_diag(x: tf.Tensor, y: tf.Tensor, dist: tf.Tensor) -> tf.Tensor:
"""
Private version of the bandwidth estimation function :py:func:`~alibi_detect.utils.tensorflow.kernels.sigma_median`,
with the +n (and -1) term excluded to account for the diagonal of the kernel matrix.
Parameters
----------
x
Tensor of instances with dimension [Nx, features].
y
Tensor of instances with dimension [Ny, features].
dist
Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`.
Returns
-------
The computed bandwidth, `sigma`.
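
    Examples
    --------
    A small worked case: with a 2x2 `dist`, `n_median = 4 // 2 = 2` picks the third-smallest entry of the
    flattened matrix (here 4.0), giving `sigma = sqrt(0.5 * 4)`. `x` and `y` are placeholders only, as just
    `dist` enters the computation:

    >>> import tensorflow as tf
    >>> dist = tf.constant([[0., 4.], [4., 0.]])
    >>> x = y = tf.zeros((2, 1))
    >>> round(float(_sigma_median_diag(x, y, dist)[0]), 4)
    1.4142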
"""
n_median = tf.math.reduce_prod(dist.shape) // 2
sigma = tf.expand_dims((.5 * tf.sort(tf.reshape(dist, (-1,)))[n_median]) ** .5, axis=0)
return sigma