Source code for alibi_detect.models.tensorflow.losses

from typing import Optional

import tensorflow as tf
from tensorflow.keras.layers import Flatten
from tensorflow.keras.losses import kld, categorical_crossentropy
import tensorflow_probability as tfp
from alibi_detect.models.tensorflow.gmm import gmm_params, gmm_energy


def elbo(y_true: tf.Tensor,
         y_pred: tf.Tensor,
         cov_full: Optional[tf.Tensor] = None,
         cov_diag: Optional[tf.Tensor] = None,
         sim: Optional[float] = None
         ) -> tf.Tensor:
    """
    Compute ELBO loss. The covariance matrix can be specified by passing the full covariance matrix,
    the matrix diagonal, or a scale identity multiplier. Only one of these should be specified.
    If none are specified, the identity matrix is used.

    Parameters
    ----------
    y_true
        Labels.
    y_pred
        Predictions.
    cov_full
        Full covariance matrix.
    cov_diag
        Diagonal (variance) of covariance matrix.
    sim
        Scale identity multiplier.

    Returns
    -------
    ELBO loss value.

    Example
    -------
    >>> import tensorflow as tf
    >>> from alibi_detect.models.tensorflow.losses import elbo
    >>> y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
    >>> y_pred = tf.constant([[0.1, 0.9], [0.8, 0.2]])

    >>> # Specifying scale identity multiplier
    >>> elbo(y_true, y_pred, sim=1.0)

    >>> # Specifying covariance matrix diagonal
    >>> elbo(y_true, y_pred, cov_diag=tf.ones(2))

    >>> # Specifying full covariance matrix
    >>> elbo(y_true, y_pred, cov_full=tf.eye(2))
    """
    if len([x for x in [cov_full, cov_diag, sim] if x is not None]) > 1:
        raise ValueError('Only one of cov_full, cov_diag or sim should be specified.')

    y_pred_flat = Flatten()(y_pred)
    if isinstance(cov_full, tf.Tensor):
        y_mn = tfp.distributions.MultivariateNormalFullCovariance(y_pred_flat, covariance_matrix=cov_full)
    else:
        if sim:
            cov_diag = sim * tf.ones(y_pred_flat.shape[-1])
        y_mn = tfp.distributions.MultivariateNormalDiag(y_pred_flat, scale_diag=cov_diag)
    loss = -tf.reduce_mean(y_mn.log_prob(Flatten()(y_true)))
    return loss


def loss_aegmm(x_true: tf.Tensor,
               x_pred: tf.Tensor,
               z: tf.Tensor,
               gamma: tf.Tensor,
               w_energy: float = .1,
               w_cov_diag: float = .005
               ) -> tf.Tensor:
    """
    Loss function used for OutlierAEGMM.

    Parameters
    ----------
    x_true
        Batch of instances.
    x_pred
        Batch of instances reconstructed by the autoencoder.
    z
        Latent space values.
    gamma
        Membership prediction for mixture model components.
    w_energy
        Weight on sample energy loss term.
    w_cov_diag
        Weight on covariance regularizing loss term.

    Returns
    -------
    Loss value.
    """
    recon_loss = tf.reduce_mean((x_true - x_pred) ** 2)
    phi, mu, cov, L, log_det_cov = gmm_params(z, gamma)
    sample_energy, cov_diag = gmm_energy(z, phi, mu, cov, L, log_det_cov, return_mean=True)
    loss = recon_loss + w_energy * sample_energy + w_cov_diag * cov_diag
    return loss
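
# Usage sketch (illustrative, not part of the library source). The shapes below
# are assumptions: a batch of 4 instances with 8 features, 2-d latent codes and
# soft membership predictions over 3 mixture components:
#
#     x_true = tf.random.normal((4, 8))
#     x_pred = tf.random.normal((4, 8))                          # reconstructions
#     z = tf.random.normal((4, 2))                               # latent codes
#     gamma = tf.nn.softmax(tf.random.normal((4, 3)), axis=-1)   # memberships
#     loss = loss_aegmm(x_true, x_pred, z, gamma)
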
def loss_vaegmm(x_true: tf.Tensor,
                x_pred: tf.Tensor,
                z: tf.Tensor,
                gamma: tf.Tensor,
                w_recon: float = 1e-7,
                w_energy: float = .1,
                w_cov_diag: float = .005,
                cov_full: tf.Tensor = None,
                cov_diag: tf.Tensor = None,
                sim: float = .05
                ) -> tf.Tensor:
    """
    Loss function used for OutlierVAEGMM.

    Parameters
    ----------
    x_true
        Batch of instances.
    x_pred
        Batch of instances reconstructed by the variational autoencoder.
    z
        Latent space values.
    gamma
        Membership prediction for mixture model components.
    w_recon
        Weight on elbo loss term.
    w_energy
        Weight on sample energy loss term.
    w_cov_diag
        Weight on covariance regularizing loss term.
    cov_full
        Full covariance matrix.
    cov_diag
        Diagonal (variance) of covariance matrix.
    sim
        Scale identity multiplier. Note that only one of cov_full, cov_diag or sim
        should be specified; since sim defaults to .05, set sim to None when
        passing cov_full or cov_diag, otherwise elbo raises a ValueError.

    Returns
    -------
    Loss value.
    """
    recon_loss = elbo(x_true, x_pred, cov_full=cov_full, cov_diag=cov_diag, sim=sim)
    phi, mu, cov, L, log_det_cov = gmm_params(z, gamma)
    sample_energy, cov_diag = gmm_energy(z, phi, mu, cov, L, log_det_cov)
    loss = w_recon * recon_loss + w_energy * sample_energy + w_cov_diag * cov_diag
    return loss
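
# Usage sketch (illustrative, not part of the library source; shapes are the
# same assumptions as in the loss_aegmm sketch above). The default sim=.05
# supplies the scale identity multiplier, so cov_full and cov_diag stay unset:
#
#     x_true = tf.random.normal((4, 8))
#     x_pred = tf.random.normal((4, 8))
#     z = tf.random.normal((4, 2))
#     gamma = tf.nn.softmax(tf.random.normal((4, 3)), axis=-1)
#     loss = loss_vaegmm(x_true, x_pred, z, gamma)
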
def loss_adv_ae(x_true: tf.Tensor,
                x_pred: tf.Tensor,
                model: tf.keras.Model = None,
                model_hl: list = None,
                w_model: float = 1.,
                w_recon: float = 0.,
                w_model_hl: list = None,
                temperature: float = 1.
                ) -> tf.Tensor:
    """
    Loss function used for AdversarialAE.

    Parameters
    ----------
    x_true
        Batch of instances.
    x_pred
        Batch of instances reconstructed by the autoencoder.
    model
        A trained tf.keras model with frozen layers (layers.trainable = False).
    model_hl
        List of tf.keras models used to extract feature maps and make predictions on hidden layers.
    w_model
        Weight on model prediction loss term.
    w_recon
        Weight on MSE reconstruction error loss term.
    w_model_hl
        Weights assigned to the loss of each model in model_hl.
    temperature
        Temperature used for model prediction scaling.
        Temperature <1 sharpens the prediction probability distribution.

    Returns
    -------
    Loss value.
    """
    y_true = model(x_true)
    y_pred = model(x_pred)

    # apply temperature scaling
    if temperature != 1.:
        y_true = y_true ** (1 / temperature)
        y_true = y_true / tf.reshape(tf.reduce_sum(y_true, axis=-1), (-1, 1))

    # compute K-L divergence loss
    loss_kld = kld(y_true, y_pred)
    std_kld = tf.math.reduce_std(loss_kld)
    loss = tf.reduce_mean(loss_kld)

    # add loss from optional K-L divergences extracted from hidden layers
    if isinstance(model_hl, list):
        if w_model_hl is None:
            w_model_hl = list(tf.ones(len(model_hl)))
        for m, w in zip(model_hl, w_model_hl):
            h_true = m(x_true)
            h_pred = m(x_pred)
            loss_kld_hl = tf.reduce_mean(kld(h_true, h_pred))
            loss += tf.constant(w) * loss_kld_hl
    loss *= w_model

    # add optional reconstruction loss, scaled to the K-L divergence loss
    if w_recon > 0.:
        loss_recon = (x_true - x_pred) ** 2
        std_recon = tf.math.reduce_std(loss_recon)
        w_scale = std_kld / (std_recon + 1e-10)
        loss_recon = w_recon * w_scale * tf.reduce_mean(loss_recon)
        loss += loss_recon
    return loss
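
# Usage sketch (illustrative, not part of the library source). `clf` is a
# hypothetical softmax classifier standing in for the trained, frozen model;
# shapes and the temperature value are assumptions:
#
#     clf = tf.keras.Sequential([tf.keras.layers.Dense(3, activation='softmax')])
#     clf.trainable = False
#     x_true = tf.random.normal((4, 8))
#     x_pred = x_true + .1 * tf.random.normal((4, 8))   # autoencoder output
#     loss = loss_adv_ae(x_true, x_pred, model=clf, temperature=.5)
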
def loss_distillation(x_true: tf.Tensor,
                      y_pred: tf.Tensor,
                      model: tf.keras.Model = None,
                      loss_type: str = 'kld',
                      temperature: float = 1.,
                      ) -> tf.Tensor:
    """
    Loss function used for Model Distillation.

    Parameters
    ----------
    x_true
        Batch of data points.
    y_pred
        Batch of predictions from the distilled model.
    model
        tf.keras model.
    loss_type
        Type of loss for distillation. Supported: 'kld', 'xent'.
    temperature
        Temperature used for model prediction scaling.
        Temperature <1 sharpens the prediction probability distribution.

    Returns
    -------
    Loss value.
    """
    y_true = model(x_true)

    # apply temperature scaling
    if temperature != 1.:
        y_true = y_true ** (1 / temperature)
        y_true = y_true / tf.reshape(tf.reduce_sum(y_true, axis=-1), (-1, 1))

    if loss_type == 'kld':
        loss_dist = kld(y_true, y_pred)
    elif loss_type == 'xent':
        loss_dist = categorical_crossentropy(y_true, y_pred, from_logits=False)
    else:
        raise NotImplementedError("loss_type should be 'kld' or 'xent'.")

    # average the distillation loss over the batch
    loss = tf.reduce_mean(loss_dist)
    return loss
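
# Usage sketch (illustrative, not part of the library source). `teacher` is a
# hypothetical softmax model playing the role of `model`, and y_pred stands in
# for the distilled model's predictions; shapes are assumptions:
#
#     teacher = tf.keras.Sequential([tf.keras.layers.Dense(3, activation='softmax')])
#     x_true = tf.random.normal((4, 8))
#     y_pred = tf.nn.softmax(tf.random.normal((4, 3)), axis=-1)
#     loss = loss_distillation(x_true, y_pred, model=teacher, loss_type='kld')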