import copy
import logging
from typing import Callable, Optional, Tuple, Union
import numpy as np
import tensorflow.compat.v1 as tf
from alibi.api.defaults import DEFAULT_DATA_CF, DEFAULT_META_CF
from alibi.api.interfaces import Explainer, Explanation
from alibi.utils.gradients import num_grad_batch
logger = logging.getLogger(__name__)
def _define_func(predict_fn: Callable,
pred_class: int,
target_class: Union[str, int] = 'same') -> Tuple[Callable, Union[str, int]]:
# TODO: convert to batchwise function
"""
Define the class-specific prediction function to be used in the optimization.
Parameters
----------
predict_fn
Classifier prediction function.
pred_class
Predicted class of the instance to be explained.
target_class
Target class of the explanation, one of ``'same'``, ``'other'`` or an integer class.
Returns
-------
Class-specific prediction function and the target class used.
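Examples
--------
A minimal sketch using a toy two-class ``predict_fn`` (the toy function below is purely
illustrative and not part of the library)::

    toy_predict_fn = lambda X: np.tile(np.array([[0.8, 0.2]]), (X.shape[0], 1))
    func, t_class = _define_func(toy_predict_fn, pred_class=0, target_class='other')
    func(np.zeros((1, 3)))  # probability of the best class other than 0, here array([[0.2]])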
"""
if target_class == 'other':
# TODO: need to optimize this
def func(X):
probas = predict_fn(X)
sorted_idx = np.argsort(-probas)  # class indices in decreasing order of probability
# take highest probability class different from class predicted for X
if sorted_idx[0, 0] == pred_class:
target_class = sorted_idx[0, 1]
# logger.debug('Target class equals predicted class')
else:
target_class = sorted_idx[0, 0]
# logger.debug('Current best target class: %s', target_class)
return (predict_fn(X)[:, target_class]).reshape(-1, 1)
return func, target_class
elif target_class == 'same':
target_class = pred_class
def func(X): # type: ignore
return (predict_fn(X)[:, target_class]).reshape(-1, 1)
return func, target_class
def CounterFactual(*args, **kwargs):
"""
The class name `CounterFactual` is deprecated, please use `Counterfactual`.
"""
# TODO: remove this function in an upcoming release
warning_msg = 'The class name `CounterFactual` is deprecated, please use `Counterfactual`.'
import warnings
warnings.warn(warning_msg, FutureWarning)
return Counterfactual(*args, **kwargs)
class Counterfactual(Explainer):
def __init__(self,
predict_fn: Union[Callable[[np.ndarray], np.ndarray], tf.keras.Model],
shape: Tuple[int, ...],
distance_fn: str = 'l1',
target_proba: float = 1.0,
target_class: Union[str, int] = 'other',
max_iter: int = 1000,
early_stop: int = 50,
lam_init: float = 1e-1,
max_lam_steps: int = 10,
tol: float = 0.05,
learning_rate_init: float = 0.1,
feature_range: Union[Tuple, str] = (-1e10, 1e10),
eps: Union[float, np.ndarray] = 0.01, # feature-wise epsilons
init: str = 'identity',
decay: bool = True,
write_dir: Optional[str] = None,
debug: bool = False,
sess: Optional[tf.Session] = None) -> None:
"""
Initialize the counterfactual explanation method based on Wachter et al. (2017).
Parameters
----------
predict_fn
`tensorflow` model or any other model's prediction function returning class probabilities.
shape
Shape of input data starting with batch size.
distance_fn
Distance function to use in the loss term.
target_proba
Target probability for the counterfactual to reach.
target_class
Target class for the counterfactual to reach, one of ``'other'``, ``'same'`` or an integer denoting
desired class membership for the counterfactual instance.
max_iter
Maximum number of iterations to run the gradient descent for (inner loop).
early_stop
Number of consecutive gradient descent steps after which to terminate the inner loop if every step produced a solution (or if none did).
lam_init
Initial value of the regularization constant `lambda` scaling the distance term of the loss.
max_lam_steps
Maximum number of times to adjust the regularization constant (outer loop) before terminating the search.
tol
Tolerance for the counterfactual target probability.
learning_rate_init
Initial learning rate for each outer loop of `lambda`.
feature_range
Tuple with `min` and `max` ranges to allow for perturbed instances. `Min` and `max` ranges can be `float`
or `numpy` arrays of shape `(1, number of features)` for feature-wise ranges.
eps
Gradient step sizes used in calculating numerical gradients, defaults to a single value for all
features, but can be passed an array for feature-wise step sizes.
init
Initialization method for the search of counterfactuals, currently must be ``'identity'``.
decay
Flag to decay learning rate to zero for each outer loop over lambda.
write_dir
Directory to write `tensorboard` files to.
debug
Flag to write `tensorboard` summaries for debugging.
sess
Optional `tensorflow` session that will be used if passed instead of creating or inferring one internally.
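Examples
--------
A minimal sketch, assuming a hypothetical `tf.keras` classifier ``model`` over inputs of
shape ``(28, 28, 1)`` (the model and shapes below are illustrative only)::

    shape = (1, 28, 28, 1)  # batch size of 1 followed by the input dimensions
    cf = Counterfactual(model, shape=shape, target_proba=1.0, target_class='other',
                        max_iter=1000, lam_init=1e-1, max_lam_steps=10, tol=0.05)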
"""
super().__init__(meta=copy.deepcopy(DEFAULT_META_CF))
# get params for storage in meta
params = locals()
remove = ['self', 'predict_fn', 'sess', '__class__']
for key in remove:
params.pop(key)
self.meta['params'].update(params)
self.data_shape = shape
self.batch_size = shape[0]
self.target_class = target_class
# options for the optimizer
self.max_iter = max_iter
self.lam_init = lam_init
self.tol = tol
self.max_lam_steps = max_lam_steps
self.early_stop = early_stop
self.eps = eps
self.init = init
self.feature_range = feature_range
self.target_proba_arr = target_proba * np.ones(self.batch_size)
self.debug = debug
# check if the passed object is a model and get session
is_model = isinstance(predict_fn, tf.keras.Model)
model_sess = tf.compat.v1.keras.backend.get_session()
self.meta['params'].update(is_model=is_model)
# if session provided, use it
if isinstance(sess, tf.Session):
self.sess = sess
else:
self.sess = model_sess
if is_model: # Keras or TF model
self.model = True
self.predict_fn = predict_fn.predict # type: ignore # array function
self.predict_tn = predict_fn # tensor function
else: # black-box model
self.predict_fn = predict_fn
self.predict_tn = None
self.model = False
self.n_classes = self.predict_fn(np.zeros(shape)).shape[1]
# flag to keep track if explainer is fit or not
self.fitted = False
# set up graph session for optimization (counterfactual search)
with tf.variable_scope('cf_search', reuse=tf.AUTO_REUSE):
# define variables for original and candidate counterfactual instances, target labels and lambda
self.orig = tf.get_variable('original', shape=shape, dtype=tf.float32)
self.cf = tf.get_variable('counterfactual', shape=shape,
dtype=tf.float32,
constraint=lambda x: tf.clip_by_value(x, feature_range[0], feature_range[1]))
# the following will be a 1-hot encoding of the target class (as predicted by the model)
self.target = tf.get_variable('target', shape=(self.batch_size, self.n_classes), dtype=tf.float32)
# constant target probability and global step variable
self.target_proba = tf.constant(target_proba * np.ones(self.batch_size), dtype=tf.float32,
name='target_proba')
self.global_step = tf.Variable(0.0, trainable=False, name='global_step')
# lambda hyperparameter - placeholder instead of variable as annealed in first epoch
self.lam = tf.placeholder(tf.float32, shape=(self.batch_size), name='lam')
# define placeholders that will be assigned to relevant variables
self.assign_orig = tf.placeholder(tf.float32, shape, name='assign_orig')
self.assign_cf = tf.placeholder(tf.float32, shape, name='assign_cf')
self.assign_target = tf.placeholder(tf.float32, shape=(self.batch_size, self.n_classes),
name='assign_target')
# L1 distance and MAD constants
# TODO: MADs?
ax_sum = list(np.arange(1, len(self.data_shape)))
if distance_fn == 'l1':
self.dist = tf.reduce_sum(tf.abs(self.cf - self.orig), axis=ax_sum, name='l1')
else:
logger.exception('Distance metric %s not supported', distance_fn)
raise ValueError
# distance loss
self.loss_dist = self.lam * self.dist
# prediction loss
if not self.model:
# will need to calculate gradients numerically
self.loss_opt = self.loss_dist
else:
# autograd gradients throughout
self.pred_proba = self.predict_tn(self.cf)
# 3 cases for target_class
if target_class == 'same':
self.pred_proba_class = tf.reduce_max(self.target * self.pred_proba, 1)
elif target_class == 'other':
self.pred_proba_class = tf.reduce_max((1 - self.target) * self.pred_proba, 1)
elif target_class in range(self.n_classes):
# if class is specified, this is known in advance
self.pred_proba_class = tf.reduce_max(tf.one_hot(target_class, self.n_classes, dtype=tf.float32)
* self.pred_proba, 1)
else:
logger.exception('Target class %s unknown', target_class)
raise ValueError
self.loss_pred = tf.square(self.pred_proba_class - self.target_proba)
self.loss_opt = self.loss_pred + self.loss_dist
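# combined objective minimized over the counterfactual x': (f_target(x') - p_target)^2 + lam * d(x', x_orig)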
# optimizer
if decay:
self.learning_rate = tf.train.polynomial_decay(learning_rate_init, self.global_step,
self.max_iter, 0.0, power=1.0)
else:
self.learning_rate = tf.convert_to_tensor(learning_rate_init)
# TODO optional argument to change type, learning rate scheduler
opt = tf.train.AdamOptimizer(self.learning_rate)
# first compute gradients, then apply them
self.compute_grads = opt.compute_gradients(self.loss_opt, var_list=[self.cf])
self.grad_ph = tf.placeholder(shape=shape, dtype=tf.float32, name='grad_cf')
grad_and_var = [(self.grad_ph, self.cf)]
self.apply_grads = opt.apply_gradients(grad_and_var, global_step=self.global_step)
# variables to initialize
self.setup: list = []
self.setup.append(self.orig.assign(self.assign_orig))
self.setup.append(self.cf.assign(self.assign_cf))
self.setup.append(self.target.assign(self.assign_target))
self.tf_init = tf.variables_initializer(var_list=tf.global_variables(scope='cf_search'))
# tensorboard
if write_dir is not None:
self.writer = tf.summary.FileWriter(write_dir, tf.get_default_graph())
self.writer.add_graph(tf.get_default_graph())
# return templates
self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'class', 'proba', 'loss'])
self.return_dict = copy.deepcopy(DEFAULT_DATA_CF)
self.return_dict['all'] = {i: [] for i in range(self.max_lam_steps)}
def _initialize(self, X: np.ndarray) -> np.ndarray:
# TODO initialization strategies ("same", "random", "from_train")
if self.init == 'identity':
X_init = X
logger.debug('Initializing search at the test point X')
else:
raise ValueError('Initialization method should be "identity"')
return X_init
def fit(self,
X: np.ndarray,
y: Optional[np.ndarray]) -> "Counterfactual":
"""
Fit method - currently unused as the counterfactual search is fully unsupervised.
Parameters
----------
X
Not used. Included for consistency.
y
Not used. Included for consistency.
Returns
-------
self
Explainer itself.
"""
# TODO feature ranges, epsilons and MADs
self.fitted = True
return self
def explain(self, X: np.ndarray) -> Explanation:
"""
Explain an instance and return the counterfactual with metadata.
Parameters
----------
X
Instance to be explained.
Returns
-------
explanation
`Explanation` object containing the counterfactual with additional metadata as attributes.
See usage at `Counterfactual examples`_ for details.
.. _Counterfactual examples:
https://docs.seldon.io/projects/alibi/en/stable/methods/CF.html
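Examples
--------
A minimal sketch, continuing the constructor example (``cf`` is an initialized explainer and
``x`` a single hypothetical test instance of the shape the model expects)::

    explanation = cf.explain(x.reshape((1,) + x.shape))
    explanation.cf['X']      # the best counterfactual instance found
    explanation.cf['class']  # its predicted class
    explanation.cf['proba']  # its prediction probabilities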
"""
# TODO change init parameters on the fly
if X.shape[0] != 1:
logger.warning('Currently only single instance explanations supported (first dim = 1), '
'but first dim = %s', X.shape[0])
# make a prediction
Y = self.predict_fn(X)
pred_class = Y.argmax(axis=1).item()
pred_prob = Y.max(axis=1).item()
self.return_dict['orig_class'] = pred_class
self.return_dict['orig_proba'] = pred_prob
logger.debug('Initial prediction: %s with p=%s', pred_class, pred_prob)
# define the class-specific prediction function
self.predict_class_fn, t_class = _define_func(self.predict_fn, pred_class, self.target_class)
# initialize with an instance
X_init = self._initialize(X)
# minimize loss iteratively
self._minimize_loss(X, X_init, Y)
return_dict = self.return_dict.copy()
self.instance_dict = dict.fromkeys(['X', 'distance', 'lambda', 'index', 'class', 'proba', 'loss'])
self.return_dict = {'cf': None, 'all': {i: [] for i in range(self.max_lam_steps)}, 'orig_class': None,
'orig_proba': None}
# create explanation object
explanation = Explanation(meta=copy.deepcopy(self.meta), data=return_dict)
return explanation
def _prob_condition(self, X_current):
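# a counterfactual is accepted when |f_target(x') - p_target| <= tol, evaluated per instance in the batch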
return np.abs(self.predict_class_fn(X_current) - self.target_proba_arr) <= self.tol
def _update_exp(self, i, l_step, lam, cf_found, X_current):
cf_found[0][l_step] += 1 # TODO: batch support
dist = self.sess.run(self.dist).item()
# populate the return dict
self.instance_dict['X'] = X_current
self.instance_dict['distance'] = dist
self.instance_dict['lambda'] = lam[0]
self.instance_dict['index'] = l_step * self.max_iter + i
preds = self.predict_fn(X_current)
pred_class = preds.argmax()
proba = preds.max()
self.instance_dict['class'] = pred_class
self.instance_dict['proba'] = preds
self.instance_dict['loss'] = (proba - self.target_proba_arr[0]) ** 2 + lam[0] * dist
self.return_dict['all'][l_step].append(self.instance_dict.copy())
# update best CF if it has a smaller distance
if self.return_dict['cf'] is None:
self.return_dict['cf'] = self.instance_dict.copy()
elif dist < self.return_dict['cf']['distance']:
self.return_dict['cf'] = self.instance_dict.copy()
logger.debug('CF found at step %s', l_step * self.max_iter + i)
def _write_tb(self, lam, lam_lb, lam_ub, cf_found, X_current, **kwargs):
if self.model:
scalars_tf = [self.global_step, self.learning_rate, self.dist[0],
self.loss_pred[0], self.loss_opt[0], self.pred_proba_class[0]]
gs, lr, dist, loss_pred, loss_opt, pred = self.sess.run(scalars_tf, feed_dict={self.lam: lam})
else:
scalars_tf = [self.global_step, self.learning_rate, self.dist[0],
self.loss_opt[0]]
gs, lr, dist, loss_opt = self.sess.run(scalars_tf, feed_dict={self.lam: lam})
loss_pred = kwargs['loss_pred']
pred = kwargs['pred']
try:
found = kwargs['found']
not_found = kwargs['not_found']
except KeyError:
found = 0
not_found = 0
summary = tf.Summary()
summary.value.add(tag='lr/global_step', simple_value=gs)
summary.value.add(tag='lr/lr', simple_value=lr)
summary.value.add(tag='lambda/lambda', simple_value=lam[0])
summary.value.add(tag='lambda/l_bound', simple_value=lam_lb[0])
summary.value.add(tag='lambda/u_bound', simple_value=lam_ub[0])
summary.value.add(tag='losses/dist', simple_value=dist)
summary.value.add(tag='losses/loss_pred', simple_value=loss_pred)
summary.value.add(tag='losses/loss_opt', simple_value=loss_opt)
summary.value.add(tag='losses/pred_div_dist', simple_value=loss_pred / (lam[0] * dist))
summary.value.add(tag='Y/pred_proba_class', simple_value=pred)
summary.value.add(tag='Y/pred_class_fn(X_current)', simple_value=self.predict_class_fn(X_current))
summary.value.add(tag='Y/n_cf_found', simple_value=cf_found[0].sum())
summary.value.add(tag='Y/found', simple_value=found)
summary.value.add(tag='Y/not_found', simple_value=not_found)
self.writer.add_summary(summary)
self.writer.flush()
def _bisect_lambda(self, cf_found, l_step, lam, lam_lb, lam_ub):
for batch_idx in range(self.batch_size): # TODO: batch not supported
if cf_found[batch_idx][l_step] >= 5: # minimum number of CF instances to warrant increasing lambda
# want to improve the solution by putting more weight on the distance term TODO: hyperparameter?
# by increasing lambda
lam_lb[batch_idx] = max(lam[batch_idx], lam_lb[batch_idx])
logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx])
if lam_ub[batch_idx] < 1e9:
lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2
else:
lam[batch_idx] *= 10
logger.debug('Changed lambda to %s', lam[batch_idx])
elif cf_found[batch_idx][l_step] < 5:
# if not enough solutions found so far, decrease lambda by a factor of 10,
# otherwise bisect up to the last known successful lambda
lam_ub[batch_idx] = min(lam_ub[batch_idx], lam[batch_idx])
logger.debug('Lambda bounds: (%s, %s)', lam_lb[batch_idx], lam_ub[batch_idx])
if lam_lb[batch_idx] > 0:
lam[batch_idx] = (lam_lb[batch_idx] + lam_ub[batch_idx]) / 2
logger.debug('Changed lambda to %s', lam[batch_idx])
else:
lam[batch_idx] /= 10
return lam, lam_lb, lam_ub
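# Worked example of the bisection above (illustrative numbers): with lam_lb=0.001, lam_ub=0.1 and lam=0.01,
# finding >= 5 CFs at this step moves the lower bound to 0.01 and lam to (0.01 + 0.1) / 2 = 0.055
# (more weight on the distance/sparsity term); finding < 5 CFs moves the upper bound to 0.01 and
# lam to (0.001 + 0.01) / 2 = 0.0055 (more weight on the prediction term).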
def _minimize_loss(self,
X: np.ndarray,
X_init: np.ndarray,
Y: np.ndarray) -> None:
# keep track of the number of CFs found for each lambda in outer loop
cf_found = np.zeros((self.batch_size, self.max_lam_steps))
# set the lower and upper bound for lamda to scale the distance loss term
lam_lb = np.zeros(self.batch_size)
lam_ub = np.ones(self.batch_size) * 1e10
# make a one-hot vector of targets
Y_ohe = np.zeros(Y.shape)
np.put(Y_ohe, np.argmax(Y, axis=1), 1)
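# note: np.put indexes into the flattened array, so this one-hot encoding is only correct for a
# single-instance batch - the only case currently supported (see the warning in explain())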
# on first run estimate lambda bounds
n_orders = 10
n_steps = self.max_iter // n_orders
lams = np.array([self.lam_init / 10 ** i for i in range(n_orders)]) # exponential decay
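# e.g. with the default lam_init=1e-1 this sweeps lams = [1e-1, 1e-2, ..., 1e-10]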
cf_count = np.zeros_like(lams)
logger.debug('Initial lambda sweep: %s', lams)
X_current = X_init
# TODO this whole initial loop should be optional?
for ix, l_step in enumerate(lams):
lam = np.ones(self.batch_size) * l_step
self.sess.run(self.tf_init)
self.sess.run(self.setup, {self.assign_orig: X,
self.assign_cf: X_current,
self.assign_target: Y_ohe})
for i in range(n_steps):
# numerical gradients
grads_num = np.zeros(self.data_shape)
if not self.model:
pred = self.predict_class_fn(X_current)
prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps)
# squared difference prediction loss
loss_pred = (pred - self.target_proba.eval(session=self.sess)) ** 2
grads_num = 2 * (pred - self.target_proba.eval(session=self.sess)) * prediction_grad
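# chain rule for the squared prediction loss: d/dx (f(x) - p_target)^2 = 2 * (f(x) - p_target) * df/dx;
# the gradient of the distance term is obtained from the TF graph below and added to these numerical gradients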
grads_num = grads_num.reshape(self.data_shape) # TODO? correct?
# add values to tensorboard (1st item in batch only) every n steps
if self.debug and not i % 50:
if not self.model:
self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, loss_pred=loss_pred, pred=pred)
else:
self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current)
# compute graph gradients
grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam})
grads_graph = [g for g, _ in grads_vars_graph][0]
# apply gradients
gradients = grads_graph + grads_num
self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam})
# does the counterfactual condition hold?
X_current = self.sess.run(self.cf)
cond = self._prob_condition(X_current).squeeze()
if cond:
cf_count[ix] += 1
# find the lower bound
logger.debug('cf_count: %s', cf_count)
try:
lb_ix = np.where(cf_count > 0)[0][1] # take the second-largest lambda that produced CFs as the lower bound
# TODO robust?
except IndexError:
logger.error('No appropriate lambda range found, try decreasing lam_init')
return
lam_lb = np.ones(self.batch_size) * lams[lb_ix]
# find the upper bound
try:
ub_ix = np.where(cf_count == 0)[0][-1] # TODO is 0 robust?
except IndexError:
ub_ix = 0
logger.debug('Could not find a lambda value for which no solutions were found, setting the upper bound to '
'lam_init=%s', lams[ub_ix])
lam_ub = np.ones(self.batch_size) * lams[ub_ix]
# start the search in the middle
lam = (lam_lb + lam_ub) / 2
logger.debug('Found upper and lower bounds: %s, %s', lam_lb[0], lam_ub[0])
# on subsequent runs bisect lambda within the bounds found initially
X_current = X_init
for l_step in range(self.max_lam_steps):
self.sess.run(self.tf_init)
# assign variables for the current iteration
self.sess.run(self.setup, {self.assign_orig: X,
self.assign_cf: X_current,
self.assign_target: Y_ohe})
found, not_found = 0, 0
# number of gradient descent steps in each inner loop
for i in range(self.max_iter):
# numerical gradients
grads_num = np.zeros(self.data_shape)
if not self.model:
pred = self.predict_class_fn(X_current)
prediction_grad = num_grad_batch(self.predict_class_fn, X_current, eps=self.eps)
# squared difference prediction loss
loss_pred = (pred - self.target_proba.eval(session=self.sess)) ** 2
grads_num = 2 * (pred - self.target_proba.eval(session=self.sess)) * prediction_grad
grads_num = grads_num.reshape(self.data_shape)
# add values to tensorboard (1st item in batch only) every n steps
if self.debug and not i % 50:
if not self.model:
self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found,
loss_pred=loss_pred, pred=pred)
else:
self._write_tb(lam, lam_lb, lam_ub, cf_found, X_current, found=found, not_found=not_found)
# compute graph gradients
grads_vars_graph = self.sess.run(self.compute_grads, feed_dict={self.lam: lam})
grads_graph = [g for g, _ in grads_vars_graph][0]
# apply gradients
gradients = grads_graph + grads_num
self.sess.run(self.apply_grads, feed_dict={self.grad_ph: gradients, self.lam: lam})
# does the counterfactual condition hold?
X_current = self.sess.run(self.cf)
cond = self._prob_condition(X_current)
if cond:
self._update_exp(i, l_step, lam, cf_found, X_current)
found += 1
not_found = 0
else:
found = 0
not_found += 1
# early stopping criterion - if no solutions or enough solutions found, change lambda
if found >= self.early_stop or not_found >= self.early_stop:
break
# adjust the lambda constant via bisection at the end of the outer loop
self._bisect_lambda(cf_found, l_step, lam, lam_lb, lam_ub)
self.return_dict['success'] = True
def reset_predictor(self, predictor: Union[Callable, tf.keras.Model]) -> None:
"""
Resets the predictor function/model.
Parameters
----------
predictor
New predictor function/model.
"""
raise NotImplementedError('Resetting a predictor is currently not supported')