Source code for alibi_detect.utils.tensorflow.kernels

import tensorflow as tf
import numpy as np
from . import distance
from typing import Optional, Union, Callable
from scipy.special import logit
from alibi_detect.utils.frameworks import Framework


def sigma_median(x: tf.Tensor, y: tf.Tensor, dist: tf.Tensor) -> tf.Tensor:
    """
    Estimate a kernel bandwidth with the median heuristic :cite:t:`Gretton2012`.

    Parameters
    ----------
    x
        Tensor of instances with dimension [Nx, features].
    y
        Tensor of instances with dimension [Ny, features].
    dist
        Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`.

    Returns
    -------
    The computed bandwidth, `sigma`, as a tensor with a single element.
    """
    n_common = min(x.shape[0], y.shape[0])
    # The offset below is only applied when `x` and `y` hold the same instances. In that case
    # the `n_common` smallest entries of `dist` are presumably the zero self-distances on the
    # diagonal, which should not take part in the median.
    same_instances = tf.reduce_all(x[:n_common] == y[:n_common]) and x.shape == y.shape
    n_skip = n_common if same_instances else 0

    # Position of the median among the sorted entries once the first `n_skip` are skipped.
    n_entries = tf.math.reduce_prod(dist.shape)
    median_idx = n_skip + (n_entries - n_skip) // 2 - 1

    sorted_dist = tf.sort(tf.reshape(dist, (-1,)))
    median_dist = sorted_dist[median_idx]
    return tf.expand_dims((.5 * median_dist) ** .5, axis=0)
class GaussianRBF(tf.keras.Model):
    """
    Gaussian RBF kernel whose bandwidth can be fixed, inferred from data or trained.

    The bandwidth is stored as the variable `log_sigma`; the `sigma` property exposes its exponential.
    """

    def __init__(
            self,
            sigma: Optional[tf.Tensor] = None,
            init_sigma_fn: Optional[Callable] = None,
            trainable: bool = False
    ) -> None:
        """
        Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2))*||x-y||^2). A forward pass takes
        a batch of instances x [Nx, features] and y [Ny, features] and returns the kernel
        matrix [Nx, Ny].

        Parameters
        ----------
        sigma
            Bandwidth used for the kernel. Needn't be specified if being inferred or trained.
            Can pass multiple values to eval kernel with and then average.
        init_sigma_fn
            Function used to compute the bandwidth `sigma`. Used when `sigma` is to be inferred.
            The function's signature should match :py:func:`~alibi_detect.utils.tensorflow.kernels.sigma_median`,
            meaning that it should take in the tensors `x`, `y` and `dist` and return `sigma`. If `None`, it is set to
            :func:`~alibi_detect.utils.tensorflow.kernels.sigma_median`.
        trainable
            Whether or not to track gradients w.r.t. sigma to allow it to be trained.
        """
        super().__init__()
        init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
        # Keep the constructor arguments so that `get_config` can reproduce them.
        self.config = {'sigma': sigma, 'trainable': trainable, 'init_sigma_fn': init_sigma_fn}
        if sigma is None:
            # Placeholder value; it is overwritten by `init_sigma_fn` on the first forward pass.
            self.log_sigma = tf.Variable(np.empty(1), dtype=tf.keras.backend.floatx(), trainable=trainable)
            self.init_required = True
        else:
            sigma = tf.cast(tf.reshape(sigma, (-1,)), dtype=tf.keras.backend.floatx())  # [Ns,]
            self.log_sigma = tf.Variable(tf.math.log(sigma), trainable=trainable)
            self.init_required = False
        self.init_sigma_fn = init_sigma_fn
        self.trainable = trainable

    @property
    def sigma(self) -> tf.Tensor:
        """The bandwidth(s), computed as the exponential of `log_sigma`."""
        return tf.math.exp(self.log_sigma)

    def call(self, x: tf.Tensor, y: tf.Tensor, infer_sigma: bool = False) -> tf.Tensor:
        """
        Compute the kernel matrix between `x` and `y`.

        Parameters
        ----------
        x
            Batch of instances; flattened to [Nx, features].
        y
            Batch of instances; cast to the dtype of `x` and flattened to [Ny, features].
        infer_sigma
            Whether to (re-)compute the bandwidth from `x` and `y` with `init_sigma_fn`.

        Returns
        -------
        Kernel matrix [Nx, Ny], averaged over the bandwidths in `sigma`.

        Raises
        ------
        ValueError
            If `infer_sigma` is set while the kernel is trainable.
        """
        y = tf.cast(y, x.dtype)
        x, y = tf.reshape(x, (x.shape[0], -1)), tf.reshape(y, (y.shape[0], -1))  # flatten
        dist = distance.squared_pairwise_distance(x, y)  # [Nx, Ny]

        if infer_sigma or self.init_required:
            if self.trainable and infer_sigma:
                raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
            sigma = self.init_sigma_fn(x, y, dist)
            # Cast so that inputs with a dtype other than that of `log_sigma` can be assigned.
            self.log_sigma.assign(tf.cast(tf.math.log(sigma), self.log_sigma.dtype))
            self.init_required = False

        # `tf.cast` rather than `tf.constant`: the latter rejects symbolic tensors and does not
        # keep the result connected to `log_sigma` for differentiation.
        gamma = tf.cast(1. / (2. * self.sigma ** 2), x.dtype)  # [Ns,]
        # Broadcast every bandwidth against the distance matrix.
        kernel_mat = tf.exp(-gamma[:, None, None] * dist[None, :, :])  # [Ns, Nx, Ny]
        return tf.reduce_mean(kernel_mat, axis=0)  # [Nx, Ny]

    def get_config(self) -> dict:
        """
        Returns a serializable config dict (excluding the init_sigma_fn, which is serialized in alibi_detect.saving).
        """
        cfg = self.config.copy()
        if isinstance(cfg['sigma'], tf.Tensor):
            cfg['sigma'] = cfg['sigma'].numpy().tolist()
        cfg.update({'flavour': Framework.TENSORFLOW.value})
        return cfg

    @classmethod
    def from_config(cls, config):
        """
        Instantiates a kernel from a config dictionary.

        Parameters
        ----------
        config
            A kernel config dictionary. It is not modified.
        """
        # Work on a shallow copy so the caller's dict can be reused, and tolerate a missing 'flavour'.
        config = dict(config)
        config.pop('flavour', None)
        return cls(**config)
class DeepKernel(tf.keras.Model):
    """
    Computes similarities as k(x,y) = (1-eps)*k_a(proj(x), proj(y)) + eps*k_b(x,y).
    A forward pass takes a batch of instances x [Nx, features] and y [Ny, features] and returns
    the kernel matrix [Nx, Ny].

    Parameters
    ----------
    proj
        The projection to be applied to the inputs before applying kernel_a
    kernel_a
        The kernel to apply to the projected inputs. Defaults to a Gaussian RBF with trainable bandwidth.
    kernel_b
        The kernel to apply to the raw inputs. Defaults to a Gaussian RBF with trainable bandwidth.
        Set to None in order to use only the deep component (i.e. eps=0).
    eps
        The proportion (in [0,1]) of weight to assign to the kernel applied to raw inputs. This can be
        either specified or set to 'trainable'. Only relevant if kernel_b is not None.
    """

    def __init__(
            self,
            proj: tf.keras.Model,
            kernel_a: Union[tf.keras.Model, str] = 'rbf',
            kernel_b: Optional[Union[tf.keras.Model, str]] = 'rbf',
            eps: Union[float, str] = 'trainable'
    ) -> None:
        super().__init__()
        # Remember the arguments exactly as passed so that `get_config` can return them.
        self.config = dict(proj=proj, kernel_a=kernel_a, kernel_b=kernel_b, eps=eps)

        # The string 'rbf' stands for a Gaussian RBF kernel with a trainable bandwidth.
        if kernel_a == 'rbf':
            kernel_a = GaussianRBF(trainable=True)
        if kernel_b == 'rbf':
            kernel_b = GaussianRBF(trainable=True)

        self.kernel_a = kernel_a
        self.kernel_b = kernel_b
        self.proj = proj

        # The mixing weight is only set up when there is a second kernel to mix in.
        if kernel_b is None:
            return
        self._init_eps(eps)

    def _init_eps(self, eps: Union[float, str]) -> None:
        """Create the `logit_eps` variable from a fixed float in (0,1) or from 'trainable'."""
        if isinstance(eps, float):
            if not 0 < eps < 1:
                raise ValueError("eps should be in (0,1)")
            fixed_logit = logit(tf.constant(eps))
            self.logit_eps = tf.Variable(tf.constant(fixed_logit), trainable=False)
            return

        if eps == 'trainable':
            self.logit_eps = tf.Variable(tf.constant(0.))
            return

        raise NotImplementedError("eps should be 'trainable' or a float in (0,1)")

    @property
    def eps(self) -> tf.Tensor:
        """Weight of `kernel_b`: the sigmoid of `logit_eps`, or zero when `kernel_b` is None."""
        if self.kernel_b is None:
            return tf.constant(0.)
        return tf.math.sigmoid(self.logit_eps)

    def call(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
        """Return the kernel matrix between `x` and `y`, mixing both kernels when `kernel_b` is set."""
        deep_similarity = self.kernel_a(self.proj(x), self.proj(y))  # type: ignore[operator]
        if self.kernel_b is None:
            return deep_similarity
        return (1-self.eps)*deep_similarity + self.eps*self.kernel_b(x, y)  # type: ignore[operator]

    def get_config(self) -> dict:
        """Return a shallow copy of the constructor arguments."""
        return self.config.copy()

    @classmethod
    def from_config(cls, config):
        """Instantiate a kernel by passing the entries of `config` to the constructor."""
        return cls(**config)