alibi.explainers.backends.pytorch.cfrl_base module
This module contains utility functions for the Counterfactual with Reinforcement Learning base class, alibi.explainers.cfrl_base, for the PyTorch backend.
- class alibi.explainers.backends.pytorch.cfrl_base.PtCounterfactualRLDataset(X, preprocessor, predictor, conditional_func, batch_size)[source]
Bases: CounterfactualRLDataset, Dataset
Pytorch backend datasets.
- __init__(X, preprocessor, predictor, conditional_func, batch_size)[source]
Constructor.
- Parameters:
X (ndarray) – Array of input instances. The input should NOT be preprocessed, as it will be preprocessed when calling the preprocessor function.
preprocessor (Callable) – Preprocessor function. This function corresponds to the preprocessing steps applied to the auto-encoder model.
predictor (Callable) – Prediction function. The classifier function should expect the input in the original format and preprocess it internally if necessary.
conditional_func (Callable) – Conditional function generator. Given a preprocessed input array, the function generates a conditional array.
batch_size (int) – Dimension of the batch used during training. The same batch size is used to infer the classification labels of the input dataset.
- alibi.explainers.backends.pytorch.cfrl_base.add_noise(Z_cf, noise, act_low, act_high, step, exploration_steps, device, **kwargs)[source]
Add noise to the counterfactual embedding.
- Parameters:
Z_cf (Tensor) – Counterfactual embedding.
noise (NormalActionNoise) – Noise generator object.
act_low (float) – Action lower bound.
act_high (float) – Action upper bound.
step (int) – Training step.
exploration_steps (int) – Number of exploration steps. For the first exploration_steps, the noised counterfactual embedding is sampled uniformly at random.
device (device) – Device to send data to.
- Return type:
Tensor
- Returns:
Z_cf_tilde – Noised counterfactual embedding.
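The exploration schedule described for add_noise can be sketched as follows. This is a minimal illustration, not the library's implementation: the name `add_noise_sketch`, the Gaussian noise with standard deviation `sigma` (standing in for the noise generator object), the clamping to `[act_low, act_high]`, and the exact step boundary are all assumptions made for the example.

```python
import torch


def add_noise_sketch(Z_cf, sigma, act_low, act_high, step, exploration_steps):
    """Hypothetical stand-in for add_noise; Gaussian noise replaces the noise object."""
    if step > exploration_steps:
        # After exploration: perturb the embedding and keep it within the action bounds (assumed).
        eps = sigma * torch.randn_like(Z_cf)
        return torch.clamp(Z_cf + eps, min=act_low, max=act_high)
    # During exploration: ignore Z_cf and sample uniformly between the action bounds.
    return act_low + (act_high - act_low) * torch.rand_like(Z_cf)


Z_cf = torch.zeros(4, 3)
Z_explore = add_noise_sketch(Z_cf, sigma=0.1, act_low=-1.0, act_high=1.0, step=0, exploration_steps=100)
Z_noised = add_noise_sketch(Z_cf, sigma=0.1, act_low=-1.0, act_high=1.0, step=101, exploration_steps=100)
```

Both calls return a tensor with the same shape as `Z_cf`; only the second one depends on the values of `Z_cf`.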
- alibi.explainers.backends.pytorch.cfrl_base.consistency_loss(Z_cf_pred, Z_cf_tgt)[source]
Default 0 consistency loss.
- Parameters:
Z_cf_pred (Tensor) – Counterfactual embedding prediction.
Z_cf_tgt (Tensor) – Counterfactual embedding target.
- Returns:
0 consistency loss.
- alibi.explainers.backends.pytorch.cfrl_base.data_generator(X, encoder_preprocessor, predictor, conditional_func, batch_size, shuffle, num_workers, **kwargs)[source]
Constructs a PyTorch data generator.
- Parameters:
X (ndarray) – Array of input instances. The input should NOT be preprocessed, as it will be preprocessed when calling the preprocessor function.
encoder_preprocessor (Callable) – Preprocessor function. This function corresponds to the preprocessing steps applied to the encoder/auto-encoder model.
predictor (Callable) – Prediction function. The classifier function should expect the input in the original format and preprocess it internally if necessary.
conditional_func (Callable) – Conditional function generator. Given a preprocessed input array, the function generates a conditional array.
batch_size (int) – Dimension of the batch used during training. The same batch size is used to infer the classification labels of the input dataset.
shuffle (bool) – Whether to shuffle the dataset each epoch. True by default.
num_workers (int) – Number of worker processes to be created.
**kwargs – Other arguments. Not used.
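As a rough illustration of how batch_size, shuffle and num_workers reach a PyTorch data loader, the sketch below wraps a toy dataset in `torch.utils.data.DataLoader`. `ToyDataset`, its dummy labels, the dictionary keys and the use of `drop_last=True` are assumptions for the example; the real dataset also applies the preprocessor, predictor and conditional function.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset returning one raw instance and a dummy label per index."""

    def __init__(self, X):
        self.X = X
        self.Y = np.zeros(len(X), dtype=np.int64)  # placeholder labels, not real predictions

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "Y_m": self.Y[idx]}


X = np.random.rand(10, 4).astype(np.float32)
# num_workers=0 loads data in the main process; drop_last=True discards the incomplete final batch.
loader = DataLoader(ToyDataset(X), batch_size=4, shuffle=True, num_workers=0, drop_last=True)
batches = list(loader)
```

With 10 instances and a batch size of 4, the loader yields two full batches per epoch and drops the remaining two instances.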
- alibi.explainers.backends.pytorch.cfrl_base.decode(Z, decoder, device, **kwargs)
Decodes an embedding tensor.
- Parameters:
Z (Tensor) – Embedding tensor to be decoded.
decoder (Module) – Pretrained decoder network.
device (device) – Device to send data to.
- Returns:
Embedding tensor decoding.
- alibi.explainers.backends.pytorch.cfrl_base.encode(X, encoder, device, **kwargs)
Encodes the input tensor.
- Parameters:
X (Tensor) – Input to be encoded.
encoder (Module) – Pretrained encoder network.
device (device) – Device to send data to.
- Returns:
Input encoding.
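A minimal sketch of what encoding and decoding with pretrained networks might look like is shown here. The helper names `encode_sketch` and `decode_sketch`, the use of `torch.no_grad()` and `eval()`, and the toy linear networks are assumptions for illustration, not the module's actual code.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def encode_sketch(X, encoder, device):
    """Hypothetical helper: run a pretrained encoder in inference mode."""
    encoder.eval()
    return encoder(X.float().to(device))


@torch.no_grad()
def decode_sketch(Z, decoder, device):
    """Hypothetical helper: map an embedding back to the input space."""
    decoder.eval()
    return decoder(Z.float().to(device))


device = torch.device("cpu")
encoder = nn.Linear(8, 2)  # toy stand-ins for pretrained networks
decoder = nn.Linear(2, 8)
X = torch.rand(5, 8)
Z = encode_sketch(X, encoder, device)
X_hat = decode_sketch(Z, decoder, device)
```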
- alibi.explainers.backends.pytorch.cfrl_base.generate_cf(Z, Y_m, Y_t, C, encoder, decoder, actor, device, **kwargs)
Generates counterfactual embedding.
- Parameters:
Z (Tensor) – Input embedding tensor.
Y_m (Tensor) – Input classification label.
Y_t (Tensor) – Target counterfactual classification label.
C (Optional[Tensor]) – Conditional tensor.
encoder (Module) – Pretrained encoder network.
decoder (Module) – Pretrained decoder network.
actor (Module) – Actor network. The model generates the counterfactual embedding.
device (device) – Device object to be used.
- Return type:
Tensor
- Returns:
Z_cf – Counterfactual embedding.
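One plausible way to combine these inputs is to concatenate the embedding, the two labels and the optional conditional tensor into a single state that the actor maps to a counterfactual embedding. The sketch below assumes exactly that layout; `generate_cf_sketch`, the one-hot label encoding and the toy actor are hypothetical.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def generate_cf_sketch(Z, Y_m, Y_t, C, actor):
    """Hypothetical helper: build the actor's state and predict a counterfactual embedding."""
    parts = [Z, Y_m, Y_t] + ([C] if C is not None else [])
    state = torch.cat(parts, dim=1)  # assumed layout: features concatenated per instance
    actor.eval()
    return actor(state)


Z = torch.rand(4, 3)               # input embeddings
Y_m = torch.eye(2)[[0, 1, 0, 1]]   # one-hot input labels
Y_t = torch.eye(2)[[1, 0, 1, 0]]   # one-hot target labels
actor = nn.Sequential(nn.Linear(3 + 2 + 2, 16), nn.ReLU(), nn.Linear(16, 3), nn.Tanh())
Z_cf = generate_cf_sketch(Z, Y_m, Y_t, None, actor)
```

The `Tanh` output keeps the toy actor's embedding within [-1, 1]; whether the real actor bounds its output this way is not stated here.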
- alibi.explainers.backends.pytorch.cfrl_base.get_actor(hidden_dim, output_dim)[source]
Constructs the actor network.
- alibi.explainers.backends.pytorch.cfrl_base.get_critic(hidden_dim)[source]
Constructs the critic network.
- Parameters:
hidden_dim (int) – Critic’s hidden dimension.
- Return type:
Module
- Returns:
Critic network.
- alibi.explainers.backends.pytorch.cfrl_base.get_device()[source]
Checks if cuda is available. If available, use cuda by default, else use cpu.
- Return type:
device
- Returns:
Device to be used.
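The device selection described above is commonly written as a one-liner in PyTorch. The snippet below is a sketch of that idiom rather than the function's actual body.

```python
import torch

# Prefer CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(2, 2).to(device)
```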
- alibi.explainers.backends.pytorch.cfrl_base.get_optimizer(model, lr=0.001)[source]
Constructs default Adam optimizer.
- Return type:
Optimizer
- Returns:
Default optimizer.
- alibi.explainers.backends.pytorch.cfrl_base.load_model(path)[source]
Loads a model and its optimizer.
- alibi.explainers.backends.pytorch.cfrl_base.save_model(path, model)[source]
Saves a model and its optimizer.
- alibi.explainers.backends.pytorch.cfrl_base.set_seed(seed=13)[source]
Sets a seed to ensure reproducibility.
- Parameters:
seed (int) – Seed to be set.
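A seeding helper of this kind typically seeds every random number generator the training loop touches. The sketch below assumes Python's `random`, NumPy and PyTorch are the relevant generators; `set_seed_sketch` is a hypothetical name and the real function may set additional options.

```python
import random

import numpy as np
import torch


def set_seed_sketch(seed=13):
    """Hypothetical helper seeding the Python, NumPy and PyTorch generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seed_sketch(13)
a = torch.rand(3)
set_seed_sketch(13)
b = torch.rand(3)  # same values as `a`, since the generator was reset
```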
- alibi.explainers.backends.pytorch.cfrl_base.sparsity_loss(X_hat_cf, X)[source]
Default L1 sparsity loss.
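An L1 sparsity term penalises the mean absolute difference between the decoded counterfactual and the original input. The sketch below assumes a single-tensor input and a scalar return value; `sparsity_loss_sketch` is hypothetical and the real function's return format may differ.

```python
import torch
import torch.nn.functional as F


def sparsity_loss_sketch(X_hat_cf, X):
    """Hypothetical L1 sparsity term: mean absolute difference to the input."""
    return F.l1_loss(X_hat_cf, X)


X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
X_hat_cf = torch.tensor([[0.0, 1.0], [2.0, 5.0]])
loss = sparsity_loss_sketch(X_hat_cf, X)  # (0 + 0 + 0 + 2) / 4 = 0.5
```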
- alibi.explainers.backends.pytorch.cfrl_base.to_numpy(X)[source]
Converts given tensor to numpy array.
- alibi.explainers.backends.pytorch.cfrl_base.to_tensor(X, device, **kwargs)[source]
Converts a tensor to a torch.Tensor.
- Return type:
Optional[Tensor]
- Returns:
torch.Tensor conversion.
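The two conversion helpers can be pictured as a round trip between NumPy and PyTorch. In this sketch, `to_tensor_sketch` and `to_numpy_sketch` are hypothetical names, and passing `None` through unchanged is an assumption suggested by the Optional return type.

```python
import numpy as np
import torch


def to_tensor_sketch(X, device):
    """Hypothetical helper: NumPy array (or None) to a float tensor on `device`."""
    return None if X is None else torch.as_tensor(X).float().to(device)


def to_numpy_sketch(X):
    """Hypothetical helper: tensor to NumPy array, detached from the autograd graph."""
    return X.detach().cpu().numpy()


device = torch.device("cpu")
X = np.arange(6, dtype=np.float32).reshape(2, 3)
T = to_tensor_sketch(X, device)
X_back = to_numpy_sketch(T)
```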
- alibi.explainers.backends.pytorch.cfrl_base.update_actor_critic(encoder, decoder, critic, actor, optimizer_critic, optimizer_actor, sparsity_loss, consistency_loss, coeff_sparsity, coeff_consistency, X, X_cf, Z, Z_cf_tilde, Y_m, Y_t, C, R_tilde, device, **kwargs)[source]
Training step. Updates actor and critic networks including additional losses.
- Parameters:
encoder (Module) – Pretrained encoder network.
decoder (Module) – Pretrained decoder network.
critic (Module) – Critic network.
actor (Module) – Actor network.
optimizer_critic (Optimizer) – Critic’s optimizer.
optimizer_actor (Optimizer) – Actor’s optimizer.
sparsity_loss (Callable) – Sparsity loss function.
consistency_loss (Callable) – Consistency loss function.
coeff_sparsity (float) – Sparsity loss coefficient.
coeff_consistency (float) – Consistency loss coefficient.
X (ndarray) – Input array.
X_cf (ndarray) – Counterfactual array.
Z (ndarray) – Input embedding.
Z_cf_tilde (ndarray) – Noised counterfactual embedding.
Y_m (ndarray) – Input classification label.
Y_t (ndarray) – Target counterfactual classification label.
C (Optional[ndarray]) – Conditional tensor.
R_tilde (ndarray) – Noised counterfactual reward.
device (device) – Torch device object.
**kwargs – Other arguments. Not used.
- Returns:
Dictionary of losses.
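To make the roles of these arguments concrete, the sketch below runs one actor-critic update on random data. It is an assumption-laden illustration, not the library's training step. The critic is assumed to regress the reward of the noised action, and the actor to maximise the critic's estimate plus a weighted sparsity term. The network shapes and dictionary keys are hypothetical, and the consistency term is omitted because its default is 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
latent_dim, n_classes, batch = 3, 2, 8
state_dim = latent_dim + 2 * n_classes

# Toy stand-ins for the actor, critic and pretrained decoder.
actor = nn.Sequential(nn.Linear(state_dim, 16), nn.ReLU(), nn.Linear(16, latent_dim), nn.Tanh())
critic = nn.Sequential(nn.Linear(state_dim + latent_dim, 16), nn.ReLU(), nn.Linear(16, 1))
decoder = nn.Linear(latent_dim, 5)
for p in decoder.parameters():
    p.requires_grad_(False)  # the pretrained decoder stays fixed
opt_actor = torch.optim.Adam(actor.parameters(), lr=1e-3)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

# Fake batch: input, embedding, one-hot labels, noised action and its reward.
X = torch.rand(batch, 5)
Z = torch.rand(batch, latent_dim)
Y_m = torch.eye(n_classes)[torch.randint(0, n_classes, (batch,))]
Y_t = torch.eye(n_classes)[torch.randint(0, n_classes, (batch,))]
Z_cf_tilde = torch.rand(batch, latent_dim) * 2 - 1
R_tilde = torch.randint(0, 2, (batch,)).float()
coeff_sparsity = 0.5

state = torch.cat([Z, Y_m, Y_t], dim=1)

# Critic step: regress the observed reward of the noised action.
opt_critic.zero_grad()
critic_loss = F.mse_loss(critic(torch.cat([state, Z_cf_tilde], dim=1)).squeeze(1), R_tilde)
critic_loss.backward()
opt_critic.step()

# Actor step: maximise the critic's estimate while keeping the decoded counterfactual close to X.
opt_actor.zero_grad()
Z_cf = actor(state)
actor_loss = -critic(torch.cat([state, Z_cf], dim=1)).mean()
sparsity = F.l1_loss(decoder(Z_cf), X)
(actor_loss + coeff_sparsity * sparsity).backward()
opt_actor.step()  # only the actor's parameters are updated here

losses = {"critic_loss": critic_loss.item(), "actor_loss": actor_loss.item(), "sparsity_loss": sparsity.item()}
```

Keeping two optimizers lets each network be stepped independently, which matches the separate optimizer_critic and optimizer_actor arguments in the signature.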