"""Source code for ``alibi_detect.saving.loading``: functions to load saved detectors and their config artefacts."""

import logging
import os
from functools import partial
from importlib import import_module
from pathlib import Path
from typing import Any, Callable, Optional, Union, Type, TYPE_CHECKING

import dill
import numpy as np
import toml
from transformers import AutoTokenizer

from alibi_detect.saving.registry import registry
from alibi_detect.saving._tensorflow import load_detector_legacy, load_embedding_tf, load_kernel_config_tf, \
    load_model_tf, load_optimizer_tf, prep_model_and_emb_tf, get_tf_dtype
from alibi_detect.saving._pytorch import load_embedding_pt, load_kernel_config_pt, load_model_pt, \
    load_optimizer_pt, prep_model_and_emb_pt, get_pt_dtype
from alibi_detect.saving._keops import load_kernel_config_ke
from alibi_detect.saving._sklearn import load_model_sk
from alibi_detect.saving.validate import validate_config
from alibi_detect.base import Detector, ConfigurableDetector, StatefulDetectorOnline
from alibi_detect.utils.frameworks import has_tensorflow, has_pytorch, Framework
from alibi_detect.saving.schemas import supported_models_tf, supported_models_torch
from alibi_detect.utils.missing_optional_dependency import import_optional
get_device = import_optional('alibi_detect.utils.pytorch.misc', names=['get_device'])

if TYPE_CHECKING:
    # Only imported for static type checking (the string annotations of `_load_optimizer_config` refer to
    # `tf` and `torch`); TYPE_CHECKING is False at runtime, so neither framework is imported here.
    import tensorflow as tf
    import torch

# Logger shared by all loading functions in this module.
logger = logging.getLogger(name=__name__)

# Sub-directory of a detector's save directory in which its state is saved (and from which it is loaded).
STATE_PATH = 'state/'

# Fields to resolve in resolve_config ("resolve" meaning either load local artefact or resolve @registry, conversion to
# tuple, np.ndarray and np.dtype are dealt with separately).
# Note: For fields consisting of nested dicts, they must be listed in order from deepest to shallowest, so that the
# deepest fields are resolved first. e.g. 'preprocess_fn.src' must be resolved before 'preprocess_fn'.
FIELDS_TO_RESOLVE = [
    ['preprocess_fn', 'src'],
    ['preprocess_fn', 'model'],
    ['preprocess_fn', 'embedding'],
    ['preprocess_fn', 'tokenizer'],
    ['preprocess_fn', 'preprocess_batch_fn'],
    ['preprocess_fn'],
    ['x_ref'],
    ['c_ref'],
    ['model'],
    ['optimizer'],
    ['reg_loss_fn'],
    ['dataset'],
    ['kernel', 'src'],
    ['kernel', 'proj'],
    ['kernel', 'init_sigma_fn'],
    ['kernel', 'kernel_a', 'src'],
    ['kernel', 'kernel_a', 'init_sigma_fn'],
    ['kernel', 'kernel_b', 'src'],
    ['kernel', 'kernel_b', 'init_sigma_fn'],
    ['kernel'],
    ['x_kernel', 'src'],
    ['x_kernel', 'init_sigma_fn'],
    ['x_kernel'],
    ['c_kernel', 'src'],
    ['c_kernel', 'init_sigma_fn'],
    ['c_kernel'],
    ['initial_diffs'],
    ['tokenizer'],
]

# Fields to convert from str to dtype
FIELDS_TO_DTYPE = [
    ['preprocess_fn', 'dtype'],
]

def load_detector(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False, **kwargs) -> \
        Union[Detector, ConfigurableDetector]:
    """
    Load outlier, drift or adversarial detector.

    Parameters
    ----------
    filepath
        Load directory, or the path to a `config.toml` file.
    enable_unsafe_loading
        Sets allow_pickle=True when a np.ndarray is loaded from a .npy file referenced in the detector config.
        Needed if you have to load objects. Only applied if the filepath is config.toml or a directory containing
        a config.toml. It has security implications: pickled objects can execute arbitrary code when they are
        loaded, so only enable this for files from a trusted source.
    kwargs
        Keyword arguments forwarded to `load_detector_legacy` when the directory contains a `meta.dill` or
        `meta.pickle` file. They are not used when loading from a `config.toml`.

    Returns
    -------
    Loaded outlier, drift or adversarial detector object.

    Raises
    ------
    ValueError
        If `filepath` is neither a `config.toml` file nor a directory containing `config.toml`, `meta.dill` or
        `meta.pickle`.
    """
    filepath = Path(filepath)

    # If reference is a 'config.toml' itself, pass to new load function
    if filepath.name == 'config.toml':
        return _load_detector_config(filepath, enable_unsafe_loading=enable_unsafe_loading)

    # Otherwise, if a directory, look for meta.dill, meta.pickle or config.toml inside it
    if filepath.is_dir():
        files = {f.name for f in filepath.iterdir() if f.is_file()}
        # config.toml takes precedence over the legacy formats
        if 'config.toml' in files:
            return _load_detector_config(filepath.joinpath('config.toml'),
                                         enable_unsafe_loading=enable_unsafe_loading)
        if 'meta.dill' in files:
            return load_detector_legacy(filepath, '.dill', **kwargs)
        if 'meta.pickle' in files:
            return load_detector_legacy(filepath, '.pickle', **kwargs)
        raise ValueError(f'Neither meta.dill, meta.pickle or config.toml exist in {filepath}.')

    # No other file types are accepted, so if not dir raise error
    raise ValueError("load_detector accepts only a filepath to a directory, or a config.toml file.")
# TODO - will eventually become load_detector
def _load_detector_config(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False) \
        -> ConfigurableDetector:
    """
    Loads a drift detector specified in a detector config dict. Validation is performed with pydantic.

    Parameters
    ----------
    filepath
        Filepath to the `config.toml` file.
    enable_unsafe_loading
        Sets allow_pickle=True when a np.ndarray is loaded from a .npy file (happens if the .toml references one).
        Needed if you have to load objects. It has security implications: pickled objects can execute arbitrary
        code when they are loaded, so only enable this for files from a trusted source.

    Returns
    -------
    The instantiated detector.
    """
    # Load toml if needed
    if isinstance(filepath, (str, os.PathLike)):
        config_file = Path(filepath)
        config_dir = config_file.parent
        cfg = read_config(config_file)
    else:
        raise ValueError("`filepath` should point to a directory containing a 'config.toml' file.")

    # Resolve and validate config
    cfg = validate_config(cfg)
    logger.info('Validated unresolved config.')
    cfg = resolve_config(cfg, config_dir=config_dir, enable_unsafe_loading=enable_unsafe_loading)
    cfg = validate_config(cfg, resolved=True)
    logger.info('Validated resolved config.')

    # Init detector from config
    logger.info('Instantiating detector.')
    detector = _init_detector(cfg)

    # Load state if it exists (and detector supports it)
    # TODO - this will be removed in follow-up offline state PR, as loading to be moved to __init__ (w/ state_dir kwarg)
    if isinstance(detector, StatefulDetectorOnline):
        state_dir = config_dir.joinpath(STATE_PATH)
        if state_dir.is_dir():
            detector.load_state(state_dir)

    logger.info('Finished loading detector.')
    return detector


def _init_detector(cfg: dict) -> ConfigurableDetector:
    """
    Instantiates a detector from a fully resolved config dictionary.

    Parameters
    ----------
    cfg
        The detector's resolved config dictionary. Its 'name' entry is removed (popped) in place.

    Returns
    -------
    The instantiated detector.
    """
    detector_name = cfg.pop('name')

    # Instantiate the detector: the class is looked up by name in the drift detection package
    klass = getattr(import_module('alibi_detect.cd'), detector_name)
    detector = klass.from_config(cfg)
    logger.info('Instantiated drift detector {}'.format(detector_name))
    return detector


def _load_kernel_config(cfg: dict, backend: str = Framework.TENSORFLOW) -> Callable:
    """
    Loads a kernel from a kernel config dict.

    Parameters
    ----------
    cfg
        A kernel config dict. (see pydantic schema's).
    backend
        The backend.

    Returns
    -------
    The kernel.
    """
    if backend == Framework.TENSORFLOW:
        kernel = load_kernel_config_tf(cfg)
    elif backend == Framework.PYTORCH:
        kernel = load_kernel_config_pt(cfg)
    else:  # backend=='keops'
        kernel = load_kernel_config_ke(cfg)
    return kernel


def _load_preprocess_config(cfg: dict) -> Optional[Callable]:
    """
    This function builds a preprocess_fn from the preprocess dict in a detector config dict. The dict format is
    expected to match that generated by serialize_preprocess in alibi_detect.utils.saving (also see pydantic schema).
    The model, tokenizer and preprocess_batch_fn are expected to be already resolved.

    Parameters
    ----------
    cfg
        A preprocess_fn config dict. (see pydantic schemas). Its 'src' (and, for preprocess_drift, 'kwargs')
        entries are removed in place.

    Returns
    -------
    The preprocess_fn function, a `functools.partial` of it if kwargs are set, or None if 'src' is not callable.
    """
    preprocess_fn = cfg.pop('src')

    if callable(preprocess_fn):
        if preprocess_fn.__name__ == 'preprocess_drift':
            # If preprocess_drift function, kwargs is preprocess cfg minus 'src' and 'kwargs'
            cfg.pop('kwargs')
            kwargs = cfg.copy()

            # Final processing of model (and/or embedding)
            model = kwargs['model']
            emb = kwargs.pop('embedding')  # embedding passed to preprocess_drift as `model` therefore remove

            # Backend specifics
            if has_tensorflow and isinstance(model, supported_models_tf):
                model = prep_model_and_emb_tf(model, emb)
            elif has_pytorch and isinstance(model, supported_models_torch):
                model = prep_model_and_emb_pt(model, emb)
            elif model is None:
                model = emb
            if model is None:
                raise ValueError("A 'model' and/or `embedding` must be specified when "
                                 "preprocess_fn='preprocess_drift'")
            kwargs.update({'model': model})

            # Set`device` if a PyTorch model, otherwise remove from kwargs
            if isinstance(model, supported_models_torch):
                device = get_device(cfg['device'])
                # Keep a reference to the (device-placed, eval-mode) model itself; `dict.update` returns None,
                # so it must not be assigned to `model`.
                model = model.to(device).eval()
                kwargs.update({'device': device})
                kwargs.update({'model': model})
            else:
                kwargs.pop('device')
        else:
            kwargs = cfg['kwargs']  # If generic callable, kwargs is cfg['kwargs']

    else:
        logger.warning('Unable to process preprocess_fn. No preprocessing function is defined.')
        return None

    if kwargs == {}:
        return preprocess_fn
    return partial(preprocess_fn, **kwargs)


def _load_model_config(cfg: dict) -> Callable:
    """
    Loads supported models from a model config dict.

    Parameters
    ----------
    cfg
        Model config dict. (see pydantic model schemas).

    Returns
    -------
    The loaded model.

    Raises
    ------
    FileNotFoundError
        If `cfg['src']` is not a directory.
    ValueError
        If `cfg['flavour']` is not a supported model flavour.
    """
    # Load model
    flavour = cfg['flavour']
    src = cfg['src']
    custom_obj = cfg['custom_objects']
    layer = cfg['layer']
    src = Path(src)
    if not src.is_dir():
        raise FileNotFoundError("The `src` field is not a recognised directory. It should be a directory containing "
                                "a compatible model.")

    if flavour == Framework.TENSORFLOW:
        model = load_model_tf(src, custom_objects=custom_obj, layer=layer)
    elif flavour == Framework.PYTORCH:
        model = load_model_pt(src, layer=layer)
    elif flavour == Framework.SKLEARN:
        model = load_model_sk(src)
    else:
        # Without this branch an unknown flavour surfaced as an UnboundLocalError on `return model`
        raise ValueError(f"Unsupported model flavour '{flavour}'.")

    return model


def _load_embedding_config(cfg: dict) -> Callable:  # TODO: Could type return more tightly
    """
    Load a pre-trained text embedding from an embedding config dict.

    Parameters
    ----------
    cfg
        An embedding config dict. (see the pydantic schemas).

    Returns
    -------
    The loaded embedding.
    """
    src = cfg['src']
    layers = cfg['layers']
    typ = cfg['type']
    flavour = cfg['flavour']
    if flavour == Framework.TENSORFLOW:
        emb = load_embedding_tf(src, embedding_type=typ, layers=layers)
    else:
        emb = load_embedding_pt(src, embedding_type=typ, layers=layers)
    return emb


def _load_tokenizer_config(cfg: dict) -> AutoTokenizer:
    """
    Loads a text tokenizer from a tokenizer config dict.

    Parameters
    ----------
    cfg
        A tokenizer config dict. (see the pydantic schemas).

    Returns
    -------
    The loaded tokenizer.
    """
    src = cfg['src']
    kwargs = cfg['kwargs']
    src = Path(src)
    tokenizer = AutoTokenizer.from_pretrained(src, **kwargs)
    return tokenizer


def _load_optimizer_config(cfg: dict, backend: str) \
        -> Union['tf.keras.optimizers.Optimizer', Type['tf.keras.optimizers.Optimizer'],
                 Type['torch.optim.Optimizer']]:
    """
    Loads an optimizer from an optimizer config dict.

    Parameters
    ----------
    cfg
        The optimizer config dict.
    backend
        The backend.

    Returns
    -------
    The loaded optimizer.
    """
    if backend == Framework.TENSORFLOW:
        return load_optimizer_tf(cfg)
    return load_optimizer_pt(cfg)


def _get_nested_value(dic: dict, keys: list) -> Any:
    """
    Get a value from a nested dictionary.

    Parameters
    ----------
    dic
        The dictionary.
    keys
        List of keys to "walk" to nested value.
        For example, to extract the value `dic['key1']['key2']['key3']`, set `keys = ['key1', 'key2', 'key3']`.

    Returns
    -------
    The nested value specified by `keys`, or None if any key along the path is missing.
    """
    for key in keys:
        try:
            dic = dic[key]
        except (TypeError, KeyError):
            return None
    return dic


def _set_nested_value(dic: dict, keys: list, value: Any):
    """
    Set a value in a nested dictionary, creating intermediate dicts as needed.

    Parameters
    ----------
    dic
        The dictionary.
    keys
        List of keys to "walk" to nested value.
        For example, to set the value `dic['key1']['key2']['key3']`, set `keys = ['key1', 'key2', 'key3']`.
    value
        The value to set.
    """
    for key in keys[:-1]:
        dic = dic.setdefault(key, {})
    dic[keys[-1]] = value


def _set_dtypes(cfg: dict):
    """
    Converts str's in the config dictionary to dtypes e.g. 'np.float32' is converted to np.float32.

    Parameters
    ----------
    cfg
        The config dictionary. Modified in place.

    Raises
    ------
    ValueError
        If a dtype string is not of the form np.<dtype>, tf.<dtype> or torch.<dtype>.
    """
    # TODO - we could explore a custom pydantic generic type for this (similar to how we handle NDArray)
    for key in FIELDS_TO_DTYPE:
        val = _get_nested_value(cfg, key)
        if val is None:
            continue
        parts = val.split('.')
        # Validate explicitly: the previous `lib is None` check could never trigger, so malformed strings raised
        # an unpacking error or a KeyError instead of this message.
        if len(parts) < 2 or parts[0] not in ('np', 'tf', 'torch'):
            raise ValueError("`dtype` must be in format np.<dtype>, tf.<dtype> or torch.<dtype>.")
        lib, dtype = parts[0], parts[1]
        if lib == 'tf':
            resolved = get_tf_dtype(dtype)
        elif lib == 'torch':
            resolved = get_pt_dtype(dtype)
        else:
            resolved = getattr(np, dtype)
        _set_nested_value(cfg, key, resolved)
def read_config(filepath: Union[os.PathLike, str]) -> dict:
    """
    This function reads a detector toml config file and returns a dict specifying the detector.

    Parameters
    ----------
    filepath
        The filepath to the config.toml file.

    Returns
    -------
    Parsed toml dictionary, with any string value "None" (in nested dicts too) replaced by None.
    """
    filepath = Path(filepath)

    cfg = dict(toml.load(filepath))  # toml.load types return as MutableMapping, force to dict
    logger.info('Loaded config file from {}'.format(str(filepath)))

    # This is necessary as no None/null in toml spec., and missing values are set to defaults set in pydantic models.
    # But we sometimes need to explicitly spec as None.
    cfg = _replace(cfg, "None", None)

    return cfg
def resolve_config(cfg: dict, config_dir: Optional[Path], enable_unsafe_loading: bool = False) -> dict:
    """
    Resolves artefacts in a config dict. For example x_ref='x_ref.npy' is resolved by loading the np.ndarray from
    the .npy file. The fields that are resolved are listed in `FIELDS_TO_RESOLVE` in this module.

    Parameters
    ----------
    cfg
        The unresolved config dict. Modified in place.
    config_dir
        Filepath to directory the `config.toml` is located in. Only required if different from the
        runtime directory, and artefacts are specified with filepaths relative to the config.toml file.
    enable_unsafe_loading
        If set to true, allow_pickle=True is set in np.load(). Needed if you have to load objects.
        It has security implications: pickled objects can execute arbitrary code when they are loaded, so only
        enable this for files from a trusted source.

    Returns
    -------
    The resolved config dict.
    """
    # Convert selected str's to required dtype's (all other type coercion is performed by pydantic)
    _set_dtypes(cfg)

    # Before main resolution, update filepaths relative to config file
    if config_dir is not None:
        _prepend_cfg_filepaths(cfg, config_dir)

    # Resolve filepaths (load files) and resolve function/object registries
    for key in FIELDS_TO_RESOLVE:
        logger.info('Resolving config field: {}.'.format(key))
        src = _get_nested_value(cfg, key)
        obj = None

        # Resolve string references to registered objects and filepaths
        if isinstance(src, str):
            # Resolve registry references
            if src.startswith('@'):
                src = src[1:]
                if src in registry.get_all():
                    obj = registry.get(src)
                else:
                    raise ValueError(
                        f"Can't find {src} in the custom function registry, It may be misspelled or missing "
                        "if you have incorrect optional dependencies installed. Make sure the loading environment"
                        " is the same as the saving environment. For more information, check the Installation "
                        "documentation at "
                        ""
                    )
                logger.info('Successfully resolved registry entry {}'.format(src))

            # Resolve dill or numpy file references
            elif Path(src).is_file():
                suffix = Path(src).suffix
                if suffix == '.dill':
                    # NOTE(security): dill, like pickle, can execute arbitrary code while loading, so only
                    # load config directories from trusted sources. Use `with` so the file handle is closed.
                    with open(src, 'rb') as f:
                        obj = dill.load(f)
                elif suffix == '.npy':
                    obj = np.load(src, allow_pickle=enable_unsafe_loading)

        # Resolve artefact dicts
        elif isinstance(src, dict):
            backend = cfg.get('backend', Framework.TENSORFLOW)
            if key[-1] in ('model', 'proj'):
                obj = _load_model_config(src)
            elif key[-1] == 'embedding':
                obj = _load_embedding_config(src)
            elif key[-1] == 'tokenizer':
                obj = _load_tokenizer_config(src)
            elif key[-1] == 'optimizer':
                obj = _load_optimizer_config(src, backend)
            elif key[-1] == 'preprocess_fn':
                obj = _load_preprocess_config(src)
            elif key[-1] in ('kernel', 'x_kernel', 'c_kernel'):
                obj = _load_kernel_config(src, backend)

        # Put the resolved function into the cfg dict
        if obj is not None:
            _set_nested_value(cfg, key, obj)

    return cfg
def _replace(cfg: dict, orig: Optional[str], new: Optional[str]) -> dict:
    """
    Recursively traverse a nested dictionary and replace values.

    Parameters
    ----------
    cfg
        The dictionary. Modified in place.
    orig
        Original value to search.
    new
        Value to replace original with.

    Returns
    -------
    The updated dictionary (the same object as `cfg`).
    """
    for k, v in cfg.items():
        # `v == orig` may not return a plain bool (e.g. an elementwise array for np.ndarray values), so only an
        # actual boolean True counts as a match.
        is_match = v == orig
        if isinstance(is_match, bool) and is_match:
            cfg[k] = new
        elif isinstance(v, dict):
            _replace(v, orig, new)
    return cfg


def _prepend_cfg_filepaths(cfg: dict, prepend_dir: Path) -> dict:
    """
    Recursively traverse through a nested dictionary and prepend a directory to any filepaths.

    A string value is only updated if joining it onto `prepend_dir` yields an existing file or directory.

    Parameters
    ----------
    cfg
        The dictionary. Modified in place.
    prepend_dir
        The filepath to prepend to any filepaths in the dictionary.

    Returns
    -------
    The updated config dictionary (the same object as `cfg`).
    """
    for k, v in cfg.items():
        # Skip empty strings: Path('') is '.', so joining it would replace '' with `prepend_dir` itself.
        if isinstance(v, str) and v:
            candidate = prepend_dir.joinpath(Path(v))
            # Update if prepending config_dir made config value a real filepath
            if candidate.is_file() or candidate.is_dir():
                cfg[k] = str(candidate)
        elif isinstance(v, dict):
            _prepend_cfg_filepaths(v, prepend_dir)
    return cfg