import copy
import logging
from typing import Callable, Optional, Tuple, Union
import numpy as np
import tensorflow.compat.v1 as tf
from alibi.api.defaults import DEFAULT_DATA_CF, DEFAULT_META_CF
from alibi.api.interfaces import Explainer, Explanation
from alibi.utils.gradients import num_grad_batch
# Module-level logger shared by every function and method in this file.
logger: logging.Logger = logging.getLogger(name=__name__)
def _define_func(predict_fn: Callable,
pred_class: int,
target_class: Union[str, int] = 'same') -> Tuple[Callable, Union[str, int]]:
# TODO: convert to batchwise function
Define the class-specific prediction function to be used in the optimization.
Classifier prediction function.
Predicted class of the instance to be explained.
Target class of the explanation, one of ``'same'``, ``'other'`` or an integer class.
Class-specific prediction function and the target class used.
if target_class == 'other':
# TODO: need to optimize this
def func(X):
probas = predict_fn(X)
sorted = np.argsort(-probas) # class indices in decreasing order of probability
# take highest probability class different from class predicted for X
if sorted[0, 0] == pred_class:
target_class = sorted[0, 1]
# logger.debug('Target class equals predicted class')
target_class = sorted[0, 0]
# logger.debug('Current best target class: %s', target_class)
return (predict_fn(X)[:, target_class]).reshape(-1, 1)
return func, target_class
elif target_class == 'same':
target_class = pred_class
def func(X): # type: ignore
return (predict_fn(X)[:, target_class]).reshape(-1, 1)
return func, target_class
def CounterFactual(*args, **kwargs):
    """
    The class name `CounterFactual` is deprecated, please use `Counterfactual`.

    Emits a ``FutureWarning`` and forwards all positional and keyword arguments to `Counterfactual`.
    """
    # TODO: remove this function in an upcoming release
    import warnings
    warning_msg = 'The class name `CounterFactual` is deprecated, please use `Counterfactual`.'
    # stacklevel=2 attributes the warning to the caller's line rather than to this shim
    warnings.warn(warning_msg, FutureWarning, stacklevel=2)
    return Counterfactual(*args, **kwargs)
class Counterfactual(Explainer):
    """
    Counterfactual explanation method based on Wachter et al. (2017).

    The search minimises ``lam * L1(cf - X) + (p_target(cf) - target_proba) ** 2`` with Adam inside a
    `tensorflow` v1 graph. An initial sweep over exponentially decaying ``lam`` values estimates bounds for
    ``lam``, which is then bisected in an outer loop. The prediction term is differentiated by `tensorflow`
    when a `tf.keras` model is passed and numerically (``num_grad_batch``) otherwise.
    """

    def __init__(self,
                 predict_fn: Union[Callable[[np.ndarray], np.ndarray], tf.keras.Model],
                 shape: Tuple[int, ...],
                 distance_fn: str = 'l1',
                 target_proba: float = 1.0,
                 target_class: Union[str, int] = 'other',
                 max_iter: int = 1000,
                 early_stop: int = 50,
                 lam_init: float = 1e-1,
                 max_lam_steps: int = 10,
                 tol: float = 0.05,
                 learning_rate_init: float = 0.1,
                 feature_range: Union[Tuple, str] = (-1e10, 1e10),
                 eps: Union[float, np.ndarray] = 0.01,  # feature-wise epsilons
                 init: str = 'identity',
                 decay: bool = True,
                 write_dir: Optional[str] = None,
                 debug: bool = False,
                 sess: Optional[tf.Session] = None) -> None:
        """
        Initialize counterfactual explanation method based on Wachter et al. (2017)

        Parameters
        ----------
        predict_fn
            `tensorflow` model or any other model's prediction function returning class probabilities.
        shape
            Shape of input data starting with batch size.
        distance_fn
            Distance function to use in the loss term. Only ``'l1'`` is currently supported.
        target_proba
            Target probability for the counterfactual to reach.
        target_class
            Target class for the counterfactual to reach, one of ``'other'``, ``'same'`` or an integer denoting
            desired class membership for the counterfactual instance.
        max_iter
            Maximum number of iterations to run the gradient descent for (inner loop).
        early_stop
            Number of steps after which to terminate gradient descent if all or none of found instances are solutions.
        lam_init
            Initial regularization constant for the prediction part of the Wachter loss.
        max_lam_steps
            Maximum number of times to adjust the regularization constant (outer loop) before terminating the search.
        tol
            Tolerance for the counterfactual target probability.
        learning_rate_init
            Initial learning rate for each outer loop of `lambda`.
        feature_range
            Tuple with `min` and `max` ranges to allow for perturbed instances. `Min` and `max` ranges can be `float`
            or `numpy` arrays with dimension (1 x nb of features) for feature-wise ranges.
        eps
            Gradient step sizes used in calculating numerical gradients, defaults to a single value for all
            features, but can be passed an array for feature-wise step sizes.
        init
            Initialization method for the search of counterfactuals, currently must be ``'identity'``.
        decay
            Flag to decay learning rate to zero for each outer loop over lambda.
        write_dir
            Directory to write `tensorboard` files to.
        debug
            Flag to write `tensorboard` summaries for debugging. Has no effect unless `write_dir` is given.
        sess
            Optional `tensorflow` session that will be used if passed instead of creating or inferring one internally.

        Raises
        ------
        ValueError
            If `distance_fn` is not ``'l1'`` or `target_class` is neither ``'same'``, ``'other'`` nor a valid
            class index.
        """
        super().__init__(meta=copy.deepcopy(DEFAULT_META_CF))

        # get params for storage in meta; snapshot locals() before any other local name is bound
        params = dict(locals())
        for key in ('self', 'predict_fn', 'sess', '__class__'):
            params.pop(key, None)
        self.meta['params'].update(params)

        self.data_shape = shape
        self.batch_size = shape[0]
        self.target_class = target_class

        # options for the optimizer
        self.max_iter = max_iter
        self.lam_init = lam_init
        self.tol = tol
        self.max_lam_steps = max_lam_steps
        self.early_stop = early_stop

        self.eps = eps
        self.init = init
        self.feature_range = feature_range
        self.target_proba_arr = target_proba * np.ones(self.batch_size)

        self.debug = debug

        # check if the passed object is a model and get session
        is_model = isinstance(predict_fn, tf.keras.Model)
        model_sess = tf.compat.v1.keras.backend.get_session()

        # if session provided, use it
        if isinstance(sess, tf.Session):
            self.sess = sess
        else:
            self.sess = model_sess

        if is_model:  # Keras or TF model
            self.model = True
            self.predict_fn = predict_fn.predict  # type: ignore # array function
            self.predict_tn = predict_fn  # tensor function
        else:  # black-box model
            self.model = False
            self.predict_fn = predict_fn
            self.predict_tn = None

        self.n_classes = self.predict_fn(np.zeros(shape)).shape[1]

        # validate before any graph op is created so a bad argument leaves no half-built graph behind
        if distance_fn != 'l1':
            raise ValueError('Distance metric {} not supported'.format(distance_fn))
        if not (target_class in ('same', 'other') or target_class in range(self.n_classes)):
            raise ValueError('Target class {} unknown'.format(target_class))

        # flag to keep track if explainer is fit or not
        self.fitted = False

        # set up graph session for optimization (counterfactual search)
        with tf.variable_scope('cf_search', reuse=tf.AUTO_REUSE):

            # define variables for original and candidate counterfactual instances, target labels and lambda
            self.orig = tf.get_variable('original', shape=shape, dtype=tf.float32)
            self.cf = tf.get_variable('counterfactual', shape=shape, dtype=tf.float32,
                                      constraint=lambda x: tf.clip_by_value(x, feature_range[0], feature_range[1]))
            # the following will be a 1-hot encoding of the target class (as predicted by the model)
            self.target = tf.get_variable('target', shape=(self.batch_size, self.n_classes), dtype=tf.float32)

            # constant target probability and global step variable
            self.target_proba = tf.constant(target_proba * np.ones(self.batch_size), dtype=tf.float32,
                                            name='target_proba')
            self.global_step = tf.Variable(0.0, trainable=False, name='global_step')

            # lambda hyperparameter - placeholder instead of variable as annealed in first epoch
            self.lam = tf.placeholder(tf.float32, shape=(self.batch_size,), name='lam')

            # define placeholders that will be assigned to relevant variables
            self.assign_orig = tf.placeholder(tf.float32, shape, name='assing_orig')
            self.assign_cf = tf.placeholder(tf.float32, shape, name='assign_cf')
            self.assign_target = tf.placeholder(tf.float32, shape=(self.batch_size, self.n_classes),
                                                name='assign_target')

            # L1 distance summed over every non-batch axis
            ax_sum = list(np.arange(1, len(self.data_shape)))
            self.dist = tf.reduce_sum(tf.abs(self.cf - self.orig), axis=ax_sum, name='l1')

            # distance loss
            self.loss_dist = self.lam * self.dist

            # prediction loss
            if not self.model:
                # prediction-loss gradients are computed numerically in `_minimize_loss` and added to these
                self.loss_opt = self.loss_dist
            else:
                # autograd gradients throughout
                self.pred_proba = self.predict_tn(self.cf)

                # 3 cases for target_class: mask selecting which class probabilities may act as the target
                if target_class == 'same':
                    target_mask = self.target
                elif target_class == 'other':
                    target_mask = 1 - self.target
                else:  # integer class, known in advance
                    target_mask = tf.one_hot(target_class, self.n_classes, dtype=tf.float32)
                self.pred_proba_class = tf.reduce_max(target_mask * self.pred_proba, 1)

                self.loss_pred = tf.square(self.pred_proba_class - self.target_proba)
                self.loss_opt = self.loss_pred + self.loss_dist

            # optimizer
            if decay:
                # linear decay from learning_rate_init to 0.0 over max_iter global steps
                self.learning_rate = tf.train.polynomial_decay(learning_rate_init, self.global_step,
                                                               self.max_iter, 0.0, power=1.0)
            else:
                self.learning_rate = tf.convert_to_tensor(learning_rate_init)

            # TODO optional argument to change type, learning rate scheduler
            opt = tf.train.AdamOptimizer(self.learning_rate)

            # first compute gradients, then apply them (numerical gradients are added in between)
            self.compute_grads = opt.compute_gradients(self.loss_opt, var_list=[self.cf])
            self.grad_ph = tf.placeholder(shape=shape, dtype=tf.float32, name='grad_cf')
            grad_and_var = [(self.grad_ph, self.cf)]
            self.apply_grads = opt.apply_gradients(grad_and_var, global_step=self.global_step)

        # ops loading the original instance, the starting counterfactual and the one-hot prediction
        self.setup: list = [self.orig.assign(self.assign_orig),
                            self.cf.assign(self.assign_cf),
                            self.target.assign(self.assign_target)]

        # variables to initialize
        self.tf_init = tf.variables_initializer(var_list=tf.global_variables(scope='cf_search'))

        # scalar tensors for tensorboard, built once so `_write_tb` does not add new ops to the graph per call
        tb_scalars = [self.global_step, self.learning_rate, self.dist[0]]
        if self.model:
            tb_scalars += [self.loss_pred[0], self.loss_opt[0], self.pred_proba_class[0]]
        else:
            tb_scalars += [self.loss_opt[0]]
        self._tb_scalars = tb_scalars

        # tensorboard
        self.writer = None
        if write_dir is not None:
            self.writer = tf.summary.FileWriter(write_dir, tf.get_default_graph())
        elif debug:
            logger.warning('debug=True but write_dir is None: tensorboard summaries will not be written')

        # return templates
        self._reset_templates()

    def _reset_templates(self) -> None:
        """Reset the scratch record of the latest solution and the return dictionary to empty templates."""
        self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'class', 'proba', 'loss'])
        self.return_dict = copy.deepcopy(DEFAULT_DATA_CF)
        self.return_dict['all'] = {i: [] for i in range(self.max_lam_steps)}

    def _initialize(self, X: np.ndarray) -> np.ndarray:
        """
        Return the starting point of the counterfactual search.

        Raises
        ------
        ValueError
            If `init` is not ``'identity'``.
        """
        # TODO initialization strategies ("same", "random", "from_train")
        if self.init != 'identity':
            raise ValueError('Initialization method should be "identity"')
        logger.debug('Initializing search at the test point X')
        return X

    def fit(self,
            X: np.ndarray,
            y: Optional[np.ndarray] = None) -> "Counterfactual":
        """
        Fit method - currently unused as the counterfactual search is fully unsupervised.

        Parameters
        ----------
        X
            Not used. Included for consistency.
        y
            Not used. Included for consistency.

        Returns
        -------
            Explainer itself.
        """
        # TODO feature ranges, epsilons and MADs
        self.fitted = True
        return self

    def explain(self, X: np.ndarray) -> Explanation:
        """
        Explain an instance and return the counterfactual with metadata.

        Parameters
        ----------
        X
            Instance to be explained.

        Returns
        -------
            `Explanation` object containing the counterfactual with additional metadata as attributes.
            See the `Counterfactual` examples in the alibi documentation for details.
        """
        # TODO change init parameters on the fly
        if X.shape[0] != 1:
            logger.warning('Currently only single instance explanations supported (first dim = 1), '
                           'but first dim = %s', X.shape[0])

        # start from clean templates, even if a previous call was interrupted by an exception
        self._reset_templates()

        # make a prediction
        Y = self.predict_fn(X)
        pred_class = Y.argmax(axis=1).item()
        pred_prob = Y.max(axis=1).item()
        self.return_dict['orig_class'] = pred_class
        self.return_dict['orig_proba'] = pred_prob
        logger.debug('Initial prediction: %s with p=%s', pred_class, pred_prob)

        # define the class-specific prediction function
        self.predict_class_fn, _ = _define_func(self.predict_fn, pred_class, self.target_class)

        # initialize with an instance
        X_init = self._initialize(X)

        # minimize loss iteratively
        self._minimize_loss(X, X_init, Y)

        # hand the filled dict to the explanation and give the explainer fresh templates for the next call
        return_dict = self.return_dict
        self._reset_templates()

        # create explanation object
        return Explanation(meta=copy.deepcopy(self.meta), data=return_dict)

    def _prob_condition(self, X_current):
        """Boolean array: is the target-class probability of `X_current` within `tol` of the target?"""
        return np.abs(self.predict_class_fn(X_current) - self.target_proba_arr) <= self.tol

    def _update_exp(self, i, l_step, lam, cf_found, X_current):
        """
        Record a counterfactual found at inner step `i` of outer step `l_step`.

        The instance is appended to ``return_dict['all'][l_step]`` and replaces ``return_dict['cf']`` when it is
        the first solution or has a smaller distance than the current best.
        """
        cf_found[0][l_step] += 1  # TODO: batch support
        dist = float(self.sess.run(self.dist)[0])  # TODO: batch support

        preds = self.predict_fn(X_current)
        # loss uses the same class-specific probability as the optimisation objective and `_prob_condition`
        proba_class = float(np.ravel(self.predict_class_fn(X_current))[0])

        # populate the return dict
        self.instance_dict = {
            'X': X_current,
            'distance': dist,
            'lambda': lam[0],
            'index': l_step * self.max_iter + i,
            'class': preds.argmax(),
            'proba': preds,
            'loss': (proba_class - self.target_proba_arr[0]) ** 2 + lam[0] * dist,
        }
        self.return_dict['all'][l_step].append(self.instance_dict.copy())

        # update best CF if it has a smaller distance
        best = self.return_dict['cf']
        if best is None or dist < best['distance']:
            self.return_dict['cf'] = self.instance_dict.copy()

        logger.debug('CF found at step %s', l_step * self.max_iter + i)

    def _write_tb(self, lam, lam_lb, lam_ub, cf_found, X_current, **kwargs):
        """
        Write `tensorboard` summaries for the first item in the batch.

        For black-box models ``loss_pred`` and ``pred`` must be passed as keyword arguments since they are
        computed outside the graph. ``found`` and ``not_found`` default to 0. Does nothing without a writer.
        """
        if self.writer is None:
            return

        if self.model:
            gs, lr, dist, loss_pred, loss_opt, pred = self.sess.run(self._tb_scalars, feed_dict={self.lam: lam})
        else:
            gs, lr, dist, loss_opt = self.sess.run(self._tb_scalars, feed_dict={self.lam: lam})
            loss_pred = kwargs['loss_pred']
            pred = kwargs['pred']
        found = kwargs.get('found', 0)
        not_found = kwargs.get('not_found', 0)

        def as_float(value):
            # black-box quantities arrive as (1, 1) arrays; protobuf `simple_value` needs a Python float
            return float(np.ravel(value)[0])

        # the distance can be 0 (e.g. while the counterfactual still equals the original instance)
        denom = lam[0] * dist
        pred_div_dist = as_float(loss_pred) / denom if denom else float('inf')

        values = [
            ('lr/global_step', gs),
            ('lr/lr', lr),
            ('lambda/lambda', lam[0]),
            ('lambda/l_bound', lam_lb[0]),
            ('lambda/u_bound', lam_ub[0]),
            ('losses/dist', dist),
            ('losses/loss_pred', loss_pred),
            ('losses/loss_opt', loss_opt),
            ('losses/pred_div_dist', pred_div_dist),
            ('Y/pred_proba_class', pred),
            ('Y/pred_class_fn(X_current)', self.predict_class_fn(X_current)),
            ('Y/n_cf_found', cf_found[0].sum()),
            ('Y/found', found),
            ('Y/not_found', not_found),
        ]
        summary = tf.Summary()
        for tag, value in values:
            summary.value.add(tag=tag, simple_value=as_float(value))
        self.writer.add_summary(summary)
        self.writer.flush()

    def _bisect_lambda(self, cf_found, l_step, lam, lam_lb, lam_ub):
        """
        Update `lam` and its bounds in place after outer step `l_step`; also returns them.

        With at least 5 solutions lambda is raised (bisected towards the upper bound, or multiplied by 10 when
        no upper bound has been found yet). Otherwise it is lowered (bisected towards the lower bound, or
        divided by 10 when the lower bound is still 0).
        """
        for batch_idx in range(self.batch_size):  # TODO: batch not supported
            if cf_found[batch_idx][l_step] >= 5:  # minimum number of CF instances to warrant increasing lambda
                # want to improve the solution by putting more weight on the distance term TODO: hyperparameter?
                # by increasing lambda
                lam_lb[batch_idx] = max(lam[batch_idx], lam_lb[batch_idx])
                logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx])
                if lam_ub[batch_idx] < 1e9:
                    lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2
                else:
                    lam[batch_idx] *= 10
                    logger.debug('Changed lambda to %s', lam[batch_idx])
            else:
                # if not enough solutions found so far, decrease lambda by a factor of 10,
                # otherwise bisect up to the last known successful lambda
                lam_ub[batch_idx] = min(lam_ub[batch_idx], lam[batch_idx])
                logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx])
                if lam_lb[batch_idx] > 0:
                    lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2
                    logger.debug('Changed lambda to %s', lam[batch_idx])
                else:
                    lam[batch_idx] /= 10
        return lam, lam_lb, lam_ub

    def _reset_search(self, X, X_current, Y_ohe):
        """Re-initialise every ``cf_search`` variable, then load the original, start point and one-hot target."""
        self.sess.run(self.tf_init)
        self.sess.run(self.setup, {self.assign_orig: X,
                                   self.assign_cf: X_current,
                                   self.assign_target: Y_ohe})

    def _numerical_grads(self, X_current, target_proba):
        """
        Numerical gradient of the squared prediction loss for black-box models.

        Returns ``(grads, loss_pred, pred)``. For `tensorflow` models the graph already contains the prediction
        loss, so the gradient is all zeros and ``loss_pred``/``pred`` are ``None``.
        """
        grads_num = np.zeros(self.data_shape)
        loss_pred, pred = None, None
        if not self.model:
            pred = self.predict_class_fn(X_current)
            prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps)
            # squared difference prediction loss and its gradient (chain rule)
            loss_pred = (pred - target_proba) ** 2
            grads_num = 2 * (pred - target_proba) * prediction_grad
            grads_num = grads_num.reshape(self.data_shape)  # TODO? correct?
        return grads_num, loss_pred, pred

    def _gradient_step(self, lam, grads_num):
        """Apply one optimizer step with graph + numerical gradients and return the updated counterfactual."""
        grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam})
        grads_graph = grads_vars_graph[0][0]  # only one (gradient, variable) pair: the counterfactual
        gradients = grads_graph + grads_num
        self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam})
        return self.sess.run(self.cf)

    def _minimize_loss(self,
                       X: np.ndarray,
                       X_init: np.ndarray,
                       Y: np.ndarray) -> None:
        """
        Run the counterfactual search, filling ``self.return_dict``.

        First sweeps 10 exponentially decaying lambda values (``max_iter // 10`` steps each) to estimate lower
        and upper bounds, then runs up to `max_lam_steps` outer loops of gradient descent, bisecting lambda
        after each. Sets ``return_dict['success']`` to ``False`` when no lambda range is found, otherwise ``True``.
        """
        # keep track of the number of CFs found for each lambda in outer loop
        cf_found = np.zeros((self.batch_size, self.max_lam_steps))

        # set the lower and upper bound for lamda to scale the distance loss term
        lam_lb = np.zeros(self.batch_size)
        lam_ub = np.ones(self.batch_size) * 1e10

        # one-hot encoding of the predicted class, row by row (np.put would use flat indices)
        Y_ohe = np.zeros(Y.shape)
        Y_ohe[np.arange(Y.shape[0]), np.argmax(Y, axis=1)] = 1

        # the target probability is a graph constant: evaluate it once instead of on every step
        target_proba = self.sess.run(self.target_proba)

        # on first run estimate lambda bounds
        n_orders = 10
        n_steps = self.max_iter // n_orders
        lams = np.array([self.lam_init / 10 ** i for i in range(n_orders)])  # exponential decay
        cf_count = np.zeros_like(lams)
        logger.debug('Initial lambda sweep: %s', lams)

        X_current = X_init
        # TODO this whole initial loop should be optional?
        for ix, lam_value in enumerate(lams):
            lam = np.ones(self.batch_size) * lam_value
            self._reset_search(X, X_current, Y_ohe)

            for i in range(n_steps):
                grads_num, loss_pred, pred = self._numerical_grads(X_current, target_proba)

                # add values to tensorboard (1st item in batch only) every n steps
                if self.debug and not i % 50:
                    tb_kwargs = {} if self.model else {'loss_pred': loss_pred, 'pred': pred}
                    self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, **tb_kwargs)

                X_current = self._gradient_step(lam, grads_num)

                # does the counterfactual condition hold?
                if self._prob_condition(X_current).squeeze():
                    cf_count[ix] += 1

        # find the lower bound: the second order of magnitude with some CFs  TODO robust?
        logger.debug('cf_count: %s', cf_count)
        with_cf = np.where(cf_count > 0)[0]
        if with_cf.size < 2:
            logger.error('No appropriate lambda range found, try decreasing lam_init')
            self.return_dict['success'] = False
            return
        lam_lb = np.ones(self.batch_size) * lams[with_cf[1]]

        # find the upper bound  TODO is 0 robust?
        without_cf = np.where(cf_count == 0)[0]
        if without_cf.size:
            ub_ix = without_cf[-1]
        else:
            ub_ix = 0
            logger.debug('Could not find upper bound for lambda where no solutions found, setting upper bound to '
                         'lam_init=%s', lams[ub_ix])
        lam_ub = np.ones(self.batch_size) * lams[ub_ix]

        # start the search in the middle
        lam = (lam_lb + lam_ub) / 2
        logger.debug('Found upper and lower bounds: %s, %s', lam_lb[0], lam_ub[0])

        # on subsequent runs bisect lambda within the bounds found initially
        X_current = X_init
        for l_step in range(self.max_lam_steps):
            # assign variables for the current iteration
            self._reset_search(X, X_current, Y_ohe)

            found, not_found = 0, 0
            # number of gradient descent steps in each inner loop
            for i in range(self.max_iter):
                grads_num, loss_pred, pred = self._numerical_grads(X_current, target_proba)

                # add values to tensorboard (1st item in batch only) every n steps
                if self.debug and not i % 50:
                    tb_kwargs = {'found': found, 'not_found': not_found}
                    if not self.model:
                        tb_kwargs.update(loss_pred=loss_pred, pred=pred)
                    self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, **tb_kwargs)

                X_current = self._gradient_step(lam, grads_num)

                # does the counterfactual condition hold? (consecutive-hit / miss counters)
                if self._prob_condition(X_current).squeeze():
                    self._update_exp(i, l_step, lam, cf_found, X_current)
                    found += 1
                    not_found = 0
                else:
                    found = 0
                    not_found += 1

                # early stopping criterion - if no solutions or enough solutions found, change lambda
                if found >= self.early_stop or not_found >= self.early_stop:
                    break

            # adjust the lambda constant via bisection at the end of the outer loop
            self._bisect_lambda(cf_found, l_step, lam, lam_lb, lam_ub)

        self.return_dict['success'] = True

    def reset_predictor(self, predictor: Union[Callable, tf.keras.Model]) -> None:
        """
        Resets the predictor function/model.

        Parameters
        ----------
        predictor
            New predictor function/model.

        Raises
        ------
        NotImplementedError
            Always; resetting is not supported.
        """
        raise NotImplementedError('Resetting a predictor is currently not supported')