# Source code for alibi.tests.utils

from contextlib import contextmanager
from typing import Optional

import numpy as np

OUT_TYPES = ['proba', 'class', 'raw', 'probability', 'probability_doubled', 'log_loss', 'continuous']


class MockPredictor:
    """
    A class that mimics the output of a classifier or regressor to allow testing of functionality that depends
    on it without inference overhead.
    """

    def __init__(self,
                 out_dim: int,
                 out_type: str = 'proba',
                 model_type: Optional[str] = None,
                 seed: Optional[int] = None,
                 ) -> None:
        """
        Parameters
        ----------
        out_dim
            The number of output classes (size of the last axis of the generated outputs).
        out_type
            Indicates if probabilities, class predictions or continuous outputs are generated. Must be one of
            `OUT_TYPES`.
        model_type
            Stored unchanged on the instance.
        seed
            Passed to `np.random.seed`, i.e. it seeds NumPy's *global* random state.

        Raises
        ------
        ValueError
            If `out_type` is not one of `OUT_TYPES`.
        """
        # Validate first so that an invalid argument does not reseed the global RNG as a side effect.
        if out_type not in OUT_TYPES:
            raise ValueError("Unknown output type. Accepted values are {}".format(OUT_TYPES))

        np.random.seed(seed)
        self.out_dim = out_dim
        self.num_outputs = out_dim
        self.out_type = out_type
        self.model_type = model_type  # Required to emulate `shap` model wrapper arguments, see test_shap_wrappers.py

    def __call__(self, *args, **kwargs):
        """
        Generate fake predictions for the input passed as the first positional argument.

        The batch shape of the output is the input's shape with its last axis dropped, so an input of shape
        ``(N, F)`` produces ``N`` predictions. Keyword arguments are forwarded to the generator (e.g. ``alpha``
        for probability outputs).

        Raises
        ------
        ValueError
            If no input is given, the input has no ``.shape`` attribute, or ``self.out_type`` is not recognised.
        """
        # can specify size s.t. multiple predictions/batches of predictions are returned
        if not args or not hasattr(args[0], 'shape'):
            raise ValueError("Predictor expects the input to have attribute .shape!")
        sz = tuple(args[0].shape[:-1])

        if self.out_type in ('proba', 'probability', 'probability_doubled'):
            return self._generate_probas(sz, *args, **kwargs)
        if self.out_type == 'class':
            return self._generate_labels(sz, *args, **kwargs)
        if self.out_type in ('raw', 'log_loss', 'continuous'):
            return self._generate_logits(sz, *args, **kwargs)
        # Only reachable if `out_type` was modified after construction: fail loudly rather than return None.
        raise ValueError("Unknown output type. Accepted values are {}".format(OUT_TYPES))

    def _generate_probas(self, sz: Optional[tuple] = None, *args, **kwargs) -> np.ndarray:
        """
        Generates probability vectors by sampling from a Dirichlet distribution. User can specify the Dirichlet
        distribution parameters via the 'alpha' kwargs. See documentation for np.random.dirichlet to see how to set
        this parameter.

        Parameters
        ----------
        sz
            Batch shape of the output, e.g. ``(N,)`` or ``(N, B)``. The Dirichlet sampler appends a class axis of
            size `out_dim`. When `out_dim` is 1, uniform values of shape `sz` are returned instead.

        Raises
        ------
        TypeError
            If `alpha` is neither a list nor an `np.ndarray`.
        ValueError
            If `alpha` is not (squeezable to) a 1-D vector of length `out_dim`.
        """
        if self.out_dim == 1:
            return np.random.uniform(size=sz)

        # set distribution parameters
        alpha = kwargs.get('alpha', np.ones(self.out_dim))
        if isinstance(alpha, np.ndarray):
            # Accept e.g. (1, out_dim) arrays; the sampler itself requires a 1-D parameter vector, so pass the
            # squeezed array on rather than only using it for the dimension check.
            alpha = alpha.squeeze()
            if alpha.ndim != 1:
                raise ValueError("Expected 1-D Dirichlet parameters, got an array of shape {}".format(alpha.shape))
            dim = alpha.shape[0]
        elif isinstance(alpha, list):
            dim = len(alpha)
        else:
            raise TypeError("Expected Dirichlet parameters to be of type list or np.ndarray!")
        if dim != self.out_dim:
            raise ValueError("The dimension of the Dirichlet distribution parameters "
                             "must match output dimension. Got alpha dim={} and "
                             "out_dim={} ".format(dim, self.out_dim))

        return np.random.dirichlet(alpha, size=sz)

    def _generate_labels(self, sz: Optional[tuple] = None, *args, **kwargs) -> np.ndarray:
        """
        Generates labels by sampling random integers in range(0, n_classes+1).

        The output has shape ``sz + (out_dim,)``; an empty `sz` (single 1-D input) still yields ``(out_dim,)``.

        NOTE(review): the upper bound includes `out_dim` itself, and one label is drawn per class slot. This is
        kept as-is because callers may rely on it; confirm whether it is intended.
        """
        # `is not None` (not truthiness): the empty batch shape () must still receive the class axis.
        if sz is not None:
            sz = tuple(sz) + (self.out_dim,)
        return np.random.randint(0, self.out_dim + 1, size=sz)

    def _generate_logits(self, sz: Optional[tuple] = None, *args, **kwargs) -> np.ndarray:
        """
        Generates fake logit values by sampling from the standard normal.

        The output has shape ``sz + (out_dim,)``; an empty `sz` (single 1-D input) still yields ``(out_dim,)``.
        """
        # `is not None` (not truthiness): the empty batch shape () must still receive the class axis.
        if sz is not None:
            sz = tuple(sz) + (self.out_dim,)
        return np.random.normal(size=sz)

    def predict(self, *args, **kwargs):
        """Alias for `__call__`, mirroring the common `model.predict(X)` interface."""
        return self.__call__(*args, **kwargs)
def issorted(arr, reverse=False):
    """
    Checks if a numpy array is sorted.

    Parameters
    ----------
    arr
        Array-like to check. It is converted with ``np.asarray`` so that Python lists are compared element-wise
        (a plain list slice comparison would be lexicographic). For multi-dimensional input, consecutive entries
        along the first axis are compared element-wise.
    reverse
        If ``True``, check for non-increasing (descending) order instead of non-decreasing order.

    Returns
    -------
    A NumPy boolean: ``True`` if every pair of consecutive entries is ordered, ``False`` otherwise. Empty and
    single-element inputs are considered sorted.
    """
    arr = np.asarray(arr)
    if reverse:
        # Equivalent to checking that arr[::-1] is non-decreasing.
        return np.all(arr[:-1] >= arr[1:])
    return np.all(arr[:-1] <= arr[1:])
@contextmanager
def not_raises(ExpectedException):
    """
    A context manager used to check that `ExpectedException` does not occur during testing.

    Any ``Exception`` raised inside the managed block is converted into an ``AssertionError``. An instance of
    `ExpectedException` gets one message, and any other exception gets another. The original exception is
    attached as ``__cause__`` so its traceback appears in the test report.

    Parameters
    ----------
    ExpectedException
        Exception class (or tuple of classes, as accepted by ``except``) that must not be raised.

    Raises
    ------
    AssertionError
        If the managed block raises.
    """
    try:
        yield
    except ExpectedException as error:
        raise AssertionError("Raised exception {} when it should not!".format(error)) from error
    except Exception as error:
        raise AssertionError("An unexpected exception {} raised.".format(error)) from error
def assert_message_in_logs(msg, records):
    """
    Helper function to check that `msg` occurs in the message of at least one log record.

    Parameters
    ----------
    msg
        Substring to search for.
    records
        Iterable of objects exposing a ``.msg`` attribute, such as ``logging.LogRecord`` instances. ``.msg`` is the
        unformatted message, so %-style arguments are not interpolated before matching.

    Raises
    ------
    AssertionError
        If no record's message contains `msg` (including when `records` is empty).
    """
    # Logging accepts any object as the message; compare against its text form (as `LogRecord.getMessage` does)
    # so that non-string messages do not raise TypeError on `in`.
    assert any(msg in str(record.msg) for record in records), \
        "Message {!r} not found in any of the log records.".format(msg)