alibi_detect.models.pytorch.trainer module

alibi_detect.models.pytorch.trainer.trainer(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam, learning_rate=0.001, preprocess_fn=None, epochs=20, verbose=1)

Train a PyTorch model.

  • model (Union[Module, Sequential]) – Model to train.

  • loss_fn (Callable) – Loss function used for training.

  • dataloader (DataLoader) – PyTorch dataloader.

  • device (device) – Device used for training.

  • optimizer (Callable) – Optimizer class used for training, passed uninstantiated (for example torch.optim.Adam).

  • learning_rate (float) – Optimizer’s learning rate.

  • preprocess_fn (Optional[Callable]) – Preprocessing function applied to each training batch.

  • epochs (int) – Number of training epochs.

  • verbose (int) – Whether to print training progress.
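
To illustrate how the parameters above fit together, here is a minimal sketch of a training loop with the same signature. It is not the library's implementation: trainer_sketch is a hypothetical stand-in, and it assumes that each batch is an (inputs, targets) pair, that optimizer is called with the model parameters and lr=learning_rate, and that preprocess_fn is applied to the inputs only. Returning the per-epoch losses is a convenience of the sketch, not documented behaviour.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def trainer_sketch(model, loss_fn, dataloader, device, optimizer=torch.optim.Adam,
                   learning_rate=1e-3, preprocess_fn=None, epochs=20, verbose=1):
    """Hypothetical stand-in mirroring the documented signature (not library code)."""
    model = model.to(device)
    # Assumption: `optimizer` is a class/factory, instantiated here with the learning rate.
    opt = optimizer(model.parameters(), lr=learning_rate)
    model.train()
    history = []  # mean loss per epoch, returned for inspection (sketch-only convenience)
    for epoch in range(epochs):
        total, n = 0.0, 0
        for x, y in dataloader:  # Assumption: each batch is an (inputs, targets) pair.
            if preprocess_fn is not None:
                x = preprocess_fn(x)  # Assumption: preprocessing touches inputs only.
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
            total += loss.item() * len(x)
            n += len(x)
        history.append(total / n)
        if verbose:
            print(f"epoch {epoch + 1}/{epochs} - loss {history[-1]:.4f}")
    return history


# Toy regression problem: learn a fixed linear map from 4 features to 1 target.
torch.manual_seed(0)
X = torch.randn(256, 4)
y = X @ torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
dataloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(4, 1))
device = torch.device("cpu")
losses = trainer_sketch(model, nn.MSELoss(), dataloader, device,
                        learning_rate=0.05, epochs=20, verbose=0)
```

The same positional arguments (model, loss_fn, dataloader, device) and keyword arguments would be passed to alibi_detect.models.pytorch.trainer.trainer; check the library source for what the real function does beyond the documented parameters.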

Return type