import logging
import torch
from torch import nn
import numpy as np
from typing import Callable, List, Tuple, Optional, Union
logger = logging.getLogger(__name__)
@torch.jit.script
def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30) -> torch.Tensor:
"""
PyTorch pairwise squared Euclidean distance between samples x and y.
Parameters
----------
x
Batch of instances of shape [Nx, features].
y
Batch of instances of shape [Ny, features].
a_min
Lower bound to clip distance values.
Returns
-------
Pairwise squared Euclidean distance [Nx, Ny].
"""
x2 = x.pow(2).sum(dim=-1, keepdim=True)
y2 = y.pow(2).sum(dim=-1, keepdim=True)
dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
return dist.clamp_min_(a_min)
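# Example (illustrative sketch, not part of the module): checking the result
# against torch.cdist on small random batches; `x` and `y` are placeholders.
#
#     x = torch.randn(5, 3)
#     y = torch.randn(7, 3)
#     dist = squared_pairwise_distance(x, y)  # shape [5, 7]
#     assert torch.allclose(dist, torch.cdist(x, y).pow(2), atol=1e-5)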
def batch_compute_kernel_matrix(
x: Union[list, np.ndarray, torch.Tensor],
y: Union[list, np.ndarray, torch.Tensor],
kernel: Union[nn.Module, nn.Sequential],
    device: Optional[torch.device] = None,
batch_size: int = int(1e10),
    preprocess_fn: Optional[Callable[..., torch.Tensor]] = None,
) -> torch.Tensor:
"""
Compute the kernel matrix between x and y by filling in blocks of size
batch_size x batch_size at a time.
Parameters
----------
x
Reference set.
y
Test set.
kernel
PyTorch module.
device
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either torch.device('cuda') or torch.device('cpu').
batch_size
Batch size used during prediction.
preprocess_fn
Optional preprocessing function for each batch.
Returns
-------
    Kernel matrix as a torch tensor.
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if type(x) != type(y): # noqa: E721
raise ValueError("x and y should be of the same type")
if isinstance(x, np.ndarray):
x, y = torch.from_numpy(x), torch.from_numpy(y)
n_x, n_y = len(x), len(y)
n_batch_x, n_batch_y = int(np.ceil(n_x / batch_size)), int(np.ceil(n_y / batch_size))
with torch.no_grad():
k_is: List[torch.Tensor] = []
for i in range(n_batch_x):
istart, istop = i * batch_size, min((i + 1) * batch_size, n_x)
x_batch = x[istart:istop]
if preprocess_fn is not None:
x_batch = preprocess_fn(x_batch)
x_batch = x_batch.to(device) # type: ignore
k_ijs: List[torch.Tensor] = []
for j in range(n_batch_y):
jstart, jstop = j * batch_size, min((j + 1) * batch_size, n_y)
y_batch = y[jstart:jstop]
if preprocess_fn is not None:
y_batch = preprocess_fn(y_batch)
y_batch = y_batch.to(device) # type: ignore
k_ijs.append(kernel(x_batch, y_batch).cpu())
k_is.append(torch.cat(k_ijs, 1))
k_mat = torch.cat(k_is, 0)
return k_mat
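# Example (illustrative sketch, not part of the module): batching the kernel
# matrix computation with a minimal Gaussian RBF kernel module. `SimpleRBF`
# is a hypothetical helper defined here purely for illustration.
#
#     class SimpleRBF(nn.Module):
#         def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#             return torch.exp(-squared_pairwise_distance(x, y))
#
#     x, y = torch.randn(100, 10), torch.randn(50, 10)
#     k_mat = batch_compute_kernel_matrix(x, y, SimpleRBF(), batch_size=32)
#     assert k_mat.shape == (100, 50)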
def mmd2_from_kernel_matrix(kernel_mat: torch.Tensor, m: int, permute: bool = False,
zero_diag: bool = True) -> torch.Tensor:
"""
Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the
full kernel matrix between the samples.
Parameters
----------
kernel_mat
Kernel matrix between samples x and y.
m
Number of instances in y.
permute
        Whether to permute the rows and columns of the kernel matrix. Used for permutation tests.
zero_diag
Whether to zero out the diagonal of the kernel matrix.
Returns
-------
MMD^2 between the samples from the kernel matrix.
"""
n = kernel_mat.shape[0] - m
if zero_diag:
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())
if permute:
idx = torch.randperm(kernel_mat.shape[0])
kernel_mat = kernel_mat[idx][:, idx]
k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2. * k_xy.mean()
return mmd2
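# Example (illustrative sketch, not part of the module): a simple permutation
# test built on mmd2_from_kernel_matrix; the RBF kernel, sample sizes and
# number of permutations are placeholder assumptions.
#
#     x, y = torch.randn(80, 5), torch.randn(40, 5)
#     z = torch.cat([x, y], 0)
#     kernel_mat = torch.exp(-squared_pairwise_distance(z, z))
#     observed = mmd2_from_kernel_matrix(kernel_mat, m=len(y))
#     perms = torch.stack([mmd2_from_kernel_matrix(kernel_mat, m=len(y), permute=True)
#                          for _ in range(100)])
#     p_val = (perms >= observed).float().mean()  # permutation p-value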
def mmd2(x: torch.Tensor, y: torch.Tensor, kernel: Callable) -> torch.Tensor:
"""
Compute MMD^2 between 2 samples.
Parameters
----------
x
Batch of instances of shape [Nx, features].
y
Batch of instances of shape [Ny, features].
kernel
Kernel function.
Returns
-------
MMD^2 between the samples x and y.
"""
n, m = x.shape[0], y.shape[0]
c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
k_xx, k_yy, k_xy = kernel(x, x), kernel(y, y), kernel(x, y)
return c_xx * (k_xx.sum() - k_xx.trace()) + c_yy * (k_yy.sum() - k_yy.trace()) - 2. * k_xy.mean()
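# Example (illustrative sketch, not part of the module): unbiased MMD^2 with
# an RBF kernel passed as a plain callable.
#
#     rbf = lambda a, b: torch.exp(-squared_pairwise_distance(a, b))
#     x, y = torch.randn(60, 4), torch.randn(60, 4) + 1.
#     stat = mmd2(x, y, rbf)  # larger when x and y come from different distributions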
def permed_lsdds(
k_all_c: torch.Tensor,
x_perms: List[torch.Tensor],
y_perms: List[torch.Tensor],
H: torch.Tensor,
H_lam_inv: Optional[torch.Tensor] = None,
lam_rd_max: float = 0.2,
return_unpermed: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
    Compute LSDD estimates from the kernel matrix across various reference and test window samples.
Parameters
----------
k_all_c
Kernel matrix of similarities between all samples and the kernel centers.
    x_perms
        List of B reference window index vectors.
    y_perms
        List of B test window index vectors.
    H
        Special (scaled) kernel matrix of similarities between the kernel centers.
    H_lam_inv
        Function of H corresponding to a particular regularization parameter lambda.
        See Eqn 11 of Bu et al. (2017).
    lam_rd_max
        The maximum relative difference between two estimates of LSDD that the regularization parameter
        lambda is allowed to cause. Defaults to 0.2. Only relevant if H_lam_inv is not supplied.
    return_unpermed
        Whether to also return the LSDD estimate for the unpermuted ordering defined by k_all_c.
Returns
-------
    Vector of B LSDD estimates, one per permutation; H_lam_inv, which may have been inferred; and \
    optionally the unpermuted LSDD estimate.
"""
# Compute (for each bootstrap) the average distance to each kernel center (Eqn 7)
k_xc_perms = torch.stack([k_all_c[x_inds] for x_inds in x_perms], 0)
k_yc_perms = torch.stack([k_all_c[y_inds] for y_inds in y_perms], 0)
h_perms = k_xc_perms.mean(1) - k_yc_perms.mean(1)
if H_lam_inv is None:
        # We perform the initialisation for multiple candidate lambda values and pick the largest
        # one for which the relative difference (RD) between two estimates of LSDD is below lam_rd_max.
# See Appendix A
candidate_lambdas = [1/(4**i) for i in range(10)] # TODO: More principled selection
H_plus_lams = torch.stack(
[H+torch.eye(H.shape[0], device=H.device)*can_lam for can_lam in candidate_lambdas], 0
)
H_plus_lam_invs = torch.inverse(H_plus_lams)
H_plus_lam_invs = H_plus_lam_invs.permute(1, 2, 0) # put lambdas in final axis
omegas = torch.einsum('jkl,bk->bjl', H_plus_lam_invs, h_perms) # (Eqn 8)
h_omegas = torch.einsum('bj,bjl->bl', h_perms, omegas)
omega_H_omegas = torch.einsum('bkl,bkl->bl', torch.einsum('bjl,jk->bkl', omegas, H), omegas)
rds = (1 - (omega_H_omegas/h_omegas)).mean(0)
less_than_rd_inds = (rds < lam_rd_max).nonzero()
if len(less_than_rd_inds) == 0:
repeats = k_all_c.shape[0] - torch.unique(k_all_c, dim=0).shape[0]
if repeats > 0:
msg = "Too many repeat instances for LSDD-based detection. \
Try using MMD-based detection instead"
else:
msg = "Unknown error. Try using MMD-based detection instead"
raise ValueError(msg)
lam_index = less_than_rd_inds[0]
lam = candidate_lambdas[lam_index]
logger.info(f"Using lambda value of {lam:.2g} with RD of {float(rds[lam_index]):.2g}")
H_plus_lam_inv = H_plus_lam_invs[:, :, lam_index.item()]
H_lam_inv = 2*H_plus_lam_inv - (H_plus_lam_inv.transpose(0, 1) @ H @ H_plus_lam_inv) # (below Eqn 11)
# Now to compute an LSDD estimate for each permutation
lsdd_perms = (h_perms * (H_lam_inv @ h_perms.transpose(0, 1)).transpose(0, 1)).sum(-1) # (Eqn 11)
if return_unpermed:
n_x = x_perms[0].shape[0]
h = k_all_c[:n_x].mean(0) - k_all_c[n_x:].mean(0)
lsdd_unpermed = (h[None, :] * (H_lam_inv @ h[:, None]).transpose(0, 1)).sum()
return lsdd_perms, H_lam_inv, lsdd_unpermed
else:
return lsdd_perms, H_lam_inv
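# Example (illustrative sketch, not part of the module): LSDD estimates over
# B random permutations of a combined sample. It assumes a Gaussian kernel, for
# which the scaled matrix H between kernel centers has the closed form
# H_ll' = (pi * sigma^2)^(d/2) * exp(-||c_l - c_l'||^2 / (4 * sigma^2));
# `sigma` and the random choice of kernel centers are placeholder assumptions.
#
#     z = torch.cat([torch.randn(50, 2), torch.randn(50, 2) + .5], 0)
#     centers = z[torch.randperm(len(z))[:20]]
#     sigma, d = 1., z.shape[-1]
#     k_all_c = torch.exp(-squared_pairwise_distance(z, centers) / (2 * sigma ** 2))
#     H = (np.pi * sigma ** 2) ** (d / 2) * torch.exp(
#         -squared_pairwise_distance(centers, centers) / (4 * sigma ** 2))
#     perms = [torch.randperm(len(z)) for _ in range(10)]
#     x_perms, y_perms = [p[:50] for p in perms], [p[50:] for p in perms]
#     lsdd_perms, H_lam_inv = permed_lsdds(k_all_c, x_perms, y_perms, H)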