# Source code for alibi_detect.models.tensorflow.trainer

from functools import partial
import numpy as np
import tensorflow as tf
from typing import Callable, Tuple


def trainer(
        model: tf.keras.Model,
        loss_fn: tf.keras.losses,
        x_train: np.ndarray,
        y_train: np.ndarray = None,
        dataset: tf.keras.utils.Sequence = None,
        optimizer: tf.keras.optimizers = tf.keras.optimizers.Adam,
        loss_fn_kwargs: dict = None,
        preprocess_fn: Callable = None,
        epochs: int = 20,
        reg_loss_fn: Callable = (lambda model: 0),
        batch_size: int = 64,
        buffer_size: int = 1024,
        verbose: bool = True,
        log_metric: Tuple[str, "tf.keras.metrics"] = None,
        callbacks: tf.keras.callbacks = None
) -> None:
    """
    Train TensorFlow model.

    Parameters
    ----------
    model
        Model to train.
    loss_fn
        Loss function used for training. Called as ``loss_fn(y, y_hat)`` when the model
        output is a single tensor, and as ``loss_fn(y, *y_hat)`` otherwise. If it is not
        callable, only ``model.losses`` and ``reg_loss_fn`` contribute to the loss.
    x_train
        Training data. Ignored when `dataset` is a ``tf.keras.utils.Sequence``.
    y_train
        Training labels. If None (and no Sequence is given), the (preprocessed) inputs
        are used as targets.
    dataset
        Training dataset which returns (x, y). If a ``y`` of None is returned, the
        (preprocessed) inputs are used as targets.
    optimizer
        Optimizer used for training. Either an optimizer class, which is instantiated
        with its default arguments, or an optimizer instance.
    loss_fn_kwargs
        Kwargs for loss function, bound to it with ``functools.partial``.
    preprocess_fn
        Preprocessing function applied to the inputs of each training batch.
    epochs
        Number of training epochs.
    reg_loss_fn
        Allows an additional regularisation term to be defined as reg_loss_fn(model).
    batch_size
        Batch size used to batch `x_train`/`y_train`. Ignored when a Sequence is given.
    buffer_size
        Shuffle buffer size used for `x_train`/`y_train`. Ignored when a Sequence is given.
    verbose
        Whether to print training progress. The progress bar shows ``loss_ma``, the running
        mean over the epoch of each batch's mean loss.
    log_metric
        Additional metric, given as a (name, metric) tuple, whose progress will be displayed
        if verbose equals True. The metric is updated with (y, y_hat) at every step and is
        not reset between epochs.
    callbacks
        Currently not used by this training loop; accepted for API compatibility.
    """
    optimizer = optimizer() if isinstance(optimizer, type) else optimizer

    # A Sequence always yields (x, y) pairs; the in-memory dataset only does so
    # when labels are provided.
    is_sequence = isinstance(dataset, tf.keras.utils.Sequence)
    return_xy = is_sequence or y_train is not None
    if not is_sequence:  # create dataset from the in-memory arrays
        train_data = x_train if y_train is None else (x_train, y_train)
        dataset = tf.data.Dataset.from_tensor_slices(train_data)
        dataset = dataset.shuffle(buffer_size=buffer_size).batch(batch_size)
    n_minibatch = len(dataset)

    if loss_fn_kwargs:
        loss_fn = partial(loss_fn, **loss_fn_kwargs)

    # iterate over epochs
    for _ in range(epochs):
        if verbose:
            # NOTE(review): Progbar's second positional parameter is `width`, so the original
            # call rendered a 1-character bar; kept as-is. Use `verbose=1` if that was intended.
            pbar = tf.keras.utils.Progbar(n_minibatch, width=1)
        if hasattr(dataset, 'on_epoch_end'):
            dataset.on_epoch_end()
        loss_val_ma = 0.
        for step, data in enumerate(dataset):
            x, y = data if return_xy else (data, None)
            if isinstance(preprocess_fn, Callable):  # type: ignore
                x = preprocess_fn(x)

            with tf.GradientTape() as tape:
                y_hat = model(x)
                y = x if y is None else y
                if isinstance(loss_fn, Callable):  # type: ignore
                    # multi-output models pass every output to the loss after the target
                    args = [y, y_hat] if tf.is_tensor(y_hat) else [y] + list(y_hat)
                    loss = loss_fn(*args)
                else:
                    loss = 0.
                if model.losses:  # additional model losses
                    loss += sum(model.losses)
                loss += reg_loss_fn(model)  # alternative way they might be specified

            grads = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            if verbose:
                # Track the scalar batch-mean loss. The previous version padded per-sample
                # losses up to `batch_size` with their mean, which leaves the mean (the value
                # Progbar displays) unchanged. However, it crashed for losses of rank >= 3 and
                # for batches larger than `batch_size`.
                loss_val = tf.reduce_mean(loss).numpy()
                loss_val_ma = loss_val_ma + (loss_val - loss_val_ma) / (step + 1)
                pbar_values = [('loss_ma', loss_val_ma)]
                if log_metric is not None:
                    log_metric[1](y, y_hat)
                    pbar_values.append((log_metric[0], log_metric[1].result().numpy()))
                pbar.add(1, values=pbar_values)