alibi_detect.models.pytorch.trainer module
- alibi_detect.models.pytorch.trainer.trainer(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam, learning_rate=0.001, preprocess_fn=None, epochs=20, reg_loss_fn=<function <lambda>>, verbose=1)
Train PyTorch model.
- Parameters:
model (Union[Module, Sequential]) – Model to train.
loss_fn (Callable) – Loss function used for training.
dataloader (DataLoader) – PyTorch dataloader.
device (device) – Device used for training.
optimizer (Callable) – Optimizer used for training.
learning_rate (float) – Optimizer’s learning rate.
preprocess_fn (Optional[Callable]) – Preprocessing function applied to each training batch.
epochs (int) – Number of training epochs.
reg_loss_fn (Callable) – The regularisation term reg_loss_fn(model) is added to the loss function being optimized.
verbose (int) – Whether to print training progress.
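To make the roles of these parameters concrete, here is a minimal sketch of a training loop written only from the descriptions above. It is not the library's implementation: the function name train_sketch, the l2_penalty helper, the assumption that each batch is an (input, target) pair, and the assumption that the optimizer is instantiated as optimizer(model.parameters(), lr=learning_rate) are all illustrative. Progress printing (verbose) is omitted.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam,
                 learning_rate=1e-3, preprocess_fn=None, epochs=20,
                 reg_loss_fn=lambda model: 0.0):
    """Hypothetical loop inferred from the parameter descriptions; not the library code."""
    model = model.to(device)
    # Assumption: `optimizer` is an optimizer class (or factory), not an instance.
    opt = optimizer(model.parameters(), lr=learning_rate)
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        total, n = 0.0, 0
        for x, y in dataloader:  # assumption: batches are (input, target) pairs
            if preprocess_fn is not None:
                x = preprocess_fn(x)  # applied to each training batch
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            # The regularisation term is added to the training loss.
            loss = loss_fn(model(x), y) + reg_loss_fn(model)
            loss.backward()
            opt.step()
            total += loss.item() * len(x)
            n += len(x)
        epoch_losses.append(total / n)  # mean loss per sample for this epoch
    return epoch_losses


# Toy regression problem to exercise the sketch.
torch.manual_seed(0)
X = torch.randn(256, 4)
y = X @ torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(4, 1))


def l2_penalty(m):
    """Illustrative reg_loss_fn: a small L2 penalty on all parameters."""
    return 1e-4 * sum((p ** 2).sum() for p in m.parameters())


losses = train_sketch(model, nn.MSELoss(), loader, torch.device("cpu"),
                      learning_rate=1e-2, epochs=20, reg_loss_fn=l2_penalty)
```

Going by the signature above, the same model, loss_fn, dataloader, device and reg_loss_fn arguments would be passed to trainer itself. Note that reg_loss_fn receives the model, so it can penalise any of its parameters.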
- Return type: