Source code for alibi_detect.utils.state.state

import os
from pathlib import Path
import logging
from abc import ABC
from typing import Union, Tuple
import numpy as np
from alibi_detect.utils.frameworks import Framework
from alibi_detect.utils.state._pytorch import save_state_dict as _save_state_dict_pt, \
    load_state_dict as _load_state_dict_pt

logger = logging.getLogger(__name__)


class StateMixin(ABC):
    """
    Utility class that provides methods to save and load stateful attributes to disk.
    """
    # Counter reported in the save/load log messages; set by the concrete detector
    # (presumably the number of updates processed so far -- confirm in subclasses).
    t: int
    # Names of the detector attributes that `save_state` writes to the state file.
    online_state_keys: Tuple[str, ...]

    def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
        """
        Set the directory path to store state in, and create an empty directory if it doesn't already exist.

        Parameters
        ----------
        dirpath
            The directory to save state file inside.
        """
        self.state_dir = Path(dirpath)
        self.state_dir.mkdir(parents=True, exist_ok=True)

    def _state_filepath(self) -> Path:
        """
        Return the path of the state file inside `self.state_dir`.

        The file is named ``state.pt`` when the detector has a `backend` attribute equal to
        `Framework.PYTORCH`, and ``state.npz`` otherwise.
        """
        # `hasattr` guard: not every detector using this mixin defines `backend`.
        is_pytorch = hasattr(self, 'backend') and self.backend == Framework.PYTORCH
        suffix = '.pt' if is_pytorch else '.npz'
        return self.state_dir.joinpath('state' + suffix)

    def save_state(self, filepath: Union[str, os.PathLike]):
        """
        Save a detector's state to disk in order to generate a checkpoint.

        The attributes named in `online_state_keys` are written to ``state.pt`` (PyTorch backend)
        or ``state.npz`` (otherwise) inside `filepath`. The directory is created if needed.

        Parameters
        ----------
        filepath
            The directory to save state to.
        """
        self._set_state_dir(filepath)
        _save_state_dict(self, self.online_state_keys, self._state_filepath())
        logger.info('Saved state for t=%s to %s', self.t, self.state_dir)

    def load_state(self, filepath: Union[str, os.PathLike]):
        """
        Load the detector's state from disk, in order to restart from a checkpoint previously generated with
        `save_state`.

        Parameters
        ----------
        filepath
            The directory to load state from.

        Raises
        ------
        FileNotFoundError
            If the expected state file does not exist inside `filepath`.
        """
        # Unlike `save_state`, do not create the directory: loading from a missing path must fail
        # without leaving an empty directory tree behind.
        self.state_dir = Path(filepath)
        _load_state_dict(self, self._state_filepath(), raise_error=True)
        logger.info('State loaded for t=%s from %s', self.t, self.state_dir)
def _save_state_dict(detector: 'StateMixin', keys: tuple, filepath: Path):
    """
    Utility function to save a detector's state dictionary to a filepath.

    Parameters
    ----------
    detector
        The detector to extract state attributes from.
    keys
        Tuple of state dict keys to populate dictionary with. Keys not present as attributes on
        `detector` are stored with the value `None`.
    filepath
        The file to save state dictionary to. A ``.pt`` suffix is written with the PyTorch helper;
        any other suffix is written with :func:`numpy.savez`.
    """
    # Construct state dictionary; missing attributes fall back to None instead of raising.
    # NOTE(review): with the numpy backend a None value is stored as an object array, which
    # `np.load` (allow_pickle=False) refuses to read back -- confirm every key is set before saving.
    state_dict = {key: getattr(detector, key, None) for key in keys}
    # Save to disk
    if filepath.suffix == '.pt':
        _save_state_dict_pt(state_dict, filepath)
    else:
        np.savez(filepath, **state_dict)


def _load_state_dict(detector: 'StateMixin', filepath: Path, raise_error: bool = True):
    """
    Utility function to load a detector's state dictionary from a filepath, and update the detectors
    attributes with the values in the state dictionary.

    Parameters
    ----------
    detector
        The detector to update.
    filepath
        File to load state dictionary from. A ``.pt`` suffix is read with the PyTorch helper; any
        other suffix is read as a numpy ``.npz`` archive (values come back as numpy arrays).
    raise_error
        Whether to raise an error if a file is not found at `filepath`. Otherwise, raise a warning and
        skip loading.

    Returns
    -------
    None. The detector is updated inplace.

    Raises
    ------
    FileNotFoundError
        If no file exists at `filepath` and `raise_error` is True.
    """
    if not filepath.is_file():
        if raise_error:
            raise FileNotFoundError('State file not found at {}.'.format(filepath))
        logger.warning('State file not found at %s. Skipping loading of state.', filepath)
        return

    if filepath.suffix == '.pt':
        state_dict = _load_state_dict_pt(filepath)
        for key, value in state_dict.items():
            setattr(detector, key, value)
    else:
        # `np.load` on an .npz returns a lazily-reading NpzFile that keeps the file handle open;
        # the context manager closes it once every array has been read.
        with np.load(str(filepath)) as npz:
            for key, value in npz.items():
                setattr(detector, key, value)