alibi_detect.models.tensorflow.trainer module
- alibi_detect.models.tensorflow.trainer.trainer(model, loss_fn, x_train, y_train=None, dataset=None, optimizer=tensorflow.keras.optimizers.Adam, loss_fn_kwargs=None, preprocess_fn=None, epochs=20, reg_loss_fn=<function <lambda>>, batch_size=64, buffer_size=1024, verbose=True, log_metric=None, callbacks=None)
Train a TensorFlow model.
- Parameters:
model – Model to train.
loss_fn – Loss function used for training.
x_train – Training data.
y_train – Training labels.
dataset – Training dataset which returns (x, y).
optimizer – Optimizer used for training.
loss_fn_kwargs – Kwargs for loss function.
preprocess_fn – Preprocessing function applied to each training batch.
epochs – Number of training epochs.
reg_loss_fn – Function defining an additional regularisation term, evaluated as reg_loss_fn(model).
batch_size – Batch size used for training.
buffer_size – Maximum number of elements that will be buffered when prefetching.
verbose – Whether to print training progress.
log_metric – Additional metrics whose progress will be displayed if verbose equals True.
callbacks – Callbacks used during training.
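
The parameters above describe a mini-batch training loop. The sketch below is a framework-free toy version of such a loop, meant only to illustrate how loss_fn, loss_fn_kwargs, preprocess_fn, reg_loss_fn, epochs and batch_size could fit together. It is not the alibi-detect implementation: toy_trainer, mse, the dict-based one-weight model, and the lr and seed arguments are all hypothetical, gradients are taken by finite differences instead of TensorFlow, and reusing the inputs as targets when y_train is None is an assumption of this sketch.

```python
import random


def mse(y_true, y_pred):
    """Mean squared error over plain Python lists."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_trainer(model, loss_fn, x_train, y_train=None, loss_fn_kwargs=None,
                preprocess_fn=None, epochs=20, reg_loss_fn=lambda model: 0.0,
                batch_size=64, lr=0.1, seed=0):
    """Hypothetical stand-in for a training loop; `model` is {"w": float}."""
    loss_fn_kwargs = loss_fn_kwargs or {}
    rng = random.Random(seed)
    # Assumption of this sketch: without labels, the inputs serve as targets.
    targets = x_train if y_train is None else y_train
    data = list(zip(x_train, targets))
    history = []
    for _ in range(epochs):
        rng.shuffle(data)  # reshuffle the training pairs every epoch
        batch_losses = []
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            xs = [x for x, _ in batch]
            ys = [y for _, y in batch]
            if preprocess_fn is not None:
                xs = preprocess_fn(xs)  # applied to each training batch

            def total_loss(w):
                # Data loss plus the regularisation term reg_loss_fn(model).
                trial = {"w": w}
                preds = [trial["w"] * x for x in xs]
                return loss_fn(ys, preds, **loss_fn_kwargs) + reg_loss_fn(trial)

            # Central finite difference replaces automatic differentiation.
            eps = 1e-5
            w = model["w"]
            grad = (total_loss(w + eps) - total_loss(w - eps)) / (2 * eps)
            model["w"] = w - lr * grad  # plain gradient-descent step
            batch_losses.append(total_loss(model["w"]))
        history.append(sum(batch_losses) / len(batch_losses))
    return history


# Toy data: the target is exactly 3 * x.
x_train = [i / 10 for i in range(1, 21)]
y_train = [3.0 * x for x in x_train]

plain = {"w": 0.0}
plain_history = toy_trainer(plain, mse, x_train, y_train, batch_size=5)

# An L2 penalty supplied through reg_loss_fn pulls the weight towards zero.
regularised = {"w": 0.0}
toy_trainer(regularised, mse, x_train, y_train, batch_size=5,
            reg_loss_fn=lambda m: 0.5 * m["w"] ** 2)

print(plain["w"], regularised["w"])
```

In this toy setting the unregularised weight approaches the true slope, while the run with reg_loss_fn settles on a smaller value. The documented function takes a TensorFlow model and an optimizer where the sketch uses a dict and lr.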