alibi_detect.models.tensorflow.trainer module

alibi_detect.models.tensorflow.trainer.trainer(model, loss_fn, x_train, y_train=None, dataset=None, optimizer=tensorflow.keras.optimizers.Adam, loss_fn_kwargs=None, preprocess_fn=None, epochs=20, reg_loss_fn=<function <lambda>>, batch_size=64, buffer_size=1024, verbose=True, log_metric=None, callbacks=None)

Train a TensorFlow model.

Parameters:
  • model – Model to train.

  • loss_fn – Loss function used for training.

  • x_train – Training data.

  • y_train – Training labels.

  • dataset – Training dataset which returns (x, y).

  • optimizer – Optimizer used for training.

  • loss_fn_kwargs – Keyword arguments passed to the loss function.

  • preprocess_fn – Preprocessing function applied to each training batch.

  • epochs – Number of training epochs.

  • reg_loss_fn – Function defining an additional regularisation term, evaluated as reg_loss_fn(model) and added to the training loss.

  • batch_size – Batch size used for training.

  • buffer_size – Maximum number of elements that will be buffered when prefetching.

  • verbose – Whether to print training progress.

  • log_metric – Additional metrics whose progress is displayed when verbose is True.

  • callbacks – Callbacks used during training.
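To illustrate how loss_fn_kwargs, preprocess_fn and reg_loss_fn might enter a single gradient step, here is a simplified, hypothetical training loop written with `tf.GradientTape`. It is a sketch under stated assumptions, not alibi-detect's implementation: `sketch_train`, `weighted_mse` and `l2_penalty` are illustrative names, and progress bars, log_metric, callbacks, buffer_size and `dataset` inputs are deliberately left out.

```python
import functools

import numpy as np
import tensorflow as tf


def sketch_train(model, loss_fn, x_train, y_train, optimizer, loss_fn_kwargs=None,
                 preprocess_fn=None, reg_loss_fn=lambda model: 0.0,
                 epochs=5, batch_size=32):
    """Hypothetical, simplified training loop (not the library's code)."""
    if loss_fn_kwargs:
        # Bind the extra keyword arguments once so the loop can call loss_fn(y, y_hat).
        loss_fn = functools.partial(loss_fn, **loss_fn_kwargs)
    dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
               .shuffle(len(x_train), seed=0)
               .batch(batch_size))
    history = []  # mean training loss per epoch
    for _ in range(epochs):
        epoch_losses = []
        for x, y in dataset:
            if preprocess_fn is not None:
                x = preprocess_fn(x)  # applied to every training batch
            with tf.GradientTape() as tape:
                y_hat = model(x, training=True)
                # Regularisation term computed from the model and added to the loss.
                loss = loss_fn(y, y_hat) + reg_loss_fn(model)
            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            epoch_losses.append(float(loss))
        history.append(float(np.mean(epoch_losses)))
    return history


def weighted_mse(y_true, y_pred, weight=1.0):
    # Hypothetical loss with an extra keyword argument, supplied via loss_fn_kwargs.
    return weight * tf.reduce_mean(tf.square(y_true - y_pred))


def l2_penalty(model, coef=1e-4):
    # Hypothetical L2 regulariser over all trainable weights.
    return coef * tf.add_n([tf.reduce_sum(tf.square(w)) for w in model.trainable_weights])


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype(np.float32)
y = x @ np.array([[1.0], [-2.0], [0.5], [3.0]], dtype=np.float32)

tf.random.set_seed(0)
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
history = sketch_train(
    model, weighted_mse, x, y,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.05),
    loss_fn_kwargs={"weight": 0.5},
    preprocess_fn=lambda b: b + tf.random.normal(tf.shape(b), stddev=0.01),
    reg_loss_fn=l2_penalty,
    epochs=5,
)
print(history)
```

Binding loss_fn_kwargs with `functools.partial` keeps the inner loop independent of the loss function's extra arguments; whether the library does exactly this is not stated here.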