alibi.explainers.backends.tensorflow.cfrl_base module
This module contains utility functions for the Counterfactual with Reinforcement Learning base class, alibi.explainers.cfrl_base, for the Tensorflow backend.
- class alibi.explainers.backends.tensorflow.cfrl_base.TfCounterfactualRLDataset(X, preprocessor, predictor, conditional_func, batch_size, shuffle=True)[source]
Bases: CounterfactualRLDataset, Sequence
Tensorflow backend datasets.
- __init__(X, preprocessor, predictor, conditional_func, batch_size, shuffle=True)[source]
Constructor.
- Parameters:
X (ndarray) – Array of input instances. The input should NOT be preprocessed, as it will be preprocessed when calling the preprocessor function.
preprocessor (Callable) – Preprocessor function. This function corresponds to the preprocessing steps applied to the encoder/auto-encoder model.
predictor (Callable) – Prediction function. The classifier function should expect the input in the original format and preprocess it internally in the predictor if necessary.
conditional_func (Callable) – Conditional function generator. Given a preprocessed input array, the function generates a conditional array.
batch_size (int) – Dimension of the batch used during training. The same batch size is used to infer the classification labels of the input dataset.
shuffle (bool) – Whether to shuffle the dataset each epoch. True by default.
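As a rough illustration of the batching and per-epoch shuffling described above, here is a minimal sketch of a Sequence-style dataset. It is not alibi's implementation: the class name `BatchedDataset`, the `seed` argument, and the choice to drop a trailing partial batch are assumptions made for the example, and NumPy arrays stand in for TensorFlow tensors.

```python
import numpy as np


class BatchedDataset:
    """Hypothetical stand-in for a Sequence-style dataset (not alibi's class)."""

    def __init__(self, X, preprocessor, batch_size, shuffle=True, seed=0):
        self.X = X
        self.preprocessor = preprocessor
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(X))
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch (assumption: a partial batch is dropped).
        return len(self.X) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        batch_idx = self.indices[start:start + self.batch_size]
        # Raw instances are preprocessed only when a batch is requested.
        return self.preprocessor(self.X[batch_idx])

    def on_epoch_end(self):
        # Reshuffle the visiting order at the end of every epoch.
        if self.shuffle:
            self.rng.shuffle(self.indices)


X = np.arange(20, dtype=np.float32).reshape(10, 2)
dataset = BatchedDataset(X, preprocessor=lambda x: x / 10.0, batch_size=4)
first_batch = dataset[0]
```

Passing the raw `X` and applying the preprocessor inside `__getitem__` mirrors the note that the input should not be preprocessed beforehand.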
- alibi.explainers.backends.tensorflow.cfrl_base.add_noise(Z_cf, noise, act_low, act_high, step, exploration_steps, **kwargs)[source]
Add noise to the counterfactual embedding.
- Parameters:
Z_cf (Union[Tensor, ndarray]) – Counterfactual embedding.
noise (NormalActionNoise) – Noise generator object.
act_low (float) – Noise lower bound.
act_high (float) – Noise upper bound.
step (int) – Training step.
exploration_steps (int) – Number of exploration steps. For the first exploration_steps, the noised counterfactual embedding is sampled uniformly at random.
**kwargs – Other arguments. Not used.
- Return type:
Tensor
- Returns:
Z_cf_tilde – Noised counterfactual embedding.
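The two-phase behaviour described above (uniform sampling during the first exploration_steps, additive noise afterwards) might be sketched as follows. This is an assumption-laden illustration, not the library code: `add_noise_sketch`, `noise_std`, and `rng` are hypothetical names, a plain Gaussian replaces the NormalActionNoise object, clipping to the bounds after adding noise is assumed, and NumPy replaces TensorFlow.

```python
import numpy as np


def add_noise_sketch(Z_cf, noise_std, act_low, act_high, step, exploration_steps, rng):
    """Illustrative only: NumPy stand-in with a plain Gaussian noise generator."""
    if step <= exploration_steps:
        # Exploration phase: ignore Z_cf and sample uniformly within the bounds.
        return rng.uniform(act_low, act_high, size=Z_cf.shape)
    # Afterwards: perturb the embedding and clip back into [act_low, act_high].
    eps = rng.normal(0.0, noise_std, size=Z_cf.shape)
    return np.clip(Z_cf + eps, act_low, act_high)


rng = np.random.default_rng(0)
Z_cf = np.zeros((3, 4))
explore = add_noise_sketch(Z_cf, 0.1, -1.0, 1.0, step=5, exploration_steps=100, rng=rng)
exploit = add_noise_sketch(Z_cf, 0.1, -1.0, 1.0, step=500, exploration_steps=100, rng=rng)
```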
- alibi.explainers.backends.tensorflow.cfrl_base.consistency_loss(Z_cf_pred, Z_cf_tgt)[source]
Default 0 consistency loss.
- Parameters:
Z_cf_pred (Tensor) – Counterfactual embedding prediction.
Z_cf_tgt (Tensor) – Counterfactual embedding target.
- Returns:
0 consistency loss.
- alibi.explainers.backends.tensorflow.cfrl_base.data_generator(X, encoder_preprocessor, predictor, conditional_func, batch_size, shuffle=True, **kwargs)[source]
Constructs a tensorflow data generator.
- Parameters:
X (ndarray) – Array of input instances. The input should NOT be preprocessed, as it will be preprocessed when calling the preprocessor function.
encoder_preprocessor (Callable) – Preprocessor function. This function corresponds to the preprocessing steps applied to the encoder/auto-encoder model.
predictor (Callable) – Prediction function. The classifier function should expect the input in the original format and preprocess it internally in the predictor if necessary.
conditional_func (Callable) – Conditional function generator. Given a preprocessed input array, the function generates a conditional array.
batch_size (int) – Dimension of the batch used during training. The same batch size is used to infer the classification labels of the input dataset.
shuffle (bool) – Whether to shuffle the dataset each epoch. True by default.
**kwargs – Other arguments. Not used.
- alibi.explainers.backends.tensorflow.cfrl_base.decode(Z, decoder, **kwargs)[source]
Decodes an embedding tensor.
- Parameters:
Z (Union[Tensor, ndarray]) – Embedding tensor to be decoded.
decoder (Model) – Pretrained decoder network.
**kwargs – Other arguments. Not used.
- Returns:
Embedding tensor decoding.
- alibi.explainers.backends.tensorflow.cfrl_base.encode(X, encoder, **kwargs)[source]
Encodes the input tensor.
- Parameters:
X (Union[Tensor, ndarray]) – Input to be encoded.
encoder (Model) – Pretrained encoder network.
**kwargs – Other arguments. Not used.
- Return type:
Tensor
- Returns:
Input encoding.
- alibi.explainers.backends.tensorflow.cfrl_base.generate_cf(Z, Y_m, Y_t, C, actor, **kwargs)[source]
Generates counterfactual embedding.
- Parameters:
Z (Union[ndarray, Tensor]) – Input embedding tensor.
Y_m (Union[ndarray, Tensor]) – Input classification label.
Y_t (Union[ndarray, Tensor]) – Target counterfactual classification label.
C – Conditional tensor.
actor (Model) – Actor network. The model generates the counterfactual embedding.
**kwargs – Other arguments. Not used.
- Return type:
Tensor
- Returns:
Z_cf – Counterfactual embedding.
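To make the role of each argument concrete, the sketch below assumes the actor receives the embedding, both labels, and the optional conditional tensor concatenated along the feature axis. That concatenation is an assumption, as are the names `generate_cf_sketch` and `toy_actor` and the toy linear actor itself; NumPy arrays stand in for tensors.

```python
import numpy as np


def generate_cf_sketch(Z, Y_m, Y_t, C, actor):
    """Illustrative only: concatenate the inputs and let the actor map them to Z_cf."""
    parts = [Z, Y_m, Y_t] + ([C] if C is not None else [])
    state = np.concatenate(parts, axis=1)
    return actor(state)


rng = np.random.default_rng(0)
W = rng.normal(size=(2 + 3 + 3, 2))  # hypothetical linear actor: 8 inputs -> 2 outputs


def toy_actor(state):
    # tanh keeps the generated embedding inside [-1, 1].
    return np.tanh(state @ W)


Z = rng.normal(size=(5, 2))        # input embeddings
Y_m = np.eye(3)[[0, 1, 2, 0, 1]]   # one-hot current labels
Y_t = np.eye(3)[[1, 2, 0, 2, 0]]   # one-hot target labels
Z_cf = generate_cf_sketch(Z, Y_m, Y_t, None, toy_actor)
```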
- alibi.explainers.backends.tensorflow.cfrl_base.get_actor(hidden_dim, output_dim)[source]
Constructs the actor network.
- alibi.explainers.backends.tensorflow.cfrl_base.get_critic(hidden_dim)[source]
Constructs the critic network.
- Parameters:
hidden_dim (int) – Critic’s hidden dimension.
- Return type:
Layer
- Returns:
Critic network.
- alibi.explainers.backends.tensorflow.cfrl_base.get_optimizer(model=None, lr=0.001)[source]
Constructs default Adam optimizer.
- alibi.explainers.backends.tensorflow.cfrl_base.initialize_actor_critic(actor, critic, Z, Z_cf_tilde, Y_m, Y_t, C, **kwargs)[source]
Initialize actor and critic layers by passing a dummy zero tensor.
- Parameters:
actor – Actor model.
critic – Critic model.
Z – Input embedding.
Z_cf_tilde – Noised counterfactual embedding.
Y_m – Input classification label.
Y_t – Target counterfactual classification label.
C – Conditional tensor.
**kwargs – Other arguments. Not used.
- alibi.explainers.backends.tensorflow.cfrl_base.initialize_optimizer(optimizer, model)[source]
Initializes an optimizer given a model.
- Parameters:
optimizer (Optimizer) – Optimizer to be initialized.
model (Model) – Model to be optimized.
- Return type:
None
- alibi.explainers.backends.tensorflow.cfrl_base.initialize_optimizers(optimizer_actor, optimizer_critic, actor, critic, **kwargs)[source]
Initializes the actor and critic optimizers.
- Parameters:
optimizer_actor – Actor optimizer to be initialized.
optimizer_critic – Critic optimizer to be initialized.
actor – Actor model to be optimized.
critic – Critic model to be optimized.
**kwargs – Other arguments. Not used.
- Return type:
None
- alibi.explainers.backends.tensorflow.cfrl_base.load_model(path)[source]
Loads a model and its optimizer.
- alibi.explainers.backends.tensorflow.cfrl_base.save_model(path, model)[source]
Saves a model and its optimizer.
- alibi.explainers.backends.tensorflow.cfrl_base.set_seed(seed=13)[source]
Sets the random seed to improve reproducibility. Note that this does NOT guarantee fully reproducible results.
- Parameters:
seed (int) – Seed to be set.
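A seeding helper of this kind typically touches several random number generators at once. The sketch below is a guess at the general shape rather than the library's code: `set_seed_sketch` is a hypothetical name, and only the Python and NumPy generators are seeded so that the example runs without TensorFlow.

```python
import os
import random

import numpy as np


def set_seed_sketch(seed=13):
    """Illustrative only: seed the common Python-level random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started afterwards, not the current interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # With TensorFlow installed, tf.random.set_seed(seed) would be called here too.


set_seed_sketch(13)
a = np.random.rand(3)
set_seed_sketch(13)
b = np.random.rand(3)  # same draws as `a`, since the generator was reset
```

Even with every generator seeded, non-deterministic GPU kernels or parallel data loading can still produce run-to-run differences.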
- alibi.explainers.backends.tensorflow.cfrl_base.sparsity_loss(X_hat_cf, X)[source]
Default L1 sparsity loss.
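For reference, an L1 sparsity penalty measures how far the decoded counterfactual moves from the original input in absolute terms. The sketch below uses the mean absolute difference; the reduction (mean rather than sum), the scalar return value, and the name `l1_sparsity_sketch` are assumptions, and NumPy stands in for TensorFlow.

```python
import numpy as np


def l1_sparsity_sketch(X_hat_cf, X):
    """Illustrative only: mean absolute difference between counterfactual and input."""
    return float(np.mean(np.abs(X_hat_cf - X)))


X = np.array([[1.0, 2.0], [3.0, 4.0]])
X_hat_cf = np.array([[1.0, 2.5], [3.0, 3.0]])
loss = l1_sparsity_sketch(X_hat_cf, X)  # (0 + 0.5 + 0 + 1) / 4 = 0.375
```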
- alibi.explainers.backends.tensorflow.cfrl_base.to_numpy(X)[source]
Converts given tensor to numpy array.
- alibi.explainers.backends.tensorflow.cfrl_base.to_tensor(X, **kwargs)[source]
Converts the input to a tf.Tensor.
- alibi.explainers.backends.tensorflow.cfrl_base.update_actor_critic(encoder, decoder, critic, actor, optimizer_critic, optimizer_actor, sparsity_loss, consistency_loss, coeff_sparsity, coeff_consistency, X, X_cf, Z, Z_cf_tilde, Y_m, Y_t, C, R_tilde, **kwargs)
Training step. Updates actor and critic networks including additional losses.
- Parameters:
encoder (Model) – Pretrained encoder network.
decoder (Model) – Pretrained decoder network.
critic (Model) – Critic network.
actor (Model) – Actor network.
optimizer_critic (Optimizer) – Critic’s optimizer.
optimizer_actor (Optimizer) – Actor’s optimizer.
sparsity_loss (Callable) – Sparsity loss function.
consistency_loss (Callable) – Consistency loss function.
coeff_sparsity (float) – Sparsity loss coefficient.
coeff_consistency (float) – Consistency loss coefficient.
X (ndarray) – Input array.
X_cf (ndarray) – Counterfactual array.
Z (ndarray) – Input embedding.
Z_cf_tilde (ndarray) – Noised counterfactual embedding.
Y_m (ndarray) – Input classification label.
Y_t (ndarray) – Target counterfactual classification label.
C (Optional[ndarray]) – Conditional tensor.
R_tilde (ndarray) – Noised counterfactual reward.
**kwargs – Other arguments. Not used.
- Return type:
Dict[str, Any]
- Returns:
Dictionary of losses.
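The page does not spell out how the losses are combined, so the following is only one plausible arrangement, in the style of a standard actor-critic step: the critic regresses onto the noised reward, while the actor maximises the critic's estimate and is penalised by the weighted sparsity and consistency terms. All function names, dictionary keys, and numbers below are hypothetical.

```python
import numpy as np


def critic_loss_sketch(critic_values, R_tilde):
    """Illustrative only: mean squared error between critic output and noised reward."""
    return float(np.mean((critic_values - R_tilde) ** 2))


def actor_objective_sketch(critic_values, sparsity, consistency,
                           coeff_sparsity, coeff_consistency):
    """Illustrative only: negated critic estimate plus weighted extra losses."""
    loss_actor = -float(np.mean(critic_values))
    total = loss_actor + coeff_sparsity * sparsity + coeff_consistency * consistency
    return {"actor_loss": loss_actor, "total_loss": float(total)}


critic_values = np.array([0.2, 0.6])
loss_critic = critic_loss_sketch(critic_values, R_tilde=np.array([0.0, 1.0]))
losses = actor_objective_sketch(critic_values, sparsity=0.5, consistency=0.0,
                                coeff_sparsity=0.5, coeff_consistency=0.5)
```

With the default zero consistency loss, only coeff_sparsity affects the actor's total in this arrangement.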