This page was generated from examples/cfproto_housing.ipynb.

Counterfactuals guided by prototypes on Boston housing dataset

This notebook walks through an example of prototypical counterfactuals, using k-d trees to build the prototypes. A separate notebook covers a more in-depth application of the method on MNIST using (auto-)encoders and trust scores.
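
As a rough sketch of the k-d tree idea (a simplified illustration on made-up 2-D data, not alibi's internal implementation): one tree is built per class, and the prototype for an instance is its nearest neighbour among the instances of another class. The toy clusters, the `trees` dictionary and the `nearest_prototype` helper are assumptions introduced here; `cKDTree` comes from SciPy.

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)

# Toy training data: one 2-D cluster per class (illustrative values only)
x_class0 = rng.normal(loc=-2.0, scale=0.5, size=(50, 2))
x_class1 = rng.normal(loc=2.0, scale=0.5, size=(50, 2))

# One k-d tree per class, built on the instances of that class
trees = {0: cKDTree(x_class0), 1: cKDTree(x_class1)}


def nearest_prototype(x, orig_class):
    """Hypothetical helper: closest training instance from any other class."""
    best_class, best_proto, best_dist = None, None, np.inf
    for c, tree in trees.items():
        if c == orig_class:
            continue  # only look at classes other than the original one
        dist, idx = tree.query(x, k=1)  # nearest neighbour in this class
        if dist < best_dist:
            best_class, best_proto, best_dist = c, tree.data[idx], dist
    return best_class, best_proto


x = np.array([-1.5, -1.8])  # instance assumed to belong to class 0
proto_class, proto = nearest_prototype(x, orig_class=0)
print(proto_class, proto)
```

The counterfactual search is then pulled towards such a prototype, which is what keeps the result close to the data of the target class.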

In this example, we will train a simple neural net to predict whether house prices in the Boston area are above the median value or not. We can then find a counterfactual to see which variables need to change to move a house price from below the median to above it, or vice versa.

Note

To enable support for CounterfactualProto, you may need to run

pip install alibi[tensorflow]
[1]:
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.datasets import load_boston
from alibi.explainers import CounterfactualProto

print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False
TF version:  2.2.0
Eager execution enabled:  False

Load and prepare Boston housing dataset

[2]:
boston = load_boston()
data = boston.data
target = boston.target
feature_names = boston.feature_names

Transform into classification task: target becomes whether house price is above the overall median or not

[3]:
y = np.zeros((target.shape[0],))
y[np.where(target > np.median(target))[0]] = 1
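
The same thresholding can be checked on a handful of made-up prices. This is only a sketch with invented values; note that the comparison is strict, so a price exactly equal to the median is labelled 0.

```python
import numpy as np

# Made-up prices, for illustration only
target = np.array([10.0, 20.0, 21.2, 30.0, 50.0])

# 1 if strictly above the median, 0 otherwise
y = (target > np.median(target)).astype(float)
print(np.median(target), y)
```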

Remove the categorical feature (CHAS, the Charles River dummy variable, at column index 3)

[4]:
data = np.delete(data, 3, 1)
feature_names = np.delete(feature_names, 3)
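
To see how the column indices shift after the deletion, here is a small sketch that assumes the standard 13-column ordering of the dataset and uses placeholder values instead of the real data. The `index_of` lookup is a hypothetical helper introduced for illustration.

```python
import numpy as np

# Assumed standard column order of the Boston housing data
feature_names = np.array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
                          'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'])
data = np.arange(2 * 13, dtype=float).reshape(2, 13)  # placeholder values

# Drop column 3 (CHAS) from both the data and the names
data = np.delete(data, 3, 1)
feature_names = np.delete(feature_names, 3)

# Hypothetical lookup from feature name to its new column index
index_of = {name: i for i, name in enumerate(feature_names)}
print(data.shape)
print(index_of['AGE'], index_of['LSTAT'])
```

Under that assumed ordering, AGE ends up at index 5 and LSTAT at index 11, leaving 12 input features.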

Explanation of remaining features:

  • CRIM: per capita crime rate by town

  • ZN: proportion of residential land zoned for lots over 25,000 sq.ft.

  • INDUS: proportion of non-retail business acres per town

  • NOX: nitric oxides concentration (parts per 10 million)

  • RM: average number of rooms per dwelling

  • AGE: proportion of owner-occupied units built prior to 1940

  • DIS: weighted distances to five Boston employment centres

  • RAD: index of accessibility to radial highways

  • TAX: full-value property-tax rate per USD10,000

  • PTRATIO: pupil-teacher ratio by town

  • B: 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town

  • LSTAT: % lower status of the population

Standardize data

[5]:
mu = data.mean(axis=0)
sigma = data.std(axis=0)
data = (data - mu) / sigma
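
A minimal sketch of this standardisation on made-up numbers, together with the inverse transform that is needed later to read counterfactuals in the original units:

```python
import numpy as np

# Made-up data with two features on very different scales
data = np.array([[1.0, 100.0],
                 [2.0, 200.0],
                 [3.0, 600.0]])

mu = data.mean(axis=0)
sigma = data.std(axis=0)

scaled = (data - mu) / sigma     # zero mean, unit variance per column
restored = scaled * sigma + mu   # inverse transform back to original units

print(scaled.mean(axis=0), scaled.std(axis=0))
print(np.allclose(restored, data))
```

`mu` and `sigma` have to be kept around, since the inverse transform reuses exactly the same values.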

Define train and test set

[6]:
idx = 475
x_train,y_train = data[:idx,:], y[:idx]
x_test, y_test = data[idx:,:], y[idx:]
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
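
The split is sequential (no shuffling), so with the 506 rows of the dataset the last 31 form the test set. The sketch below uses placeholder arrays and mimics the one-hot encoding with plain NumPy; `np.eye` stands in for `to_categorical` here and is an assumption made for illustration, not the Keras implementation.

```python
import numpy as np

n_rows, n_features = 506, 12          # dataset size after dropping one column
data = np.zeros((n_rows, n_features)) # placeholder features
y = np.arange(n_rows) % 2             # placeholder binary labels

idx = 475
x_train, y_train = data[:idx, :], y[:idx]
x_test, y_test = data[idx:, :], y[idx:]

# One-hot encode integer labels: row i has a 1 in column y[i]
y_test_onehot = np.eye(2)[y_test]

print(x_train.shape, x_test.shape, y_test_onehot.shape)
```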

Train model

[7]:
np.random.seed(42)
tf.random.set_seed(42)
[8]:
def nn_model():
    x_in = Input(shape=(12,))
    x = Dense(40, activation='relu')(x_in)
    x = Dense(40, activation='relu')(x)
    x_out = Dense(2, activation='softmax')(x)
    nn = Model(inputs=x_in, outputs=x_out)
    nn.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return nn
[9]:
nn = nn_model()
nn.summary()
nn.fit(x_train, y_train, batch_size=64, epochs=500, verbose=0)
nn.save('nn_boston.h5', save_format='h5')
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 12)]              0
_________________________________________________________________
dense (Dense)                (None, 40)                520
_________________________________________________________________
dense_1 (Dense)              (None, 40)                1640
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 82
=================================================================
Total params: 2,242
Trainable params: 2,242
Non-trainable params: 0
_________________________________________________________________
[10]:
nn = load_model('nn_boston.h5')
score = nn.evaluate(x_test, y_test, verbose=0)
print('Test accuracy: ', score[1])
Test accuracy:  0.83870965
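
For reference, the accuracy metric compares the arg-max of the softmax output with the arg-max of the one-hot labels. A small sketch with made-up probabilities and labels:

```python
import numpy as np

# Made-up softmax outputs and one-hot labels for four instances
probs = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.6, 0.4],
                  [0.3, 0.7]])
labels = np.array([[1, 0],
                   [0, 1],
                   [0, 1],
                   [0, 1]])

pred_class = np.argmax(probs, axis=1)
true_class = np.argmax(labels, axis=1)
accuracy = np.mean(pred_class == true_class)
print(accuracy)  # 3 of 4 correct -> 0.75
```

With a 31-row test set, the reported value of about 0.8387 is consistent with 26 correctly classified instances.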

Generate counterfactual guided by the nearest class prototype

Original instance:

[11]:
X = x_test[1].reshape((1,) + x_test[1].shape)
shape = X.shape

Run counterfactual:

[12]:
# reload the trained model from disk
nn = load_model('nn_boston.h5')

# initialize explainer, fit and generate counterfactual
cf = CounterfactualProto(nn, shape, use_kdtree=True, theta=10., max_iterations=1000,
                         feature_range=(x_train.min(axis=0), x_train.max(axis=0)),
                         c_init=1., c_steps=10)

cf.fit(x_train)
explanation = cf.explain(X)

The prediction flipped from 0 (value below the median) to 1 (above the median):

[13]:
print(f'Original prediction: {explanation.orig_class}')
print(f"Counterfactual prediction: {explanation.cf['class']}")
Original prediction: 0
Counterfactual prediction: 1

Let’s take a look at the counterfactual. To make the results more interpretable, we will first undo the pre-processing step and then check where the counterfactual differs from the original instance:

[14]:
orig = X * sigma + mu
counterfactual = explanation.cf['X'] * sigma + mu
delta = counterfactual - orig
for i, f in enumerate(feature_names):
    if np.abs(delta[0][i]) > 1e-4:
        print('{}: {}'.format(f, delta[0][i]))
AGE: -6.526239960695747
LSTAT: -4.799340540220259

So, to push the predicted house price above the median, the proportion of owner-occupied units built prior to 1940 should decrease by ~6.5 percentage points. This is not surprising since the proportion for this observation is very high at 93.6%. Furthermore, the % of the population with “lower status” should decrease by ~4.8 percentage points.

[15]:
print('% owner-occupied units built prior to 1940: {}'.format(orig[0][5]))
print('% lower status of the population: {}'.format(orig[0][11]))
% owner-occupied units built prior to 1940: 93.6
% lower status of the population: 18.68
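
The two steps above (computing the differences, then looking up the original values) can be combined into one small report. This is a sketch on a three-feature subset: the AGE and LSTAT values come from the outputs above (changes rounded), the DIS value is made up, and `changed_features` is a hypothetical helper, not part of alibi.

```python
import numpy as np

feature_names = np.array(['AGE', 'DIS', 'LSTAT'])   # subset, for illustration
orig = np.array([[93.6, 2.3, 18.68]])               # DIS value is made up
counterfactual = orig + np.array([[-6.5, 0.0, -4.8]])  # rounded changes


def changed_features(orig, cf, names, tol=1e-4):
    """Map each changed feature to (original value, counterfactual value, change)."""
    delta = (cf - orig)[0]
    return {
        str(name): (float(orig[0][i]), float(cf[0][i]), float(delta[i]))
        for i, name in enumerate(names)
        if np.abs(delta[i]) > tol
    }


changes = changed_features(orig, counterfactual, feature_names)
for name, (before, after, change) in changes.items():
    print(f'{name}: {before:.2f} -> {after:.2f} ({change:+.2f})')
```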

Clean up:

[16]:
os.remove('nn_boston.h5')