Source code for alibi.explainers.anchors.language_model_text_sampler

import string
from functools import partial

from typing import (Dict, List, Optional, Tuple)

import numpy as np
import tensorflow as tf

from alibi.utils.lang_model import LanguageModel
from alibi.explainers.anchors.text_samplers import AnchorTextSampler


class LanguageModelSampler(AnchorTextSampler):
    # filling procedures
    FILLING_PARALLEL: str = 'parallel'          #: Parallel filling procedure.
    FILLING_AUTOREGRESSIVE = 'autoregressive'   #: Autoregressive filling procedure. Considerably slower.

    def __init__(self, model: LanguageModel, perturb_opts: dict):
        """
        Initialize language model sampler. This sampler replaces words with the ones
        sampled according to the output distribution of the language model. There are
        two modes to use the sampler: ``'parallel'`` and ``'autoregressive'``. In the
        ``'parallel'`` mode, all words are replaced simultaneously. In the ``'autoregressive'``
        mode, the words are replaced one by one, starting from left to right. Thus the
        following words are conditioned on the previously predicted words.

        Parameters
        ----------
        model
            Transformers masked language model.
        perturb_opts
            Perturbation options.
        """
        super().__init__()

        # set language model and perturbation options
        self.model = model
        self.perturb_opts = perturb_opts

        # Define language model's vocab
        vocab: Dict[str, int] = self.model.tokenizer.get_vocab()

        # Define masking sampling tensor. This tensor is used to avoid sampling
        # certain tokens from the vocabulary such as: subwords, punctuation, etc.
        self.subwords_mask = np.zeros(len(vocab.keys()), dtype=np.bool_)

        for token in vocab:
            # Add subwords to the sampling mask. This means that subwords
            # will not be considered when sampling for the masked words.
            if self.model.is_subword_prefix(token):
                self.subwords_mask[vocab[token]] = True
                continue

            # Add punctuation to the sampling mask. This means that the
            # punctuation will not be considered when sampling for the masked words.
            sample_punctuation: bool = perturb_opts.get('sample_punctuation', False)
            punctuation: str = perturb_opts.get('punctuation', string.punctuation)

            if (not sample_punctuation) and self.model.is_punctuation(token, punctuation):
                self.subwords_mask[vocab[token]] = True

        # define head, tail part of the text
        self.head: str = ''
        self.tail: str = ''
        self.head_tokens: List[str] = []
        self.tail_tokens: List[str] = []

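    # Note: `perturb_opts` is expected to carry the keys consumed throughout this class,
    # e.g. 'sample_proba', 'filling', 'frac_mask_templates', 'top_n', 'batch_size_lm',
    # 'punctuation', 'stopwords', 'sample_punctuation', 'use_proba' and 'temperature';
    # see the illustrative usage sketch at the end of the module.
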
    def get_sample_ids(self,
                       punctuation: str = string.punctuation,
                       stopwords: Optional[List[str]] = None,
                       **kwargs) -> None:
        """
        Find the indices of the words which can be perturbed.

        Parameters
        ----------
        punctuation
            String of punctuation characters.
        stopwords
            List of stopwords.
        **kwargs
            Other arguments. Not used.
        """
        # transform stopwords to lowercase
        if stopwords:
            stopwords = [w.lower().strip() for w in stopwords]

        # Initialize list of indices allowed to be perturbed
        ids_sample = list(np.arange(len(self.head_tokens)))

        # Define partial function for stopword checking
        is_stop_word = partial(
            self.model.is_stop_word,
            tokenized_text=self.head_tokens,
            punctuation=punctuation,
            stopwords=stopwords
        )

        # lambda expression to check for a subword
        subword_cond = lambda token, idx: self.model.is_subword_prefix(token)  # noqa: E731
        # lambda expression to check for a stopword
        stopwords_cond = lambda token, idx: is_stop_word(start_idx=idx)  # noqa: E731
        # lambda expression to check for punctuation
        punctuation_cond = lambda token, idx: self.model.is_punctuation(token, punctuation)  # noqa: E731

        # Gather all in a list of conditions
        conds = [punctuation_cond, stopwords_cond, subword_cond]

        # Remove indices of the tokens that are not allowed to be masked
        for i, token in enumerate(self.head_tokens):
            if any([cond(token, i) for cond in conds]):
                ids_sample.remove(i)

        # Save the indices allowed to be masked and the corresponding mapping.
        # The anchor base algorithm alters indices one by one. By saving the mapping
        # and sending only the initial token of a word, we avoid unnecessary sampling.
        # E.g. word = token1 ##token2. Instead of trying two anchors (1, 0), (1, 1) - which are
        # equivalent because we take the full word - just try one: (1).
        self.ids_sample = np.array(ids_sample)
        self.ids_mapping = {i: id for i, id in enumerate(self.ids_sample)}

    def set_text(self, text: str) -> None:
        """
        Sets the text to be processed.

        Parameters
        ----------
        text
            Text to be processed.
        """
        # Some language models can only work with a limited number of tokens. Thus the text needs
        # to be split into head_text and tail_text. We will only alter the head_tokens.
        self.head, self.tail, self.head_tokens, self.tail_tokens = self.model.head_tail_split(text)

        # define indices of the words which can be perturbed
        self.get_sample_ids(**self.perturb_opts)

        # set dtypes
        self.set_data_type()

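    # For example (illustrative): with a language model that accepts only a limited number of
    # sub-tokens, `head_tail_split` keeps a prefix of the text as the head, which is the only
    # part that gets masked and resampled; the tail tokens are re-appended unchanged after
    # sampling (see `_append_tail`).
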
    def __call__(self, anchor: tuple, num_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        The function returns a `numpy` array of `num_samples` where randomly chosen features,
        except those in `anchor`, are replaced by words sampled according to the language
        model's predictions.

        Parameters
        ----------
        anchor
            Indices representing the positions of the words to be kept unchanged.
        num_samples
            Number of perturbed sentences to be returned.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.language_model_text_sampler.LanguageModelSampler.perturb_sentence`.
        """
        assert self.perturb_opts, "Perturbation options are not set."
        return self.perturb_sentence(anchor, num_samples, **self.perturb_opts)

    def perturb_sentence(self,
                         anchor: tuple,
                         num_samples: int,
                         sample_proba: float = .5,
                         top_n: int = 100,
                         batch_size_lm: int = 32,
                         filling: str = "parallel",
                         **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        The function returns a `numpy` array of `num_samples` where randomly chosen features,
        except those in `anchor`, are replaced by words sampled according to the language
        model's predictions.

        Parameters
        ----------
        anchor
            Indices representing the positions of the words to be kept unchanged.
        num_samples
            Number of perturbed sentences to be returned.
        sample_proba
            Probability of a token being replaced by a similar token.
        top_n
            Used for top n sampling.
        batch_size_lm
            Batch size used for the language model.
        filling
            Method to fill masked words. Either ``'parallel'`` or ``'autoregressive'``.
        **kwargs
            Other arguments to be passed to other methods.

        Returns
        -------
        raw
            Array containing `num_samples` elements. Each element is a perturbed sentence.
        data
            A `(num_samples, m)`-dimensional boolean array, where `m` is the number of tokens
            in the instance to be explained.
        """
        # create the mask
        raw, data = self.create_mask(
            anchor=anchor,
            num_samples=num_samples,
            sample_proba=sample_proba,
            filling=filling,
            **kwargs
        )

        # If the anchor does not cover the entire sentence,
        # then fill in the mask with the language model.
        if len(anchor) != len(self.ids_sample):
            raw, data = self.fill_mask(
                raw=raw, data=data,
                num_samples=num_samples,
                top_n=top_n,
                batch_size_lm=batch_size_lm,
                filling=filling,
                **kwargs
            )

        # append tail if it exists
        raw = self._append_tail(raw) if self.tail else raw
        return raw, data

    def create_mask(self,
                    anchor: tuple,
                    num_samples: int,
                    sample_proba: float = 1.0,
                    filling: str = 'parallel',
                    frac_mask_templates: float = 0.1,
                    **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create mask for words to be perturbed.

        Parameters
        ----------
        anchor
            Indices representing the positions of the words to be kept unchanged.
        num_samples
            Number of perturbed sentences to be returned.
        sample_proba
            Probability of a word being replaced.
        filling
            Method to fill masked words. Either ``'parallel'`` or ``'autoregressive'``.
        frac_mask_templates
            Fraction of mask templates from the number of requested samples.
        **kwargs
            Other arguments to be passed to other methods.

        Returns
        -------
        raw
            Array with masked instances.
        data
            A `(num_samples, m)`-dimensional boolean array, where `m` is the number of tokens
            in the instance to be explained.
        """
        # make sure that frac_mask_templates is in [0, 1]
        frac_mask_templates = np.clip(frac_mask_templates, 0, 1).item()

        # compute indices allowed to be masked
        all_indices = range(len(self.ids_sample))
        allowed_indices = list(set(all_indices) - set(anchor))

        if len(allowed_indices) == 0 or filling == self.FILLING_AUTOREGRESSIVE:
            # If the anchor covers all the words that can be perturbed (it can happen),
            # then the number of mask templates should be equal to the number of samples requested.
            # If the filling is autoregressive, just generate `num_samples` masks from the start,
            # because the computational performance is pretty similar.
            mask_templates = num_samples
        else:
            # If the probability of sampling a word is 1, then all words will be masked.
            # Thus there is no point in generating more than one mask.
            # Otherwise compute the number of masking templates according to the fraction
            # passed as argument and make sure that at least one mask template is generated.
            mask_templates = 1 if np.isclose(sample_proba, 1) else max(1, int(num_samples * frac_mask_templates))

        # allocate memory
        data = np.ones((mask_templates, len(self.ids_sample)), dtype=np.int32)
        raw = np.zeros((mask_templates, len(self.head_tokens)), dtype=self.dtype_token)

        # fill each row of the raw data matrix with the text instance to be explained
        raw[:] = self.head_tokens

        # create mask
        if len(allowed_indices):
            for i in range(mask_templates):
                # Here the sampling of the indices of the words to be masked is done by rows
                # and not by columns as in the other sampling methods. The reason is that it
                # is much easier to ensure that at least one word in the sentence is masked.
                # If the sampling is performed over the columns, it might be the case
                # that no word in a sentence will be masked.
                n_changed = max(1, np.random.binomial(len(allowed_indices), sample_proba))
                changed = np.random.choice(allowed_indices, n_changed, replace=False)

                # mark the entries as masked
                data[i, changed] = 0

                # Mask the corresponding words. This requires a mapping from indices
                # to the actual positions of the words in the text.
                changed_mapping = [self.ids_mapping[j] for j in changed]
                raw[i, changed_mapping] = self.model.mask

                # Have to remove the subwords of the masked words, which has to be done iteratively
                for j in changed_mapping:
                    self._remove_subwords(raw=raw, row=i, col=j, **kwargs)

        # join words
        raw = np.apply_along_axis(self._joiner, axis=1, arr=raw, dtype=self.dtype_sent)
        return raw, data

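    # Worked example of the template arithmetic above (illustrative numbers): with
    # `num_samples=100`, `frac_mask_templates=0.1` and `sample_proba=0.5` under 'parallel'
    # filling, `max(1, int(100 * 0.1)) = 10` mask templates are built and each is later
    # filled multiple times by the language model; with `sample_proba=1.0` a single
    # template suffices, while 'autoregressive' filling (or an anchor covering every
    # perturbable word) produces all 100 templates directly.
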
    def _append_tail(self, raw: np.ndarray) -> np.ndarray:
        """
        Appends the tail part of the text to the newly sampled head.

        Parameters
        ----------
        raw
            Newly sampled heads.

        Returns
        -------
        full_raw
            Concatenation of the newly sampled head with the original tail.
        """
        full_raw = []

        for i in range(raw.shape[0]):
            new_head_tokens = self.model.tokenizer.tokenize(raw[i])
            new_tokens = new_head_tokens + self.tail_tokens
            full_raw.append(self.model.tokenizer.convert_tokens_to_string(new_tokens))

        # convert to array and return
        return np.array(full_raw, dtype=self.dtype_sent)

    def _joiner(self, arr: np.ndarray, dtype: Optional[str] = None) -> np.ndarray:
        """
        Function to concatenate a `numpy` array of strings along a specified axis.

        Parameters
        ----------
        arr
            1D `numpy` array of strings.
        dtype
            Array type, used to avoid truncation of strings when concatenating along the axis.

        Returns
        -------
        Array with one element, the concatenation of the strings in the input array.
        """
        filtered_arr = list(filter(lambda x: len(x) > 0, arr))
        str_arr = self.model.tokenizer.convert_tokens_to_string(filtered_arr)

        if not dtype:
            return np.array(str_arr)

        return np.array(str_arr).astype(dtype)

    def fill_mask(self,
                  raw: np.ndarray,
                  data: np.ndarray,
                  num_samples: int,
                  top_n: int = 100,
                  batch_size_lm: int = 32,
                  filling: str = "parallel",
                  **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fill in the masked tokens with the language model.

        Parameters
        ----------
        raw
            Array of mask templates.
        data
            Binary mask having 0 where the word was masked.
        num_samples
            Number of samples to be drawn.
        top_n
            Use the top n words when sampling.
        batch_size_lm
            Batch size used for the language model.
        filling
            Method to fill masked words. Either ``'parallel'`` or ``'autoregressive'``.
        **kwargs
            Other parameters to be passed to other methods.

        Returns
        -------
        raw
            Array containing `num_samples` elements. Each element is a perturbed sentence.
        data
            Binary array having 0 where the tokens were masked. Has `num_samples` rows.
        """
        # choose the perturbation function
        perturb_func = self._perturb_instances_parallel if filling == self.FILLING_PARALLEL \
            else self._perturb_instance_ar

        # perturb instances
        tokens, data = perturb_func(raw=raw,
                                    data=data,
                                    num_samples=num_samples,
                                    batch_size_lm=batch_size_lm,
                                    top_n=top_n,
                                    **kwargs)

        # decode the tokens and remove special tokens such as <pad>, <cls>, etc.
        raw = self.model.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        return np.array(raw), data

    def _remove_subwords(self, raw: np.ndarray, row: int, col: int, punctuation: str = '', **kwargs) -> np.ndarray:
        """
        Deletes the subwords that follow a given token identified by the `(row, col)` pair in the `raw` matrix.
        A token is considered to be part of a word if it is not a punctuation and if it has the subword prefix
        specific to the language model used. The subwords are not actually deleted, but they are replaced
        by the empty string ``''``.

        Parameters
        ----------
        raw
            Array of tokens.
        row
            Row coordinate of the word to be removed.
        col
            Column coordinate of the word to be removed.
        punctuation
            String containing the punctuation to be considered.

        Returns
        -------
        raw
            Array of tokens where deleted subwords are replaced by the empty string.
        """
        for next_col in range(col + 1, len(self.head_tokens)):
            # if we encounter a punctuation, just stop
            if self.model.is_punctuation(raw[row, next_col], punctuation):
                break

            # if it is a subword prefix, then replace it by the empty string
            if self.model.is_subword_prefix(raw[row, next_col]):
                raw[row, next_col] = ''
            else:
                break

        return raw

    def _perturb_instances_parallel(self,
                                    num_samples: int,
                                    raw: np.ndarray,
                                    data: np.ndarray,
                                    top_n: int = 100,
                                    batch_size_lm: int = 32,
                                    temperature: float = 1.0,
                                    use_proba: bool = False,
                                    **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perturb the instances in a single forward pass (parallel).

        Parameters
        ----------
        num_samples
            Number of samples to be generated.
        raw
            Array of mask templates. Has `mask_templates` rows.
        data
            Binary array having 0 where the tokens are masked. Has `mask_templates` rows.
        top_n
            Use the top n words when sampling.
        batch_size_lm
            Batch size used for the language model.
        temperature
            Sample weight hyper-parameter.
        use_proba
            Whether to sample according to the predicted words distribution.
        **kwargs
            Other arguments. Not used.

        Returns
        -------
        sampled_tokens
            Array containing the ids of the sampled tokens. Has `num_samples` rows.
        sampled_data
            Binary array having 0 where the tokens were masked. Has `num_samples` rows.
        """
        # tokenize instances
        tokens_plus = self.model.tokenizer.batch_encode_plus(list(raw), padding=True, return_tensors='tf')

        # number of samples to generate per mask template
        remainder = num_samples % len(raw)
        mult_factor = num_samples // len(raw)

        # fill in masks with the language model
        # (mask_templates x max_length_sentence x num_tokens)
        logits = self.model.predict_batch_lm(x=tokens_plus,
                                             vocab_size=self.model.tokenizer.vocab_size,
                                             batch_size=batch_size_lm)

        # select rows and cols where the input tokens are masked
        tokens = tokens_plus['input_ids']  # (mask_templates x max_length_sentence)
        mask_pos = tf.where(tokens == self.model.mask_id)
        mask_row, mask_col = mask_pos[:, 0], mask_pos[:, 1]

        # buffers containing the sampled tokens and the corresponding binary masks
        sampled_tokens = np.zeros((num_samples, tokens.shape[1]), dtype=np.int32)
        sampled_data = np.zeros((num_samples, data.shape[1]))

        for i in range(logits.shape[0]):
            # select indices corresponding to the current row `i`
            idx = tf.reshape(tf.where(mask_row == i), shape=-1)

            # select columns corresponding to the current row `i`
            cols = tf.gather(mask_col, idx)

            # select the logits of the masked input
            logits_mask = logits[i, cols, :]

            # mask out tokens according to the subwords_mask
            logits_mask[:, self.subwords_mask] = -np.inf

            # select top n tokens from each distribution
            top_k = tf.math.top_k(logits_mask, top_n)
            top_k_logits, top_k_tokens = top_k.values, top_k.indices
            top_k_logits = (top_k_logits / temperature) if use_proba else (top_k_logits * 0)

            # sample `num_samples` instances for the current mask template
            for j in range(mult_factor + int(i < remainder)):
                # compute the buffer index
                idx = i * mult_factor + j + min(i, remainder)

                # sample indices
                ids_k = tf.reshape(tf.random.categorical(top_k_logits, 1), shape=-1)

                # copy the unmasked tokens and replace the masked ones with the samples drawn
                sampled_tokens[idx] = tokens[i]
                sampled_tokens[idx, cols] = tf.gather(top_k_tokens, ids_k, batch_dims=1)

            # Add the original binary mask which marks the beginning of a masked
            # word, as it is needed for the anchor algorithm (backend stuff).
            idx, offset = i * mult_factor, min(i, remainder)
            sampled_data[idx + offset:idx + mult_factor + offset + (i < remainder)] = data[i]

        # check that there are no masked tokens left
        assert np.all(sampled_tokens != self.model.mask_id)
        assert np.all(np.any(sampled_tokens != 0, axis=1))
        return sampled_tokens, sampled_data

    def _perturb_instance_ar(self,
                             num_samples: int,
                             raw: np.ndarray,
                             data: np.ndarray,
                             top_n: int = 100,
                             batch_size: int = 32,
                             temperature: float = 1.0,
                             use_proba: bool = False,
                             **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perturb the instances in an autoregressive fashion (sequential).

        Parameters
        ----------
        num_samples
            Number of samples to be generated.
        raw
            Array of mask templates. Has `mask_templates` rows.
        data
            Binary array having 0 where the tokens are masked. Has `mask_templates` rows.
        top_n
            Use the top n words when sampling.
        batch_size
            Batch size used for the language model.
        temperature
            Sample weight hyper-parameter.
        use_proba
            Whether to sample according to the predicted words distribution.
        **kwargs
            Other arguments. Not used.

        Returns
        -------
        sampled_tokens
            Array containing the ids of the sampled tokens. Has `num_samples` rows.
        sampled_data
            Binary array having 0 where the tokens were masked. Has `num_samples` rows.
        """
        # number of samples to generate per mask template
        assert num_samples == raw.shape[0]

        # tokenize instances
        tokens_plus = self.model.tokenizer.batch_encode_plus(list(raw), padding=True, return_tensors='tf')
        tokens = tokens_plus['input_ids'].numpy()  # (mask_templates x max_length_sentence)

        # store the column indices for each row where a token is a mask
        masked_idx = []
        max_len_idx = -1

        mask_pos = tf.where(tokens == self.model.mask_id)
        mask_row, mask_col = mask_pos[:, 0], mask_pos[:, 1]

        for i in range(tokens.shape[0]):
            # get the column indices and store them in the buffer
            idx = tf.reshape(tf.where(mask_row == i), shape=-1)
            cols = tf.gather(mask_col, idx)
            masked_idx.append(cols)

            # update maximum length
            max_len_idx = max(max_len_idx, len(cols))

        # iterate through all possible column indices
        for i in range(max_len_idx):
            masked_rows, masked_cols = [], []

            # iterate through all possible examples
            for row in range(tokens.shape[0]):
                # this means that the row does not have any more masked columns
                if len(masked_idx[row]) <= i:
                    continue

                masked_rows.append(row)
                masked_cols.append(masked_idx[row][i])

            # compute logits
            logits = self.model.predict_batch_lm(x=tokens_plus,
                                                 vocab_size=self.model.tokenizer.vocab_size,
                                                 batch_size=batch_size)

            # select only the logits of the first masked word in each row
            logits_mask = logits[masked_rows, masked_cols, :]

            # mask out words according to the subwords_mask
            logits_mask[:, self.subwords_mask] = -np.inf

            # select top n tokens from each distribution
            top_k = tf.math.top_k(logits_mask, top_n)
            top_k_logits, top_k_tokens = top_k.values, top_k.indices
            top_k_logits = (top_k_logits / temperature) if use_proba else (top_k_logits * 0)

            # sample indices
            ids_k = tf.reshape(tf.random.categorical(top_k_logits, 1), shape=-1)

            # replace masked tokens with the sampled ones
            tokens[masked_rows, masked_cols] = tf.gather(top_k_tokens, ids_k, batch_dims=1)
            tokens_plus['input_ids'] = tf.convert_to_tensor(tokens)

        return tokens, data

    def set_data_type(self) -> None:
        """
        Working with `numpy` arrays of strings requires setting the data type to avoid
        truncating examples. This function estimates the longest sentence expected during
        the sampling process, which is used to set the number of characters for the samples
        and examples arrays. This depends on the perturbation method used for sampling.
        """
        # get the vocabulary
        vocab = self.model.tokenizer.get_vocab()
        max_len = 0

        # go through the vocabulary and compute the maximum length of a token
        for token in vocab.keys():
            max_len = len(token) if len(token) > max_len else max_len

        # Length of the longest token. Adding the prefix is just a precaution;
        # for example <mask> -> _<mask>, which is not in the vocabulary.
        max_len += len(self.model.SUBWORD_PREFIX)

        # length of the longest text
        max_sent_len = (len(self.head_tokens) + len(self.tail_tokens)) * max_len

        # define the types to be used
        self.dtype_token = '<U' + str(max_len)
        self.dtype_sent = '<U' + str(max_sent_len)

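    # For instance (illustrative numbers): if the longest vocabulary token has 18 characters
    # and the subword prefix adds 2 more, `dtype_token` becomes '<U20'; for a text split into
    # 30 head and 10 tail tokens, `dtype_sent` becomes '<U800' (40 tokens * 20 characters).
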
    def seed(self, seed: int) -> None:
        tf.random.set_seed(seed)

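# Minimal usage sketch (values and text are illustrative, not part of the public API). It
# assumes the `DistilbertBaseUncased` wrapper from `alibi.utils.lang_model`; any other
# `LanguageModel` subclass should work the same way. The option names mirror the keys
# consumed by the sampler above.
if __name__ == '__main__':
    from alibi.utils.lang_model import DistilbertBaseUncased

    lm = DistilbertBaseUncased()
    opts = {
        'sample_proba': 0.5,              # probability of masking each perturbable word
        'filling': 'parallel',            # or 'autoregressive'
        'frac_mask_templates': 0.1,
        'top_n': 100,
        'batch_size_lm': 32,
        'punctuation': string.punctuation,
        'stopwords': ['and', 'the', 'but'],
        'sample_punctuation': False,
        'use_proba': False,
        'temperature': 1.0,
    }

    sampler = LanguageModelSampler(model=lm, perturb_opts=opts)
    sampler.set_text('a visually flashy but narratively opaque and emotionally vapid experience')

    # keep the first perturbable word fixed and resample the rest
    raw, data = sampler(anchor=(0,), num_samples=8)
    print(raw.shape, data.shape)  # expected: (8,) and (8, len(sampler.ids_sample))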