This page was generated from doc/source/methods/adversarialvae.ipynb.
Variational Auto-Encoder Adversarial Detector
The adversarial VAE detector is first trained on a batch of unlabeled, but normal (inlier) data. Unsupervised or semi-supervised training is desirable since labeled data is often scarce. The loss, however, differs from traditional VAE training: it minimizes the KL-divergence between a classifier's class prediction probabilities on the original data and on the VAE's reconstruction. When an adversarial instance is fed to the VAE, the KL-divergence between the predictions on the adversarial example and on its reconstruction is large, since the reconstruction does not contain the adversarial artefacts and therefore yields a different prediction distribution. As a result, the adversarial instance is flagged. The algorithm works well on tabular and image data.
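The scoring idea can be sketched with a small numerical example. This is a minimal sketch with made-up class probabilities, not the library's implementation: a confident prediction on the original instance combined with a different prediction distribution on the reconstruction produces a large KL-divergence score.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-10):
    # KL divergence between two categorical distributions, per instance
    return np.sum(p * np.log((p + eps) / (q + eps)), axis=-1)

# hypothetical classifier prediction probabilities on an original
# (adversarial) instance and on its VAE reconstruction
p_orig = np.array([[0.05, 0.9, 0.05]])   # confident prediction on the original
p_recon = np.array([[0.6, 0.3, 0.1]])    # reconstruction removes the artefact

score = kl_divergence(p_orig, p_recon)   # large score for adversarial inputs
is_adversarial = score > 0.1             # compare with the detector threshold
```

For a normal instance, the reconstruction typically yields a very similar prediction distribution, so the score stays close to zero and below the threshold.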
threshold: threshold value above which the instance is flagged as an adversarial instance.
latent_dim: latent dimension of the VAE.
encoder_net: tf.keras.Sequential instance containing the encoder network. Example:
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, 3)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu)
    ]
)
decoder_net: tf.keras.Sequential instance containing the decoder network. Example:
decoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(latent_dim,)),
        Dense(4*4*128),
        Reshape(target_shape=(4, 4, 128)),
        Conv2DTranspose(256, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(3, 4, strides=2, padding='same', activation='sigmoid')
    ]
)
vae: instead of using a separate encoder and decoder, the VAE can also be passed as a tf.keras.Model instance.
model: the classifier as a tf.keras.Model instance. Example:
inputs = tf.keras.Input(shape=(input_dim,))
outputs = tf.keras.layers.Dense(output_dim, activation=tf.nn.softmax)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
samples: number of samples drawn during detection for each instance to detect.
beta: weight on the KL-divergence loss term following the \(\beta\)-VAE framework. Default equals 0.
data_type: can specify data type added to metadata. E.g. ‘tabular’ or ‘image’.
Initialized adversarial detector example:
from alibi_detect.ad import AdversarialVAE

ad = AdversarialVAE(
    threshold=0.1,
    encoder_net=encoder_net,
    decoder_net=decoder_net,
    model=model,
    latent_dim=50,
    samples=10
)
We then need to train the adversarial detector. The following parameters can be specified:
X: training batch as a numpy array, preferably consisting of normal (inlier) data.
loss_fn: loss function used for training. Defaults to the custom adversarial loss.
w_model: weight on the loss term minimizing the KL-divergence between model prediction probabilities on the original and reconstructed instance. Defaults to 1.
w_recon: weight on the elbo loss term. Defaults to 0.
optimizer: optimizer used for training. Defaults to Adam with learning rate 1e-3.
cov_elbo: dictionary with covariance matrix options in case the elbo loss function is used. Either use the full covariance matrix inferred from X (dict(cov_full=None)), only the variance (dict(cov_diag=None)) or a float representing the same standard deviation for each feature (e.g. dict(sim=.05)) which is the default.
epochs: number of training epochs.
batch_size: batch size used during training.
verbose: boolean whether to print training progress.
log_metric: additional metrics whose progress will be displayed if verbose equals True.
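Combining the weight parameters above, the overall training objective can be sketched as a weighted sum of the two loss terms. This is a hypothetical helper, not the library's implementation; the scalar values are made up:

```python
def adversarial_loss(kl_model, elbo, w_model=1., w_recon=0.):
    # w_model weighs the KL-divergence between classifier predictions on the
    # original and reconstructed instances; w_recon weighs the elbo term
    return w_model * kl_model + w_recon * elbo

# with the default weights, only the KL term contributes to the loss
loss = adversarial_loss(kl_model=0.4, elbo=120., w_model=1., w_recon=0.)
```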
ad.fit(X_train, epochs=5)
It is often hard to find a good threshold value. If we have a batch of normal and adversarial data and we know approximately the percentage of normal data in the batch, we can infer a suitable threshold:
ad.infer_threshold(X, threshold_perc=95)
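The idea behind threshold inference can be sketched with numpy: if roughly 95% of the batch is normal, the 95th percentile of the adversarial scores separates the two groups. This is a simplified sketch with synthetic scores, not the library's implementation:

```python
import numpy as np

def infer_threshold(scores, threshold_perc=95.):
    # treat the top (100 - threshold_perc)% of scores as adversarial
    return np.percentile(scores, threshold_perc)

rng = np.random.default_rng(0)
scores = np.concatenate([rng.uniform(0., .1, 95),   # low scores: normal
                         rng.uniform(.5, 1., 5)])   # high scores: adversarial
threshold = infer_threshold(scores, threshold_perc=95.)
```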
We detect adversarial instances by simply calling predict on a batch of instances X. We can also return the instance level adversarial score by setting return_instance_score to True.
The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata while data is also a dictionary which contains the actual predictions stored in the following keys:
is_adversarial: boolean whether instances are above the threshold and therefore adversarial instances. The array is of shape (batch size,).
instance_score: contains instance level scores if return_instance_score equals True.
preds = ad.predict(X, return_instance_score=True)
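For illustration, assuming the dictionary structure described above (the values below are made up), the flagged instances can be extracted as follows:

```python
# sketch of the returned dictionary structure with made-up values
preds = {
    'meta': {'name': 'AdversarialVAE', 'data_type': 'image'},
    'data': {
        'is_adversarial': [0, 1, 0],       # flag per instance
        'instance_score': [.02, .35, .01]  # instance level scores
    }
}

# indices of instances flagged as adversarial
flagged = [i for i, flag in enumerate(preds['data']['is_adversarial']) if flag]
```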