import copy
import inspect
import logging
import math
import numbers
import sys
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
no_type_check)
import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics
from tqdm import tqdm
from alibi.api.defaults import (DEFAULT_DATA_PERMUTATION_IMPORTANCE,
DEFAULT_META_PERMUTATION_IMPORTANCE)
from alibi.api.interfaces import Explainer, Explanation
if sys.version_info >= (3, 8):
from typing import Literal, get_args
else:
from typing_extensions import Literal, get_args
logger = logging.getLogger(__name__)
[docs]
class Method(str, Enum):
""" Enumeration of supported method. """
EXACT = 'exact'
ESTIMATE = 'estimate'
[docs]
class Kind(str, Enum):
""" Enumeration of supported kind. """
DIFFERENCE = 'difference'
RATIO = 'ratio'
LOSS_FNS = {
# regression
"mean_absolute_error": sklearn.metrics.mean_absolute_error,
"mean_squared_error": sklearn.metrics.mean_squared_error,
"mean_squared_log_error": sklearn.metrics.mean_squared_log_error,
"mean_absolute_percentage_error": sklearn.metrics.mean_absolute_percentage_error,
# classification
"log_loss": sklearn.metrics.log_loss,
}
"""
Dictionary of supported string specified loss functions
- ``'mean_absolute_error'`` - Mean absolute error regression loss. See `sklearn.metrics.mean_absolute_error`_ \
for documentation.
- ``'mean_squared_error'`` - Mean squared error regression loss. See `sklearn.metrics.mean_squared_error`_ \
for documentation.
- ``'mean_squared_log_error'`` - Mean squared logarithmic error regression loss. \
See `sklearn.metrics.mean_squared_log_error`_ for documentation.
- ``'mean_absolute_percentage_error'`` - Mean absolute percentage error (MAPE) regression loss. \
See `sklearn.metrics.mean_absolute_percentage_error`_ for documentation.
- ``'log_loss'`` - Log loss, aka logistic loss or cross-entropy loss. \
See `sklearn.metrics.log_loss`_ for documentation.
.. _sklearn.metrics.mean_absolute_error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html#sklearn.metrics.mean_absolute_error
.. _sklearn.metrics.mean_squared_error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html#sklearn.metrics.mean_squared_error
.. _sklearn.metrics.mean_squared_log_error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_log_error.html#sklearn.metrics.mean_squared_log_error
.. _sklearn.metrics.mean_absolute_percentage_error:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html#sklearn.metrics.mean_absolute_percentage_error
.. _sklearn.metrics.log_loss:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html#sklearn.metrics.log_loss
"""
LossFnName = Literal[
# regression
"mean_absolute_error",
"mean_squared_error",
"mean_squared_log_error",
"mean_absolute_percentage_error",
# classification
"log_loss"
]
SCORE_FNS = {
# classification
"accuracy": sklearn.metrics.accuracy_score,
"precision": sklearn.metrics.precision_score,
"recall": sklearn.metrics.recall_score,
"f1": sklearn.metrics.f1_score,
"roc_auc": sklearn.metrics.roc_auc_score,
# regression
"r2": sklearn.metrics.r2_score
}
"""
Dictionary of supported string specified score functions
- ``'accuracy'`` - Accuracy classification score. See `sklearn.metrics.accuracy_score`_ for documentation.
- ``'precision'`` - Precision score. See `sklearn.metrics.precision_score`_ for documentation.
- ``'recall'`` - Recall score. See `sklearn.metrics.recall_score`_ for documentation.
- ``'f1_score'`` - F1 score. See `sklearn.metrics.f1_score`_ for documentation.
- ``'roc_auc_score'`` - Area Under the Receiver Operating Characteristic Curve (ROC AUC) score. \
See `sklearn.metrics.roc_auc_score`_ for documentation.
- ``'r2_score'`` - :math:`R^2` (coefficient of determination) regression score. \
See `sklearn.metrics.r2_score`_ for documentation.
.. _sklearn.metrics.accuracy_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html#sklearn.metrics.accuracy_score
.. _sklearn.metrics.precision_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score
.. _sklearn.metrics.recall_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score
.. _sklearn.metrics.f1_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
.. _sklearn.metrics.roc_auc_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
.. _sklearn.metrics.r2_score:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html
"""
ScoreFnName = Literal[
# classification
"accuracy",
"precision",
"recall",
"f1",
"roc_auc",
# regression
"r2"
]
assert set(get_args(LossFnName)) == set(LOSS_FNS.keys())
assert set(get_args(ScoreFnName)) == set(SCORE_FNS.keys())
[docs]
class PermutationImportance(Explainer):
""" Implementation of the permutation feature importance for tabular datasets. The method measure the importance
of a feature as the relative increase/decrease in the loss/score function when the feature values are permuted.
Supports black-box models.
For details of the method see the papers:
- https://link.springer.com/article/10.1023/A:1010933404324
- https://arxiv.org/abs/1801.01489
"""
[docs]
def __init__(self,
predictor: Callable[[np.ndarray], np.ndarray],
loss_fns: Optional[
Union[
LossFnName,
List[LossFnName],
Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
Dict[str, Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float]]
]
] = None,
score_fns: Optional[
Union[
ScoreFnName,
List[ScoreFnName],
Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
Dict[str, Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float]]
]
] = None,
feature_names: Optional[List[str]] = None,
verbose: bool = False):
"""
Initialize the permutation feature importance.
Parameters
----------
predictor
A prediction function which receives as input a `numpy` array of size `N x F`, and outputs a
`numpy` array of size `N` (i.e. `(N, )`) or `N x T`, where `N` is the number of input instances,
`F` is the number of features, and `T` is the number of targets. Note that the output shape must be
compatible with the loss and score functions provided in `loss_fns` and `score_fns`.
loss_fns
A literal, or a list of literals, or a loss function, or a dictionary of loss functions having as keys the
names of the loss functions and as values the loss functions (i.e., lower values are better). The available
literal values are described in :py:data:`alibi.explainers.permutation_importance.LOSS_FNS`. Note that the
`predictor` output must be compatible with every loss function. Every loss function is expected to receive
the following arguments:
- `y_true` : ``np.ndarray`` - a `numpy` array of ground-truth labels.
- `y_pred` | `y_score` : ``np.ndarray`` - a `numpy` array of model predictions. This corresponds to \
the output of the model.
- `sample_weight`: ``Optional[np.ndarray]`` - a `numpy` array of sample weights.
score_fns
A literal, or a list or literals, or a score function, or a dictionary of score functions having as keys the
names of the score functions and as values the score functions (i.e, higher values are better). The
available literal values are described in :py:data:`alibi.explainers.permutation_importance.SCORE_FNS`.
As with the `loss_fns`, the `predictor` output must be compatible with every score function and the score
function must have the same signature presented in the `loss_fns` parameter description.
feature_names
A list of feature names used for displaying results.
verbose
Whether to print the progress of the explainer.
"""
super().__init__(meta=copy.deepcopy(DEFAULT_META_PERMUTATION_IMPORTANCE))
self.predictor = predictor
self.feature_names = feature_names
self.verbose = verbose
if (loss_fns is None) and (score_fns is None):
raise ValueError('At least one loss function or a score function must be provided.')
# initialize loss and score functions
self.loss_fns = PermutationImportance._init_metrics(metric_fns=loss_fns, metric_type='loss') # type: ignore[arg-type] # noqa
self.score_fns = PermutationImportance._init_metrics(metric_fns=score_fns, metric_type='score') # type: ignore[arg-type] # noqa
[docs]
def explain(self, # type: ignore[override]
X: np.ndarray,
y: np.ndarray,
features: Optional[List[Union[int, Tuple[int, ...]]]] = None,
method: Literal["estimate", "exact"] = "estimate",
kind: Literal["ratio", "difference"] = "ratio",
n_repeats: int = 50,
sample_weight: Optional[np.ndarray] = None) -> Explanation:
"""
Computes the permutation feature importance for each feature with respect to the given loss or score
functions and the dataset `(X, y)`.
Parameters
----------
X
A `N x F` input feature dataset used to calculate the permutation feature importance. This is typically the
test dataset.
y
Ground-truth labels array of size `N` (i.e. `(N, )`) corresponding the input feature `X`.
features
An optional list of features or tuples of features for which to compute the permutation feature
importance. If not provided, the permutation feature importance will be computed for every single features
in the dataset. Some example of `features` would be: ``[0, 2]``, ``[0, 2, (0, 2)]``, ``[(0, 2)]``,
where ``0`` and ``2`` correspond to column 0 and 2 in `X`, respectively.
method
The method to be used to compute the feature importance. If set to ``'exact'``, a "switch" operation is
performed across all observed pairs, by excluding pairings that are actually observed in the original
dataset. This operation is quadratic in the number of samples (`N x (N - 1)` samples) and thus can be
computationally intensive. If set to ``'estimate'``, the dataset will be divided in half. The values of
the first half containing the ground-truth labels the rest of the features (i.e. features that are left
intact) is matched with the values of the second half of the permuted features, and the other way around.
This method is computationally lighter and provides estimate error bars given by the standard deviation.
Note that for some specific loss and score functions, the estimate does not converge to the exact metric
value.
kind
Whether to report the importance as the loss/score ratio or the loss/score difference.
Available values are: ``'ratio'`` | ``'difference'``.
n_repeats
Number of times to permute the feature values. Considered only when ``method='estimate'``.
sample_weight
Optional weight for each sample instance.
Returns
-------
explanation
An `Explanation` object containing the data and the metadata of the permutation feature importance.
See usage at `Permutation feature importance examples`_ for details
.. _Permutation feature importance examples:
https://docs.seldon.io/projects/alibi/en/stable/methods/PermutationImportance.html
"""
n_features = X.shape[1]
# set the `features_names` when the user did not provide the feature names
if self.feature_names is None:
self.feature_names = [f'f_{i}' for i in range(n_features)]
# construct `feature_names` based on the `features`. If `features` is ``None``, then initialize
# `features` with all single feature available in the dataset.
if features:
feature_names = [tuple([self.feature_names[f] for f in features])
if isinstance(features, tuple) else self.feature_names[features]
for features in features]
else:
feature_names = self.feature_names # type: ignore[assignment]
features = list(range(n_features))
# unaltered model predictions
y_hat = self.predictor(X)
# compute original loss
loss_orig = PermutationImportance._compute_metrics(metric_fns=self.loss_fns,
y=y,
y_hat=y_hat,
sample_weight=sample_weight)
# compute original score
score_orig = PermutationImportance._compute_metrics(metric_fns=self.score_fns,
y=y,
y_hat=y_hat,
sample_weight=sample_weight)
# compute permutation feature importance for every feature
# TODO: implement parallel version - future work as it can be done for ALE too
individual_feature_importance = []
for ifeatures in tqdm(features, disable=not self.verbose):
individual_feature_importance.append(
self._compute_permutation_importance(
X=X,
y=y,
method=method,
kind=kind,
n_repeats=n_repeats,
sample_weight=sample_weight,
features=ifeatures,
loss_orig=loss_orig,
score_orig=score_orig
)
)
# update meta data params
self.meta['params'].update(feature_names=feature_names,
method=method,
kind=kind,
n_repeats=n_repeats,
sample_weight=sample_weight)
# build and return the explanation object
return self._build_explanation(feature_names=feature_names, # type: ignore[arg-type]
individual_feature_importance=individual_feature_importance)
@staticmethod
def _init_metrics(metric_fns: Optional[
Union[
str,
List[str],
Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
Dict[str, Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float]]
]
],
metric_type: Literal['loss', 'score']
) -> Dict[str, Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float]]:
"""
Helper function to initialize the loss and score functions.
Parameters
----------
metric_fns
See `loss_fns` or `score_fns` as defined in
:py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain`.
metric_type
Metric function type. Supported types: ``'loss'`` | ``'score'``.
Returns
-------
Initialized loss and score functions.
"""
if metric_fns is None:
return {}
if callable(metric_fns):
return {metric_type: metric_fns}
if isinstance(metric_fns, str):
metric_fns = [metric_fns]
if isinstance(metric_fns, list):
dict_metric_fns = {}
METRIC_FNS = LOSS_FNS if metric_type == 'loss' else SCORE_FNS
for metric_fn in metric_fns:
if not isinstance(metric_fn, str):
raise ValueError(f'The {metric_type} inside {metric_type}_fns must be of type `str`.')
if metric_fn not in METRIC_FNS:
raise ValueError(f'Unknown {metric_type} name. Received {metric_fn}. '
f'Supported values are: {list(METRIC_FNS.keys())}')
dict_metric_fns[metric_fn] = METRIC_FNS[metric_fn]
return dict_metric_fns
return metric_fns
@staticmethod
def _compute_metrics(metric_fns: Dict[str, Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float]],
y: np.ndarray,
y_hat: np.ndarray,
sample_weight: Optional[np.ndarray] = None) -> Dict[str, List[float]]:
"""
Helper function to compute multiple metrics.
Parameters
----------
metric_fns
A dictionary of metric functions having as keys the names of the metric functions and as
values the metric functions.
y
Ground truth targets.
y_hat
Predicted outcome as returned by the classifier.
sample_weight
Weight of each sample instance.
Returns
-------
Dictionary having as keys the metric names and as values the evaluation of the metrics.
"""
metrics = defaultdict(list)
for metric_name, metric_fn in metric_fns.items():
metrics[metric_name].append(
PermutationImportance._compute_metric(
metric_fn=metric_fn,
y=y,
y_hat=y_hat,
sample_weight=sample_weight
)
)
return metrics
@staticmethod
def _compute_metric(metric_fn: Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
y: np.ndarray,
y_hat: np.ndarray,
sample_weight: Optional[np.ndarray] = None) -> float:
"""
Helper function to compute a metric. It also checks if the metric function contains in its signature the
arguments `y_true`, `y_pred` or `y_score`, and optionally `sample_weight`.
Parameters
----------
metric_fn
Metric function to be used. Note that the loss/score function must be compatible with the
`y_true`, `y_pred`, and optionally with `sample_weight`.
y, y_hat, sample_weight
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance._compute_metrics`.
Returns
-------
Evaluation of the metric.
"""
str_args = inspect.signature(metric_fn).parameters.keys()
if 'y_true' not in str_args:
raise ValueError('The `scoring` function must have the argument `y_true` in its definition.')
if ('y_pred' not in str_args) and ('y_score' not in str_args):
raise ValueError('The `scoring` function must have the argument `y_pred` or `y_score` in its definition.')
kwargs: Dict[str, Optional[np.ndarray]] = {
'y_true': y,
'y_pred' if 'y_pred' in str_args else 'y_score': y_hat
}
if 'sample_weight' not in str_args:
# some metrics might not support `sample_weight` such as:
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.max_error.html#sklearn.metrics.max_error
if sample_weight is not None:
logger.warning(f"The loss function '{metric_fn.__name__}' does not support argument `sample_weight`. "
f"Calling the method without `sample_weight`.")
else:
# include `sample_weight` int the `kwargs` if the metric supports it
kwargs['sample_weight'] = sample_weight
return metric_fn(**kwargs) # type: ignore [call-arg]
def _compute_permutation_importance(self,
X: np.ndarray,
y: np.ndarray,
method: Literal["estimate", "exact"],
kind: Literal["difference", "ratio"],
n_repeats: int,
sample_weight: Optional[np.ndarray],
features: Union[int, Tuple[int, ...]],
loss_orig: Dict[str, List[float]],
score_orig: Dict[str, List[float]]) -> Dict[str, Any]:
"""
Helper function to compute the permutation feature importance for a given feature or tuple of features.
Parameters
----------
X, y, method, kind, n_repeats, sample_weight
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain`.#
features
The feature or the tuple of features to compute the permutation feature importance for.
loss_orig
Original loss value when the features are left intact. The loss is computed on the original datasets.
score_orig
Original score value when the feature are left intact. The score is computed on the original dataset.
Returns
--------
A dictionary having as keys the metric names and as values the permutation feature importance associated
with the corresponding metrics.
"""
if method == Method.EXACT:
# computation of the exact statistic which is quadratic in the number of samples
return self._compute_exact(X=X,
y=y,
kind=kind,
sample_weight=sample_weight,
features=features,
loss_orig=loss_orig,
score_orig=score_orig)
# sample approximation
return self._compute_estimate(X=X,
y=y,
kind=kind,
n_repeats=n_repeats,
sample_weight=sample_weight,
features=features,
loss_orig=loss_orig,
score_orig=score_orig)
def _compute_exact(self,
X: np.ndarray,
y: np.ndarray,
kind: str,
sample_weight: Optional[np.ndarray],
features: Union[int, Tuple[int, ...]],
loss_orig: Dict[str, List[float]],
score_orig: Dict[str, List[float]]) -> Dict[str, Any]:
"""
Helper function to compute the "exact" value of the permutation feature importance.
Parameters
----------
X, y, kind, sample_weight, features, loss_orig, score_orig
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance._compute_permutation_importance`. # noqa
Returns
-------
A dictionary having as keys the metric names and as values the permutation feature importance associated
with the corresponding metrics.
"""
y_perm, y_perm_hat = [], []
weights: Optional[List[np.ndarray]] = [] if sample_weight else None
for i in range(len(X)):
# create input features dataset: (1, F1, F2, ...) -> (N - 1, F1, F2, ... )
# where N is the number of instances in the dataset and Fi is the dimension in the axis i.
X_tmp = np.tile(X[i:i+1], reps=(len(X) - 1, ) + (1, ) * (len(X.shape) - 1))
X_tmp[:, features] = np.delete(arr=X[:, features], obj=i, axis=0)
# create ground-truth labels: (1, C1, C2, ...) -> (N - 1, C1, C2, ... )
# where N is the number of instances in the dataset and Ci is the dimension in the axis i.
y_tmp = np.tile(y[i:i+1], reps=(len(y) - 1, ) + (1, ) * (len(y.shape) - 1))
# compute predictions
y_perm_hat.append(self.predictor(X_tmp))
y_perm.append(y_tmp)
# create sample weights if necessary
if sample_weight is not None:
weights.append(np.full(shape=(len(X_tmp),), fill_value=sample_weight[i])) # type: ignore[union-attr]
# concatenate all predictions and construct ground-truth array. At this point, the `y_hat` vector
# should contain `N x (N - 1)` predictions, where `N` is the number of samples in `X`.
y_perm_hat = np.concatenate(y_perm_hat, axis=0)
y_perm = np.concatenate(y_perm, axis=0)
if weights is not None:
weights = np.concatenate(weights, axis=0)
# compute loss values for the altered dataset
loss_permuted = PermutationImportance._compute_metrics(metric_fns=self.loss_fns,
y=y_perm, # type: ignore[arg-type]
y_hat=y_perm_hat, # type: ignore[arg-type]
sample_weight=weights) # type: ignore[arg-type]
# compute score values for the altered dataset
score_permuted = PermutationImportance._compute_metrics(metric_fns=self.score_fns,
y=y_perm, # type: ignore[arg-type]
y_hat=y_perm_hat, # type: ignore[arg-type]
sample_weight=weights) # type: ignore[arg-type]
# compute feature importance for the loss functions
loss_feature_importance = PermutationImportance._compute_importances(metric_orig=loss_orig,
metric_permuted=loss_permuted,
kind=kind,
lower_is_better=True)
# compute feature importance for the score functions
score_feature_importance = PermutationImportance._compute_importances(metric_orig=score_orig,
metric_permuted=score_permuted,
kind=kind,
lower_is_better=False)
return {**loss_feature_importance, **score_feature_importance}
def _compute_estimate(self,
X: np.ndarray,
y: np.ndarray,
kind: str,
n_repeats: int,
sample_weight: Optional[np.ndarray],
features: Union[int, Tuple[int, ...]],
loss_orig: Dict[str, List[float]],
score_orig: Dict[str, List[float]]) -> Dict[str, Any]:
"""
Helper function to compute the "estimate" mean, standard deviation and sample values of the permutation
feature importance.
Parameters
----------
X, y, kind, sample_weight, features, loss_orig, score_orig
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance._compute_permutation_importance`. # noqa
Returns
-------
A dictionary having as keys the metric names and as values the permutation feature importance associated
with the corresponding metrics.
"""
N = len(X)
start, middle, end = 0, N // 2, N if N % 2 == 0 else N - 1
fh, sh = np.s_[start:middle], np.s_[middle:end]
loss_permuted: Dict[str, List[float]] = defaultdict(list)
score_permuted: Dict[str, List[float]] = defaultdict(list)
for i in range(n_repeats):
# get random permutation. Note that this includes also the last element
shuffled_indices = np.random.permutation(len(X))
# shuffle the dataset
X_tmp, y_tmp = X[shuffled_indices], y[shuffled_indices]
sample_weight_tmp = None if (sample_weight is None) else sample_weight[shuffled_indices]
# permute values from the first half into the second half and the other way around
fvals_tmp = X_tmp[fh, features].copy()
X_tmp[fh, features] = X_tmp[sh, features]
X_tmp[sh, features] = fvals_tmp
# compute scores
y_tmp_hat = self.predictor(X_tmp[:end])
y_tmp = y_tmp[:end]
weights = None if (sample_weight_tmp is None) else sample_weight_tmp[:end]
# compute loss values for the altered dataset
tmp_loss_permuted = PermutationImportance._compute_metrics(metric_fns=self.loss_fns,
y=y_tmp,
y_hat=y_tmp_hat,
sample_weight=weights)
for loss_name in tmp_loss_permuted:
loss_permuted[loss_name] += tmp_loss_permuted[loss_name]
# compute score values for the altered dataset
tmp_score_permuted = PermutationImportance._compute_metrics(metric_fns=self.score_fns,
y=y_tmp,
y_hat=y_tmp_hat,
sample_weight=weights)
for score_name in tmp_score_permuted:
score_permuted[score_name] += tmp_score_permuted[score_name]
# compute feature importance for the loss functions
loss_feature_importance = PermutationImportance._compute_importances(metric_orig=loss_orig,
metric_permuted=loss_permuted,
kind=kind,
lower_is_better=True)
# compute feature importance for the score functions
score_feature_importance = PermutationImportance._compute_importances(metric_orig=score_orig,
metric_permuted=score_permuted,
kind=kind,
lower_is_better=False)
return {**loss_feature_importance, **score_feature_importance}
@staticmethod
def _compute_importances(metric_orig: Dict[str, List[float]],
metric_permuted: Dict[str, List[float]],
kind: str,
lower_is_better: bool) -> Dict[str, Any]:
"""
Helper function to compute the feature importance as the metric ration or the metric difference
based on the `kind` parameter and the `lower_is_better` flag for multiple metric functions.
Parameters
----------
metric_orig
A dictionary having as keys the names of the metric functions and as values the metric evaluations when
the feature values are left intact. The values are lists with a single element.
metric_permuted
A dictionary having as keys the names of the metric functions and as values a list of metric evaluations
when the feature values are permuted.
kind
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain`.
lower_is_better
Whether lower metric value is better.
Returns
-------
A dictionary having as keys the names of the metric functions and as values the feature importance or
a dictionary containing the mean and the standard deviation of the feature importance, and the samples used to
compute the two statistics for the corresponding metrics.
"""
feature_importance = {}
for metric_name in metric_orig:
importance_values = [
PermutationImportance._compute_importance(
metric_orig=metric_orig[metric_name][0], # a list with just one element
metric_permuted=metric_permuted_value,
kind=kind,
lower_is_better=lower_is_better
) for metric_permuted_value in metric_permuted[metric_name]
]
if len(importance_values) > 1:
feature_importance[metric_name] = {
"mean": np.mean(importance_values),
"std": np.std(importance_values),
"samples": np.array(importance_values),
}
else:
feature_importance[metric_name] = importance_values[0] # type: ignore
return feature_importance
@staticmethod
def _compute_importance(metric_orig: float, metric_permuted: float, kind: str, lower_is_better: bool) -> float:
"""
Helper function to compute the feature importance as the metric ratio or the metric difference
based on the `kind` parameter and `lower_is_better` flag.
Parameters
----------
metric_orig
Metric value when the feature values are left intact.
metric_permuted
Metric value when the feature value are permuted.
kind
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain`.
lower_is_better
See :py:meth:`alibi.explainers.permutation_importance.PermutationImportance._compute_importances.
Returns
-------
Importance score.
"""
if lower_is_better:
return metric_permuted / metric_orig if kind == Kind.RATIO else metric_permuted - metric_orig
return metric_orig / metric_permuted if kind == Kind.RATIO else metric_orig - metric_permuted
def _build_explanation(self,
feature_names: List[Union[str, Tuple[str, ...]]],
individual_feature_importance: List[
Union[
Dict[str, float],
Dict[str, Dict[str, float]]
]
]) -> Explanation:
"""
Helper method to build `Explanation` object.
Parameters
----------
feature_names
List of names of the explained features.
individual_feature_importance
List of dictionary having as keys the names of the metric functions and as values the feature
importance when ``kind='exact'`` or a dictionary containing the mean and the standard deviation of the
feature importance, and the samples used to compute the two statistics when``kind='estimate'`` for
the corresponding metrics.
Returns
-------
`Explanation` object.
"""
# list of metrics names
metric_names = list(individual_feature_importance[0].keys())
# list of lists of features importance, one list per loss function
feature_importance: List[List[Union[float, Dict[str, float]]]] = []
for metric_name in metric_names:
feature_importance.append([])
for i in range(len(feature_names)):
feature_importance[-1].append(individual_feature_importance[i][metric_name])
data = copy.deepcopy(DEFAULT_DATA_PERMUTATION_IMPORTANCE)
data.update(
feature_names=feature_names,
metric_names=metric_names,
feature_importance=feature_importance,
)
return Explanation(meta=copy.deepcopy(self.meta), data=data)
[docs]
def reset_predictor(self, predictor: Callable) -> None:
"""
Resets the predictor function.
Parameters
----------
predictor
New predictor function.
"""
self.predictor = predictor
# No type check due to the generic explanation object
[docs]
@no_type_check
def plot_permutation_importance(exp: Explanation,
features: Union[List[int], Literal['all']] = 'all',
metric_names: Union[List[Union[str, int]], Literal['all']] = 'all',
n_cols: int = 3,
sort: bool = True,
top_k: Optional[int] = None,
ax: Optional[Union['plt.Axes', np.ndarray]] = None,
bar_kw: Optional[dict] = None,
fig_kw: Optional[dict] = None) -> 'plt.Axes':
"""
Plot permutation feature importance on `matplotlib` axes.
Parameters
----------
exp
An `Explanation` object produced by a call to the
:py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain` method.
features
A list of feature entries provided in `feature_names` argument to the
:py:meth:`alibi.explainers.permutation_importance.PermutationImportance.explain` method, or
``'all'`` to plot all the explained features. For example, consider that the
``feature_names = ['temp', 'hum', 'windspeed', 'season']``. If we set `features=None` in the `explain` method,
meaning that all the feature were explained, and we want to plot only the values for the ``'temp'`` and
``'windspeed'``, then we would set ``features=[0, 2]``. Otherwise, if we set `features=[1, 2, 3]` in the
explain method, meaning that we explained ``['hum', 'windspeed', 'season']``, and we want to plot the values
only for ``['windspeed', 'season']``, then we would set ``features=[1, 2]`` (i.e., their index in the
`features` list passed to the `explain` method). Defaults to ``'all'``.
metric_names
A list of metric entries in the `exp.data['metrics']` to plot the permutation feature importance for,
or ``'all'`` to plot the permutation feature importance for all metrics (i.e., loss and score functions).
The ordering is given by the concatenation of the loss metrics followed by the score metrics.
n_cols
Number of columns to organize the resulting plot into.
sort
Boolean flag whether to sort the values in descending order.
top_k
Number of top k values to be displayed if the ``sort=True``. If not provided, then all values will be displayed.
ax
A `matplotlib` axes object or a `numpy` array of `matplotlib` axes to plot on.
bar_kw
Keyword arguments passed to the `matplotlib.pyplot.barh`_ function.
fig_kw
Keyword arguments passed to the `matplotlib.figure.set`_ function.
.. _matplotlib.pyplot.barh:
https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.barh.html
.. _matplotlib.figure.set:
https://matplotlib.org/stable/api/figure_api.html
Returns
--------
`plt.Axes` with the feature importance plot.
"""
from matplotlib.gridspec import GridSpec
# define figure arguments
default_fig_kw = {'tight_layout': 'tight'}
if fig_kw is None:
fig_kw = {}
fig_kw = {**default_fig_kw, **fig_kw}
if features == 'all':
features = list(range(len(exp.data['feature_names'])))
metric_names = deepcopy(exp.data['metric_names'] if metric_names == 'all' else metric_names)
# `features` sanity checks
for ifeature in features:
if ifeature >= len(exp.data['feature_names']):
raise IndexError(f"The `features` indices must be less than the "
f"``len(feature_names) = {len(exp.data['feature_names'])}``. Received {ifeature}.")
# construct vector of feature names to display importance for
feature_names = [exp.data['feature_names'][i] for i in features]
# `metric_names` sanity checks
for i, imetric_name in enumerate(metric_names):
if isinstance(imetric_name, str) and (imetric_name not in exp.data['metric_names']):
raise ValueError(f"Unknown metric name. Received {imetric_name}. "
f"Available values are: {exp.data['metric_names']}.")
if isinstance(imetric_name, numbers.Integral):
if imetric_name >= len(exp.data['metric_names']):
raise IndexError(f"Metric name index out of range. Received {imetric_name}. "
f"The number of `metric_names` is {len(exp.data['metric_names'])}")
# convert index to string
metric_names[i] = exp.data['metric_names'][imetric_name]
if ax is None:
fix, ax = plt.subplots()
# number of metrics will correspond to the number of axis
n_metric_names = len(metric_names)
if isinstance(ax, plt.Axes) and n_metric_names != 1:
ax.set_axis_off() # treat passed axis as a canvas for subplots
fig = ax.figure
n_cols = min(n_cols, n_metric_names)
n_rows = math.ceil(n_metric_names / n_cols)
axes = np.empty((n_rows, n_cols), dtype=object)
axes_ravel = axes.ravel()
gs = GridSpec(n_rows, n_cols)
for i, spec in zip(range(n_metric_names), gs):
axes_ravel[i] = fig.add_subplot(spec)
else: # array-like
if isinstance(ax, plt.Axes):
ax = np.array(ax)
if ax.size < n_metric_names:
raise ValueError(f"Expected ax to have {n_metric_names} axes, got {ax.size}")
axes = np.atleast_2d(ax)
axes_ravel = axes.ravel()
fig = axes_ravel[0].figure
for i in range(n_metric_names):
ax = axes_ravel[i]
metric_idx = exp.data['metric_names'].index(metric_names[i])
# define bar plot data
y_labels = feature_names
y_labels = ['(' + ', '.join(y_label) + ')' if isinstance(y_label, tuple) else y_label for y_label in y_labels]
if exp.meta['params']['method'] == Method.EXACT:
width = [exp.data['feature_importance'][metric_idx][j] for j in features]
xerr = None
else:
width = [exp.data['feature_importance'][metric_idx][j]['mean'] for j in features]
xerr = [exp.data['feature_importance'][metric_idx][j]['std'] for j in features]
if sort:
sorted_indices = np.argsort(width)[::-1][:top_k]
width = [width[j] for j in sorted_indices]
y_labels = [y_labels[j] for j in sorted_indices]
if exp.meta['params']['method'] == Method.ESTIMATE:
xerr = [xerr[j] for j in sorted_indices]
y = np.arange(len(width))
default_bar_kw = {'align': 'center'}
bar_kw = default_bar_kw if bar_kw is None else {**default_bar_kw, **bar_kw}
ax.barh(y=y, width=width, xerr=xerr, **bar_kw)
ax.set_yticks(y)
ax.set_yticklabels(y_labels)
ax.invert_yaxis() # labels read top-to-bottom
ax.set_xlabel('Permutation feature importance')
ax.set_title(metric_names[i])
fig.set(**fig_kw)
return axes