Source code for alibi.explainers.backends.tensorflow.cfrl_tabular

"""
This module contains utility functions for the Counterfactual with Reinforcement Learning tabular class (`cfrl_tabular`)
for the Tensorflow backend.
"""
from typing import List, Dict, Union

import numpy as np
import tensorflow as tf

from alibi.explainers.backends.cfrl_tabular import split_ohe, generate_condition  # noqa: F401

# The following functions are imported because `alibi.explainers.backends.tensorflow.cfrl_tabular` is an extension
# of `alibi.explainers.backends.tensorflow.cfrl_base`. In the explainer class `alibi.explainers.cfrl_tabular`, the
# backend-specific functions are accessed through `self.backend`, which is of type `types.ModuleType`. Since some of
# the functions imported below are common to both data modalities and are accessed through `self.backend`, we
# import them here even though they are not used explicitly in this module.

from alibi.explainers.backends.tensorflow.cfrl_base import get_actor, get_critic, get_optimizer, data_generator, \
    encode, decode, generate_cf, update_actor_critic, add_noise, to_numpy, to_tensor, set_seed, \
    save_model, load_model, initialize_optimizers, initialize_actor_critic  # noqa: F403, F401


def sample_differentiable(X_hat_split: List[tf.Tensor], category_map: Dict[int, List[str]]) -> List[tf.Tensor]:
    """
    Samples differentiable reconstruction.

    The numerical part of the reconstruction is returned unchanged. Each categorical head is replaced by the
    one-hot encoding of its argmax in the forward pass. Gradients are routed through the softmax of that head
    (straight-through estimator).

    Parameters
    ----------
    X_hat_split
        List of reconstructed columns from the auto-encoder. When numerical attributes are present, they are taken
        from the first element. The last ``len(category_map)`` elements are the categorical heads.
    category_map
        Dictionary of category mapping. The keys are column indexes and the values are lists containing the
        possible values for an attribute.

    Returns
    -------
    Differentiable reconstruction.
    """
    n_cat = len(category_map)
    n_num = len(X_hat_split) - n_cat

    # Numerical attributes: the leading element is passed through untouched.
    X_out = [X_hat_split[0]] if n_num > 0 else []

    # The guard is required: with n_cat == 0 the slice `[-0:]` would select every element.
    categorical_heads = X_hat_split[-n_cat:] if n_cat > 0 else []
    for logits in categorical_heads:
        # NOTE(review): heads presumably have shape (batch, n_categories), given axis=1 / shape[1] — TODO confirm.
        hard = tf.one_hot(tf.argmax(logits, axis=1), depth=logits.shape[1])
        soft = tf.nn.softmax(logits, axis=1)
        # The forward value equals `hard` (up to floating-point rounding); the gradient is the one of `soft`.
        X_out.append(hard - tf.stop_gradient(soft) + soft)
    return X_out
def l0_ohe(input: tf.Tensor, target: tf.Tensor, reduction: str = 'none') -> tf.Tensor:
    """
    Computes the L0 loss for a one-hot encoding representation.

    The unreduced loss of each row is ``sum_j max(eps, target_j - input_j)``, where ``eps = 1e-7 / input.shape[1]``.

    Parameters
    ----------
    input
        Input tensor.
    target
        Target tensor.
    reduction
        Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``.

    Returns
    -------
    L0 loss.

    Raises
    ------
    ValueError
        If `reduction` is not one of the supported values.
    """
    # Per-entry lower bound; a row whose entries are all clipped sums to exactly 1e-7.
    lower_bound = 1e-7 / input.shape[1] + tf.zeros_like(input)
    # The bound must remain the FIRST argument of `tf.maximum`: with the operands reversed, the gradient of the
    # zeros would still flow. TODO: maybe consider clipping a bit higher?
    per_row = tf.reduce_sum(tf.maximum(lower_bound, target - input), axis=1)

    if reduction == 'mean':
        return tf.reduce_mean(per_row)
    if reduction == 'sum':
        return tf.reduce_sum(per_row)
    if reduction == 'none':
        return per_row
    raise ValueError(f"Reduction {reduction} not implemented.")
def l1_loss(input: tf.Tensor, target: tf.Tensor, reduction: str = 'none') -> tf.Tensor:
    """
    Computes the L1 loss.

    Parameters
    ----------
    input
        Input tensor.
    target
        Target tensor.
    reduction
        Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``.

    Returns
    -------
    L1 loss. For ``reduction='none'`` this is the element-wise absolute difference ``|input - target|``. For
    ``'mean'`` or ``'sum'`` it is the mean or sum of that difference over all elements.

    Raises
    ------
    ValueError
        If `reduction` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
    """
    # `target` is a required, annotated parameter. Declaring it as `target=tf.Tensor` would make the class itself
    # a default value, and a call omitting `target` would only fail inside the subtraction below.
    loss = tf.abs(input - target)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return tf.reduce_mean(loss)
    if reduction == 'sum':
        return tf.reduce_sum(loss)
    raise ValueError(f"Reduction {reduction} not implemented.")
def sparsity_loss(X_hat_split: List[tf.Tensor],
                  X_ohe: tf.Tensor,
                  category_map: Dict[int, List[str]],
                  weight_num: float = 1.0,
                  weight_cat: float = 1.0) -> Dict[str, Union[float, tf.Tensor]]:
    """
    Computes heterogeneous sparsity loss.

    Parameters
    ----------
    X_hat_split
        List of reconstructed columns from the auto-encoder.
    X_ohe
        One-hot encoded representation of the input.
    category_map
        Dictionary of category mapping. The keys are column indexes and the values are lists containing the
        possible values for an attribute.
    weight_num
        Numerical loss weight.
    weight_cat
        Categorical loss weight.

    Returns
    -------
    Heterogeneous sparsity loss, as a dictionary with two keys:

    - ``'sparsity_num_loss'``: `weight_num` times the mean L1 loss on the numerical part.
    - ``'sparsity_cat_loss'``: `weight_cat` times the mean L0 loss, averaged over the categorical heads.

    A term is ``0.`` (times its weight) when the corresponding group of attributes is absent.
    """
    # Split the input into a list of tensors, where each element corresponds to a network head.
    X_ohe_num_split, X_ohe_cat_split = split_ohe(X_ohe=X_ohe, category_map=category_map)

    # Sample a differentiable output (one-hot categorical heads with gradients).
    X_ohe_hat_split = sample_differentiable(X_hat_split=X_hat_split, category_map=category_map)

    num_loss, cat_loss = 0., 0.

    # When a numerical part exists, it occupies the first slot of the reconstruction and the categorical
    # heads follow it, hence the offset.
    offset = 0
    if len(X_ohe_num_split) > 0:
        offset = 1
        num_loss = tf.reduce_mean(l1_loss(input=X_ohe_hat_split[0], target=X_ohe_num_split[0], reduction='none'))

    # Average the per-head categorical losses; the guard avoids a division by zero.
    if len(X_ohe_cat_split) > 0:
        for i, X_ohe_cat in enumerate(X_ohe_cat_split):
            cat_loss += tf.reduce_mean(l0_ohe(input=X_ohe_hat_split[i + offset], target=X_ohe_cat, reduction='none'))
        cat_loss /= len(X_ohe_cat_split)

    return {"sparsity_num_loss": weight_num * num_loss, "sparsity_cat_loss": weight_cat * cat_loss}
def consistency_loss(Z_cf_pred: tf.Tensor, Z_cf_tgt: Union[np.ndarray, tf.Tensor], **kwargs):
    """
    Computes heterogeneous consistency loss.

    Parameters
    ----------
    Z_cf_pred
        Counterfactual embedding prediction.
    Z_cf_tgt
        Counterfactual embedding target.
    **kwargs
        Not used by this function.

    Returns
    -------
    Heterogeneous consistency loss, as a dictionary with the single key ``'consistency_loss'``. Its value is the
    mean squared difference between `Z_cf_pred` and `Z_cf_tgt`.
    """
    residual = Z_cf_pred - Z_cf_tgt
    mse = tf.reduce_mean(tf.square(residual))
    return {"consistency_loss": mse}