This page was generated from examples/cfproto_cat_adult_ord.ipynb.

Counterfactual explanations with ordinally encoded categorical variables

This example notebook illustrates how to obtain counterfactual explanations for instances with a mixture of ordinally encoded categorical and numerical variables. A more elaborate notebook highlighting additional functionality can be found here. We generate counterfactuals for instances in the adult dataset where we predict whether a person’s income is above or below $50k.

Note

To enable support for CounterfactualProto, you may need to run

pip install alibi[tensorflow]
[1]:
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Dense, Input, Embedding, Concatenate, Reshape, Dropout, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.preprocessing import OneHotEncoder
from time import time
from alibi.datasets import fetch_adult
from alibi.explainers import CounterfactualProto

print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False
TF version:  2.2.0
Eager execution enabled:  False

Load adult dataset

The fetch_adult function returns a Bunch object containing the features, the targets, the feature names and a mapping of the categories in each categorical variable.

[2]:
adult = fetch_adult()
data = adult.data
target = adult.target
feature_names = adult.feature_names
category_map_tmp = adult.category_map
target_names = adult.target_names

Define shuffled training and test set:

[3]:
def set_seed(s=0):
    np.random.seed(s)
    tf.random.set_seed(s)
[4]:
set_seed()
data_perm = np.random.permutation(np.c_[data, target])
X = data_perm[:,:-1]
y = data_perm[:,-1]
[5]:
idx = 30000
y_train, y_test = y[:idx], y[idx+1:]

Reorganize data so categorical features come first:

[6]:
X = np.c_[X[:, 1:8], X[:, 11], X[:, 0], X[:, 8:11]]

Adjust feature_names and category_map as well:

[7]:
feature_names = feature_names[1:8] + feature_names[11:12] + feature_names[0:1] + feature_names[8:11]
print(feature_names)
['Workclass', 'Education', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Country', 'Age', 'Capital Gain', 'Capital Loss', 'Hours per week']
[8]:
category_map = {}
for i, (_, v) in enumerate(category_map_tmp.items()):
    category_map[i] = v

Create a dictionary with as keys the categorical columns and values the number of categories for each variable in the dataset. This dictionary will later be used in the counterfactual explanation.

[9]:
cat_vars_ord = {}
n_categories = len(list(category_map.keys()))
for i in range(n_categories):
    cat_vars_ord[i] = len(np.unique(X[:, i]))
print(cat_vars_ord)
{0: 9, 1: 7, 2: 4, 3: 9, 4: 6, 5: 5, 6: 2, 7: 11}

Preprocess data

Scale numerical features between -1 and 1:

[10]:
X_num = X[:, -4:].astype(np.float32, copy=False)
xmin, xmax = X_num.min(axis=0), X_num.max(axis=0)
rng = (-1., 1.)
X_num_scaled = (X_num - xmin) / (xmax - xmin) * (rng[1] - rng[0]) + rng[0]
X_num_scaled_train = X_num_scaled[:idx, :]
X_num_scaled_test = X_num_scaled[idx+1:, :]

Combine numerical and categorical data:

[11]:
X = np.c_[X[:, :-4], X_num_scaled].astype(np.float32, copy=False)
X_train, X_test = X[:idx, :], X[idx+1:, :]
print(X_train.shape, X_test.shape)
(30000, 12) (2560, 12)

Train a neural net

The neural net will use entity embeddings for the categorical variables.

[12]:
def nn_ord():

    x_in = Input(shape=(12,))
    layers_in = []

    # embedding layers
    for i, (_, v) in enumerate(cat_vars_ord.items()):
        emb_in = Lambda(lambda x: x[:, i:i+1])(x_in)
        emb_dim = int(max(min(np.ceil(.5 * v), 50), 2))
        emb_layer = Embedding(input_dim=v+1, output_dim=emb_dim, input_length=1)(emb_in)
        emb_layer = Reshape(target_shape=(emb_dim,))(emb_layer)
        layers_in.append(emb_layer)

    # numerical layers
    num_in = Lambda(lambda x: x[:, -4:])(x_in)
    num_layer = Dense(16)(num_in)
    layers_in.append(num_layer)

    # combine
    x = Concatenate()(layers_in)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x = Dense(60, activation='relu')(x)
    x = Dropout(.2)(x)
    x_out = Dense(2, activation='softmax')(x)

    nn = Model(inputs=x_in, outputs=x_out)
    nn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return nn
[13]:
set_seed()
nn = nn_ord()
nn.summary()
nn.fit(X_train, to_categorical(y_train), batch_size=128, epochs=30, verbose=0)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 12)]         0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 1)            0           input_1[0][0]
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 5)         50          lambda[0][0]
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 4)         32          lambda_1[0][0]
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 1, 2)         10          lambda_2[0][0]
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 1, 5)         50          lambda_3[0][0]
__________________________________________________________________________________________________
embedding_4 (Embedding)         (None, 1, 3)         21          lambda_4[0][0]
__________________________________________________________________________________________________
embedding_5 (Embedding)         (None, 1, 3)         18          lambda_5[0][0]
__________________________________________________________________________________________________
embedding_6 (Embedding)         (None, 1, 2)         6           lambda_6[0][0]
__________________________________________________________________________________________________
embedding_7 (Embedding)         (None, 1, 6)         72          lambda_7[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 4)            0           input_1[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 5)            0           embedding[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 4)            0           embedding_1[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 2)            0           embedding_2[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 5)            0           embedding_3[0][0]
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 3)            0           embedding_4[0][0]
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 3)            0           embedding_5[0][0]
__________________________________________________________________________________________________
reshape_6 (Reshape)             (None, 2)            0           embedding_6[0][0]
__________________________________________________________________________________________________
reshape_7 (Reshape)             (None, 6)            0           embedding_7[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 16)           80          lambda_8[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 46)           0           reshape[0][0]
                                                                 reshape_1[0][0]
                                                                 reshape_2[0][0]
                                                                 reshape_3[0][0]
                                                                 reshape_4[0][0]
                                                                 reshape_5[0][0]
                                                                 reshape_6[0][0]
                                                                 reshape_7[0][0]
                                                                 dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 60)           2820        concatenate[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 60)           0           dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 60)           3660        dropout[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 60)           0           dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 60)           3660        dropout_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 60)           0           dense_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 2)            122         dropout_2[0][0]
==================================================================================================
Total params: 10,601
Trainable params: 10,601
Non-trainable params: 0
__________________________________________________________________________________________________
[13]:
<tensorflow.python.keras.callbacks.History at 0x7f482905f8d0>

Generate counterfactual

Original instance:

[14]:
X = X_test[0].reshape((1,) + X_test[0].shape)

Initialize counterfactual parameters:

[15]:
shape = X.shape
beta = .01
c_init = 1.
c_steps = 5
max_iterations = 500
rng = (-1., 1.)  # scale features between -1 and 1
rng_shape = (1,) + data.shape[1:]
feature_range = ((np.ones(rng_shape) * rng[0]).astype(np.float32),
                 (np.ones(rng_shape) * rng[1]).astype(np.float32))

Initialize explainer. Since the Embedding layers in tf.keras do not let gradients propagate through, we will only make use of the model’s predict function, treat it as a black box and perform numerical gradient calculations.

[16]:
set_seed()

# define predict function
predict_fn = lambda x: nn.predict(x)

cf = CounterfactualProto(predict_fn,
                         shape,
                         beta=beta,
                         cat_vars=cat_vars_ord,
                         max_iterations=max_iterations,
                         feature_range=feature_range,
                         c_init=c_init,
                         c_steps=c_steps,
                         eps=(.01, .01)  # perturbation size for numerical gradients
                        )

Fit explainer. Please check the documentation for more info about the optional arguments.

[17]:
cf.fit(X_train, d_type='abdm', disc_perc=[25, 50, 75]);

Explain instance:

[18]:
set_seed()
explanation = cf.explain(X)

Helper function to more clearly describe explanations:

[19]:
def describe_instance(X, explanation, eps=1e-2):
    print('Original instance: {}  -- proba: {}'.format(target_names[explanation.orig_class],
                                                       explanation.orig_proba[0]))
    print('Counterfactual instance: {}  -- proba: {}'.format(target_names[explanation.cf['class']],
                                                             explanation.cf['proba'][0]))
    print('\nCounterfactual perturbations...')
    print('\nCategorical:')
    X_orig_ord = X
    X_cf_ord = explanation.cf['X']
    delta_cat = {}
    for i, (_, v) in enumerate(category_map.items()):
        cat_orig = v[int(X_orig_ord[0, i])]
        cat_cf = v[int(X_cf_ord[0, i])]
        if cat_orig != cat_cf:
            delta_cat[feature_names[i]] = [cat_orig, cat_cf]
    if delta_cat:
        for k, v in delta_cat.items():
            print('{}: {}  -->   {}'.format(k, v[0], v[1]))
    print('\nNumerical:')
    delta_num = X_cf_ord[0, -4:] - X_orig_ord[0, -4:]
    n_keys = len(list(cat_vars_ord.keys()))
    for i in range(delta_num.shape[0]):
        if np.abs(delta_num[i]) > eps:
            print('{}: {:.2f}  -->   {:.2f}'.format(feature_names[i+n_keys],
                                            X_orig_ord[0,i+n_keys],
                                            X_cf_ord[0,i+n_keys]))
[20]:
describe_instance(X, explanation)
Original instance: <=50K  -- proba: [0.6976237  0.30237624]
Counterfactual instance: >50K  -- proba: [0.49604183 0.5039582 ]

Counterfactual perturbations...

Categorical:

Numerical:
Capital Gain: -1.00  -->   -0.88

The person’s incomce is predicted to be above $50k by increasing his or her capital gain.