Source code for alibi.explainers.similarity.base

from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

import numpy as np
from alibi.api.interfaces import Explainer
from alibi.explainers.similarity.backends import _select_backend
from alibi.utils.frameworks import Framework, has_pytorch, has_tensorflow
from alibi.utils.missing_optional_dependency import import_optional
from tqdm import tqdm
from typing_extensions import Literal

_TfTensor = import_optional('tensorflow', ['Tensor'])
_PtTensor = import_optional('torch', ['Tensor'])

if TYPE_CHECKING:
    import tensorflow
    import torch


class BaseSimilarityExplainer(Explainer, ABC):
    """Base class for similarity explainers."""
    def __init__(self,
                 predictor: 'Union[tensorflow.keras.Model, torch.nn.Module]',
                 loss_fn: '''Union[Callable[[tensorflow.Tensor, tensorflow.Tensor], tensorflow.Tensor],
                                   Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]''',
                 sim_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
                 precompute_grads: bool = False,
                 backend: Framework = Framework.TENSORFLOW,
                 device: 'Union[int, str, torch.device, None]' = None,
                 meta: Optional[dict] = None,
                 verbose: bool = False,
                 ):
        """Constructor

        Parameters
        ----------
        predictor
            Model to be explained.
        loss_fn
            Loss function.
        sim_fn
            Similarity function. Takes two inputs and returns a similarity value.
        precompute_grads
            Whether to precompute and store the gradients when fitting.
        backend
            Deep learning backend.
        device
            Device to be used. Will default to the same device the backend defaults to.
        meta
            Metadata specific to explainers that inherit from this class. Should be initialized in the child
            class and passed in here. Is used in the `__init__` of the base `Explainer` class.
        verbose
            Whether to display progress bars while computing gradients.
        """
        # Select backend.
        self.backend = _select_backend(backend)
        self.backend.set_device(device)  # type: ignore

        self.predictor = predictor
        self.loss_fn = loss_fn
        self.sim_fn = sim_fn
        self.precompute_grads = precompute_grads
        self.verbose = verbose

        meta = {} if meta is None else meta
        super().__init__(meta=meta)
    def fit(self,
            X_train: Union[np.ndarray, List[Any]],
            Y_train: np.ndarray) -> "Explainer":
        """Fit the explainer.

        If ``self.precompute_grads == True`` then the gradients are precomputed and stored.

        Parameters
        ----------
        X_train
            Training data.
        Y_train
            Training labels.

        Returns
        -------
        self
            Returns self.
        """
        self.X_train = X_train
        self.Y_train = Y_train
        self.X_dims = self.X_train.shape[1:] if isinstance(self.X_train, np.ndarray) else None
        self.Y_dims = self.Y_train.shape[1:]
        self.grad_X_train = np.array([])

        # Compute and store gradients.
        if self.precompute_grads:
            grads = []
            X: Union[np.ndarray, List[Any]]
            for X, Y in tqdm(zip(self.X_train, self.Y_train), disable=not self.verbose):
                grad_X_train = self._compute_grad(self._format(X), Y[None])
                grads.append(grad_X_train[None])
            self.grad_X_train = np.concatenate(grads, axis=0)
        return self
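    # Note: with ``precompute_grads=True`` the loop above stores one flattened
    # gradient vector per training instance in ``self.grad_X_train``, trading
    # memory for faster repeated explanations; with ``precompute_grads=False``
    # the same gradients are recomputed on demand in
    # ``_compute_adhoc_similarity``.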
    @staticmethod
    def _is_tensor(x: Any) -> bool:
        """Checks if an object is a tensor."""
        if has_tensorflow and isinstance(x, _TfTensor):
            return True
        if has_pytorch and isinstance(x, _PtTensor):
            return True
        if isinstance(x, np.ndarray):
            return True
        return False

    @staticmethod
    def _format(x: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any]'
                ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
        """Adds a batch dimension."""
        if BaseSimilarityExplainer._is_tensor(x):
            return x[None]
        return [x]

    def _verify_fit(self) -> None:
        """Verify that the explainer has been fitted.

        Raises
        ------
        ValueError
            If the explainer has not been fitted.
        """
        if not hasattr(self, 'X_train') or not hasattr(self, 'Y_train'):
            raise ValueError('Training data not set. Call `fit` and pass training data first.')

    def _match_shape_to_data(self,
                             data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any, List[Any]]',
                             target_type: Literal['X', 'Y']
                             ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
        """Verify the shape of `data` against the shape of the training data.

        Used to ensure the input is the correct shape for the gradient methods implemented in the backends.
        `data` will be the features or label of the instance being explained. If `data` is not a batch, it is
        reshaped to be a single batch element, i.e. if the training data shape is `(3, 28, 28)` and the data
        shape is `(3, 28, 28)`, the data is reshaped to `(1, 3, 28, 28)`.

        Parameters
        ----------
        data
            Data to be matched shape-wise against the training data.
        target_type
            Type of data: ``'X'`` | ``'Y'``. Used to determine whether the data should take the shape of the
            predictor input or the predictor output. ``'X'`` uses the `X_dims` attribute, which stores the
            shape of the training data. ``'Y'`` matches the shape of `Y_dims`, which is the shape of the
            target data.

        Raises
        ------
        ValueError
            If the shape of `data` does not match the shape of the training data, or `fit` has not been called
            prior to calling this method.
        """
        if self._is_tensor(data):
            return self._match_shape_to_data_tensor(data, target_type)
        return self._match_shape_to_data_any(data)

    def _match_shape_to_data_tensor(self,
                                    data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
                                    target_type: Literal['X', 'Y']
                                    ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]':
        """Verify the shape of `data` against the shape of the training data for tensor-like data."""
        target_shape = getattr(self, f'{target_type}_dims')
        if data.shape == target_shape:
            data = data[None]
        if data.shape[1:] != target_shape:
            raise ValueError((f'Input `{target_type}` has shape {data.shape[1:]}'
                              f' but training data has shape {target_shape}'))
        return data

    @staticmethod
    def _match_shape_to_data_any(data: Union[Any, List[Any]]) -> list:
        """Ensures that any other data type is a list."""
        if isinstance(data, list):
            return data
        return [data]

    def _compute_adhoc_similarity(self, grad_X: np.ndarray) -> np.ndarray:
        """Computes the similarity between the gradients of the test instances and all the training instances.

        The method computes the gradients of the training instances on the fly, without storing them in memory.

        Parameters
        ----------
        grad_X
            Gradients of the test instances.
        """
        scores = np.zeros((len(grad_X), len(self.X_train)))
        X: Union[np.ndarray, List[Any]]
        for i, (X, Y) in tqdm(enumerate(zip(self.X_train, self.Y_train)), disable=not self.verbose):
            grad_X_train = self._compute_grad(self._format(X), Y[None])
            scores[:, i] = self.sim_fn(grad_X, grad_X_train[None])[:, 0]
        return scores

    def _compute_grad(self,
                      X: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]',
                      Y: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]') -> np.ndarray:
        """Computes predictor parameter gradients and returns a flattened `numpy` array."""
        X = self.backend.to_tensor(X) if isinstance(X, np.ndarray) else X
        Y = self.backend.to_tensor(Y) if isinstance(Y, np.ndarray) else Y
        return self.backend.get_grads(self.predictor, X, Y, self.loss_fn)
    def reset_predictor(self, predictor: 'Union[tensorflow.keras.Model, torch.nn.Module]') -> None:
        """Resets the predictor to the given predictor.

        Parameters
        ----------
        predictor
            The new predictor to use.
        """
        self.predictor = predictor
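# --- Usage sketch (illustrative; not part of the module above) ---
# A minimal sketch of how a concrete subclass might drive this base class,
# assuming the TensorFlow backend. ``DemoExplainer``, the toy model and the
# dot-product ``sim_fn`` below are hypothetical stand-ins; in `alibi`, the
# concrete explainer built on this base is `GradientSimilarity`.
#
#     import numpy as np
#     import tensorflow as tf
#
#     class DemoExplainer(BaseSimilarityExplainer):
#         def explain(self, X, Y):
#             """Minimal stub satisfying the abstract `Explainer` interface."""
#             self._verify_fit()
#             # Batch the instance, compute its parameter gradients and score
#             # it against every training instance on the fly.
#             grad_X = self._compute_grad(self._match_shape_to_data(X, 'X'),
#                                         self._match_shape_to_data(Y, 'Y'))
#             return self._compute_adhoc_similarity(grad_X[None])
#
#     model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
#     model.build(input_shape=(None, 4))
#
#     explainer = DemoExplainer(
#         predictor=model,
#         loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#         sim_fn=lambda a, b: a @ b.T,  # (n, d), (m, d) -> (n, m) dot products
#         backend=Framework.TENSORFLOW,
#     )
#     explainer.fit(X_train=np.random.rand(8, 4).astype(np.float32),
#                   Y_train=np.random.randint(0, 2, size=(8, 1)))
#     scores = explainer.explain(np.zeros(4, dtype=np.float32), np.array([0]))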