import numpy as np
import torch
from torch import nn
from . import distance
from typing import Optional, Union, Callable
from alibi_detect.utils.frameworks import Framework
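# NOTE: sketch of the median-heuristic bandwidth function referenced by `init_sigma_fn`
# below. The exact implementation lives elsewhere in this module in alibi-detect; this is a
# reconstruction under the documented signature (x, y, dist) -> sigma, setting 2 * sigma^2
# to the median of the squared pairwise distances (excluding the zero diagonal when x and y
# are the same batch).
def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
    n = min(x.shape[0], y.shape[0])
    n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0  # number of zero self-distances
    n_median = n + (np.prod(dist.shape) - n) // 2 - 1  # median index, skipping diagonal zeros
    sigma = (.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** .5
    return sigma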
class GaussianRBF(nn.Module):
    def __init__(
self,
sigma: Optional[torch.Tensor] = None,
init_sigma_fn: Optional[Callable] = None,
trainable: bool = False
) -> None:
"""
        Gaussian RBF kernel: k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)). A forward pass takes
a batch of instances x [Nx, features] and y [Ny, features] and returns the kernel
matrix [Nx, Ny].
Parameters
----------
        sigma
            Bandwidth used for the kernel. Need not be specified if the bandwidth is to be
            inferred or trained. Multiple values can be passed, in which case the kernel is
            evaluated with each value and the resulting kernel matrices are averaged.
init_sigma_fn
Function used to compute the bandwidth `sigma`. Used when `sigma` is to be inferred.
The function's signature should match :py:func:`~alibi_detect.utils.pytorch.kernels.sigma_median`,
meaning that it should take in the tensors `x`, `y` and `dist` and return `sigma`. If `None`, it is set to
:func:`~alibi_detect.utils.pytorch.kernels.sigma_median`.
trainable
Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
"""
super().__init__()
init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
self.config = {'sigma': sigma, 'trainable': trainable, 'init_sigma_fn': init_sigma_fn}
if sigma is None:
self.log_sigma = nn.Parameter(torch.empty(1), requires_grad=trainable)
self.init_required = True
else:
sigma = sigma.reshape(-1) # [Ns,]
self.log_sigma = nn.Parameter(sigma.log(), requires_grad=trainable)
self.init_required = False
self.init_sigma_fn = init_sigma_fn
self.trainable = trainable
@property
def sigma(self) -> torch.Tensor:
return self.log_sigma.exp()
    def forward(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor],
                infer_sigma: bool = False) -> torch.Tensor:
x, y = torch.as_tensor(x), torch.as_tensor(y)
dist = distance.squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny]
if infer_sigma or self.init_required:
if self.trainable and infer_sigma:
raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
sigma = self.init_sigma_fn(x, y, dist)
with torch.no_grad():
self.log_sigma.copy_(sigma.log().clone())
self.init_required = False
gamma = 1. / (2. * self.sigma ** 2) # [Ns,]
# TODO: do matrix multiplication after all?
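        # evaluate the kernel with each bandwidth in sigma and average the resulting matrices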
kernel_mat = torch.exp(- torch.cat([(g * dist)[None, :, :] for g in gamma], dim=0)) # [Ns, Nx, Ny]
return kernel_mat.mean(dim=0) # [Nx, Ny]
    def get_config(self) -> dict:
        """
        Returns a serializable config dict (excluding the `init_sigma_fn`, which is serialized
        in alibi_detect.saving).
        """
cfg = self.config.copy()
if isinstance(cfg['sigma'], torch.Tensor):
cfg['sigma'] = cfg['sigma'].detach().cpu().numpy().tolist()
cfg.update({'flavour': Framework.PYTORCH.value})
return cfg
    @classmethod
def from_config(cls, config):
"""
Instantiates a kernel from a config dictionary.
Parameters
----------
config
A kernel config dictionary.
"""
config.pop('flavour')
return cls(**config)
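# Illustrative usage sketch (not part of the original module): evaluating the kernel with an
# inferred bandwidth. `_demo_gaussian_rbf` is a hypothetical helper, for illustration only.
def _demo_gaussian_rbf() -> torch.Tensor:
    kernel = GaussianRBF()  # sigma=None, so the bandwidth must be inferred
    x, y = torch.randn(5, 3), torch.randn(7, 3)
    k_xy = kernel(x, y, infer_sigma=True)  # bandwidth set via sigma_median on the first call
    assert k_xy.shape == (5, 7)
    return k_xy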
class DeepKernel(nn.Module):
"""
Computes similarities as k(x,y) = (1-eps)*k_a(proj(x), proj(y)) + eps*k_b(x,y).
A forward pass takes a batch of instances x [Nx, features] and y [Ny, features] and returns
the kernel matrix [Nx, Ny].
Parameters
----------
proj
The projection to be applied to the inputs before applying kernel_a
kernel_a
The kernel to apply to the projected inputs. Defaults to a Gaussian RBF with trainable bandwidth.
kernel_b
The kernel to apply to the raw inputs. Defaults to a Gaussian RBF with trainable bandwidth.
Set to None in order to use only the deep component (i.e. eps=0).
    eps
        The proportion (in [0,1]) of weight to assign to the kernel applied to raw inputs. This
        can either be specified as a float or set to 'trainable'. Only relevant if kernel_b is
        not None.
"""
def __init__(
self,
proj: nn.Module,
kernel_a: Union[nn.Module, str] = 'rbf',
kernel_b: Optional[Union[nn.Module, str]] = 'rbf',
eps: Union[float, str] = 'trainable'
) -> None:
super().__init__()
self.config = {'proj': proj, 'kernel_a': kernel_a, 'kernel_b': kernel_b, 'eps': eps}
if kernel_a == 'rbf':
kernel_a = GaussianRBF(trainable=True)
if kernel_b == 'rbf':
kernel_b = GaussianRBF(trainable=True)
self.kernel_a = kernel_a
self.kernel_b = kernel_b
self.proj = proj
if kernel_b is not None:
self._init_eps(eps)
def _init_eps(self, eps: Union[float, str]) -> None:
if isinstance(eps, float):
if not 0 < eps < 1:
raise ValueError("eps should be in (0,1)")
self.logit_eps = nn.Parameter(torch.tensor(eps).logit(), requires_grad=False)
elif eps == 'trainable':
self.logit_eps = nn.Parameter(torch.tensor(0.))
else:
raise NotImplementedError("eps should be 'trainable' or a float in (0,1)")
@property
def eps(self) -> torch.Tensor:
return self.logit_eps.sigmoid() if self.kernel_b is not None else torch.tensor(0.)
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
similarity = self.kernel_a(self.proj(x), self.proj(y)) # type: ignore[operator]
if self.kernel_b is not None:
similarity = (1-self.eps)*similarity + self.eps*self.kernel_b(x, y) # type: ignore[operator]
return similarity
    def get_config(self) -> dict:
        """Returns a serializable config dict."""
        return self.config.copy()
    @classmethod
    def from_config(cls, config):
        """Instantiates a deep kernel from a config dictionary."""
        return cls(**config)
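# Illustrative usage sketch (not part of the original module): a deep kernel over a small
# projection network, with a fixed 5% weight on the raw-input RBF. `_demo_deep_kernel` and
# the projection sizes are hypothetical, for illustration only.
def _demo_deep_kernel() -> torch.Tensor:
    proj = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    kernel = DeepKernel(proj, eps=0.05)  # k(x,y) = 0.95*k_a(proj(x),proj(y)) + 0.05*k_b(x,y)
    x, y = torch.randn(6, 8), torch.randn(9, 8)
    k_xy = kernel(x, y)  # kernel matrix of shape [6, 9]
    assert k_xy.shape == (6, 9)
    return k_xy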