Source code for alibi.explainers.anchors.text_samplers

import logging
from abc import abstractmethod

from typing import (TYPE_CHECKING, Dict, List, Optional, Tuple, Union)

import numpy as np

if TYPE_CHECKING:
    import spacy  # noqa: F811

logger = logging.getLogger(__name__)


class Neighbors:

    def __init__(self, nlp_obj: 'spacy.language.Language', n_similar: int = 500, w_prob: float = -15.) -> None:
        """
        Initialize class identifying neighbouring words from the embedding for a given word.

        Parameters
        ----------
        nlp_obj
            `spaCy` model.
        n_similar
            Number of similar words to return.
        w_prob
            Smoothed log probability estimate of token's type.
        """
        self.nlp = nlp_obj
        self.w_prob = w_prob
        # list with spaCy lexemes in vocabulary
        # first if statement is a workaround due to some missing keys in models:
        # https://github.com/SeldonIO/alibi/issues/275#issuecomment-665017691
        self.to_check = [self.nlp.vocab[w] for w in self.nlp.vocab.vectors
                         if int(w) in self.nlp.vocab.strings and  # type: ignore[operator]
                         self.nlp.vocab[w].prob >= self.w_prob]
        self.n_similar = n_similar

    def neighbors(self, word: str, tag: str, top_n: int) -> dict:
        """
        Find similar words for a certain word in the vocabulary.

        Parameters
        ----------
        word
            Word for which we need to find similar words.
        tag
            Part of speech tag for the words.
        top_n
            Return only `top_n` neighbors.

        Returns
        -------
        A dict with two fields. The ``'words'`` field contains a `numpy` array of the `top_n` most similar words, \
        whereas the ``'similarities'`` field is a `numpy` array with corresponding word similarities.
        """
        # the word itself is excluded so we add one to return the expected number of words
        top_n += 1

        texts: List = []
        similarities: List = []

        if word in self.nlp.vocab:
            word_vocab = self.nlp.vocab[word]
            queries = [w for w in self.to_check if w.is_lower == word_vocab.is_lower]
            if word_vocab.prob < self.w_prob:
                queries += [word_vocab]

            by_similarity = sorted(queries, key=lambda w: word_vocab.similarity(w), reverse=True)[:self.n_similar]

            # find similar words with the same part of speech
            for lexeme in by_similarity:
                # because we don't add the word itself anymore
                if len(texts) == top_n - 1:
                    break

                token = self.nlp(lexeme.orth_)[0]
                if token.tag_ != tag or token.text == word:
                    continue

                texts.append(token.text)
                similarities.append(word_vocab.similarity(lexeme))

        words = np.array(texts) if texts else np.array(texts, dtype='<U')
        return {'words': words, 'similarities': np.array(similarities)}
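
# A minimal usage sketch (not part of the library API): `Neighbors` requires a
# spaCy model with word vectors, e.g. `en_core_web_md`, with the `lexeme_prob`
# table loaded via `load_spacy_lexeme_prob` defined below.
#
#   import spacy
#   nlp = load_spacy_lexeme_prob(spacy.load('en_core_web_md'))
#   neigh = Neighbors(nlp)
#   out = neigh.neighbors('good', 'JJ', top_n=5)
#   # out['words']        -> array of up to 5 adjectives similar to 'good'
#   # out['similarities'] -> matching similarity scores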


def load_spacy_lexeme_prob(nlp: 'spacy.language.Language') -> 'spacy.language.Language':
    """
    This utility function loads the `lexeme_prob` table for a spacy model if it is not present.
    This is required to enable support for different spacy versions.
    """
    import spacy
    SPACY_VERSION = spacy.__version__.split('.')
    MAJOR, MINOR = int(SPACY_VERSION[0]), int(SPACY_VERSION[1])

    if MAJOR == 2:
        if MINOR < 3:
            return nlp
        elif MINOR == 3:
            # spacy 2.3.0 moved lexeme_prob into a different package `spacy_lookups_data`
            # https://github.com/explosion/spaCy/issues/5638
            try:
                table = nlp.vocab.lookups_extra.get_table('lexeme_prob')  # type: ignore[attr-defined]
                # remove the default empty table
                if table == dict():
                    nlp.vocab.lookups_extra.remove_table('lexeme_prob')  # type: ignore[attr-defined]
            except KeyError:
                pass
            finally:
                # access the `prob` of any word to load the full table
                assert nlp.vocab["a"].prob != -20.0, f"Failed to load the `lexeme_prob` table for model {nlp}"
    elif MAJOR >= 3:
        # in spacy 3.x we need to manually add the tables
        # https://github.com/explosion/spaCy/discussions/6388#discussioncomment-331096
        if 'lexeme_prob' not in nlp.vocab.lookups.tables:
            from spacy.lookups import load_lookups
            lookups = load_lookups(nlp.lang, ['lexeme_prob'])  # type: ignore[arg-type]
            nlp.vocab.lookups.add_table('lexeme_prob', lookups.get_table('lexeme_prob'))

    return nlp
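
# For example (a sketch): under spacy 3.x the call below requires the
# `spacy-lookups-data` package to be installed, so the table can be fetched.
#
#   import spacy
#   nlp = load_spacy_lexeme_prob(spacy.load('en_core_web_md'))
#   nlp.vocab['the'].prob  # a real log-probability, not the -20.0 default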


class AnchorTextSampler:

    @abstractmethod
    def set_text(self, text: str) -> None:
        pass

    @abstractmethod
    def __call__(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        pass

    def _joiner(self, arr: np.ndarray, dtype: Optional[str] = None) -> np.ndarray:
        """
        Function to concatenate a 1D `numpy` array of strings into a single space-separated
        string. Applied row-wise via `np.apply_along_axis` by the samplers below.

        Parameters
        ----------
        arr
            1D `numpy` array of strings.
        dtype
            Array type, used to avoid truncation of strings when concatenating.

        Returns
        -------
        Array with one element, the concatenation of the strings in the input array.
        """
        if not dtype:
            return np.array(' '.join(arr))

        return np.array(' '.join(arr)).astype(dtype)
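
# Note on dtypes (an illustrative sketch): assigning into a fixed-width unicode
# array silently truncates longer strings, which is why `_joiner` and the
# samplers below carry an explicit dtype.
#
#   arr = np.array(['hello', 'world'])   # dtype '<U5'
#   out = np.zeros(1, dtype=arr.dtype)
#   out[0] = ' '.join(arr)               # stored as 'hello' -- truncated
#   out = np.zeros(1, dtype='<U11')
#   out[0] = ' '.join(arr)               # stored as 'hello world'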


class UnknownSampler(AnchorTextSampler):
    UNK: str = "UNK"  #: Unknown token to be used.

    def __init__(self, nlp: 'spacy.language.Language', perturb_opts: Dict):
        """
        Initialize unknown sampler. This sampler replaces words with the `UNK` token.

        Parameters
        ----------
        nlp
            `spaCy` object.
        perturb_opts
            Perturbation options.
        """
        super().__init__()

        # set nlp and perturbation options
        self.nlp = load_spacy_lexeme_prob(nlp)
        self.perturb_opts: Union[Dict, None] = perturb_opts

        # define buffer for word, punctuation and position
        self.words: List = []
        self.punctuation: List = []
        self.positions: List = []

    def set_text(self, text: str) -> None:
        """
        Sets the text to be processed.

        Parameters
        ----------
        text
            Text to be processed.
        """
        # process text
        processed = self.nlp(text)                               # spaCy tokens for text
        self.words = [x.text for x in processed]                 # list with words in text
        self.positions = [x.idx for x in processed]              # positions of words in text
        self.punctuation = [x for x in processed if x.is_punct]  # list with punctuation in text

        # set dtype
        self.set_data_type()

    def __call__(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        The function returns a `numpy` array of `num_samples` perturbed sentences, where
        randomly chosen words, except those in the anchor, are replaced by the ``'UNK'`` token.

        Parameters
        ----------
        anchor
            Indices representing the positions of the words to be kept unchanged.
        num_samples
            Number of perturbed sentences to be returned.

        Returns
        -------
        raw
            Array containing `num_samples` elements. Each element is a perturbed sentence.
        data
            A `(num_samples, m)`-dimensional boolean array, where `m` is the number of tokens
            in the instance to be explained.
        """
        assert self.perturb_opts, "Perturbation options are not set."

        # allocate memory for the binary mask and the perturbed instances
        data = np.ones((num_samples, len(self.words)))
        raw = np.zeros((num_samples, len(self.words)), self.dtype)

        # fill each row of the raw data matrix with the text instance to be explained
        raw[:] = self.words

        for i, t in enumerate(self.words):
            # do not perturb words that are in the anchor
            if i in anchor:
                continue

            # sample the words in the text outside of the anchor that are replaced with UNKs
            n_changed = np.random.binomial(num_samples, self.perturb_opts['sample_proba'])
            changed = np.random.choice(num_samples, n_changed, replace=False)
            raw[changed, i] = UnknownSampler.UNK
            data[changed, i] = 0

        # join the words
        raw = np.apply_along_axis(self._joiner, axis=1, arr=raw, dtype=self.dtype)
        return raw, data

    def set_data_type(self) -> None:
        """
        Working with `numpy` arrays of strings requires setting the data type to avoid
        truncating examples. This function estimates the longest sentence expected during
        the sampling process, which is used to set the number of characters for the samples
        and examples arrays. This depends on the perturbation method used for sampling.
        """
        max_len = max(len(self.UNK), len(max(self.words, key=len)))
        max_sent_len = len(self.words) * max_len + len(self.UNK) * len(self.punctuation) + 1
        self.dtype = '<U' + str(max_sent_len)
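
# A minimal usage sketch (hypothetical values; `sample_proba` is the probability
# of replacing each non-anchor word with UNK):
#
#   sampler = UnknownSampler(nlp, perturb_opts={'sample_proba': 0.5})
#   sampler.set_text('a very good book')
#   raw, data = sampler(anchor=(2,), num_samples=4)  # keep the word 'good' fixed
#   # raw  -> 4 sentences such as 'UNK very good UNK'
#   # data -> (4, 4) binary mask; column 2 is all ones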


class SimilaritySampler(AnchorTextSampler):

    def __init__(self, nlp: 'spacy.language.Language', perturb_opts: Dict):
        """
        Initialize similarity sampler. This sampler replaces words with similar words.

        Parameters
        ----------
        nlp
            `spaCy` object.
        perturb_opts
            Perturbation options.
        """
        super().__init__()

        # set nlp and perturbation options
        self.nlp = load_spacy_lexeme_prob(nlp)
        self.perturb_opts = perturb_opts

        # define synonym generator
        self._synonyms_generator = Neighbors(self.nlp)

        # dict containing an np.array of similar words with same part of speech and an np.array of similarities
        self.synonyms: Dict[str, Dict[str, np.ndarray]] = {}

        # define buffers for the processed tokens, words, positions and punctuation
        self.tokens: 'spacy.tokens.Doc'
        self.words: List[str] = []
        self.positions: List[int] = []
        self.punctuation: List['spacy.tokens.Token'] = []

    def set_text(self, text: str) -> None:
        """
        Sets the text to be processed.

        Parameters
        ----------
        text
            Text to be processed.
        """
        processed = self.nlp(text)                               # spaCy tokens for text
        self.words = [x.text for x in processed]                 # list with words in text
        self.positions = [x.idx for x in processed]              # positions of words in text
        self.punctuation = [x for x in processed if x.is_punct]  # punctuation in text
        self.tokens = processed

        # find similar words
        self.find_similar_words()

        # set dtype
        self.set_data_type()

    def find_similar_words(self) -> None:
        """
        This function queries a `spaCy` nlp model to find the `top_n` most similar words with
        the same part of speech for each word in the instance to be explained. For each word,
        the search procedure returns a dictionary containing a `numpy` array of words
        (``'words'``) and a `numpy` array of word similarities (``'similarities'``).
        """
        for word, token in zip(self.words, self.tokens):
            if word not in self.synonyms:
                self.synonyms[word] = self._synonyms_generator.neighbors(word, token.tag_, self.perturb_opts['top_n'])

    def __call__(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        The function returns a `numpy` array of `num_samples` perturbed sentences, where
        randomly chosen words, except those in the anchor, are replaced by similar words
        with the same part of speech tag. See
        :py:meth:`alibi.explainers.anchors.text_samplers.SimilaritySampler.perturb_sentence_similarity`
        for details of how the replacement works.

        Parameters
        ----------
        anchor
            Indices representing the positions of the words to be kept unchanged.
        num_samples
            Number of perturbed sentences to be returned.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.text_samplers.SimilaritySampler.perturb_sentence_similarity`.
        """
        assert self.perturb_opts, "Perturbation options are not set."
        return self.perturb_sentence_similarity(anchor, num_samples, **self.perturb_opts)

    def perturb_sentence_similarity(self,
                                    present: tuple,
                                    n: int,
                                    sample_proba: float = 0.5,
                                    forbidden: frozenset = frozenset(),
                                    forbidden_tags: frozenset = frozenset(['PRP$']),
                                    forbidden_words: frozenset = frozenset(['be']),
                                    temperature: float = 1.,
                                    pos: frozenset = frozenset(['NOUN', 'VERB', 'ADJ', 'ADV', 'ADP', 'DET']),
                                    use_proba: bool = False,
                                    **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perturb the text instance to be explained.

        Parameters
        ----------
        present
            Word indices in the text for the words in the proposed anchor.
        n
            Number of samples used when sampling from the corpus.
        sample_proba
            Sample probability for a word if ``use_proba=False``.
        forbidden
            Forbidden lemmas.
        forbidden_tags
            Forbidden POS tags.
        forbidden_words
            Forbidden words.
        temperature
            Sample weight hyper-parameter if ``use_proba=True``.
        pos
            POS that can be changed during perturbation.
        use_proba
            Bool whether to sample according to a similarity score with the corpus embeddings.
        **kwargs
            Other arguments. Not used.

        Returns
        -------
        raw
            Array of perturbed text instances.
        data
            Matrix with 1s and 0s indicating whether a word in the text
            has not been perturbed for each sample.
        """
        # allocate memory for the binary mask and the perturbed instances
        raw = np.zeros((n, len(self.tokens)), self.dtype)
        data = np.ones((n, len(self.tokens)))

        # fill each row of the raw data matrix with the text to be explained
        raw[:] = [x.text for x in self.tokens]

        for i, t in enumerate(self.tokens):  # apply sampling to each token
            # if the word is part of the anchor, move on to the next token
            if i in present:
                continue

            # check that the token does not fall in any forbidden category
            if (t.text not in forbidden_words and t.pos_ in pos and
                    t.lemma_ not in forbidden and t.tag_ not in forbidden_tags):

                t_neighbors = self.synonyms[t.text]['words']
                # no neighbours with the same tag or word not in spaCy vocabulary
                if t_neighbors.size == 0:
                    continue

                n_changed = np.random.binomial(n, sample_proba)
                changed = np.random.choice(n, n_changed, replace=False)

                if use_proba:  # use similarity scores to sample changed tokens
                    weights = self.synonyms[t.text]['similarities']
                    weights = np.exp(weights / temperature)  # weighting by temperature (check previous implementation)
                    weights = weights / sum(weights)
                else:
                    weights = np.ones((t_neighbors.shape[0],))
                    weights /= t_neighbors.shape[0]

                raw[changed, i] = np.random.choice(t_neighbors, n_changed, p=weights, replace=True)
                data[changed, i] = 0

        raw = np.apply_along_axis(self._joiner, axis=1, arr=raw, dtype=self.dtype)
        return raw, data
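
    # Worked example of the temperature weighting above (illustrative numbers):
    # with similarities [0.9, 0.5] and temperature=1.0, the sampling weights are
    # softmax([0.9, 0.5]) ~= [0.60, 0.40]; lowering the temperature concentrates
    # the probability mass on the most similar neighbours.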

    def set_data_type(self) -> None:
        """
        Working with `numpy` arrays of strings requires setting the data type to avoid
        truncating examples. This function estimates the longest sentence expected during
        the sampling process, which is used to set the number of characters for the samples
        and examples arrays. This depends on the perturbation method used for sampling.
        """
        max_len = 0
        max_sent_len = 0

        for word in self.words:
            similar_words = self.synonyms[word]['words']
            # `itemsize / dtype('<U1').itemsize` recovers the number of characters
            # in the widest synonym stored in the fixed-width unicode array
            max_len = max(max_len, int(similar_words.dtype.itemsize /
                                       np.dtype(similar_words.dtype.char + '1').itemsize))
            max_sent_len += max_len

        self.dtype = '<U' + str(max_sent_len)
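
# A minimal usage sketch (hypothetical options): `top_n` controls how many
# synonyms are retrieved per word and `sample_proba` the per-word replacement
# probability.
#
#   opts = {'top_n': 20, 'sample_proba': 0.5, 'temperature': 1., 'use_proba': False}
#   sampler = SimilaritySampler(nlp, perturb_opts=opts)
#   sampler.set_text('a very good book')
#   raw, data = sampler(anchor=(2,), num_samples=4)
#   # raw -> sentences like 'a really good novel', with 'good' always kept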