alibi.explainers.anchors.language_model_text_sampler module

class alibi.explainers.anchors.language_model_text_sampler.LanguageModelSampler(model, perturb_opts)[source]

Bases: AnchorTextSampler

FILLING_AUTOREGRESSIVE = 'autoregressive'

Autoregressive filling procedure. Considerably slower than the parallel procedure.

FILLING_PARALLEL: str = 'parallel'

Parallel filling procedure.

__call__(anchor, num_samples)[source]

Returns num_samples perturbed sentences in which randomly chosen words, except those in anchor, are replaced by words sampled according to the language model’s predictions.

Parameters:
  • anchor (tuple) – Indices represent the positions of the words to be kept unchanged.

  • num_samples (int) – Number of perturbed sentences to be returned.

Return type:

Tuple[ndarray, ndarray]

Returns:

See alibi.explainers.anchors.language_model_text_sampler.LanguageModelSampler.perturb_sentence().

__init__(model, perturb_opts)[source]

Initialize the language model sampler. This sampler replaces words with ones sampled according to the output distribution of the language model. The sampler has two modes: 'parallel' and 'autoregressive'. In 'parallel' mode, all words are replaced simultaneously. In 'autoregressive' mode, the words are replaced one by one, from left to right, so each replacement is conditioned on the previously predicted words.

Parameters:
  • model (LanguageModel) – Transformers masked language model.

  • perturb_opts (dict) – Perturbation options.
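The difference between the two filling modes can be illustrated with a toy sketch. It does not use alibi or a real language model: `toy_predict`, `fill_parallel`, `fill_autoregressive` and the `[MASK]` token string are hypothetical names chosen for illustration, and the "prediction" only records which left-hand context it saw.

```python
MASK = "[MASK]"  # assumed mask token string, for illustration only


def toy_predict(tokens, position):
    # Hypothetical stand-in for a masked language model: the "prediction"
    # simply records the token found to the left of the masked position.
    left = tokens[position - 1] if position > 0 else "<s>"
    return f"fill({left})"


def fill_parallel(tokens):
    # Every masked position is predicted from the same, still-masked template.
    return [toy_predict(tokens, i) if t == MASK else t
            for i, t in enumerate(tokens)]


def fill_autoregressive(tokens):
    # Masked positions are filled one at a time, left to right, so each
    # prediction sees the words already filled in to its left.
    filled = list(tokens)
    for i, t in enumerate(tokens):
        if t == MASK:
            filled[i] = toy_predict(filled, i)
    return filled


template = ["the", MASK, MASK, "movie"]
parallel = fill_parallel(template)
autoregressive = fill_autoregressive(template)
print(parallel)        # ['the', 'fill(the)', 'fill([MASK])', 'movie']
print(autoregressive)  # ['the', 'fill(the)', 'fill(fill(the))', 'movie']
```

In the parallel result the second replacement still sees `[MASK]` as its left neighbour, while in the autoregressive result it sees the word that was just filled in.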

create_mask(anchor, num_samples, sample_proba=1.0, filling='parallel', frac_mask_templates=0.1, **kwargs)[source]

Create mask for words to be perturbed.

Parameters:
  • anchor (tuple) – Indices represent the positions of the words to be kept unchanged.

  • num_samples (int) – Number of perturbed sentences to be returned.

  • sample_proba (float) – Probability of a word being replaced.

  • filling (str) – Method to fill masked words. Either 'parallel' or 'autoregressive'.

  • frac_mask_templates (float) – Fraction of mask templates from the number of requested samples.

  • **kwargs – Other arguments to be passed to other methods.

Return type:

Tuple[ndarray, ndarray]

Returns:

  • raw – Array with masked instances.

  • data – A (num_samples, m)-dimensional boolean array, where m is the number of tokens in the instance to be explained.
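As a rough illustration of the mask described above, the following sketch builds a boolean array in which anchor positions are always kept and every other position is masked with probability `sample_proba`. It is not the library's implementation: `sketch_mask`, the `seed` argument, the whitespace-split `words` and the `[MASK]` string are assumptions, and `frac_mask_templates` is not modelled.

```python
import numpy as np


def sketch_mask(num_tokens, anchor, num_samples, sample_proba=1.0, seed=0):
    """Hypothetical helper: True means keep the word, False means mask it."""
    rng = np.random.default_rng(seed)
    # A position is kept when its uniform draw is at least sample_proba,
    # so it is masked with probability sample_proba.
    data = rng.random((num_samples, num_tokens)) >= sample_proba
    # Anchor positions are never masked.
    data[:, np.asarray(anchor, dtype=int)] = True
    return data


words = ["a", "truly", "great", "movie", "!"]
data = sketch_mask(len(words), anchor=(1, 3), num_samples=4, sample_proba=1.0)

# Build the masked templates from the boolean array.
raw = [" ".join(w if keep else "[MASK]" for w, keep in zip(words, row))
       for row in data]
print(raw[0])  # [MASK] truly [MASK] movie [MASK]
```

With `sample_proba=1.0` every non-anchor word is masked; smaller values leave some of those words untouched.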

fill_mask(raw, data, num_samples, top_n=100, batch_size_lm=32, filling='parallel', **kwargs)[source]

Fill in the masked tokens with language model.

Parameters:
  • raw (ndarray) – Array of mask templates.

  • data (ndarray) – Binary mask having 0 where the word was masked.

  • num_samples (int) – Number of samples to be drawn.

  • top_n (int) – Use the top n words when sampling.

  • batch_size_lm (int) – Batch size used for language model.

  • filling (str) – Method to fill masked words. Either 'parallel' or 'autoregressive'.

  • **kwargs – Other parameters to be passed to other methods.

Return type:

Tuple[ndarray, ndarray]

Returns:

raw – Array containing num_samples elements. Each element is a perturbed sentence.
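The role of `top_n` can be sketched as follows: only the `top_n` most probable candidate words are retained, and the replacement is drawn from them in proportion to their probabilities. This is a minimal sketch, not the library's code; `sample_top_n` and the `probs` dictionary are hypothetical stand-ins for a language model's output distribution at one masked position.

```python
import random


def sample_top_n(probs, top_n, rng):
    """Hypothetical helper: draw one word from the top_n most probable ones."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_n]
    candidates = [word for word, _ in top]
    weights = [p for _, p in top]
    # random.choices treats the weights as relative, so no renormalization
    # is needed after truncating the distribution.
    return rng.choices(candidates, weights=weights, k=1)[0]


# Made-up distribution over candidate words for a single masked position.
probs = {"great": 0.5, "good": 0.3, "fine": 0.15, "purple": 0.05}
rng = random.Random(0)
draws = [sample_top_n(probs, top_n=2, rng=rng) for _ in range(200)]
print(set(draws) <= {"great", "good"})  # True
```

A smaller `top_n` keeps replacements close to the model's most likely words, while a larger value allows more varied (and less likely) substitutions.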

get_sample_ids(punctuation='!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~', stopwords=None, **kwargs)[source]

Find the indices of the words that can be perturbed.

Parameters:
  • punctuation (str) – String of punctuation characters.

  • stopwords (Optional[List[str]]) – List of stopwords.

  • **kwargs – Other arguments. Not used.

Return type:

None
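The idea of excluding punctuation and stopwords from perturbation can be sketched as below. Unlike the method above, which returns None, the hypothetical helper `perturbable_indices` returns the indices directly so the result is visible. Splitting into whole words and lowercasing before the stopword comparison are assumptions of this sketch, not documented behaviour.

```python
import string


def perturbable_indices(words, punctuation=string.punctuation, stopwords=None):
    """Hypothetical helper: indices of words that are neither punctuation
    nor stopwords."""
    stop = {s.lower() for s in (stopwords or [])}
    return [
        i for i, word in enumerate(words)
        # Skip stopwords and words made up entirely of punctuation characters.
        if word.lower() not in stop
        and not all(ch in punctuation for ch in word)
    ]


words = ["The", "movie", "was", "great", "!"]
ids = perturbable_indices(words, stopwords=["the", "was"])
print(ids)  # [1, 3]
```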

perturb_sentence(anchor, num_samples, sample_proba=0.5, top_n=100, batch_size_lm=32, filling='parallel', **kwargs)[source]

Returns num_samples perturbed sentences in which randomly chosen words, except those in anchor, are replaced by words sampled according to the language model’s predictions.

Parameters:
  • anchor (tuple) – Indices represent the positions of the words to be kept unchanged.

  • num_samples (int) – Number of perturbed sentences to be returned.

  • sample_proba (float) – Probability of a token being replaced by a similar token.

  • top_n (int) – Use the top n words when sampling.

  • batch_size_lm (int) – Batch size used for language model.

  • filling (str) – Method to fill masked words. Either 'parallel' or 'autoregressive'.

  • **kwargs – Other arguments to be passed to other methods.

Return type:

Tuple[ndarray, ndarray]

Returns:

  • raw – Array containing num_samples elements. Each element is a perturbed sentence.

  • data – A (num_samples, m)-dimensional boolean array, where m is the number of tokens in the instance to be explained.

seed(seed)[source]

Return type:

None

set_data_type()[source]

Working with numpy arrays of strings requires setting the data type to avoid truncating examples. This function estimates the longest sentence expected during the sampling process, which is used to set the number of characters for the samples and examples arrays. This depends on the perturbation method used for sampling.

Return type:

None
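The truncation problem mentioned above comes from NumPy's fixed-width string dtypes, and can be reproduced with a short sketch. The sentences and the width `max_len = 64` are made-up values for illustration; how the library estimates the required width is not shown here.

```python
import numpy as np

sentences = ["good film", "bad"]

# NumPy infers a fixed width from the longest string: 9 characters here.
arr = np.array(sentences)
arr[1] = "a much longer perturbed sentence"
print(arr.dtype, repr(str(arr[1])))  # <U9 'a much lo'  (silently truncated)

# Reserving enough characters up front avoids the truncation.
max_len = 64  # hypothetical upper bound on sentence length
safe = np.array(sentences, dtype=f"U{max_len}")
safe[1] = "a much longer perturbed sentence"
print(repr(str(safe[1])))  # 'a much longer perturbed sentence'
```

Because the assignment truncates without raising an error, the width has to be chosen before any perturbed sentences are written into the array.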

set_text(text)[source]

Sets the text to be processed.

Parameters:
  • text (str) – Text to be processed.

Return type:

None