alibi_detect.cd.pytorch.learned_kernel module
- class alibi_detect.cd.pytorch.learned_kernel.LearnedKernelDriftTorch(x_ref, kernel, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, n_permutations=100, var_reg=1e-05, reg_loss_fn=<function LearnedKernelDriftTorch.<lambda>>, train_size=0.75, retrain_from_scratch=True, optimizer=torch.optim.Adam, learning_rate=0.001, batch_size=32, batch_size_predict=32, preprocess_batch_fn=None, epochs=3, num_workers=0, verbose=0, train_kwargs=None, device=None, dataset=<class 'alibi_detect.utils.pytorch.data.TorchDataset'>, dataloader=torch.utils.data.DataLoader, input_shape=None, data_type=None)[source]
Bases: BaseLearnedKernelDrift
- class JHat(kernel, var_reg)[source]
Bases: Module
A module that wraps around the kernel. When passed a batch of reference instances and a batch of test instances, it returns an estimate of a correlate of test power (Equation 4 of https://arxiv.org/abs/2002.09116).
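As a sketch of that criterion in the paper's notation (here $\lambda$ corresponds to the var_reg parameter documented below):

$$\hat{J}_\lambda(S_P, S_Q; k) = \frac{\widehat{\mathrm{MMD}}_u^2(S_P, S_Q; k)}{\sqrt{\hat{\sigma}_{H_1}^2 + \lambda}}$$

where $\widehat{\mathrm{MMD}}_u^2$ is the unbiased MMD estimate between the reference split $S_P$ and the test split $S_Q$ under kernel $k$, and $\hat{\sigma}_{H_1}^2$ estimates its variance under the alternative hypothesis. Maximising $\hat{J}_\lambda$ (approximately) maximises the power of the resulting two-sample test.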
- __init__(x_ref, kernel, p_val=0.05, x_ref_preprocessed=False, preprocess_at_init=True, update_x_ref=None, preprocess_fn=None, n_permutations=100, var_reg=1e-05, reg_loss_fn=<function LearnedKernelDriftTorch.<lambda>>, train_size=0.75, retrain_from_scratch=True, optimizer=torch.optim.Adam, learning_rate=0.001, batch_size=32, batch_size_predict=32, preprocess_batch_fn=None, epochs=3, num_workers=0, verbose=0, train_kwargs=None, device=None, dataset=<class 'alibi_detect.utils.pytorch.data.TorchDataset'>, dataloader=torch.utils.data.DataLoader, input_shape=None, data_type=None)[source]
Maximum Mean Discrepancy (MMD) data drift detector where the kernel is trained to maximise an estimate of the test power. The kernel is trained on a split of the reference and test instances, after which the MMD is evaluated on the held-out instances and a permutation test is performed. A usage sketch follows the parameter list below.
For details see Liu et al. (2020): Learning Deep Kernels for Non-Parametric Two-Sample Tests (https://arxiv.org/abs/2002.09116)
- Parameters:
  - x_ref (Union[ndarray, list]) – Data used as reference distribution.
  - kernel (Union[Module, Sequential]) – Trainable PyTorch module that returns a similarity between two instances.
  - p_val (float) – p-value used for the significance of the test.
  - x_ref_preprocessed (bool) – Whether the given reference data x_ref has already been preprocessed. If x_ref_preprocessed=True, only the test data x will be preprocessed at prediction time. If x_ref_preprocessed=False, the reference data will also be preprocessed.
  - preprocess_at_init (bool) – Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference data will be preprocessed at prediction time. Only applies if x_ref_preprocessed=False.
  - update_x_ref (Optional[Dict[str, int]]) – Reference data can optionally be updated to the last n instances seen by the detector or via reservoir sampling with size n. For the former, pass {'last': n}; for reservoir sampling, pass {'reservoir_sampling': n}.
  - preprocess_fn (Optional[Callable]) – Function to preprocess the data before applying the kernel.
  - n_permutations (int) – The number of permutations to use in the permutation test once the MMD has been computed.
  - var_reg (float) – Constant added to the estimated variance of the MMD for stability.
  - reg_loss_fn (Callable) – The regularisation term reg_loss_fn(kernel) is added to the loss function being optimized.
  - train_size (Optional[float]) – Optional fraction (float between 0 and 1) of the dataset used to train the kernel. Drift is detected on the remaining 1 - train_size fraction.
  - retrain_from_scratch (bool) – Whether the kernel should be retrained from scratch for each set of test data, or whether it should instead continue training from where it left off on the previous set.
  - optimizer (Optimizer) – Optimizer used during training of the kernel.
  - learning_rate (float) – Learning rate used by the optimizer.
  - batch_size (int) – Batch size used during training of the kernel.
  - batch_size_predict (int) – Batch size used for the trained drift detector predictions.
  - preprocess_batch_fn (Optional[Callable]) – Optional batch preprocessing function, for example to convert a list of objects into a batch that the kernel can process.
  - epochs (int) – Number of training epochs for the kernel. An epoch corresponds to one pass over the smaller of the reference and test sets.
  - num_workers (int) – Number of workers for the dataloader. The default (num_workers=0) means multi-process data loading is disabled. Setting num_workers>0 may be unreliable on Windows.
  - verbose (int) – Verbosity level during training of the kernel; 0 is silent and 1 prints a progress bar.
  - train_kwargs (Optional[dict]) – Optional additional kwargs when training the kernel.
  - device (Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None]) – Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by passing either 'cuda', 'gpu', 'cpu' or an instance of torch.device. Only relevant for the 'pytorch' backend.
  - dataset (Callable) – Dataset object used during training.
  - dataloader (Callable) – Dataloader object used during training. Only relevant for the 'pytorch' backend.
  - input_shape (Optional[tuple]) – Optionally pass the shape of the input data.
  - data_type (Optional[str]) – Optionally specify the data type (tabular, image or time-series). Added to metadata.
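A minimal usage sketch (the data shapes, projection network and eps value are illustrative; DeepKernel from alibi_detect.utils.pytorch is one convenient way to construct a trainable kernel):

```python
import numpy as np
import torch.nn as nn

from alibi_detect.cd.pytorch.learned_kernel import LearnedKernelDriftTorch
from alibi_detect.utils.pytorch import DeepKernel

# Illustrative data: 1000 reference and 500 test instances with 32 features.
x_ref = np.random.randn(1000, 32).astype(np.float32)
x_test = np.random.randn(500, 32).astype(np.float32)

# Projection network; the deep kernel compares instances in the projected
# space, mixed (via eps) with a kernel on the raw inputs.
proj = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
kernel = DeepKernel(proj, eps=0.01)

cd = LearnedKernelDriftTorch(x_ref, kernel, p_val=.05, epochs=3, batch_size=32)

preds = cd.predict(x_test)
print(preds['data']['is_drift'], preds['data']['p_val'])
```

Note that each call to predict first trains the kernel on a train_size split of the reference and test instances before running the permutation test on the held-out remainder.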
- score(x)[source]
Compute the p-value resulting from a permutation test using the maximum mean discrepancy as a distance measure between the reference data and the data to be tested. The kernel used within the MMD is first trained to maximise an estimate of the resulting test power.
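A sketch of calling score directly, assuming the cd detector and x_test array from the usage example above; the three returned values (p-value, MMD² estimate, and the distance threshold above which drift is flagged) are assumed here to follow the convention of the other MMD-based detectors in alibi-detect:

```python
# p_val: permutation-test p-value
# dist: MMD^2 estimate between the held-out reference and test instances
# dist_threshold: MMD^2 value above which drift is flagged
p_val, dist, dist_threshold = cd.score(x_test)
if p_val < cd.p_val:  # cd.p_val is the significance level passed at init
    print('Drift detected')
```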
- static trainer(j_hat, dataloaders, device, optimizer=torch.optim.Adam, learning_rate=0.001, preprocess_fn=None, epochs=20, reg_loss_fn=<function LearnedKernelDriftTorch.<lambda>>, verbose=1)[source]
Train the kernel to maximise an estimate of test power using minibatch gradient descent.
- Return type: None