Source code for alibi.explainers.anchors.anchor_text

import copy
import logging
import string
from copy import deepcopy
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union)

import numpy as np
import spacy

from alibi.utils.missing_optional_dependency import import_optional
from alibi.api.defaults import DEFAULT_DATA_ANCHOR, DEFAULT_META_ANCHOR
from alibi.api.interfaces import Explainer, Explanation
from alibi.exceptions import (PredictorCallError,
                              PredictorReturnTypeError)

from alibi.utils.wrappers import ArgmaxTransformer
from .anchor_base import AnchorBaseBeam
from .anchor_explanation import AnchorExplanation

from .text_samplers import UnknownSampler, SimilaritySampler, load_spacy_lexeme_prob

LanguageModelSampler = import_optional(
    'alibi.explainers.anchors.language_model_text_sampler',
    names=['LanguageModelSampler'])

if TYPE_CHECKING:
    import spacy  # noqa: F811
    from alibi.utils.lang_model import LanguageModel
else:
    from alibi.utils import LanguageModel

logger = logging.getLogger(__name__)

DEFAULT_SAMPLING_UNKNOWN = dict(sample_proba=0.5)
"""
Default perturbation options for ``'unknown'`` sampling

    - ``'sample_proba'`` : ``float`` - probability of a word to be masked.
"""

DEFAULT_SAMPLING_SIMILARITY = dict(
    sample_proba=0.5,
    top_n=100,
    temperature=1.0,
    use_proba=False,
)
"""
Default perturbation options for ``'similarity'`` sampling

    - ``'sample_proba'`` : ``float`` - probability of a word to be masked.

    - ``'top_n'`` : ``int`` - number of similar words to sample for perturbations.

    - ``'temperature'`` : ``float`` - sample weight hyper-parameter, used only when `use_proba=True`.

    - ``'use_proba'`` : ``bool`` - whether to sample according to the words' similarity.
"""

DEFAULT_SAMPLING_LANGUAGE_MODEL = {
    "filling": "parallel",
    "sample_proba": 0.5,
    "top_n": 100,
    "temperature": 1.0,
    "use_proba": False,
    "frac_mask_templates": 0.1,
    "batch_size_lm": 32,
    "punctuation": string.punctuation,
    "stopwords": [],
    "sample_punctuation": False,
}
"""
Default perturbation options for ``'language_model'`` sampling

    - ``'filling'`` : ``str`` - filling method for language models. Allowed values: ``'parallel'``, \
    ``'autoregressive'``. The ``'parallel'`` method corresponds to a single forward pass through the language model. \
    The masked words are sampled independently, according to the selected probability distribution (see `top_n`, \
    `temperature`, `use_proba`). The ``'autoregressive'`` method fills the words one at a time. This corresponds to \
    multiple forward passes through the language model, which is computationally expensive.

    - ``'sample_proba'`` : ``float`` - probability of a word to be masked.

    - ``'top_n'`` : ``int`` - number of candidate words to sample for perturbations.

    - ``'temperature'`` : ``float`` - sample weight hyper-parameter, used only when `use_proba` equals ``True``.

    - ``'use_proba'`` : ``bool`` - whether to sample according to the predicted words distribution. If set to \
    ``False``, the `top_n` words are sampled uniformly at random.

    - ``'frac_mask_templates'`` : ``float`` - fraction from the number of samples of mask templates to be generated. \
    In each sampling call, `int(frac_mask_templates * num_samples)` masking templates will be generated. \
    A lower fraction corresponds to lower computation time since the batch fed to the language model is smaller. \
    After the words' distributions are predicted for each mask, a total of `num_samples` will be generated by \
    sampling evenly from each template. Note that a lower fraction might yield less diverse samples. A \
    `sample_proba=1` corresponds to masking each word; in this case only one masking template will be constructed. \
    A `filling='autoregressive'` will generate `num_samples` masking templates regardless of the value \
    of `frac_mask_templates`.

    - ``'batch_size_lm'`` : ``int`` - batch size used for the language model forward pass.

    - ``'punctuation'`` : ``str`` - string of punctuation not to be masked.

    - ``'stopwords'`` : ``List[str]`` - list of words not to be masked.

    - ``'sample_punctuation'`` : ``bool`` - whether to sample punctuation to fill the masked words. If ``False``, the \
    punctuation defined in `punctuation` will not be sampled.
"""


class AnchorText(Explainer):
    """
    Anchor explainer for text classifiers.

    Anchors are searched with `AnchorBaseBeam` over perturbations of the input text produced by the sampler
    selected through `sampling_strategy` (``'unknown'``, ``'similarity'`` or ``'language_model'``).
    """
    # sampling methods
    SAMPLING_UNKNOWN = 'unknown'  #: Unknown sampling strategy.
    SAMPLING_SIMILARITY = 'similarity'  #: Similarity sampling strategy.
    SAMPLING_LANGUAGE_MODEL = 'language_model'  #: Language model sampling strategy.

    # default params
    DEFAULTS: Dict[str, Dict] = {
        SAMPLING_UNKNOWN: DEFAULT_SAMPLING_UNKNOWN,
        SAMPLING_SIMILARITY: DEFAULT_SAMPLING_SIMILARITY,
        SAMPLING_LANGUAGE_MODEL: DEFAULT_SAMPLING_LANGUAGE_MODEL,
    }

    # class of samplers
    CLASS_SAMPLER = {
        SAMPLING_UNKNOWN: UnknownSampler,
        SAMPLING_SIMILARITY: SimilaritySampler,
        SAMPLING_LANGUAGE_MODEL: LanguageModelSampler
    }

    def __init__(self,
                 predictor: Callable[[List[str]], np.ndarray],
                 sampling_strategy: str = 'unknown',
                 nlp: Optional['spacy.language.Language'] = None,
                 language_model: Union['LanguageModel', None] = None,
                 seed: int = 0,
                 **kwargs: Any) -> None:
        """
        Initialize anchor text explainer.

        Parameters
        ----------
        predictor
            A callable that takes a list of text strings representing `N` data points as inputs and returns `N` outputs.
        sampling_strategy
            Perturbation distribution method:

             - ``'unknown'`` - replaces words with UNKs.

             - ``'similarity'`` - samples according to a similarity score with the corpus embeddings.

             - ``'language_model'`` - samples according to the language model's output distributions.

            An unrecognised value falls back to ``'unknown'`` (a warning is logged).
        nlp
            `spaCy` object when sampling method is ``'unknown'`` or ``'similarity'``.
        language_model
            Transformers masked language model. This is a model that adheres to the `LanguageModel` interface
            we define in :py:class:`alibi.utils.lang_model.LanguageModel`.
        seed
            If set, ensure identical random streams. Seeds `numpy` and, for ``'language_model'`` sampling,
            the sampler's own random state as well.
        kwargs
            Sampling arguments can be passed as `kwargs` depending on the `sampling_strategy`.
            Check default arguments defined in:

                - :py:data:`alibi.explainers.anchors.anchor_text.DEFAULT_SAMPLING_UNKNOWN`

                - :py:data:`alibi.explainers.anchors.anchor_text.DEFAULT_SAMPLING_SIMILARITY`

                - :py:data:`alibi.explainers.anchors.anchor_text.DEFAULT_SAMPLING_LANGUAGE_MODEL`

        Raises
        ------
        :py:class:`alibi.exceptions.PredictorCallError`
            If calling `predictor` fails at runtime.
        :py:class:`alibi.exceptions.PredictorReturnTypeError`
            If the return type of `predictor` is not `np.ndarray`.
        ValueError
            If the model required by the selected `sampling_strategy` (`nlp` or `language_model`) is ``None``.
        """
        super().__init__(meta=copy.deepcopy(DEFAULT_META_ANCHOR))
        self._seed(seed)

        # set the predictor
        self.predictor = self._transform_predictor(predictor)

        # define model which can be either spacy object or LanguageModel
        # the initialization of the model happens in _validate_kwargs
        self.model: Union['spacy.language.Language', LanguageModel]  #: Language model to be used.

        # validate kwargs
        self.perturb_opts, all_opts = self._validate_kwargs(sampling_strategy=sampling_strategy, nlp=nlp,
                                                            language_model=language_model, **kwargs)

        # set perturbation
        self.perturbation: Any = \
            self.CLASS_SAMPLER[self.sampling_strategy](self.model, self.perturb_opts)  #: Perturbation method.

        # The language-model sampler keeps its own random state (TensorFlow). It can only be seeded now that the
        # sampler exists: the `_seed` call above runs before `self.perturbation` is created.
        if self.sampling_strategy == self.SAMPLING_LANGUAGE_MODEL:
            self.perturbation.seed(seed)

        # update metadata
        self.meta['params'].update(seed=seed)
        self.meta['params'].update(**all_opts)

    def _validate_kwargs(self,
                         sampling_strategy: str,
                         nlp: Optional['spacy.language.Language'] = None,
                         language_model: Optional['LanguageModel'] = None,
                         **kwargs: Any) -> Tuple[dict, dict]:
        """
        Validate the sampling strategy and its options, and set the model used by the sampler.

        Sets `self.sampling_strategy` and `self.model` (and `self.model_class` for language-model sampling).

        Parameters
        ----------
        sampling_strategy
            Requested sampling strategy. Surrounding whitespace and case are ignored; an unknown value falls
            back to ``'unknown'`` with a logged warning.
        nlp
            `spaCy` object, required for ``'unknown'`` and ``'similarity'`` sampling.
        language_model
            Language model, required for ``'language_model'`` sampling.
        **kwargs
            User-provided sampling options. Keys not present in the strategy's defaults are reported with a
            warning and ignored by the sampler.

        Returns
        -------
        A tuple `(perturb_opts, all_opts)`: the strategy's default options updated with the recognised `kwargs`,
        and the defaults updated with every provided `kwarg` (including unrecognised ones), used for metadata.

        Raises
        ------
        ValueError
            If the model required by the selected strategy is ``None``.
        """
        # set sampling method
        sampling_strategy = sampling_strategy.strip().lower()
        sampling_strategies = [
            self.SAMPLING_UNKNOWN,
            self.SAMPLING_SIMILARITY,
            self.SAMPLING_LANGUAGE_MODEL
        ]

        # validate sampling method; report the rejected value *before* falling back to the default
        if sampling_strategy not in sampling_strategies:
            logger.warning(f"Sampling method {sampling_strategy} is not valid. "
                           f"Using the default value `{self.SAMPLING_UNKNOWN}`")
            sampling_strategy = self.SAMPLING_UNKNOWN

        if sampling_strategy in [self.SAMPLING_UNKNOWN, self.SAMPLING_SIMILARITY]:
            if nlp is None:
                raise ValueError("spaCy model can not be `None` when "
                                 f"`sampling_strategy` set to `{sampling_strategy}`.")
            # set nlp object
            self.model = load_spacy_lexeme_prob(nlp)
        else:
            if language_model is None:
                raise ValueError("Language model can not be `None` when "
                                 f"`sampling_strategy` set to `{sampling_strategy}`")
            # set language model object
            self.model = language_model
            self.model_class = type(language_model).__name__

        # set sampling method
        self.sampling_strategy = sampling_strategy

        # get default args
        default_args: dict = self.DEFAULTS[self.sampling_strategy]
        perturb_opts: dict = deepcopy(default_args)  # contains only the perturbation params
        all_opts = deepcopy(default_args)  # contains params + some potential incorrect params

        # compute common keys
        allowed_keys = set(perturb_opts.keys())
        provided_keys = set(kwargs.keys())
        common_keys = allowed_keys & provided_keys

        # incorrect keys
        if len(common_keys) < len(provided_keys):
            incorrect_keys = ", ".join(provided_keys - common_keys)
            logger.warning("The following keys are incorrect: " + incorrect_keys)

        # update defaults args and all params
        perturb_opts.update({key: kwargs[key] for key in common_keys})
        all_opts.update(kwargs)
        return perturb_opts, all_opts

    def sampler(self, anchor: Tuple[int, tuple], num_samples: int, compute_labels: bool = True) -> \
            Union[List[Union[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, int]], List[np.ndarray]]:
        """
        Generate perturbed samples while maintaining features in positions specified in
        anchor unchanged.

        Parameters
        ----------
        anchor
             - ``int`` - the position of the anchor in the input batch.

             - ``tuple`` - the anchor itself, a list of words to be kept unchanged.

        num_samples
            Number of generated perturbed samples.
        compute_labels
            If ``True``, an array of comparisons between predictions on perturbed samples and
            instance to be explained is returned.

        Returns
        -------
        If ``compute_labels=True``, a list containing the following is returned

         - `covered_true` - perturbed examples where the anchor applies and the model prediction \
         on perturbation is the same as the instance prediction.

         - `covered_false` - perturbed examples where the anchor applies and the model prediction \
         is NOT the same as the instance prediction.

         - `labels` - num_samples ints indicating whether the prediction on the perturbed sample \
         matches (1) the label of the instance to be explained or not (0).

         - `data` - Matrix with 1s and 0s indicating whether a word in the text has been perturbed for each sample.

         - `-1.0` - indicates exact coverage is not computed for this algorithm.

         - `anchor[0]` - position of anchor in the batch request.

        Otherwise, a list containing the data matrix only is returned.
        """
        raw_data, data = self.perturbation(anchor[1], num_samples)

        # create labels using model predictions as true labels
        if compute_labels:
            labels = self.compare_labels(raw_data)
            covered_true = raw_data[labels][:self.n_covered_ex]
            covered_false = raw_data[np.logical_not(labels)][:self.n_covered_ex]

            # coverage set to -1.0 as we can't compute 'true' coverage for this model
            return [covered_true, covered_false, labels.astype(int), data, -1.0, anchor[0]]
        else:
            return [data]

    def compare_labels(self, samples: np.ndarray) -> np.ndarray:
        """
        Compute the agreement between a classifier prediction on an instance to be explained
        and the prediction on a set of samples which have a subset of features fixed to a
        given value (aka compute the precision of anchors).

        Parameters
        ----------
        samples
            Samples whose labels are to be compared with the instance label.

        Returns
        -------
        A `numpy` boolean array indicating whether the prediction was the same as the instance label.
        """
        return self.predictor(samples.tolist()) == self.instance_label

    def explain(self,  # type: ignore[override]
                text: str,
                threshold: float = 0.95,
                delta: float = 0.1,
                tau: float = 0.15,
                batch_size: int = 100,
                coverage_samples: int = 10000,
                beam_size: int = 1,
                stop_on_first: bool = True,
                max_anchor_size: Optional[int] = None,
                min_samples_start: int = 100,
                n_covered_ex: int = 10,
                binary_cache_size: int = 10000,
                cache_margin: int = 1000,
                verbose: bool = False,
                verbose_every: int = 1,
                **kwargs: Any) -> Explanation:
        """
        Explain instance and return anchor with metadata.

        Parameters
        ----------
        text
            Text instance to be explained.
        threshold
            Minimum anchor precision threshold. The algorithm tries to find an anchor that maximizes the coverage
            under precision constraint. The precision constraint is formally defined as
            :math:`P(prec(A) \\ge t) \\ge 1 - \\delta`, where :math:`A` is an anchor, :math:`t` is the `threshold`
            parameter, :math:`\\delta` is the `delta` parameter, and :math:`prec(\\cdot)` denotes the precision of an
            anchor. In other words, we are seeking an anchor having its precision greater than or equal to the
            given `threshold` with a confidence of `(1 - delta)`. A higher value guarantees that the anchors are
            faithful to the model, but also leads to more computation time. Note that there are cases in which the
            precision constraint cannot be satisfied. If that is the case, the best (i.e. highest coverage)
            non-eligible anchor is returned.
        delta
            Significance threshold. `1 - delta` represents the confidence threshold for the anchor precision
            (see `threshold`) and the selection of the best anchor candidate in each iteration (see `tau`).
        tau
            Multi-armed bandit parameter used to select candidate anchors in each iteration. The multi-armed bandit
            algorithm tries to find within a tolerance `tau` the most promising (i.e. according to the precision)
            `beam_size` candidate anchor(s) from a list of proposed anchors. Formally, when the `beam_size=1`,
            the multi-armed bandit algorithm seeks to find an anchor :math:`A` such that
            :math:`P(prec(A) \\ge prec(A^\\star) - \\tau) \\ge 1 - \\delta`, where :math:`A^\\star` is the anchor
            with the highest true precision (which we don't know), :math:`\\tau` is the `tau` parameter,
            :math:`\\delta` is the `delta` parameter, and :math:`prec(\\cdot)` denotes the precision of an anchor.
            In other words, in each iteration, the algorithm returns with a probability of at least `1 - delta` an
            anchor :math:`A` with a precision within an error tolerance of `tau` from the precision of the
            highest true precision anchor :math:`A^\\star`. A bigger value for `tau` means faster convergence but
            also looser anchor conditions.
        batch_size
            Batch size used for sampling. The Anchor algorithm will query the black-box model in batches of size
            `batch_size`. A larger `batch_size` gives more confidence in the anchor, again at the expense of
            computation time since it involves more model prediction calls.
        coverage_samples
            Number of samples used to estimate coverage from during anchor search.
        beam_size
            Number of candidate anchors selected by the multi-armed bandit algorithm in each iteration from a list of
            proposed anchors. A bigger beam width can lead to a better overall anchor (i.e. prevents the algorithm
            of getting stuck in a local maximum) at the expense of more computation time.
        stop_on_first
            If ``True``, the beam search algorithm will return the first anchor that satisfies the
            probability constraint.
        max_anchor_size
            Maximum number of features to include in an anchor.
        min_samples_start
            Number of samples used for anchor search initialisation.
        n_covered_ex
            How many examples where anchors apply to store for each anchor sampled during search
            (both examples where prediction on samples agrees/disagrees with predicted label are stored).
        binary_cache_size
            The anchor search pre-allocates `binary_cache_size` batches for storing the boolean arrays
            returned during sampling.
        cache_margin
            When only ``max(cache_margin, batch_size)`` positions in the binary cache remain empty, a new cache
            of the same size is pre-allocated to continue buffering samples.
        verbose
            Display updates during the anchor search iterations.
        verbose_every
            Frequency of displayed iterations during anchor search process.
        **kwargs
            Other keyword arguments passed to the anchor beam search and the text sampling and perturbation
            functions.

        Returns
        -------
        `Explanation` object containing the anchor explaining the instance with additional metadata as attributes. \
        Contains the following data-related attributes

         - `anchor` : ``List[str]`` - a list of words in the proposed anchor.

         - `precision` : ``float`` - the fraction of times the sampled instances where the anchor holds yields \
         the same prediction as the original instance. The precision will always be at least `threshold` \
         for a valid anchor.

         - `coverage` : ``float`` - the fraction of sampled instances the anchor applies to.
        """
        # get params for storage in meta
        # NOTE: `locals()` must stay the first statement so it only captures the call arguments.
        params = locals()
        remove = ['text', 'self']
        for key in remove:
            params.pop(key)
        params = deepcopy(params)  # Get a reference to itself if not deepcopy for LM sampler

        # store n_covered_ex positive/negative examples for each anchor
        self.n_covered_ex = n_covered_ex
        self.instance_label = self.predictor([text])[0]

        # set sampler
        self.perturbation.set_text(text)

        # get anchors and add metadata
        mab = AnchorBaseBeam(
            samplers=[self.sampler],
            sample_cache_size=binary_cache_size,
            cache_margin=cache_margin,
            **kwargs
        )

        result: Any = mab.anchor_beam(
            delta=delta,
            epsilon=tau,
            batch_size=batch_size,
            desired_confidence=threshold,
            max_anchor_size=max_anchor_size,
            min_samples_start=min_samples_start,
            beam_size=beam_size,
            coverage_samples=coverage_samples,
            stop_on_first=stop_on_first,
            verbose=verbose,
            verbose_every=verbose_every,
            **kwargs,
        )

        if self.sampling_strategy == self.SAMPLING_LANGUAGE_MODEL:
            # take the whole word (this points just to the first part of the word)
            result['positions'] = [self.perturbation.ids_mapping[i] for i in result['feature']]
            result['names'] = [
                self.perturbation.model.select_word(
                    self.perturbation.head_tokens,
                    idx_feature,
                    self.perturbation.perturb_opts['punctuation']
                ) for idx_feature in result['positions']
            ]
        else:
            result['names'] = [self.perturbation.words[x] for x in result['feature']]
            result['positions'] = [self.perturbation.positions[x] for x in result['feature']]

        # set mab
        self.mab = mab
        return self._build_explanation(text, result, self.instance_label, params)

    def _build_explanation(self, text: str, result: dict, predicted_label: int, params: dict) -> Explanation:
        """
        Uses the metadata returned by the anchor search algorithm together with
        the instance to be explained to build an explanation object.

        Parameters
        ----------
        text
            Instance to be explained.
        result
            Dictionary containing the search result and metadata. Updated in place with the
            `instance`, `instances` and `prediction` entries.
        predicted_label
            Label of the instance to be explained. Inferred if not received.
        params
            Arguments passed to `explain`. Currently not stored in the explanation metadata.

        Returns
        -------
        `Explanation` whose metadata is a copy of the explainer's `meta`.
        """
        result['instance'] = text
        result['instances'] = [text]  # TODO: should this be an array?
        result['prediction'] = np.array([predicted_label])
        exp = AnchorExplanation('text', result)

        # output explanation dictionary
        data = copy.deepcopy(DEFAULT_DATA_ANCHOR)
        data.update(anchor=exp.names(),
                    precision=exp.precision(),
                    coverage=exp.coverage(),
                    raw=exp.exp_map)

        # create explanation object
        explanation = Explanation(meta=copy.deepcopy(self.meta), data=data)

        # params passed to explain
        # explanation.meta['params'].update(params)
        return explanation

    def _transform_predictor(self, predictor: Callable) -> Callable:
        """
        Validate `predictor` on a dummy input and, if needed, wrap it so it returns predicted classes.

        Parameters
        ----------
        predictor
            Callable taking a list of strings and returning a `np.ndarray`.

        Returns
        -------
        `predictor` itself when its output's largest dimension is the first one, otherwise `predictor` wrapped
        in an `ArgmaxTransformer`.

        Raises
        ------
        :py:class:`alibi.exceptions.PredictorCallError`
            If calling `predictor` on the dummy input fails.
        :py:class:`alibi.exceptions.PredictorReturnTypeError`
            If `predictor` does not return a `np.ndarray`.
        """
        # check if predictor returns predicted class or prediction probabilities for each class
        # if needed adjust predictor so it returns the predicted class
        x = ['Hello world']
        try:
            prediction = predictor(x)
        except Exception as e:
            msg = f"Predictor failed to be called on x={x}. " \
                  f"Check that `predictor` works with inputs of type List[str]."
            raise PredictorCallError(msg) from e

        if not isinstance(prediction, np.ndarray):
            msg = f"Expected predictor return type to be {np.ndarray} but got {type(prediction)}."
            raise PredictorReturnTypeError(msg)

        # For the single-instance input above, a 1-D output has its largest dimension at axis 0 and is used as-is,
        # while a `(1, C)` output with `C > 1` peaks at axis 1 and gets wrapped.
        # NOTE(review): a `(1, 1)` output also yields 0 here and is therefore used unwrapped.
        if np.argmax(prediction.shape) == 0:
            return predictor
        else:
            transformer = ArgmaxTransformer(predictor)
            return transformer

    def reset_predictor(self, predictor: Callable) -> None:
        """
        Resets the predictor function.

        Parameters
        ----------
        predictor
            New predictor function.
        """
        self.predictor = self._transform_predictor(predictor)

    def _seed(self, seed: int) -> None:
        """
        Seed `numpy` and, once it exists, the language-model sampler's own random state.

        Parameters
        ----------
        seed
            Seed value.
        """
        np.random.seed(seed)
        # If LanguageModel is used, we need to set the seed for tf as well. The sampling strategy is checked
        # instead of `isinstance(..., LanguageModelSampler)` because that name may be a placeholder object when the
        # optional dependency is missing, and because `self.model` holds the language model, not the sampler.
        if getattr(self, 'sampling_strategy', None) == self.SAMPLING_LANGUAGE_MODEL and hasattr(self, 'perturbation'):
            self.perturbation.seed(seed)