import logging
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union, cast
import numpy as np
from tqdm import tqdm
from alibi.api.defaults import DEFAULT_DATA_CFRL, DEFAULT_META_CFRL
from alibi.api.interfaces import Explainer, Explanation, FitMixin
from alibi.explainers.backends.cfrl_base import (generate_empty_condition,
get_classification_reward,
get_hard_distribution,
identity_function)
from alibi.utils.frameworks import Framework, has_pytorch, has_tensorflow
if TYPE_CHECKING:
    # Only required by the string type annotations below; never imported at runtime.
    import tensorflow
    import torch

# Each deep-learning backend module is imported only when its framework is installed.
if has_pytorch:
    from alibi.explainers.backends.pytorch import cfrl_base as pytorch_base_backend

if has_tensorflow:
    from alibi.explainers.backends.tensorflow import cfrl_base as tensorflow_base_backend

# Module-level logger.
logger = logging.getLogger(__name__)
class NormalActionNoise:
    """ Normal noise generator. """

    def __init__(self, mu: float, sigma: float) -> None:
        """
        Constructor.

        Parameters
        ----------
        mu
            Mean of the normal noise.
        sigma
            Standard deviation of the noise.
        """
        self.mu, self.sigma = mu, sigma

    def __call__(self, shape: Tuple[int, ...]) -> np.ndarray:
        """
        Draws normal noise with mean `mu` and standard deviation `sigma`.

        Parameters
        ----------
        shape
            Shape of the array to be generated.

        Returns
        -------
        Array of the requested shape with i.i.d. entries drawn from `N(mu, sigma^2)`.
        """
        # Sample from the standard normal distribution, then shift and scale it.
        standard_sample = np.random.randn(*shape)
        return self.mu + self.sigma * standard_sample

    def __repr__(self) -> str:
        return f'NormalActionNoise(mu={self.mu}, sigma={self.sigma})'
class ReplayBuffer:
    """
    Circular experience replay buffer for `CounterfactualRL` (DDPG). When the buffer is filled, the oldest
    experience is replaced by the new one (FIFO). The experience batch size is kept constant and inferred when
    the first batch of data is stored. Allowing flexible batch size can generate `tensorflow` warning due to
    the `tf.function` retracing, which can lead to a drop in performance.
    """
    X: np.ndarray  #: Inputs buffer.
    Y_m: np.ndarray  #: Model's prediction buffer.
    Y_t: np.ndarray  #: Counterfactual targets buffer.
    Z: np.ndarray  #: Input embedding buffer.
    Z_cf_tilde: np.ndarray  #: Noised counterfactual embedding buffer.
    R_tilde: np.ndarray  #: Noise counterfactual rewards buffer.

    def __init__(self, size: int = 1000) -> None:
        """
        Constructor.

        Parameters
        ----------
        size
            Dimension of the buffer in batch size. This means that the total memory allocated is proportional to
            `size x batch_size`, where `batch_size` is inferred from the first array to be stored.

        Raises
        ------
        ValueError
            If `size` is smaller than 1.
        """
        # A zero-capacity buffer cannot store anything (and would divide by zero when advancing the cursor).
        if size < 1:
            raise ValueError(f"The replay buffer size should be at least 1. Found {size} instead.")

        self.idx = 0  # cursor for the buffer
        self.len = 0  # current length of the buffer (in batches)
        self.size = size  # buffer's maximum capacity (in batches)
        self.batch_size = 0  # batch size (inferred during `append`)
        self.C: Optional[np.ndarray] = None  # buffer for the conditional tensor

    def append(self,
               X: np.ndarray,
               Y_m: np.ndarray,
               Y_t: np.ndarray,
               Z: np.ndarray,
               Z_cf_tilde: np.ndarray,
               C: Optional[np.ndarray],
               R_tilde: np.ndarray,
               **kwargs) -> None:
        """
        Adds experience to the replay buffer. When the buffer is filled, the oldest experience is replaced
        by the new one (FIFO).

        Parameters
        ----------
        X
            Input array.
        Y_m
            Model's prediction class of `X`.
        Y_t
            Counterfactual target class.
        Z
            Input's embedding.
        Z_cf_tilde
            Noised counterfactual embedding.
        C
            Conditional array.
        R_tilde
            Noised counterfactual reward array.
        **kwargs
            Other arguments. Not used.
        """
        # Initialize the buffers on the first call; the batch size is fixed from here on.
        if not hasattr(self, 'X'):
            self.batch_size = X.shape[0]
            capacity = self.size * self.batch_size

            # Allocate memory.
            self.X = np.zeros((capacity, *X.shape[1:]), dtype=np.float32)
            self.Y_m = np.zeros((capacity, *Y_m.shape[1:]), dtype=np.float32)
            self.Y_t = np.zeros((capacity, *Y_t.shape[1:]), dtype=np.float32)
            self.Z = np.zeros((capacity, *Z.shape[1:]), dtype=np.float32)
            self.Z_cf_tilde = np.zeros((capacity, *Z_cf_tilde.shape[1:]), dtype=np.float32)
            self.R_tilde = np.zeros((capacity, *R_tilde.shape[1:]), dtype=np.float32)

            # Conditional tensor can be `None` when no condition is included. If it is not `None`, allocate memory.
            if C is not None:
                self.C = np.zeros((capacity, *C.shape[1:]), dtype=np.float32)

        # Increase the length of the buffer if not full.
        if self.len < self.size:
            self.len += 1

        # Compute the first position where to add the most recent experience.
        start = self.batch_size * self.idx
        stop = start + self.batch_size

        # Add new data / replace old experience (note that a full batch is added at once).
        self.X[start:stop] = X
        self.Y_m[start:stop] = Y_m
        self.Y_t[start:stop] = Y_t
        self.Z[start:stop] = Z
        self.Z_cf_tilde[start:stop] = Z_cf_tilde
        self.R_tilde[start:stop] = R_tilde

        if C is not None:
            self.C = cast(np.ndarray, self.C)  # helping mypy out as self.C cannot be None at this point
            self.C[start:stop] = C

        # Compute the next index. Note that if the buffer reached its maximum capacity, for the next iteration
        # we start replacing old batches.
        self.idx = (self.idx + 1) % self.size

    def sample(self) -> Dict[str, Optional[np.ndarray]]:
        """
        Sample a batch of experience from the replay buffer (uniformly, with replacement).

        Returns
        -------
        A batch experience. For a description of the keys and values returned, see parameter descriptions \
        in :py:meth:`alibi.explainers.cfrl_base.ReplayBuffer.append` method. The batch size returned is the same \
        as the one passed in the :py:meth:`alibi.explainers.cfrl_base.ReplayBuffer.append`.

        Raises
        ------
        ValueError
            If the buffer is empty (i.e. `append` was never called).
        """
        if self.len == 0:
            raise ValueError("Cannot sample from an empty replay buffer. Call `append` first.")

        # Generate random indices to be sampled among the filled positions only.
        rand_idx = np.random.randint(low=0, high=self.len * self.batch_size, size=(self.batch_size,))

        # Extract data from buffers.
        X = self.X[rand_idx]  # input array
        Y_m = self.Y_m[rand_idx]  # model's prediction
        Y_t = self.Y_t[rand_idx]  # counterfactual target
        Z = self.Z[rand_idx]  # input embedding
        Z_cf_tilde = self.Z_cf_tilde[rand_idx]  # noised counterfactual embedding
        C = self.C[rand_idx] if (self.C is not None) else None  # conditional array if exists
        R_tilde = self.R_tilde[rand_idx]  # noised counterfactual reward

        return {
            "X": X,
            "Y_m": Y_m,
            "Y_t": Y_t,
            "Z": Z,
            "Z_cf_tilde": Z_cf_tilde,
            "C": C,
            "R_tilde": R_tilde
        }
DEFAULT_BASE_PARAMS = {
    "act_noise": 0.1,
    "act_low": -1.0,
    "act_high": 1.0,
    "replay_buffer_size": 1000,
    "batch_size": 100,
    "num_workers": 4,
    "shuffle": True,
    "exploration_steps": 100,
    "update_every": 1,
    "update_after": 10,
    "train_steps": 100000,
    "backend": "tensorflow",
    "encoder_preprocessor": identity_function,
    "decoder_inv_preprocessor": identity_function,
    "reward_func": get_classification_reward,
    "postprocessing_funcs": [],
    "conditional_func": generate_empty_condition,
    "callbacks": [],
    "actor": None,
    "critic": None,
    "optimizer_actor": None,
    "optimizer_critic": None,
    "lr_actor": 1e-3,
    "lr_critic": 1e-3,
    "actor_hidden_dim": 256,
    "critic_hidden_dim": 256,
}
"""
Default Counterfactual with Reinforcement Learning parameters.

- ``'act_noise'`` : ``float`` - standard deviation for the normal noise added to the actor for exploration.

- ``'act_low'`` : ``float`` - minimum action value. Each action component takes values between \
  `[act_low, act_high]`.

- ``'act_high'`` : ``float`` - maximum action value. Each action component takes values between \
  `[act_low, act_high]`.

- ``'replay_buffer_size'`` : ``int`` - dimension of the replay buffer in `batch_size` units. The total memory \
  allocated is proportional to `size x batch_size`.

- ``'batch_size'`` : ``int`` - training batch size.

- ``'num_workers'`` : ``int`` - number of workers used by the data loader if ``'pytorch'`` backend is selected.

- ``'shuffle'`` : ``bool`` - whether to shuffle the datasets every epoch.

- ``'exploration_steps'`` : ``int`` - number of exploration steps. For the first `exploration_steps`, the \
  counterfactual embedding coordinates are sampled uniformly at random from the interval `[act_low, act_high]`.

- ``'update_every'`` : ``int`` - number of steps that should elapse between gradient updates. Regardless of the \
  waiting steps, the ratio of waiting steps to gradient steps is locked to 1.

- ``'update_after'`` : ``int`` - number of steps to wait before starting to update the actor and critic. This \
  ensures that the replay buffer is full enough for useful updates.

- ``'backend'`` : ``str`` - backend to be used: ``'tensorflow'`` | ``'pytorch'``. Default ``'tensorflow'``.

- ``'train_steps'`` : ``int`` - number of train steps.

- ``'encoder_preprocessor'`` : ``Callable`` - encoder/auto-encoder data preprocessor. Transforms the input data \
  into the format expected by the auto-encoder. By default, the identity function.

- ``'decoder_inv_preprocessor'`` : ``Callable`` - decoder/auto-encoder data inverse preprocessor. Transforms data \
  from the auto-encoder output format to the original input format. Before calling the prediction function, the \
  data is inverse preprocessed to match the original input format. By default, the identity function.

- ``'reward_func'`` : ``Callable`` - element-wise reward function. By default, considers classification task and \
  checks if the counterfactual prediction label matches the target label. Note that this is element-wise, so a \
  tensor is expected to be returned.

- ``'postprocessing_funcs'`` : ``List[Postprocessing]`` - list of post-processing functions. The functions are \
  applied in order, from low to high index. Non-differentiable post-processing can be applied. Each function \
  expects as arguments `X_cf` - the counterfactual instance, `X` - the original input instance and `C` - the \
  conditional vector, and returns the post-processed counterfactual instance `X_cf_pp` which is passed as `X_cf` \
  to the following function. By default, no post-processing is applied (empty list).

- ``'conditional_func'`` : ``Callable`` - generates a conditional vector given a pre-processed input instance. By \
  default, the function returns ``None`` which is equivalent to no conditioning.

- ``'callbacks'`` : ``List[Callback]`` - list of callback functions applied after every actor-critic update step.

- ``'actor'`` : ``Optional[Union[tensorflow.keras.Model, torch.nn.Module]]`` - actor network.

- ``'critic'`` : ``Optional[Union[tensorflow.keras.Model, torch.nn.Module]]`` - critic network.

- ``'optimizer_actor'`` : ``Optional[Union[tensorflow.keras.optimizers.Optimizer, torch.optim.Optimizer]]`` - \
  actor optimizer.

- ``'optimizer_critic'`` : ``Optional[Union[tensorflow.keras.optimizers.Optimizer, torch.optim.Optimizer]]`` - \
  critic optimizer.

- ``'lr_actor'`` : ``float`` - actor learning rate.

- ``'lr_critic'`` : ``float`` - critic learning rate.

- ``'actor_hidden_dim'`` : ``int`` - actor hidden layer dimension.

- ``'critic_hidden_dim'`` : ``int`` - critic hidden layer dimension.
"""
_PARAM_TYPES = {
"primitives": [
"act_noise", "act_low", "act_high", "replay_buffer_size", "batch_size", "num_workers", "shuffle",
"exploration_steps", "update_every", "update_after", "train_steps", "backend", "actor_hidden_dim",
"critic_hidden_dim",
],
"complex": [
"encoder_preprocessor", "decoder_inv_preprocessor", "reward_func", "postprocessing_funcs", "conditional_func",
"callbacks", "actor", "critic", "optimizer_actor", "optimizer_critic", "encoder", "decoder", "predictor",
"sparsity_loss", "consistency_loss",
]
}
"""
Parameter types for serialization
- ``'primitives'`` : List[str] - list of parameters having primitive data types.
- ``'complex'`` : List[str] - list of parameters having complex data types (e.g., functions, models,\
optimizers etc.).
"""
class CounterfactualRL(Explainer, FitMixin):
    """ Counterfactual Reinforcement Learning. """

    def __init__(self,
                 predictor: Callable[[np.ndarray], np.ndarray],
                 encoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
                 decoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
                 coeff_sparsity: float,
                 coeff_consistency: float,
                 latent_dim: Optional[int] = None,
                 backend: str = "tensorflow",
                 seed: int = 0,
                 **kwargs):
        """
        Constructor.

        Parameters
        ----------
        predictor
            A callable that takes a `numpy` array of `N` data points as inputs and returns `N` outputs. For
            classification task, the second dimension of the output should match the number of classes. Thus, the
            output can be either a soft label distribution or a hard label distribution (i.e. one-hot encoding)
            without affecting the performance since `argmax` is applied to the predictor's output.
        encoder
            Pretrained encoder network.
        decoder
            Pretrained decoder network.
        coeff_sparsity
            Sparsity loss coefficient.
        coeff_consistency
            Consistency loss coefficient.
        latent_dim
            Auto-encoder latent dimension. Can be omitted if the actor network is user specified.
        backend
            Deep learning backend: ``'tensorflow'`` | ``'pytorch'``. Default ``'tensorflow'``.
        seed
            Seed for reproducibility. The results are not reproducible for ``'tensorflow'`` backend.
        **kwargs
            Used to replace any default parameter from :py:data:`alibi.explainers.cfrl_base.DEFAULT_BASE_PARAMS`.
        """
        super().__init__(meta=deepcopy(DEFAULT_META_CFRL))

        # Clean backend flag.
        backend = backend.strip().lower()

        # Verify that the backend is supported and installed.
        CounterfactualRL._verify_backend(backend)

        # Select backend.
        self.backend = self._select_backend(backend, **kwargs)

        # Set seed for reproducibility.
        self.backend.set_seed(seed)

        # Validate arguments.
        self.params = self._validate_kwargs(predictor=predictor,
                                            encoder=encoder,
                                            decoder=decoder,
                                            latent_dim=latent_dim,
                                            coeff_sparsity=coeff_sparsity,
                                            coeff_consistency=coeff_consistency,
                                            backend=backend,
                                            seed=seed,
                                            **kwargs)

        # For the pytorch backend, move all the networks to the device returned by `get_device`
        # (expected to be the GPU when one is available).
        if self.params["backend"] == Framework.PYTORCH:
            from alibi.explainers.backends.pytorch.cfrl_base import get_device
            self.params.update({"device": get_device()})

            # Send encoder and decoder to device.
            self.params["encoder"].to(self.params["device"])
            self.params["decoder"].to(self.params["device"])

            # Send actor and critic to device.
            self.params["actor"].to(self.params["device"])
            self.params["critic"].to(self.params["device"])

        # Update meta-data with all the validated parameters (unknown keyword arguments are logged and dropped
        # by `_validate_kwargs`).
        self.meta["params"].update(CounterfactualRL._serialize_params(self.params))

    @staticmethod
    def _serialize_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Parameter serialization. The function replaces objects by a human-readable representation.

        Parameters
        ----------
        params
            Dictionary of parameters to be serialized.

        Returns
        -------
        Human-readable replacement of data.
        """
        meta = dict()

        for param, value in params.items():
            if param in _PARAM_TYPES["primitives"]:
                # primitive types are passed as they are
                meta.update({param: value})

            elif param in _PARAM_TYPES["complex"]:
                if isinstance(value, list):
                    # each complex element in the list is serialized by replacing it with a name
                    meta.update({param: [CounterfactualRL._get_name(v) for v in value]})
                else:
                    # complex element is serialized by replacing it with a name
                    meta.update({param: CounterfactualRL._get_name(value)})
            else:
                # Unknown parameters are passed as they are. TODO: think of a better way to handle this.
                meta.update({param: value})

        return meta

    @staticmethod
    def _get_name(a: Any) -> str:
        """
        Constructs a name for the given object. If the object has a built-in name, the name is returned.
        If the object has a built-in class name, the name of the class is returned. Otherwise ``'unknown'`` is returned.

        Parameters
        ----------
        a
            Object to give the name for.

        Returns
        -------
        Name of the object.
        """
        if hasattr(a, "__name__"):
            return a.__name__

        if hasattr(a, "__class__"):
            return str(a.__class__)

        return "unknown"

    @staticmethod
    def _verify_backend(backend: str):
        """
        Verifies if the backend is supported.

        Parameters
        ----------
        backend
            Backend to be checked.

        Raises
        ------
        ImportError
            If the requested framework is not installed.
        NotImplementedError
            If the backend is neither ``'tensorflow'`` nor ``'pytorch'``.
        """
        # Check if pytorch/tensorflow backend supported.
        if (backend == Framework.PYTORCH and not has_pytorch) or \
                (backend == Framework.TENSORFLOW and not has_tensorflow):
            raise ImportError(f'{backend} not installed. Cannot initialize and run the CounterfactualRL'
                              f' with {backend} backend.')

        # Allow only pytorch and tensorflow.
        elif backend not in [Framework.PYTORCH, Framework.TENSORFLOW]:
            raise NotImplementedError(f'{backend} not implemented. Use `tensorflow` or `pytorch` instead.')

    def _select_backend(self, backend: str, **kwargs):
        """
        Selects the backend according to the `backend` flag.

        Parameters
        ---------
        backend
            Deep learning backend: ``'tensorflow'`` | ``'pytorch'``. Default `tensorflow`.
        **kwargs
            Other arguments. Not used.
        """
        return tensorflow_base_backend if backend == "tensorflow" else pytorch_base_backend

    def _validate_kwargs(self,
                         predictor: Callable,
                         encoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
                         decoder: 'Union[tensorflow.keras.Model, torch.nn.Module]',
                         latent_dim: Optional[int],
                         coeff_sparsity: float,
                         coeff_consistency: float,
                         backend: str,
                         seed: int,
                         **kwargs):
        """
        Validates arguments and builds the complete parameter dictionary.

        Valid user-specified keyword arguments are merged into the defaults *before* the default components
        (actor, critic, optimizers, losses) are constructed, so that hyper-parameters such as ``'lr_actor'`` or
        ``'actor_hidden_dim'`` are actually used by those components. Unknown keyword arguments are logged and
        ignored.

        Parameters
        ----------
        predictor
            A callable that takes a `numpy` array of `N` data points as inputs and returns `N` outputs.
        encoder
            Pretrained encoder network.
        decoder
            Pretrained decoder network.
        latent_dim
            Auto-encoder latent dimension.
        coeff_sparsity
            Sparsity loss coefficient.
        coeff_consistency
            Consistency loss coefficient.
        backend
            Deep learning backend: ``'tensorflow'`` | ``'pytorch'``.
        seed
            Seed for reproducibility.
        **kwargs
            Other arguments. Valid keys are those of :py:data:`alibi.explainers.cfrl_base.DEFAULT_BASE_PARAMS`,
            ``'sparsity_loss'`` and ``'consistency_loss'``.

        Returns
        -------
        Dictionary containing all the parameters used by the explainer.

        Raises
        ------
        ValueError
            If, for the ``'pytorch'`` backend, an optimizer is specified without the corresponding network.
        """
        # Copy default parameters.
        params = deepcopy(DEFAULT_BASE_PARAMS)

        # Update parameters with mandatory arguments
        params.update({
            "encoder": encoder,
            "decoder": decoder,
            "latent_dim": latent_dim,
            "predictor": predictor,
            "coeff_sparsity": coeff_sparsity,
            "coeff_consistency": coeff_consistency,
            "backend": backend,
            "seed": seed,
        })

        # Besides the parameters above, the sparsity and consistency losses can be user-specified as well.
        allowed_keys = set(params.keys()) | {"sparsity_loss", "consistency_loss"}
        provided_keys = set(kwargs.keys())

        # Check if some provided keys are incorrect.
        incorrect_keys = provided_keys - allowed_keys
        if incorrect_keys:
            logger.warning("The following keys are incorrect: " + ", ".join(incorrect_keys))

        # Merge the valid user-specified parameters now, so that the default components built below
        # use the user's hyper-parameters (e.g. hidden dimensions, learning rates).
        params.update({key: kwargs[key] for key in provided_keys & allowed_keys})

        # Add actor and critic if not user-specified.
        not_specified = {"actor": "actor" not in kwargs, "critic": "critic" not in kwargs}

        if not_specified["actor"]:
            params["actor"] = self.backend.get_actor(hidden_dim=params["actor_hidden_dim"],
                                                     output_dim=params["latent_dim"])

        if not_specified["critic"]:
            params["critic"] = self.backend.get_critic(hidden_dim=params["critic_hidden_dim"])

        # Add optimizers if not user-specified.
        for optim in ["optimizer_actor", "optimizer_critic"]:
            # extract the model in question: "actor" or "critic"
            model_name = optim.split("_")[1]

            if optim in kwargs:
                # The user-specified optimizer was already merged above. For pytorch, optimizers are built from a
                # model (see `get_optimizer(model=...)` below), so one cannot be paired with a network built here.
                if params["backend"] == Framework.PYTORCH and not_specified[model_name]:
                    raise ValueError(f"Can not specify {optim} when {model_name} not specified for pytorch backend.")
                continue

            # The optimizer is not user-specified, so it needs to be initialized. The initialization is backend
            # specific.
            lr = params["lr_" + model_name]
            if params["backend"] == Framework.TENSORFLOW:
                params[optim] = self.backend.get_optimizer(lr=lr)
            else:
                params[optim] = self.backend.get_optimizer(model=params[model_name], lr=lr)

        # Add sparsity loss if not user-specified.
        if "sparsity_loss" not in kwargs:
            params["sparsity_loss"] = self.backend.sparsity_loss

        # Add consistency loss if not user-specified.
        if "consistency_loss" not in kwargs:
            params["consistency_loss"] = self.backend.consistency_loss

        return params

    @classmethod
    def load(cls, path: Union[str, os.PathLike], predictor: Any) -> "Explainer":
        return super().load(path, predictor)

    def reset_predictor(self, predictor: Any) -> None:
        """
        Resets the predictor.

        Parameters
        ----------
        predictor
            New predictor.
        """
        self.params["predictor"] = predictor
        self.meta["params"].update(CounterfactualRL._serialize_params(self.params))

    def save(self, path: Union[str, os.PathLike]) -> None:
        super().save(path)

    def fit(self, X: np.ndarray) -> "Explainer":
        """
        Fit the model agnostic counterfactual generator.

        Parameters
        ----------
        X
            Training data array.

        Returns
        -------
        self
            The explainer itself.
        """
        # Define boolean flag for initializing actor and critic network for Tensorflow backend.
        initialize_actor_critic = False

        # Define replay buffer (this will deal only with numpy arrays).
        replay_buff = ReplayBuffer(size=self.params["replay_buffer_size"])

        # Define noise variable.
        noise = NormalActionNoise(mu=0, sigma=self.params["act_noise"])

        # Define data generator.
        data_generator = self.backend.data_generator(X=X, **self.params)
        data_iter = iter(data_generator)

        for step in tqdm(range(self.params["train_steps"])):
            # Sample training data. When the data is exhausted, start a new epoch.
            try:
                data = next(data_iter)
            except StopIteration:
                if hasattr(data_generator, "on_epoch_end"):
                    # This is just for tensorflow backend.
                    data_generator.on_epoch_end()

                data_iter = iter(data_generator)
                data = next(data_iter)

            # Add None condition if condition does not exist.
            if "C" not in data:
                data["C"] = None

            # Compute input embedding.
            Z = self.backend.encode(X=data["X"], **self.params)
            data.update({"Z": Z})

            # Compute counterfactual embedding.
            Z_cf = self.backend.generate_cf(**data, **self.params)
            data.update({"Z_cf": Z_cf})

            # Add noise to the counterfactual embedding.
            Z_cf_tilde = self.backend.add_noise(noise=noise, step=step, **data, **self.params)
            data.update({"Z_cf_tilde": Z_cf_tilde})

            # Decode noised counterfactual and apply postprocessing step to X_cf_tilde.
            X_cf_tilde = self.backend.decode(Z=data["Z_cf_tilde"], **self.params)

            for pp_func in self.params["postprocessing_funcs"]:
                # Post-process noised counterfactual.
                X_cf_tilde = pp_func(self.backend.to_numpy(X_cf_tilde),
                                     self.backend.to_numpy(data["X"]),
                                     self.backend.to_numpy(data["C"]))
            data.update({"X_cf_tilde": X_cf_tilde})

            # Compute model's prediction on the noised counterfactual.
            X_cf_tilde = self.params["decoder_inv_preprocessor"](self.backend.to_numpy(data["X_cf_tilde"]))
            Y_m_cf_tilde = self.params["predictor"](X_cf_tilde)

            # Compute reward.
            R_tilde = self.params["reward_func"](self.backend.to_numpy(Y_m_cf_tilde),
                                                 self.backend.to_numpy(data["Y_t"]))
            data.update({"R_tilde": R_tilde, "Y_m_cf_tilde": Y_m_cf_tilde})

            # Store experience in the replay buffer.
            data = {key: self.backend.to_numpy(data[key]) for key in data.keys()}
            replay_buff.append(**data)

            if step % self.params['update_every'] == 0 and step > self.params["update_after"]:
                for i in range(self.params['update_every']):
                    # Sample batch of experience from the replay buffer.
                    sample = replay_buff.sample()

                    # Initialize actor and critic. This is required for tensorflow in order to reinitialize the
                    # explainer object and call fit multiple times. If the models are not reinitialized, the
                    # error: "tf.function-decorated function tried to create variables on non-first call" is raised.
                    # This is due to @tf.function and building the model for the first time in a compiled function.
                    if not initialize_actor_critic and self.params["backend"] == Framework.TENSORFLOW:
                        self.backend.initialize_actor_critic(**sample, **self.params)
                        self.backend.initialize_optimizers(**sample, **self.params)
                        initialize_actor_critic = True

                    if "C" not in sample:
                        sample["C"] = None

                    # Decode counterfactual. This procedure has to be done here and not in the experience loop
                    # since the actor is updating but old experience is used. Thus, the decoding of the counterfactual
                    # will not correspond to the latest actor network. Remember that the counterfactual is used
                    # for the consistency loss. The counterfactual generation is performed here due to @tf.function
                    # which does not allow all post-processing functions.
                    Z_cf = self.backend.generate_cf(Z=self.backend.to_tensor(sample["Z"], **self.params),
                                                    Y_m=self.backend.to_tensor(sample["Y_m"], **self.params),
                                                    Y_t=self.backend.to_tensor(sample["Y_t"], **self.params),
                                                    C=self.backend.to_tensor(sample["C"], **self.params),
                                                    **self.params)

                    X_cf = self.backend.decode(Z=Z_cf, **self.params)
                    for pp_func in self.params["postprocessing_funcs"]:
                        # Post-process counterfactual.
                        X_cf = pp_func(self.backend.to_numpy(X_cf),
                                       self.backend.to_numpy(sample["X"]),
                                       self.backend.to_numpy(sample["C"]))

                    # Add counterfactual instance to the sample to be used in the update function for consistency loss
                    sample.update({"Z_cf": self.backend.to_numpy(Z_cf),
                                   "X_cf": self.backend.to_numpy(X_cf)})

                    # Update actor and critic by one-step gradient descent.
                    losses = self.backend.update_actor_critic(**sample, **self.params)

                    # Convert all losses from tensors to numpy arrays.
                    losses = {key: self.backend.to_numpy(losses[key]).item() for key in losses.keys()}

                    # Call all callbacks.
                    for callback in self.params["callbacks"]:
                        callback(step=step, update=i, model=self, sample=sample, losses=losses)
        return self

    @staticmethod
    def _validate_target(Y_t: Optional[np.ndarray]):
        """
        Validate the targets by checking the dimensions.

        Parameters
        ----------
        Y_t
            Targets to be checked.

        Raises
        ------
        ValueError
            If the targets are ``None`` or are neither 1D nor 2D.
        """
        if Y_t is None:
            raise ValueError("Target can not be `None`.")

        if len(Y_t.shape) not in [1, 2]:
            raise ValueError(f"Target shape should be at least 1 and at most 2. Found {len(Y_t.shape)} instead.")

    @staticmethod
    def _validate_condition(C: Optional[np.ndarray]):
        """
        Validate condition vector.

        Parameters
        ----------
        C
            Condition vector.

        Raises
        ------
        ValueError
            If the condition vector is not ``None`` and is not 2D.
        """
        if (C is not None) and len(C.shape) != 2:
            raise ValueError(f"Condition vector shape should be 2. Found {len(C.shape)} instead.")

    @staticmethod
    def _is_classification(pred: np.ndarray) -> bool:
        """
        Check if the prediction task is classification by looking at the model's prediction shape.

        Parameters
        ----------
        pred
            Model's prediction.

        Returns
        -------
        ``True`` if the prediction is 2D and its second dimension is greater than 1. ``False`` otherwise.
        """
        return len(pred.shape) == 2 and pred.shape[1] > 1

    def explain(self,  # type: ignore[override]
                X: np.ndarray,
                Y_t: np.ndarray,
                C: Optional[np.ndarray] = None,
                batch_size: int = 100) -> Explanation:
        """
        Explains an input instance.

        Parameters
        ----------
        X
            Instances to be explained.
        Y_t
            Counterfactual targets. Either a single target (broadcast to all the instances) or one per instance.
        C
            Conditional vectors. If ``None``, it means that no conditioning was used during training (i.e. the
            `conditional_func` returns ``None``). Either a single vector (broadcast to all the instances) or one
            per instance.
        batch_size
            Batch size to be used when generating counterfactuals.

        Returns
        -------
        explanation
            `Explanation` object containing the counterfactual with additional metadata as attributes. \
            See usage at `CFRL examples`_ for details.

            .. _CFRL examples:
                https://docs.seldon.io/projects/alibi/en/stable/methods/CFRL.html
        """
        # General validation.
        self._validate_target(Y_t)
        self._validate_condition(C)

        # Check the number of target labels.
        if Y_t.shape[0] != 1 and Y_t.shape[0] != X.shape[0]:
            raise ValueError("The number target labels should be 1 or equals the number of samples in X.")

        # Check the number of conditional vectors
        if (C is not None) and C.shape[0] != 1 and C.shape[0] != X.shape[0]:
            raise ValueError("The number of conditional vectors should be 1 or equals the number if samples in X.")

        # Transform target into a 2D array.
        Y_t = Y_t.reshape(Y_t.shape[0], -1)

        # Repeat the same label to match the number of input instances.
        if Y_t.shape[0] == 1:
            Y_t = np.tile(Y_t, (X.shape[0], 1))

        # Repeat the same conditional vector to match the number of input instances. Only a single vector is
        # tiled; tiling `N` vectors `N` times would needlessly allocate `N^2` rows.
        if (C is not None) and C.shape[0] == 1:
            C = np.tile(C, (X.shape[0], 1))

        # Perform prediction in mini-batches.
        n_minibatch = int(np.ceil(X.shape[0] / batch_size))
        batch_results = []

        for i in tqdm(range(n_minibatch)):
            istart, istop = i * batch_size, min((i + 1) * batch_size, X.shape[0])
            batch_results.append(self._compute_counterfactual(X=X[istart:istop],
                                                              Y_t=Y_t[istart:istop],
                                                              C=C[istart:istop] if (C is not None) else C))

        # Concatenate the mini-batch results once (instead of re-copying the accumulated arrays for every batch).
        # Entries which are `None` (e.g. no conditioning) remain `None`.
        all_results: Dict[str, Optional[np.ndarray]] = {}
        for key in (batch_results[0] if batch_results else {}):
            values = [result[key] for result in batch_results]
            arrays = [value for value in values if value is not None]
            all_results[key] = np.concatenate(arrays, axis=0) if len(arrays) == len(values) else None

        # see https://github.com/python/mypy/issues/5382 for the type ignore
        return self._build_explanation(**all_results)  # type: ignore[arg-type]

    def _compute_counterfactual(self,
                                X: np.ndarray,
                                Y_t: np.ndarray,
                                C: Optional[np.ndarray] = None) -> Dict[str, Optional[np.ndarray]]:  # TODO: TypedDict
        """
        Compute counterfactual instance for a given input, target and condition vector.

        Parameters
        ----------
        X
            Instances to be explained.
        Y_t
            Counterfactual targets.
        C
            Conditional vector. If ``None``, it means that no conditioning was used during training (i.e. the
            `conditional_func` returns ``None``).

        Returns
        -------
        Dictionary containing the input instances in the original format, input classification labels,
        counterfactual instances in the original format, counterfactual classification labels, target labels,
        conditional vectors.
        """
        # Save original input for later usage.
        X_orig = X

        # Compute model's prediction.
        Y_m = self.params["predictor"](X_orig)

        # Check if the prediction task is classification. Please refer to
        # `alibi.explainers.backends.cfrl_base.CounterfactualRLDataset` for a justification.
        if self._is_classification(pred=Y_m):
            Y_m = get_hard_distribution(Y=Y_m, num_classes=Y_m.shape[1])
            Y_t = get_hard_distribution(Y=Y_t, num_classes=Y_m.shape[1])
        else:
            Y_m = Y_m.reshape(-1, 1)
            Y_t = Y_t.reshape(-1, 1)

        # Apply autoencoder preprocessing step.
        X = self.params["encoder_preprocessor"](X_orig)

        # Convert to tensors.
        X = self.backend.to_tensor(X, **self.params)
        Y_m = self.backend.to_tensor(Y_m, **self.params)
        Y_t = self.backend.to_tensor(Y_t, **self.params)
        C = self.backend.to_tensor(C, **self.params)

        # Encode instance.
        Z = self.backend.encode(X, **self.params)

        # Generate counterfactual embedding.
        Z_cf = self.backend.generate_cf(Z, Y_m, Y_t, C, **self.params)

        # Decode counterfactual.
        X_cf = self.backend.decode(Z_cf, **self.params)

        # Convert to numpy for postprocessing
        X_cf = self.backend.to_numpy(X_cf)
        X = self.backend.to_numpy(X)
        C = self.backend.to_numpy(C)

        # Apply postprocessing functions.
        for pp_func in self.params["postprocessing_funcs"]:
            X_cf = pp_func(X_cf, X, C)

        # Apply inverse autoencoder pre-processor.
        X_cf = self.params["decoder_inv_preprocessor"](X_cf)

        # Classify counterfactual instances.
        Y_m_cf = self.params["predictor"](X_cf)

        # Convert tensors to numpy.
        Y_m = self.backend.to_numpy(Y_m)
        Y_t = self.backend.to_numpy(Y_t)

        # If the prediction is a classification task, convert the distributions into labels.
        if self._is_classification(pred=Y_m):
            Y_m = np.argmax(Y_m, axis=1)
            Y_t = np.argmax(Y_t, axis=1)
            Y_m_cf = np.argmax(Y_m_cf, axis=1)

        return {
            "X": X_orig,  # input instances
            "Y_m": Y_m,  # input classification labels
            "X_cf": X_cf,  # counterfactual instances
            "Y_m_cf": Y_m_cf,  # counterfactual classification labels
            "Y_t": Y_t,  # target labels
            "C": C  # conditional vectors
        }

    def _build_explanation(self,
                           X: np.ndarray,
                           Y_m: np.ndarray,
                           X_cf: np.ndarray,
                           Y_m_cf: np.ndarray,
                           Y_t: np.ndarray,
                           C: Optional[np.ndarray]) -> Explanation:
        """
        Builds the explanation of the current object.

        Parameters
        ----------
        X
            Inputs instance in the original format.
        Y_m
            Inputs classification labels.
        X_cf
            Counterfactuals instances in the original format.
        Y_m_cf
            Counterfactuals classification labels.
        Y_t
            Target labels.
        C
            Condition vector. If ``None``, it means that no conditioning was used during training (i.e. the
            `conditional_func` returns ``None``).

        Returns
        -------
        `Explanation` object containing the inputs with the corresponding labels, the counterfactuals with the \
        corresponding labels, targets and additional metadata.
        """
        data = deepcopy(DEFAULT_DATA_CFRL)

        # update original input entry
        data["orig"] = {}
        data["orig"].update({"X": X, "class": Y_m.reshape(-1, 1)})

        # update counterfactual entry
        data["cf"] = {}
        data["cf"].update({"X": X_cf, "class": Y_m_cf.reshape(-1, 1)})

        # update target and condition
        data["target"] = Y_t.reshape(-1, 1)
        data["condition"] = C
        return Explanation(meta=self.meta, data=data)
class Postprocessing(ABC):
    """ Base class for the counterfactual post-processing functions. """

    @abstractmethod
    def __call__(self, X_cf: Any, X: np.ndarray, C: Optional[np.ndarray]) -> Any:
        """
        Post-processing function.

        Parameters
        ----------
        X_cf
            Counterfactual instance. The datatype depends on the output of the decoder. For example, for an image
            dataset, the output is ``np.ndarray``. For a tabular dataset, the output is ``List[np.ndarray]`` where each
            element of the list corresponds to a feature. This corresponds to the decoder's output from the
            heterogeneous autoencoder (see :py:class:`alibi.models.tensorflow.autoencoder.HeAE` and
            :py:class:`alibi.models.pytorch.autoencoder.HeAE`).
        X
            Input instance.
        C
            Conditional vector. If ``None``, it means that no conditioning was used during training (i.e. the
            `conditional_func` returns ``None``).

        Returns
        -------
        X_cf
            Post-processed `X_cf`.
        """
        ...
class Callback(ABC):
    """ Training callback class. """

    @abstractmethod
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]) -> None:
        """
        Training callback applied after every actor-critic update.

        Parameters
        ----------
        step
            Current experience step.
        update
            Current update step. The ratio between the number of experience steps and the number of training
            updates is bound to 1.
        model
            CounterfactualRL explainer. All the parameters defined in
            :py:data:`alibi.explainers.cfrl_base.DEFAULT_BASE_PARAMS` can be accessed through `model.params`.
        sample
            Dictionary of samples used for an update which contains

                - ``'X'`` : ``np.ndarray`` - input instances.

                - ``'Y_m'`` : ``np.ndarray`` - predictor outputs for the input instances.

                - ``'Y_t'`` : ``np.ndarray`` - target outputs.

                - ``'Z'`` : ``np.ndarray`` - input embeddings.

                - ``'Z_cf_tilde'`` : ``np.ndarray`` - noised counterfactual embeddings.

                - ``'C'`` : ``Optional[np.ndarray]`` - conditional vector.

                - ``'R_tilde'`` : ``np.ndarray`` - reward obtained for the noised counterfactual instances.

                - ``'Z_cf'`` : ``np.ndarray`` - counterfactual embeddings.

                - ``'X_cf'`` : ``np.ndarray`` - counterfactual instances obtained after decoding the counterfactual \
                embeddings `Z_cf` and applying the post-processing functions.

        losses
            Dictionary of loss values (already converted to Python floats) which contains

                - ``'loss_actor'`` : ``float`` - actor network loss.

                - ``'loss_critic'`` : ``float`` - critic network loss.

                - ``'sparsity_loss'`` : ``float`` - sparsity loss for the \
                :py:class:`alibi.explainers.cfrl_base.CounterfactualRL` class.

                - ``'sparsity_num_loss'`` : ``float`` - numerical features sparsity loss for the \
                :py:class:`alibi.explainers.cfrl_tabular.CounterfactualRLTabular` class.

                - ``'sparsity_cat_loss'`` : ``float`` - categorical features sparsity loss for the \
                :py:class:`alibi.explainers.cfrl_tabular.CounterfactualRLTabular` class.

                - ``'consistency_loss'`` : ``float`` - consistency loss if used.
        """
        pass