Source code for alibi.explainers.anchors.anchor_image

import copy
import logging
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
from skimage.segmentation import felzenszwalb, quickshift, slic

from alibi.api.defaults import DEFAULT_DATA_ANCHOR_IMG, DEFAULT_META_ANCHOR
from alibi.api.interfaces import Explainer, Explanation
from alibi.exceptions import (PredictorCallError,
                              PredictorReturnTypeError)
from alibi.utils.wrappers import ArgmaxTransformer

from .anchor_base import AnchorBaseBeam
from .anchor_explanation import AnchorExplanation

logger = logging.getLogger(__name__)

# Keyword arguments applied to each built-in segmentation function when the user
# selects it by name without supplying `segmentation_kwargs`. An empty dict means
# the segmentation function is called with its own defaults.
DEFAULT_SEGMENTATION_KWARGS: Dict[str, Dict] = {
    'felzenszwalb': {},
    'quickshift': {},
    'slic': dict(n_segments=10, compactness=10, sigma=.5, start_label=0),
}


def scale_image(image: np.ndarray, scale: tuple = (0, 255)) -> np.ndarray:
    """
    Scales an image in a specified range.

    Parameters
    ----------
    image
        Image to be scaled.
    scale
        The scaling interval, given as ``(lower, upper)``.

    Returns
    -------
    img_scaled
        Scaled image. The minimum pixel value is mapped to ``scale[0]`` and the maximum to ``scale[1]``.
        A constant image (all pixels equal) is mapped to ``scale[0]`` everywhere.
    """
    img_min, img_max = image.min(), image.max()
    value_range = img_max - img_min
    if value_range == 0:
        # A constant image has no dynamic range; dividing by zero would fill the output with NaNs.
        # Dividing by 1 instead keeps the usual output dtype and maps every pixel to the lower bound.
        value_range = 1

    img_std = (image - img_min) / value_range
    img_scaled = img_std * (scale[1] - scale[0]) + scale[0]
    return img_scaled
class AnchorImageSampler:
    """
    Draws perturbed versions of a single image and compares the predictions on them with the
    prediction on the original image. Perturbations act on superpixels: a superpixel is either kept
    as is or replaced (by its average value, or by the corresponding region of a background image).
    """

    def __init__(
        self,
        predictor: Callable,
        segmentation_fn: Callable,
        custom_segmentation: bool,
        image: np.ndarray,
        images_background: Optional[np.ndarray] = None,
        p_sample: float = 0.5,
        n_covered_ex: int = 10,
    ):
        """
        Initialize anchor image sampler.

        Parameters
        ----------
        predictor
            A callable that takes a `numpy` array of `N` data points as inputs and returns `N` outputs.
        segmentation_fn
            Function used to segment the images. The segmentation function is expected to return a segmentation
            mask containing all integer values from `0` to `K-1`, where `K` is the number of image
            segments (superpixels).
        custom_segmentation
            Whether `segmentation_fn` is user-defined. If ``False``, single-channel images are repeated across
            three channels before being passed to `segmentation_fn`.
        image
            Image to be explained.
        images_background
            Images to overlay superpixels on.
        p_sample
            Probability for a pixel to be represented by the average value of its superpixel.
        n_covered_ex
            How many examples where anchors apply to store for each anchor sampled during search
            (both examples where prediction on samples agrees/disagrees with `desired_label` are stored).
        """
        self.predictor = predictor
        self.segmentation_fn = segmentation_fn
        self.custom_segmentation = custom_segmentation
        self.image = image
        self.images_background = images_background
        self.n_covered_ex = n_covered_ex
        self.p_sample = p_sample

        # segment the image once; all perturbations reuse the same superpixels
        self.segments = self.generate_superpixels(image)
        self.segment_labels = list(np.unique(self.segments))
        # prediction on the unperturbed image, used as the reference label
        self.instance_label = self.predictor(image[np.newaxis, ...])[0]

    def __call__(
        self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True
    ) -> List[Union[np.ndarray, float, int]]:
        """
        Sample images from a perturbation distribution by masking randomly chosen superpixels
        from the original image and replacing them with pixel values from superimposed images
        if background images are provided to the explainer. Otherwise, the superpixels from the
        original image are replaced with their average values.

        Parameters
        ----------
        anchor
            - ``int`` - order of anchor in the batch.
            - ``tuple`` - features (= superpixels) present in the proposed anchor.
        num_samples
            Number of samples used.
        compute_labels
            If ``True``, an array of comparisons between predictions on perturbed samples and
            instance to be explained is returned.

        Returns
        -------
        If ``compute_labels=True``, a list containing the following is returned

         - `covered_true` - perturbed examples where the anchor applies and the model prediction on perturbed is the \
         same as the instance prediction.

         - `covered_false` - perturbed examples where the anchor applies and the model prediction on perturbed sample \
         is NOT the same as the instance prediction.

         - `labels` - `num_samples` ints indicating whether the prediction on the perturbed sample matches (1) \
         the label of the instance to be explained or not (0).

         - `data` - Matrix with 1s and 0s indicating whether the values in a superpixel will remain unchanged (1) or \
         will be perturbed (0), for each sample.

         - `-1.0` - indicates exact coverage is not computed for this algorithm.

         - `anchor[0]` - position of anchor in the batch request

        Otherwise, a list containing the data matrix only is returned.
        """
        if not compute_labels:
            # Draw the masks from the same distribution as the labelled samples, i.e. with the
            # configured `p_sample` rather than the default of `_choose_superpixels`.
            data = self._choose_superpixels(num_samples, p_sample=self.p_sample)
            data[:, anchor[1]] = 1  # superpixels in candidate anchor are not perturbed
            return [data]

        raw_data, data = self.perturbation(anchor[1], num_samples)
        labels = self.compare_labels(raw_data)
        agree = raw_data[labels][: self.n_covered_ex]
        disagree = raw_data[np.logical_not(labels)][: self.n_covered_ex]
        covered_true = [scale_image(img) for img in agree]
        covered_false = [scale_image(img) for img in disagree]
        # coverage set to -1.0 as we can't compute 'true' coverage for this model
        return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]]

    def compare_labels(self, samples: np.ndarray) -> np.ndarray:
        """
        Compute the agreement between a classifier prediction on an instance to be explained
        and the prediction on a set of samples which have a subset of perturbed superpixels.

        Parameters
        ----------
        samples
            Samples whose labels are to be compared with the instance label.

        Returns
        -------
        A boolean array indicating whether the prediction was the same as the instance label.
        """
        return self.predictor(samples) == self.instance_label

    def _choose_superpixels(
        self, num_samples: int, p_sample: float = 0.5
    ) -> np.ndarray:
        """
        Generates a binary mask of dimension [num_samples, M] where M is the number of
        image superpixels (segments).

        Parameters
        ----------
        num_samples
            Number of perturbed images to be generated
        p_sample:
            The probability that a superpixel is perturbed

        Returns
        -------
        data
            Binary 2D mask, where each non-zero entry in a row indicates that
            the values of the particular image segment will not be perturbed.
        """
        n_features = len(self.segment_labels)
        # 0 (perturb) is drawn with probability `p_sample`, 1 (keep) with probability `1 - p_sample`
        data = np.random.choice(
            [0, 1], num_samples * n_features, p=[p_sample, 1 - p_sample]
        )
        return data.reshape((num_samples, n_features))

    @staticmethod
    def _average_superpixels(image: np.ndarray, segments: np.ndarray) -> np.ndarray:
        """
        Returns a copy of `image` in which every pixel is replaced by the per-channel mean
        of the superpixel it belongs to.
        """
        fudged_image = image.copy()
        n_channels = image.shape[-1]
        for label in np.unique(segments):
            in_segment = segments == label
            pixels = image[in_segment]
            fudged_image[in_segment] = [np.mean(pixels[:, i]) for i in range(n_channels)]
        return fudged_image

    def perturbation(
        self, anchor: tuple, num_samples: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perturbs an image by altering the values of selected superpixels. If a dataset of image
        backgrounds is provided to the explainer, then the superpixels are replaced with the
        equivalent superpixels from the background image. Otherwise, the superpixels are replaced
        by their average value.

        Parameters
        ----------
        anchor:
            Contains the superpixels whose values are not going to be perturbed.
        num_samples:
            Number of perturbed samples to be returned.

        Returns
        -------
        imgs
            A `[num_samples, H, W, C]` array of perturbed images.
        segments_mask
            A `[num_samples, M]` binary mask, where `M` is the number of image superpixels
            segments. 1 indicates the values in that particular superpixels are not
            perturbed.
        """
        image = self.image
        segments = self.segments

        # choose superpixels to be perturbed; those in the anchor are always kept
        segments_mask = self._choose_superpixels(num_samples, p_sample=self.p_sample)
        segments_mask[:, anchor] = 1

        backgrounds: Union[np.ndarray, List[None]]
        fudged_image: Optional[np.ndarray] = None
        if self.images_background is not None:
            # for each sample, need to sample one of the background images
            backgrounds = np.random.choice(
                range(len(self.images_background)),
                segments_mask.shape[0],
                replace=True,
            )
        else:
            backgrounds = [None] * segments_mask.shape[0]
            # the averaged image is only needed when there is no background to draw pixels from
            fudged_image = self._average_superpixels(image, segments)

        pert_imgs = []
        for sample_mask, background_idx in zip(segments_mask, backgrounds):
            perturbed = image.copy()
            to_perturb = np.where(sample_mask == 0)[0]
            # boolean `[H, W]` mask of the pixels belonging to superpixels absent from the sample
            pixel_mask = np.isin(segments, to_perturb)
            if background_idx is not None:
                # replace values with those of background image
                perturbed[pixel_mask] = self.images_background[background_idx][pixel_mask]  # type: ignore[index]
            else:
                # ... or with the averaged superpixel value
                perturbed[pixel_mask] = fudged_image[pixel_mask]  # type: ignore[index]
            pert_imgs.append(perturbed)

        return np.array(pert_imgs), segments_mask

    def generate_superpixels(self, image: np.ndarray) -> np.ndarray:
        """
        Generates superpixels from (i.e., segments) an image.

        Parameters
        ----------
        image
            A grayscale or RGB image.

        Returns
        -------
        A `[H, W]` array of integers. Each integer is a segment (superpixel) label.
        """
        image_preproc = self._preprocess_img(image)
        return self.segmentation_fn(image_preproc)

    def _preprocess_img(self, image: np.ndarray) -> np.ndarray:
        """
        Applies necessary transformations to the image prior to segmentation.

        Parameters
        ----------
        image
            A grayscale or RGB image.

        Returns
        -------
        A preprocessed image.
        """
        # Grayscale images are repeated across channels, but only for the built-in
        # segmentation functions; custom functions receive a plain copy.
        if not self.custom_segmentation and image.shape[-1] == 1:
            return np.repeat(image, 3, axis=2)
        return image.copy()
class AnchorImage(Explainer):
    """
    Anchor explainer for image classifiers. An image is segmented into superpixels and the anchor
    search looks for a set of superpixels which, when kept fixed, preserves the model prediction
    on perturbed versions of the image.
    """

    def __init__(self,
                 predictor: Callable[[np.ndarray], np.ndarray],
                 image_shape: tuple,
                 dtype: Type[np.generic] = np.float32,
                 segmentation_fn: Any = 'slic',
                 segmentation_kwargs: Optional[dict] = None,
                 images_background: Optional[np.ndarray] = None,
                 seed: Optional[int] = None) -> None:
        """
        Initialize anchor image explainer.

        Parameters
        ----------
        predictor
            A callable that takes a `numpy` array of `N` data points as inputs and returns `N` outputs.
        image_shape
            Shape of the image to be explained. The channel axis is expected to be last.
        dtype
            A `numpy` scalar type that corresponds to the type of input array expected by `predictor`. This may be
            used to construct arrays of the given type to be passed through the `predictor`. For most use cases
            this argument should have no effect, but it is exposed for use with predictors that would break when
            called with an array of unsupported type.
        segmentation_fn
            Any of the built in segmentation function strings: ``'felzenszwalb'``, ``'slic'`` or ``'quickshift'`` or
            a custom segmentation function (callable) which returns an image mask with labels for each superpixel.
            The segmentation function is expected to return a segmentation mask containing all integer values
            from `0` to `K-1`, where `K` is the number of image segments (superpixels).
            See http://scikit-image.org/docs/dev/api/skimage.segmentation.html for more info.
        segmentation_kwargs
            Keyword arguments for the built in segmentation functions.
        images_background
            Images to overlay superpixels on.
        seed
            If set, ensures different runs with the same input will yield same explanation.

        Raises
        ------
        :py:class:`alibi.exceptions.PredictorCallError`
            If calling `predictor` fails at runtime.
        :py:class:`alibi.exceptions.PredictorReturnTypeError`
            If the return type of `predictor` is not `np.ndarray`.
        """
        super().__init__(meta=copy.deepcopy(DEFAULT_META_ANCHOR))
        np.random.seed(seed)

        # TODO: this logic needs improvement. We should check against a fixed set of strings
        #  for built-ins instead of any `str`.
        if isinstance(segmentation_fn, str) and segmentation_kwargs is None:
            try:
                # copy so that the module-level defaults are never aliased into the explainer metadata
                segmentation_kwargs = dict(DEFAULT_SEGMENTATION_KWARGS[segmentation_fn])
            except KeyError:
                logger.warning(
                    'DEFAULT_SEGMENTATION_KWARGS did not contain any entry '
                    'for segmentation method {}. No kwargs will be passed to '
                    'the segmentation function!'.format(segmentation_fn)
                )
                segmentation_kwargs = {}
        elif callable(segmentation_fn) and segmentation_kwargs:
            logger.warning(
                'Specified both a segmentation function to create superpixels and '
                'keyword arguments for built-in segmentation functions. By default '
                'the specified segmentation function will be used.'
            )

        # set the predictor
        self.image_shape = tuple(image_shape)  # coerce lists
        self.dtype = dtype
        self.predictor = self._transform_predictor(predictor)

        # segmentation function is either a user-defined function or one of the values in `fn_options`
        fn_options = {'felzenszwalb': felzenszwalb, 'slic': slic, 'quickshift': quickshift}
        self.custom_segmentation = callable(segmentation_fn)
        if self.custom_segmentation:
            self.segmentation_fn = segmentation_fn
        else:
            # an unknown built-in name surfaces here as a `KeyError`
            self.segmentation_fn = partial(fn_options[segmentation_fn], **segmentation_kwargs)  # type: ignore[arg-type]

        self.images_background = images_background
        # a superpixel is perturbed with prob p_sample (and kept with prob 1 - p_sample)
        self.p_sample: float = 0.5

        # update metadata
        self.meta['params'].update(
            custom_segmentation=self.custom_segmentation,
            segmentation_kwargs=segmentation_kwargs,
            p_sample=self.p_sample,
            seed=seed,
            image_shape=self.image_shape,
            images_background=self.images_background
        )
        # a callable cannot be stored as metadata, so custom functions are recorded by a placeholder name
        self.meta['params'].update(
            segmentation_fn='custom' if self.custom_segmentation else segmentation_fn
        )

    def generate_superpixels(self, image: np.ndarray) -> np.ndarray:
        """
        Generates superpixels from (i.e., segments) an image.

        Parameters
        ----------
        image
            A grayscale or RGB image.

        Returns
        -------
        A `[H, W]` array of integers. Each integer is a segment (superpixel) label.
        """
        image_preproc = self._preprocess_img(image)
        return self.segmentation_fn(image_preproc)

    def _preprocess_img(self, image: np.ndarray) -> np.ndarray:
        """
        Applies necessary transformations to the image prior to segmentation.

        Parameters
        ----------
        image
            A grayscale or RGB image.

        Returns
        -------
        A preprocessed image.
        """
        # Grayscale images are repeated across channels, but only for the built-in
        # segmentation functions; custom functions receive a plain copy.
        if not self.custom_segmentation and image.shape[-1] == 1:
            return np.repeat(image, 3, axis=2)
        return image.copy()

    def explain(self,  # type: ignore[override]
                image: np.ndarray,
                p_sample: float = 0.5,
                threshold: float = 0.95,
                delta: float = 0.1,
                tau: float = 0.15,
                batch_size: int = 100,
                coverage_samples: int = 10000,
                beam_size: int = 1,
                stop_on_first: bool = False,
                max_anchor_size: Optional[int] = None,
                min_samples_start: int = 100,
                n_covered_ex: int = 10,
                binary_cache_size: int = 10000,
                cache_margin: int = 1000,
                verbose: bool = False,
                verbose_every: int = 1,
                **kwargs: Any) -> Explanation:
        """
        Explain instance and return anchor with metadata.

        Parameters
        ----------
        image
            Image to be explained.
        p_sample
            The probability of simulating the absence of a superpixel. If the `images_background` is not provided,
            the absent superpixels will be replaced by the average value of their constituent pixels.
            Otherwise, the synthetic instances are created by fixing the present superpixels and
            superimposing another image from the `images_background` over the rest of the absent superpixels.
        threshold
            Minimum anchor precision threshold. The algorithm tries to find an anchor that maximizes the coverage
            under precision constraint. The precision constraint is formally defined as
            :math:`P(prec(A) \\ge t) \\ge 1 - \\delta`, where :math:`A` is an anchor, :math:`t` is the `threshold`
            parameter, :math:`\\delta` is the `delta` parameter, and :math:`prec(\\cdot)` denotes the precision
            of an anchor. In other words, we are seeking for an anchor having its precision greater or equal than
            the given `threshold` with a confidence of `(1 - delta)`. A higher value guarantees that the anchors are
            faithful to the model, but also leads to more computation time. Note that there are cases in which the
            precision constraint cannot be satisfied due to the quantile-based discretisation of the numerical
            features. If that is the case, the best (i.e. highest coverage) non-eligible anchor is returned.
        delta
            Significance threshold. `1 - delta` represents the confidence threshold for the anchor precision
            (see `threshold`) and the selection of the best anchor candidate in each iteration (see `tau`).
        tau
            Multi-armed bandit parameter used to select candidate anchors in each iteration. The multi-armed bandit
            algorithm tries to find within a tolerance `tau` the most promising (i.e. according to the precision)
            `beam_size` candidate anchor(s) from a list of proposed anchors. Formally, when the `beam_size=1`,
            the multi-armed bandit algorithm seeks to find an anchor :math:`A` such that
            :math:`P(prec(A) \\ge prec(A^\\star) - \\tau) \\ge 1 - \\delta`, where :math:`A^\\star` is the anchor
            with the highest true precision (which we don't know), :math:`\\tau` is the `tau` parameter,
            :math:`\\delta` is the `delta` parameter, and :math:`prec(\\cdot)` denotes the precision of an anchor.
            In other words, in each iteration, the algorithm returns with a probability of at least `1 - delta` an
            anchor :math:`A` with a precision within an error tolerance of `tau` from the precision of the
            highest true precision anchor :math:`A^\\star`. A bigger value for `tau` means faster convergence but also
            looser anchor conditions.
        batch_size
            Batch size used for sampling. The Anchor algorithm will query the black-box model in batches of size
            `batch_size`. A larger `batch_size` gives more confidence in the anchor, again at the expense of
            computation time since it involves more model prediction calls.
        coverage_samples
            Number of samples used to estimate coverage from during result search.
        beam_size
            Number of candidate anchors selected by the multi-armed bandit algorithm in each iteration from a list of
            proposed anchors. A bigger beam width can lead to a better overall anchor (i.e. prevents the algorithm
            of getting stuck in a local maximum) at the expense of more computation time.
        stop_on_first
            If ``True``, the beam search algorithm will return the first anchor that satisfies the
            probability constraint.
        max_anchor_size
            Maximum number of features in result.
        min_samples_start
            Min number of initial samples.
        n_covered_ex
            How many examples where anchors apply to store for each anchor sampled during search
            (both examples where prediction on samples agrees/disagrees with `desired_label` are stored).
        binary_cache_size
            The result search pre-allocates `binary_cache_size` batches for storing the binary arrays
            returned during sampling.
        cache_margin
            When only ``max(cache_margin, batch_size)`` positions in the binary cache remain empty, a new cache
            of the same size is pre-allocated to continue buffering samples.
        verbose
            Display updates during the anchor search iterations.
        verbose_every
            Frequency of displayed iterations during anchor search process.
        **kwargs
            Additional keyword arguments, forwarded both to the `AnchorBaseBeam` constructor and to its
            `anchor_beam` method.

        Returns
        -------
        explanation
            `Explanation` object containing the anchor explaining the instance with additional metadata as attributes.
            See usage at `AnchorImage examples`_ for details.

            .. _AnchorImage examples:
                https://docs.seldon.io/projects/alibi/en/stable/methods/Anchors.html
        """
        # get params for storage in meta; this must stay the first statement so that only the call
        # arguments are captured. A copy is taken so later locals can never leak into the metadata.
        params = dict(locals())
        for key in ('image', 'self'):
            params.pop(key)

        sampler = AnchorImageSampler(
            predictor=self.predictor,
            segmentation_fn=self.segmentation_fn,
            custom_segmentation=self.custom_segmentation,
            image=image,
            images_background=self.images_background,
            p_sample=p_sample,
            n_covered_ex=n_covered_ex,
        )

        # get anchors and add metadata
        mab = AnchorBaseBeam(
            samplers=[sampler],
            sample_cache_size=binary_cache_size,
            cache_margin=cache_margin,
            **kwargs)
        result: Any = mab.anchor_beam(
            desired_confidence=threshold,
            delta=delta,
            epsilon=tau,
            batch_size=batch_size,
            coverage_samples=coverage_samples,
            beam_size=beam_size,
            stop_on_first=stop_on_first,
            max_anchor_size=max_anchor_size,
            min_samples_start=min_samples_start,
            verbose=verbose,
            verbose_every=verbose_every,
            **kwargs,
        )

        return self._build_explanation(
            image, result, sampler.instance_label, params, sampler
        )

    def _build_explanation(
            self,
            image: np.ndarray,
            result: dict,
            predicted_label: int,
            params: dict,
            sampler: AnchorImageSampler,
    ) -> Explanation:
        """
        Uses the metadata returned by the anchor search algorithm together with
        the instance to be explained to build an explanation object.

        Parameters
        ----------
        image
            Instance to be explained.
        result
            Dictionary containing the search anchor and metadata.
        predicted_label
            Label of the instance to be explained.
        params
            Parameters passed to :py:meth:`alibi.explainers.anchor_image.AnchorImage.explain`.
        sampler
            Sampler used during the anchor search; its `segments` attribute provides the superpixels
            of `image`.

        Returns
        -------
        `Explanation` object holding the anchor, the superpixels, precision, coverage and the raw search output.
        """
        result['instance'] = image
        result['instances'] = np.expand_dims(image, 0)
        result['prediction'] = np.array([predicted_label])

        # overlay image with anchor mask
        anchor = self.overlay_mask(image, sampler.segments, result['feature'])
        exp = AnchorExplanation('image', result)

        # output explanation dictionary
        data = copy.deepcopy(DEFAULT_DATA_ANCHOR_IMG)
        data.update(
            anchor=anchor,
            segments=sampler.segments,
            precision=exp.precision(),
            coverage=exp.coverage(),
            raw=exp.exp_map
        )

        # create explanation object
        explanation = Explanation(meta=copy.deepcopy(self.meta), data=data)

        # params passed to explain
        explanation.meta['params'].update(params)
        return explanation

    def overlay_mask(self, image: np.ndarray, segments: np.ndarray, mask_features: list,
                     scale: tuple = (0, 255)) -> np.ndarray:
        """
        Overlay image with mask described by the mask features.

        Parameters
        ----------
        image
            Image to be explained.
        segments
            Superpixels.
        mask_features
            List with superpixels present in mask.
        scale
            Pixel scale for masked image.

        Returns
        -------
        masked_image
            Image overlaid with mask.
        """
        # 1 for pixels that belong to one of the listed superpixels, 0 elsewhere
        mask = np.zeros(segments.shape)
        for f in mask_features:
            mask[segments == f] = 1

        image = scale_image(image, scale=scale)
        # broadcast the `[H, W]` mask over the channel axis
        masked_image = (image * np.expand_dims(mask, 2)).astype(int)
        return masked_image

    def _transform_predictor(self, predictor: Callable) -> Callable:
        """
        Checks that `predictor` can be called on an array of shape ``(1,) + image_shape`` and returns a
        callable that outputs predicted classes.

        Parameters
        ----------
        predictor
            Prediction function to be checked.

        Returns
        -------
        `predictor` itself if the first axis of its output is the largest one, otherwise `predictor`
        wrapped in an `ArgmaxTransformer`.

        Raises
        ------
        :py:class:`alibi.exceptions.PredictorCallError`
            If calling `predictor` fails at runtime.
        :py:class:`alibi.exceptions.PredictorReturnTypeError`
            If the return type of `predictor` is not `np.ndarray`.
        """
        # check if predictor returns predicted class or prediction probabilities for each class
        # if needed adjust predictor so it returns the predicted class
        x = np.zeros((1,) + self.image_shape, dtype=self.dtype)
        try:
            prediction = predictor(x)
        except Exception as e:
            msg = f"Predictor failed to be called on {type(x)} of shape {x.shape} and dtype {x.dtype}. " \
                  f"Check that the parameter `image_shape` is correctly specified."
            raise PredictorCallError(msg) from e

        if not isinstance(prediction, np.ndarray):
            msg = f"Expected predictor return type to be {np.ndarray} but got {type(prediction)}."
            raise PredictorReturnTypeError(msg)

        # for a single input, the batch axis is the largest one only if the output holds one value per instance
        if np.argmax(prediction.shape) == 0:
            return predictor
        return ArgmaxTransformer(predictor)

    def reset_predictor(self, predictor: Callable) -> None:
        """
        Resets the predictor function.

        Parameters
        ----------
        predictor
            New predictor function.
        """
        self.predictor = self._transform_predictor(predictor)