Source code for alibi_detect.utils.frameworks

from .missing_optional_dependency import ERROR_TYPES
from typing import Optional, List, Dict, Iterable
from enum import Enum

[docs]class Framework(str, Enum): PYTORCH = 'pytorch' TENSORFLOW = 'tensorflow' KEOPS = 'keops' SKLEARN = 'sklearn'
try: import tensorflow as tf # noqa import tensorflow_probability as tfp # noqa has_tensorflow = True except ImportError: has_tensorflow = False try: import torch # noqa has_pytorch = True except ImportError: has_pytorch = False try: import pykeops # noqa import torch # noqa has_keops = True except ImportError: has_keops = False # Map from backend name to boolean value indicating its presence HAS_BACKEND = { 'tensorflow': has_tensorflow, 'pytorch': has_pytorch, 'sklearn': True, 'keops': has_keops, } def _iter_to_str(iterable: Iterable[str]) -> str: """ Correctly format iterable of items to comma seperated sentence string.""" items = [f'`{option}`' for option in iterable] last_item_str = f'{items[-1]}' if not items[:-1] else f' and {items[-1]}' return ', '.join(items[:-1]) + last_item_str
[docs]class BackendValidator:
[docs] def __init__(self, backend_options: Dict[Optional[str], List[str]], construct_name: str): """Checks for required sets of backend options. Takes a dictionary of backends plus extra dependencies and generates correct error messages if they are unmet. Parameters ---------- backend_options Dictionary from backend to list of dependencies that must be satisfied. The keys are the available options for the user and the values should be a list of dependencies that are checked via the `HAS_BACKEND` map defined in this module. An example of `backend_options` would be `{'tensorflow': ['tensorflow'], 'pytorch': ['pytorch'], None: []}`.This would mean `'tensorflow'`, `'pytorch'` or `None` are available backend options. If the user passes a different backend they will receive and error listing the correct backends. In addition, if one of the dependencies in the `backend_option` values is missing for the specified backend the validator will issue an error message telling the user what dependency bucket to install. construct_name Name of the object that has a set of backends we need to verify. """ self.backend_options = backend_options self.construct_name = construct_name
[docs] def verify_backend(self, backend: str): """Verifies backend choice. Verifies backend is implemented and that the correct dependencies are installed for the requested backend. If the backend is not implemented or a dependency is missing then an error is issued. Parameters ---------- backend Choice of backend the user wishes to initialize the alibi-detect construct with. Must be one of the keys in the `self.backend_options` dictionary. Raises ------ NotImplementedError If backend is not a member of `self.backend_options.keys()` a `NotImplementedError` is raised. Note `None` is a valid choice of backend if it is set as a key on `self.backend_options.keys()`. If a backend is not implemented for an alibi-detect object then it should not have a key on `self.backend_options`. ImportError If one of the dependencies in `self.backend_options[backend]` is missing then an ImportError will be thrown including a message informing the user how to install. """ if backend not in self.backend_options: self._raise_implementation_error(backend) dependencies = self.backend_options[backend] missing_deps = [] for dependency in dependencies: if not HAS_BACKEND[dependency]: missing_deps.append(dependency) if missing_deps: self._raise_import_error(missing_deps, backend)
def _raise_import_error(self, missing_deps: List[str], backend: str): """Raises import error if backend choice has missing dependency.""" optional_dependencies = list(ERROR_TYPES[missing_dep] for missing_dep in missing_deps) optional_dependencies.sort() missing_deps_str = _iter_to_str(missing_deps) error_msg = (f'{missing_deps_str} not installed. Cannot initialize and run {self.construct_name} ' f'with {backend} backend.') pip_msg = '' if not optional_dependencies else \ (f'The necessary missing dependencies can be installed using ' f'`pip install alibi-detect[{" ".join(optional_dependencies)}]`.') raise ImportError(f'{error_msg} {pip_msg}') def _raise_implementation_error(self, backend: str): """Raises NotImplementedError error if backend choice is not implemented.""" backend_list = _iter_to_str(self.backend_options.keys()) raise NotImplementedError(f"{backend} backend not implemented. Use one of {backend_list} instead.")