This page was generated from doc/source/methods/adversarialvae.ipynb.


Variational Auto-Encoder Adversarial Detector


The adversarial VAE detector is first trained on a batch of unlabeled but normal (inlier) data. Unsupervised or semi-supervised training is desirable because labeled data is often scarce. The loss differs from traditional VAE training, however: it minimizes the KL-divergence between a classifier’s class prediction probabilities on the original data and on the data reconstructed by the VAE. When an adversarial instance is fed to the VAE, the reconstruction does not contain the adversarial artefacts, so the classifier’s prediction distribution on the reconstruction differs from its prediction distribution on the adversarial instance. The KL-divergence between the two is therefore large and the instance is flagged as adversarial. The algorithm works well on tabular and image data.
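As a rough illustration of the scoring idea described above, the following NumPy sketch computes a per-instance KL-divergence between class probabilities predicted on original instances and on their reconstructions. The function name kl_adversarial_score, the eps clipping and the probability arrays are assumptions made for this example. They stand in for the outputs of a real classifier and VAE and are not the detector's own implementation.

```python
import numpy as np


def kl_adversarial_score(p_orig, p_recon, eps=1e-12):
    """Per-instance KL(p_orig || p_recon) for arrays of shape (batch, classes)."""
    # Clip to avoid log(0) and division by zero (an assumption of this sketch).
    p = np.clip(np.asarray(p_orig, dtype=float), eps, 1.0)
    q = np.clip(np.asarray(p_recon, dtype=float), eps, 1.0)
    return np.sum(p * np.log(p / q), axis=-1)


# Hypothetical classifier outputs: rows are instances, columns are classes.
p_orig = np.array([[0.90, 0.05, 0.05],    # prediction on a normal instance
                   [0.10, 0.85, 0.05]])   # prediction on an adversarial instance
p_recon = np.array([[0.88, 0.06, 0.06],   # reconstruction agrees with the original
                    [0.80, 0.10, 0.10]])  # reconstruction is assigned another class

scores = kl_adversarial_score(p_orig, p_recon)
```

In this toy input the second instance receives a much larger score than the first, which is the behaviour the threshold is meant to pick up.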




The detector is initialized with the following parameters:

  • threshold: threshold value for the adversarial score above which an instance is flagged as adversarial.

  • latent_dim: latent dimension of the VAE.

  • encoder_net: tf.keras.Sequential instance containing the encoder network. Example:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, InputLayer

encoder_net = tf.keras.Sequential([
    InputLayer(input_shape=(32, 32, 3)),
    Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu)
])

  • decoder_net: tf.keras.Sequential instance containing the decoder network. Example:

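from tensorflow.keras.layers import Conv2DTranspose, Dense, InputLayer, Reshape
decoder_net = tf.keras.Sequential([
    InputLayer(input_shape=(latent_dim,)),
    Dense(4 * 4 * 128),  # project the latent vector to 4*4*128 units before reshaping
    Reshape(target_shape=(4, 4, 128)),
    Conv2DTranspose(256, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2DTranspose(64, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2DTranspose(3, 4, strides=2, padding='same', activation='sigmoid')
])
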
  • vae: instead of using a separate encoder and decoder, the VAE can also be passed as a tf.keras.Model.

  • model: the classifier as a tf.keras.Model. Example:

inputs = tf.keras.Input(shape=(input_dim,))
outputs = tf.keras.layers.Dense(output_dim, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
  • samples: number of samples drawn during detection for each instance to detect.

  • beta: weight on the KL-divergence loss term following the \(\beta\)-VAE framework. Default equals 0.

  • data_type: optional data type added to the metadata, e.g. ‘tabular’ or ‘image’.
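The encoder and decoder examples in the list above both use three layers with stride 2 and padding='same'. The short sketch below works through the resulting spatial sizes, to show why a 32x32 input reaches a 4x4 feature map and why a 4x4 feature map is upsampled back to 32x32. The helper names are hypothetical; the size rules assumed are the standard ones for 'same' padding (ceil(size / stride) for a convolution, size * stride for a transposed convolution).

```python
import math


def same_conv_out(size, stride):
    """Spatial output size of a strided convolution with padding='same'."""
    return math.ceil(size / stride)


def same_conv_transpose_out(size, stride):
    """Spatial output size of a strided transposed convolution with padding='same'."""
    return size * stride


# Encoder: 32x32 input followed by three stride-2 convolutions.
size = 32
encoder_sizes = [size]
for _ in range(3):
    size = same_conv_out(size, 2)
    encoder_sizes.append(size)

# Decoder: 4x4 feature map followed by three stride-2 transposed convolutions.
size = 4
decoder_sizes = [size]
for _ in range(3):
    size = same_conv_transpose_out(size, 2)
    decoder_sizes.append(size)
```

The two lists mirror each other, so the reconstruction has the same height and width as the input image.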

Example of an initialized adversarial detector:

from import AdversarialVAE

ad = AdversarialVAE(


We then need to train the adversarial detector. The following parameters can be specified:

  • X: training batch as a numpy array of preferably normal data.

  • loss_fn: loss function used for training. Defaults to the custom adversarial loss.

  • w_model: weight on the loss term minimizing the KL-divergence between model prediction probabilities on the original and reconstructed instance. Defaults to 1.

  • w_recon: weight on the ELBO loss term. Defaults to 0.

  • optimizer: optimizer used for training. Defaults to Adam with learning rate 1e-3.

  • cov_elbo: dictionary with covariance matrix options, used when the ELBO loss term is active. Use the full covariance matrix inferred from X (dict(cov_full=None)), only the variance (dict(cov_diag=None)), or a float giving the same standard deviation for each feature (e.g. dict(sim=.05)), which is the default.

  • epochs: number of training epochs.

  • batch_size: batch size used during training.

  • verbose: boolean indicating whether to print training progress.

  • log_metric: additional metrics whose progress will be displayed if verbose equals True.
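To make the three cov_elbo options in the list above more concrete, the sketch below builds the corresponding covariance matrices with NumPy. It is only an illustration under stated assumptions: the function build_covariance and its boolean flags are hypothetical and do not reproduce the dictionary-based interface or the internals of the detector.

```python
import numpy as np


def build_covariance(X, cov_full=False, cov_diag=False, sim=0.05):
    """Return a (features, features) covariance matrix for data X of shape (n, ...)."""
    X = np.asarray(X, dtype=float).reshape(len(X), -1)  # flatten each instance
    n_features = X.shape[1]
    if cov_full:
        # Full covariance inferred from X, including correlations between features.
        return np.atleast_2d(np.cov(X, rowvar=False))
    if cov_diag:
        # Per-feature variance only; off-diagonal terms are zero.
        return np.diag(X.var(axis=0, ddof=1))
    # Same standard deviation `sim` for every feature.
    return (sim ** 2) * np.eye(n_features)


rng = np.random.default_rng(0)
X_example = rng.normal(size=(200, 3))  # hypothetical training batch

full = build_covariance(X_example, cov_full=True)
diag = build_covariance(X_example, cov_diag=True)
iso = build_covariance(X_example, sim=0.05)
```

The full and diagonal variants share the same diagonal, while the isotropic variant ignores X apart from its number of features.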

It is often hard to find a good threshold value. If we have a batch containing both normal and adversarial (outlier) data and we know the approximate percentage of normal data in the batch, we can infer a suitable threshold from that percentage.
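One way to picture the threshold inference described above is as a percentile of the adversarial scores: if roughly 95% of the batch is normal, the threshold is placed at the 95th percentile. The sketch below is a minimal NumPy illustration. The function infer_threshold_from_scores and the synthetic scores are assumptions made for this example and are not taken from the detector.

```python
import numpy as np


def infer_threshold_from_scores(scores, perc_normal):
    """Threshold such that about `perc_normal` percent of the scores lie below it."""
    return float(np.percentile(scores, perc_normal))


rng = np.random.default_rng(0)
# Hypothetical adversarial scores: 95 normal instances with low scores
# and 5 adversarial instances with clearly higher scores.
scores_batch = np.concatenate([
    rng.uniform(0.0, 0.1, size=95),
    rng.uniform(1.0, 2.0, size=5),
])

threshold = infer_threshold_from_scores(scores_batch, perc_normal=95)
flagged = scores_batch > threshold
```

With well-separated scores like these, the threshold lands between the two groups and only the high-scoring instances are flagged.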



We detect adversarial instances by calling predict on a batch of instances X. We can also return the instance-level adversarial score by setting return_instance_score to True.

The prediction takes the form of a dictionary with meta and data keys. meta contains the detector’s metadata while data is also a dictionary which contains the actual predictions stored in the following keys:

  • is_adversarial: boolean array indicating whether each instance’s score is above the threshold, in which case it is considered adversarial. The array is of shape (batch size,).

  • instance_score: contains the instance-level scores if return_instance_score equals True.

preds = ad.predict(X, return_instance_score=True)
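The structure of the returned dictionary can be mimicked with a small stand-alone sketch. The helper make_prediction below is hypothetical and only mirrors the meta and data keys described above; in particular, returning None for instance_score when it is not requested is an assumption of this example.

```python
import numpy as np


def make_prediction(scores, threshold, return_instance_score=True, meta=None):
    """Assemble a prediction dictionary with `meta` and `data` keys."""
    scores = np.asarray(scores, dtype=float)
    data = {
        # Boolean array of shape (batch size,): True where the score exceeds the threshold.
        "is_adversarial": scores > threshold,
        # Assumption: None when instance-level scores are not requested.
        "instance_score": scores if return_instance_score else None,
    }
    return {"meta": dict(meta or {}), "data": data}


example_preds = make_prediction(
    [0.01, 0.9, 0.2], threshold=0.5, meta={"data_type": "tabular"}
)
adversarial_idx = np.where(example_preds["data"]["is_adversarial"])[0]
```

Indexing the original batch with adversarial_idx would then select the instances that were flagged.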