alibi.explainers.similarity.grad module

Gradient-based explainer.

This module implements the gradient-based explainers grad-dot, grad-cos and grad-asym-dot.

class alibi.explainers.similarity.grad.GradientSimilarity(predictor, loss_fn, sim_fn='grad_dot', task='classification', precompute_grads=False, backend='tensorflow', device=None, verbose=False)[source]

Bases: BaseSimilarityExplainer

__init__(predictor, loss_fn, sim_fn='grad_dot', task='classification', precompute_grads=False, backend='tensorflow', device=None, verbose=False)[source]

GradientSimilarity explainer.

The gradient similarity explainer is used to find examples in the training data that the predictor considers similar to the test instances the user wants to explain. It uses the gradients of the loss between the model output and the training data labels; these are compared using the similarity function specified by sim_fn. The GradientSimilarity explainer can be applied to models trained for both classification and regression tasks.

Parameters:
  • predictor (Union[Model, Module]) – Model to explain.

  • loss_fn (Union[Callable[[tf.Tensor, tf.Tensor], tf.Tensor], Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]) – Loss function used. The gradient of the loss function is used to compute the similarity between the test instances and the training set.

  • sim_fn (Literal['grad_dot', 'grad_cos', 'grad_asym_dot']) – Similarity function to use. The 'grad_dot' similarity function computes the dot product of the gradients, see alibi.explainers.similarity.metrics.dot(). The 'grad_cos' similarity function computes the cosine similarity between the gradients, see alibi.explainers.similarity.metrics.cos(). The 'grad_asym_dot' similarity function is similar to 'grad_dot' but asymmetric, see alibi.explainers.similarity.metrics.asym_dot(). A sketch of these functions is given after this parameter list.

  • task (Literal['classification', 'regression']) – Type of task performed by the model. If the task is 'classification', the target value passed to the explain method can either be specified directly or left as None; if None, the model's highest-scoring prediction is used. If the task is 'regression', the target value of the test instance must be specified directly.

  • precompute_grads (bool) – Whether to precompute the gradients. If False, gradients are computed on the fly; otherwise they are precomputed during fit, which can make computing explanations faster. Note that this option may be memory intensive if the model is large.

  • backend (Literal['tensorflow', 'pytorch']) – Backend to use.

  • device (Union[int, str, device, None]) – Device to use. If None, the default device for the backend is used. For the pytorch backend, see the pytorch device docs for valid options; note that in this case the parameter can also be a torch.device. For the tensorflow backend, see the tensorflow docs for valid options.

  • verbose (bool) – Whether to print the progress of the explainer.
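
For intuition, the following is a minimal numpy sketch of what the similarity functions compute on a pair of flattened gradient vectors. The actual implementations live in alibi.explainers.similarity.metrics; in particular the exact normalization used by asym_dot should be checked there.

    import numpy as np

    def grad_dot(grad_a: np.ndarray, grad_b: np.ndarray) -> float:
        # Unnormalized dot product: the magnitude of the gradients matters.
        return float(np.dot(grad_a, grad_b))

    def grad_cos(grad_a: np.ndarray, grad_b: np.ndarray) -> float:
        # Cosine similarity: only the direction of the gradients matters.
        norm = np.linalg.norm(grad_a) * np.linalg.norm(grad_b)
        return float(np.dot(grad_a, grad_b) / norm)

    def grad_asym_dot(grad_a: np.ndarray, grad_b: np.ndarray) -> float:
        # Asymmetric variant: normalized by only one gradient's norm, so
        # grad_asym_dot(a, b) != grad_asym_dot(b, a). The exact form here is
        # an assumption; see alibi.explainers.similarity.metrics.asym_dot().
        return float(np.dot(grad_a, grad_b) / np.linalg.norm(grad_b) ** 2)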

Raises:
  • ValueError – If the task is not 'classification' or 'regression'.

  • ValueError – If the sim_fn is not 'grad_dot', 'grad_cos' or 'grad_asym_dot'.

  • ValueError – If the backend is not 'tensorflow' or 'pytorch'.

  • TypeError – If the device is not an int, str, torch.device or None when using the pytorch backend, or not a str or None when using the tensorflow backend.
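
A minimal construction sketch with the tensorflow backend. The model architecture, shapes and loss below are illustrative choices, not requirements of the API:

    import tensorflow as tf
    from alibi.explainers import GradientSimilarity

    # Illustrative classifier; any tf.keras model can be passed as the
    # predictor (or a torch.nn.Module together with backend='pytorch').
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])

    explainer = GradientSimilarity(
        predictor=model,
        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
        sim_fn='grad_dot',
        task='classification',
        precompute_grads=False,
        backend='tensorflow',
    )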

explain(X, Y=None)[source]

Explain the predictor’s predictions for a given input.

Computes the similarity score between the inputs and the training set. Returns an Explanation object containing the scores, the indices of the training set instances sorted by descending similarity, and the most similar and least similar instances of the training set for the input. Note that the input may be a single instance or a batch of instances.

Parameters:
  • X (Union[ndarray, tf.Tensor, torch.Tensor, Any, List[Any]]) – X can be a numpy array, tensorflow tensor or pytorch tensor of the same shape as the training data, or a list of objects, with or without a leading batch dimension. If the batch dimension is missing, it is added.

  • Y (Union[ndarray, tf.Tensor, torch.Tensor, None]) – Y can be a numpy array, tensorflow tensor or pytorch tensor. For a regression task, the Y argument must be provided. For a classification task, Y defaults to the model prediction.

Return type:

Explanation

Returns:

Explanation object containing the ordered similarity scores for the test instance(s) with additional metadata as attributes. Contains the following data-related attributes –

  • scores: np.ndarray - similarity scores for each pair of instances in the training and test set sorted in descending order.

  • ordered_indices: np.ndarray - indices of the paired training and test set instances sorted by the similarity score in descending order.

  • most_similar: np.ndarray - 5 most similar instances in the training set for each test instance. The first element is the most similar instance.

  • least_similar: np.ndarray - 5 least similar instances in the training set for each test instance. The first element is the least similar instance.

Raises:
  • ValueError – If Y is None and the task is 'regression'.

  • ValueError – If the shape of X or Y does not match the shape of the training or target data.

  • ValueError – If the fit method has not been called prior to calling this method.
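
A usage sketch, assuming the explainer from the construction example above has already been fitted with fit() (documented below); the instance shape is illustrative:

    import numpy as np

    # Single test instance; the missing batch dimension is added automatically.
    x = np.random.rand(4).astype(np.float32)
    explanation = explainer.explain(x)  # Y=None: the model prediction is used

    scores = explanation.data['scores']                    # descending similarity
    ordered_indices = explanation.data['ordered_indices']  # sorted training indices
    most_similar = explanation.data['most_similar']        # 5 most similar instances
    least_similar = explanation.data['least_similar']      # 5 least similar instances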

fit(X_train, Y_train)[source]

Fit the explainer.

The GradientSimilarity explainer requires the model gradients over the training data. In the explain method these are compared to the model gradients for the test instance(s). If precompute_grads=True was set at initialization, the gradients are precomputed here and stored. This will speed up the explain method call but storing the gradients may not be feasible for large models.

Parameters:
  • X_train (Union[ndarray, List[Any]]) – Training data.

  • Y_train (ndarray) – Training labels.

Return type:

Explainer

Returns:

self – Returns self.
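
A fitting sketch, continuing the construction example above; the arrays are illustrative placeholders for real training data:

    import numpy as np

    # Illustrative training data matching the model's input/output shapes.
    X_train = np.random.rand(100, 4).astype(np.float32)
    y_train = np.random.randint(0, 3, size=100)

    explainer.fit(X_train=X_train, Y_train=y_train)
    # With precompute_grads=True, the gradients over the training data
    # would be computed and stored at this point.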

class alibi.explainers.similarity.grad.Task(value)[source]

Bases: str, Enum

Enum of supported tasks.

CLASSIFICATION = 'classification'
REGRESSION = 'regression'
__format__(format_spec)

Returns format using actual value type unless __str__ has been overridden.
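
Because Task mixes in str, its members compare equal to their string values, for example:

    from alibi.explainers.similarity.grad import Task

    assert Task.CLASSIFICATION == 'classification'  # str mixin: equal to its value
    assert Task('regression') is Task.REGRESSION    # lookup by value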