This page was generated from doc/source/methods/CFProto.ipynb.


Counterfactuals Guided by Prototypes


According to Molnar’s Interpretable Machine Learning book, a counterfactual explanation of a prediction describes the smallest change to the feature values that changes the prediction. One use case for counterfactuals is loan approvals. Ann might want to know why her loan application was rejected, and what would need to change for it to be approved.

Counterfactuals are generated from the original instance by applying perturbations. The counterfactual then needs to satisfy certain constraints related to the model prediction or sparsity of the perturbed instance.

Some issues arise during the perturbation process:

  • the training data manifold needs to be respected
  • finding a satisfactory counterfactual can take time, especially for high dimensional data
  • there is often a trade-off between sparsity and interpretability of the counterfactual

We can address these issues by incorporating additional loss terms in the objective function that is optimized using gradient descent. A basic loss function for a counterfactual can look like this:

\(Loss = cL_{pred} + \beta L_{1} + L_{2}\)

The first loss term, \(cL_{pred}\), encourages the perturbed instance to be predicted as a different class than the original instance. The \(\beta L_{1} + L_{2}\) term acts as a regularizer and introduces sparsity by penalizing the size of the difference between the counterfactual and the original instance. While we can obtain sparse counterfactuals using this objective function, these are often not very interpretable because the training data manifold is not taken into account, and the perturbations are not necessarily meaningful.
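To make the sparsity effect of the regularizer concrete, the following minimal sketch compares two toy perturbations. The function name and the use of the squared \(L_{2}\) norm are assumptions made for illustration, not necessarily what the library computes internally.

```python
import numpy as np

def elastic_net_penalty(x_orig, x_cf, beta=0.1):
    """Return beta * L1 + L2 for the perturbation delta = x_cf - x_orig."""
    delta = (x_cf - x_orig).ravel()
    l1 = np.abs(delta).sum()
    l2 = np.square(delta).sum()  # squared L2 norm: an assumption of this sketch
    return beta * l1 + l2

x_orig = np.zeros(4)
sparse_cf = np.array([1.0, 0.0, 0.0, 0.0])  # one feature changed by 1
dense_cf = np.array([0.5, 0.5, 0.5, 0.5])   # four features changed by 0.5

penalty_sparse = elastic_net_penalty(x_orig, sparse_cf)
penalty_dense = elastic_net_penalty(x_orig, dense_cf)
```

Both perturbations have the same squared \(L_{2}\) size here, so the difference in penalty comes only from the \(L_{1}\) term, which favours the perturbation that changes fewer features.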

The Contrastive Explanation Method (CEM) uses an auto-encoder which is trained to reconstruct instances of the training set. We can then add the \(L_{2}\) reconstruction error of the perturbed instance as a loss term to keep the counterfactual close to the training data manifold. The loss function becomes:

\(Loss = cL_{pred} + \beta L_{1} + L_{2} + \gamma L_{AE}\)

The \(L_{AE}\) term does not, however, necessarily lead to interpretable solutions or speed up the counterfactual search. That’s where the prototype loss term \(L_{proto}\) comes in. To define the prototype for each prediction class, we can use the encoder part of the previously mentioned auto-encoder. We also need the training data or at least a representative sample. We use the model to make predictions on this data set. For each predicted class, we encode the instances belonging to that class. The class prototype is simply the average encoding for that class.

When we want to generate a counterfactual, we first find the nearest prototype other than the one for the class predicted on the original instance. The \(L_{proto}\) loss term tries to minimize the \(L_{2}\) distance between the counterfactual and the nearest prototype. As a result, the perturbations are guided to the closest prototype, speeding up the counterfactual search and making the perturbations more meaningful as they move towards a typical in-distribution instance.

If we do not have a trained encoder available, we can build class representations using k-d trees for each class. The prototype is then the nearest instance from a k-d tree other than the tree which represents the class predicted on the original instance. The loss function now looks as follows:

\(Loss = cL_{pred} + \beta L_{1} + L_{2} + \gamma L_{AE} + \theta L_{proto}\)
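The prototype construction described above can be sketched with plain NumPy. This is an illustrative outline only: the helper names are hypothetical, the encodings are toy 2-D vectors standing in for encoder outputs, and the squared distance in `proto_loss` is an assumption.

```python
import numpy as np

def class_prototypes(encodings, predicted_classes):
    """Average the encodings of all instances predicted as each class."""
    return {int(c): encodings[predicted_classes == c].mean(axis=0)
            for c in np.unique(predicted_classes)}

def nearest_other_prototype(enc_x, orig_class, prototypes, target_class=None):
    """Return the class and prototype closest to enc_x, excluding orig_class."""
    candidates = [c for c in prototypes
                  if c != orig_class and (target_class is None or c in target_class)]
    best = min(candidates, key=lambda c: np.linalg.norm(enc_x - prototypes[c]))
    return best, prototypes[best]

def proto_loss(enc_cf, prototype, theta=100.0):
    """theta times the squared L2 distance to the prototype (assumed form)."""
    return theta * np.sum((enc_cf - prototype) ** 2)

# Toy encodings and model predictions for six training instances.
encodings = np.array([[0., 0.], [0., 2.], [4., 0.], [4., 2.], [0., 10.], [2., 10.]])
predicted = np.array([0, 0, 1, 1, 2, 2])
protos = class_prototypes(encodings, predicted)

enc_x = np.array([1., 1.])  # encoding of the instance to explain, predicted as class 0
nearest_class, nearest_proto = nearest_other_prototype(enc_x, 0, protos)
forced_class, _ = nearest_other_prototype(enc_x, 0, protos, target_class=[2])
```

Restricting the candidates with `target_class` mirrors the option of steering the search towards specific prototype classes.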

The method allows us to select specific prototype classes to guide the counterfactual. For example, in MNIST the closest prototype to 9 is 4. However, we can specify that we want to move towards the 7 prototype and avoid 4.

In order to help interpretability, we can also add a trust score constraint on the proposed counterfactual. The trust score is defined as the ratio of the distance between the encoded counterfactual and the prototype of the class predicted on the original instance, and the distance between the encoded counterfactual and the prototype of the class predicted for the counterfactual instance. Intuitively, a high trust score implies that the counterfactual is far from the originally predicted class compared to the counterfactual class. For more info on trust scores, please check out the documentation.
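As a rough illustration of the ratio described above, the sketch below computes a trust score from two prototypes in encoded space. The function name and the small `eps` guard against division by zero are assumptions for this example.

```python
import numpy as np

def trust_score(enc_cf, proto_orig, proto_cf, eps=1e-12):
    """Distance to the original-class prototype over distance to the counterfactual-class prototype."""
    d_orig = np.linalg.norm(enc_cf - proto_orig)
    d_cf = np.linalg.norm(enc_cf - proto_cf)
    return d_orig / (d_cf + eps)  # eps is a hypothetical numerical guard

enc_cf = np.array([3.0, 0.0])      # encoded counterfactual (toy values)
proto_orig = np.array([0.0, 0.0])  # prototype of the originally predicted class
proto_cf = np.array([4.0, 0.0])    # prototype of the counterfactual class

score = trust_score(enc_cf, proto_orig, proto_cf)
```

A score above 1 means the encoded counterfactual lies closer to the prototype of its new class than to the prototype of the original class.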

Because of the \(L_{proto}\) term, we can actually remove the prediction loss term and still obtain an interpretable counterfactual. This is especially relevant for fully black box models. When we provide the counterfactual search method with a Keras or TensorFlow model, it is incorporated in the TensorFlow graph and evaluated using automatic differentiation. However, if we only have access to the model’s predict function, the gradient updates are numerical and typically require a large number of prediction calls because of \(L_{pred}\). These prediction calls can slow the search down significantly and become a bottleneck. We can represent the gradient of the loss term as follows:

\begin{equation*} \frac{\partial L_{pred}}{\partial x} = \frac{\partial L_{pred}}{\partial p} \frac{\partial p}{\partial x} \end{equation*}

where \(p\) is the predict function and \(x\) the input features to optimize. For a 28 by 28 MNIST image, the \(\frac{\partial p}{\partial x}\) term alone would require a prediction call with batch size 28x28x2 = 1568. By using the prototypes to guide the search, however, we can remove the prediction loss term and only make a single prediction at the end of each gradient update to check whether the predicted class on the proposed counterfactual is different from the original class.
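The batch size of 1568 follows from perturbing each of the 784 pixels once upwards and once downwards. A minimal central-difference sketch, assuming a hypothetical `numerical_dp_dx` helper and a toy linear stand-in for the predict function, shows where the count comes from:

```python
import numpy as np

def numerical_dp_dx(predict_fn, x, eps=1e-2):
    """Central-difference estimate of dp/dx using one batched prediction call."""
    flat = x.ravel()
    n = flat.size
    # Build 2 * n perturbed copies: +eps and -eps on each feature in turn.
    plus = flat + eps * np.eye(n)
    minus = flat - eps * np.eye(n)
    batch = np.concatenate([plus, minus]).reshape((2 * n,) + x.shape)
    preds = predict_fn(batch)                    # shape (2n, n_classes)
    grads = (preds[:n] - preds[n:]) / (2 * eps)  # shape (n, n_classes)
    return grads.T, batch.shape[0]

# Toy linear "model" with 10 outputs, used only to check the estimate.
W = np.random.default_rng(0).normal(size=(28 * 28, 10))
toy_predict = lambda batch: batch.reshape(len(batch), -1) @ W

x = np.zeros((28, 28))
dp_dx, batch_size = numerical_dp_dx(toy_predict, x)
```

For a real classifier each of these perturbed images goes through the model, which is why dropping \(L_{pred}\) removes most of the prediction cost.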

The different use cases are highlighted in the example notebooks linked at the bottom of the page.

More details will be revealed in a forthcoming paper.



The counterfactuals guided by prototypes method works on fully black-box models, meaning it can work with arbitrary functions that take arrays and return arrays. However, if the user has access to a full TensorFlow (TF) or Keras model, this can be passed in as well to take advantage of the automatic differentiation in TF to speed up the search. This section describes the initialization for a TF/Keras model. Please see the numerical gradients section for black-box models.

We first load our MNIST classifier and the (optional) auto-encoder and encoder:

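from tensorflow.keras.models import load_model  # assumes the Keras API bundled with TensorFlow

cnn = load_model('mnist_cnn.h5')  # MNIST classifier
ae = load_model('mnist_ae.h5')  # optional auto-encoder
enc = load_model('mnist_enc.h5')  # optional encoder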

Because the optimizer is defined in TensorFlow, we need to run the explainer within a TensorFlow session. The example below uses the session set by the loaded Keras models:

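from tensorflow.keras import backend as K  # assumes TF 1.x, where the Keras backend exposes get_session
sess = K.get_session()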

We can now initialize the counterfactual:

from alibi.explainers import CounterFactualProto

shape = (1,) + x_train.shape[1:]  # single instance: batch dimension of 1
cf = CounterFactualProto(sess, cnn, shape, kappa=0., beta=.1, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=1., c_steps=5,
                         learning_rate_init=1e-2, clip=(-1000., 1000.), write_dir='./cf')

Besides passing the previously defined session as well as the predictive model and the (optional) auto-encoder and encoder models, we set a number of hyperparameters …

… general:


  • shape: shape of the instance to be explained, starting with batch dimension. Currently only single explanations are supported, so the batch dimension should be equal to 1.
  • feature_range: global or feature-wise min and max values for the perturbed instance.
  • write_dir: write directory for TensorBoard logging of the loss terms. It can be helpful when tuning the hyperparameters for your use case, as it makes it easy to verify that, e.g., no single loss term dominates the optimization or that the number of iterations is adequate. You can access TensorBoard by running tensorboard --logdir {write_dir} in the terminal. The figure below for example shows the loss to be optimized over different \(c\) iterations. It is clear that within each iteration, the number of max_iterations steps is too high and we can speed up the search.


… related to the optimizer:

  • max_iterations: number of loss optimization steps for each value of \(c\), the multiplier of the first loss term.
  • learning_rate_init: initial learning rate, follows polynomial decay.
  • clip: min and max gradient values.
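To give a feel for the optimizer settings, here is a small sketch of a polynomial learning-rate decay and of gradient clipping. The decay formula is the standard polynomial schedule; the end rate of 0 and the power of 0.5 are assumptions for illustration, and both helper names are hypothetical.

```python
import numpy as np

def polynomial_decay(lr_init, step, max_iterations, end_lr=0.0, power=0.5):
    """Standard polynomial decay; end_lr and power are assumed values."""
    frac = min(step, max_iterations) / max_iterations
    return (lr_init - end_lr) * (1 - frac) ** power + end_lr

def clip_gradients(grads, clip=(-1000.0, 1000.0)):
    """Bound every gradient entry to the [min, max] range given by clip."""
    return np.clip(grads, clip[0], clip[1])

lr_start = polynomial_decay(1e-2, 0, 500)
lr_late = polynomial_decay(1e-2, 375, 500)
lr_end = polynomial_decay(1e-2, 500, 500)
clipped = clip_gradients(np.array([-5000.0, 3.0, 2000.0]))
```

With these assumed settings the learning rate falls from learning_rate_init to zero over max_iterations steps, and the restart for each new value of \(c\) begins again from the initial rate.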

… related to the objective function:

  • c_init and c_steps: the multiplier \(c\) of the first loss term is updated for c_steps iterations, starting at c_init. The first loss term encourages the perturbed instance to be predicted as a different class than the original instance. If we find a candidate counterfactual for the current value of \(c\), we reduce the value of \(c\) for the next optimization cycle to put more emphasis on the other loss terms and improve the solution. If we cannot find a solution, \(c\) is increased to put more weight on the prediction class restrictions of the counterfactual.
  • kappa: the first term in the loss function is defined by a difference between the predicted probabilities for the perturbed instance of the original class and the max of the other classes. \(\kappa \geq 0\) defines a cap for this difference, limiting its impact on the overall loss to be optimized. Similar to CEM, we set \(\kappa\) to 0 in the examples.
  • beta: \(\beta\) is the \(L_{1}\) loss term multiplier. A higher value for \(\beta\) means more weight on the sparsity restrictions of the perturbations. \(\beta\) equal to 0.1 works well for the example datasets.
  • gamma: multiplier for the optional \(L_{2}\) reconstruction error. A higher value for \(\gamma\) means more emphasis on the reconstruction error penalty defined by the auto-encoder. A value of 100 is reasonable for the examples.
  • theta: multiplier for the \(L_{proto}\) loss term. A higher \(\theta\) means more emphasis on the gradients guiding the counterfactual towards the nearest class prototype. A value of 100 worked well for the examples.
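The role of kappa in the first loss term can be illustrated with a hinge-style formulation, as used in CEM-like objectives. This is one plausible form written for illustration; the function name is hypothetical and the library's exact expression may differ.

```python
import numpy as np

def pred_loss(proba, orig_class, kappa=0.0, c=1.0):
    """Hinge on the gap between the original class and the best other class."""
    p_orig = proba[orig_class]
    p_other = np.max(np.delete(proba, orig_class))
    # The loss is zero once the best other class leads the original by kappa.
    return c * max(0.0, p_orig - p_other + kappa)

still_original = pred_loss(np.array([0.7, 0.2, 0.1]), orig_class=0)
already_flipped = pred_loss(np.array([0.3, 0.6, 0.1]), orig_class=0)
flipped_with_margin = pred_loss(np.array([0.3, 0.6, 0.1]), orig_class=0, kappa=0.5)
```

With kappa set to 0, the term stops contributing as soon as another class overtakes the original one, leaving the remaining loss terms to shape the counterfactual.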

While the default values for the loss term coefficients worked well for the simple examples provided in the notebooks, it is recommended to test their robustness for your own applications.


If we use an encoder to find the class prototypes, we need an additional fit step on the training data:

cf.fit(x_train)


We can now explain the instance and close the TensorFlow session when we are done:

explanation = cf.explain(X, Y=None, target_class=None, threshold=0.,
                         verbose=True, print_every=100, log_every=100)
  • X: original instance
  • Y: one-hot-encoding of class label for X, inferred from the prediction on X if None.
  • target_class: classes considered for the nearest class prototype. Either a list with class indices or None.
  • threshold: threshold level for the ratio between the distance of the counterfactual to the prototype of the predicted class for the original instance over the distance to the prototype of the predicted class for the counterfactual. If the trust score is below the threshold, the proposed counterfactual does not meet the requirements and is rejected.
  • verbose: if True, print progress of counterfactual search every print_every steps.
  • log_every: if write_dir for Tensorboard is specified, then log losses every log_every steps.
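When Y is None, the label is inferred from the prediction on X. A minimal sketch of that inference, assuming a hypothetical helper and a toy predict function that returns class probabilities, could look like this:

```python
import numpy as np

def infer_one_hot(predict_fn, X):
    """One-hot encode the class with the highest predicted probability."""
    proba = predict_fn(X)  # shape (batch, n_classes)
    return np.eye(proba.shape[1])[proba.argmax(axis=1)]

# Toy predict function that always returns the same probabilities.
toy_predict = lambda X: np.array([[0.1, 0.7, 0.2]])
Y_inferred = infer_one_hot(toy_predict, np.zeros((1, 4)))
```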

The explain method returns a dictionary with the following key: value pairs:

  • cf: a dictionary with the overall best counterfactual found. explanation['cf'] has the following key: value pairs:
    • X: the counterfactual instance
    • class: predicted class for the counterfactual
    • proba: predicted class probabilities for the counterfactual
    • grads_graph: gradient values computed from the TF graph with respect to the input features at the counterfactual
    • grads_num: numerical gradient values with respect to the input features at the counterfactual
  • orig_class: predicted class for original instance
  • orig_proba: predicted class probabilities for original instance
  • all: a dictionary with the iterations as keys and, for each iteration, a list with the counterfactuals found in that iteration as values. For instance, during the first iteration, explanation['all'][0], we typically find fairly noisy counterfactuals that improve over the course of the iteration. The counterfactuals for the subsequent iterations then need to be better (sparser) than the previous best counterfactual. So over the next few iterations, we probably find fewer but better solutions.
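The structure of the returned dictionary can be inspected with a few lines of Python. The sketch below uses a hand-built placeholder that follows the layout described above; its values are made up and the helper name is hypothetical.

```python
def count_counterfactuals(explanation):
    """Map each c iteration to the number of counterfactuals found in it."""
    return {it: len(cfs) for it, cfs in explanation['all'].items()}

# Hand-built placeholder with the documented layout; values are made up.
toy_explanation = {
    'cf': {'X': None, 'class': 4},
    'orig_class': 9,
    'all': {0: ['cf_a', 'cf_b', 'cf_c'], 1: ['cf_d'], 2: []},
}
counts = count_counterfactuals(toy_explanation)
class_changed = toy_explanation['cf']['class'] != toy_explanation['orig_class']
```

Applied to a real explanation, a count that shrinks over the iterations would match the behaviour described above, where later iterations only keep sparser solutions.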

Numerical Gradients

So far, the whole optimization problem could be defined within the TF graph, making automatic differentiation possible. It is however possible that we do not have access to the model architecture and weights, and are only provided with a predict function returning probabilities for each class. The counterfactual can then be initialized in the TF session as follows:

# define model
cnn = load_model('mnist_cnn.h5')
predict_fn = lambda x: cnn.predict(x)
ae = load_model('mnist_ae.h5')
enc = load_model('mnist_enc.h5')

sess = K.get_session()

# initialize explainer
shape = (1,) + x_train.shape[1:]
cf = CounterFactualProto(sess, predict_fn, shape, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=1., c_steps=4,
                         eps=(1e-2, 1e-2), update_num_grad=100)

In this case, we need to evaluate the gradients of the loss function with respect to the input features numerically:

\begin{equation*} \frac{\partial L_{pred}}{\partial x} = \frac{\partial L_{pred}}{\partial p} \frac{\partial p}{\partial x} \end{equation*}

where \(L_{pred}\) is the predict function loss term, \(p\) the predict function and \(x\) the input features to optimize. There are now 2 additional hyperparameters to consider:

  • eps: a tuple to define the perturbation size used to compute the numerical gradients. eps[0] and eps[1] are used respectively for \(\frac{\partial L_{pred}}{\partial p}\) and \(\frac{\partial p}{\partial x}\). eps[0] and eps[1] can be a combination of float values or numpy arrays. For eps[0], the array dimension should be (1 x nb of prediction categories) and for eps[1] it should be (1 x nb of features). For the Iris dataset, eps could look as follows:
eps0 = np.array([[1e-2, 1e-2, 1e-2]])  # 3 prediction categories, equivalent to 1e-2
eps1 = np.array([[1e-2, 1e-2, 1e-2, 1e-2]])  # 4 features, also equivalent to 1e-2
eps = (eps0, eps1)
  • update_num_grad: for complex models with a high number of parameters and a high dimensional feature space (e.g. Inception on ImageNet), evaluating numerical gradients can be expensive as they involve prediction calls for each perturbed instance. The update_num_grad parameter allows you to set a batch size on which to evaluate the numerical gradients, reducing the number of prediction calls required.
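The chain rule above, together with the two perturbation sizes in eps, can be sketched with central differences as follows. This is an illustration under stated assumptions: the helper name, the toy linear predict function (which does not return probabilities) and the toy loss are all hypothetical, and the library's own implementation may differ.

```python
import numpy as np

def numerical_grad(loss_fn, predict_fn, x, eps=(1e-2, 1e-2)):
    """Estimate dL/dx = dL/dp * dp/dx for x of shape (1, n_features)."""
    p = predict_fn(x)  # shape (1, n_classes)
    n_classes, n_features = p.shape[1], x.shape[1]
    eps0 = np.broadcast_to(eps[0], (1, n_classes))   # step sizes for dL/dp
    eps1 = np.broadcast_to(eps[1], (1, n_features))  # step sizes for dp/dx
    # dL/dp: perturb one class output at a time.
    dp = np.eye(n_classes) * eps0
    dl_dp = (loss_fn(p + dp) - loss_fn(p - dp)) / (2 * eps0.ravel())
    # dp/dx: perturb one input feature at a time.
    dx = np.eye(n_features) * eps1
    dp_dx = (predict_fn(x + dx) - predict_fn(x - dx)) / (2 * eps1.T)
    return (dp_dx @ dl_dp)[None, :]  # shape (1, n_features)

# Toy setup with Iris-like dimensions: 4 features, 3 prediction categories.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
w_loss = rng.normal(size=3)
toy_predict = lambda x: x @ W        # stand-in for the black-box predict function
toy_loss = lambda p: p @ w_loss      # stand-in for L_pred, one value per row

eps0 = np.array([[1e-2, 1e-2, 1e-2]])
eps1 = np.array([[1e-2, 1e-2, 1e-2, 1e-2]])
grad = numerical_grad(toy_loss, toy_predict, np.zeros((1, 4)), eps=(eps0, eps1))
```

Because both toy functions are linear, the estimate can be compared against the exact gradient `W @ w_loss`; for a real model the result depends on the chosen perturbation sizes.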

We can also remove the prediction loss term by setting c_init to 0 and c_steps to 1, and still obtain an interpretable counterfactual. This dramatically speeds up the counterfactual search (e.g. by 100x in the MNIST example notebook):

cf = CounterFactualProto(sess, predict_fn, shape, gamma=100., theta=100.,
                         ae_model=ae, enc_model=enc, max_iterations=500,
                         feature_range=(-.5, .5), c_init=0., c_steps=1)

k-d trees

So far, we assumed that we have a trained encoder available to find the nearest class prototype. This is however not a hard requirement. As mentioned in the Overview section, we can use k-d trees to build class representations, find prototypes by querying the trees for each class and return the closest class instance as the nearest prototype. We can run the counterfactual as follows:

cf = CounterFactualProto(sess, cnn, shape, use_kdtree=True, theta=10., feature_range=(-.5, .5))
cf.fit(x_train, trustscore_kwargs=None)
explanation = cf.explain(X)
  • trustscore_kwargs: keyword arguments for the trust score object used to define the k-d trees for each class. Please check the trust scores documentation for more info.
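The k-d tree variant of the prototype search can be sketched with SciPy. The helper names are hypothetical and the data is a toy example; the sketch only illustrates the idea of querying one tree per class and keeping the closest instance from a class other than the predicted one.

```python
import numpy as np
from scipy.spatial import cKDTree

def build_class_trees(X_train, predicted_classes):
    """Build one k-d tree per predicted class on the flattened instances."""
    X_flat = X_train.reshape(len(X_train), -1)
    return {int(c): cKDTree(X_flat[predicted_classes == c])
            for c in np.unique(predicted_classes)}

def nearest_kdtree_prototype(x, orig_class, trees):
    """Return the class and closest training instance outside orig_class."""
    x_flat = x.reshape(-1)
    best = None
    for c, tree in trees.items():
        if c == orig_class:
            continue  # skip the tree of the originally predicted class
        dist, idx = tree.query(x_flat, k=1)
        if best is None or dist < best[0]:
            best = (dist, c, tree.data[idx])
    return best[1], best[2]

# Toy training data with model predictions for three classes.
X_train = np.array([[0., 0.], [0., 1.], [5., 5.], [2., 0.], [10., 10.]])
predicted = np.array([0, 0, 1, 1, 2])
trees = build_class_trees(X_train, predicted)

proto_class, proto = nearest_kdtree_prototype(np.array([1., 0.]), 0, trees)
```

Unlike the encoder-based approach, the prototype here is an actual training instance, not an average encoding.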