# Source code for alibi.explainers.anchors.anchor_tabular_distributed

import logging
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, Union)

import numpy as np
import ray

from alibi.api.interfaces import Explanation
from alibi.utils.discretizer import Discretizer
from alibi.utils.mapping import ohe_to_ord
from .anchor_base import AnchorBaseBeam
from .anchor_tabular import AnchorTabular, TabularSampler
from alibi.utils.distributed import ActorPool
from functools import partial

class DistributedAnchorBaseBeam(AnchorBaseBeam):
    """Distributed variant of :py:class:`alibi.explainers.anchors.anchor_base.AnchorBaseBeam`:
    sampling requests are dispatched to a pool of ray actors instead of being run in-process."""

    def __init__(self, samplers: List[Callable], **kwargs) -> None:
        """
        Parameters
        ----------
        samplers
            Handles to ray actors that wrap `RemoteSampler` instances.
        **kwargs
            May contain ``chunksize``, the number of anchors submitted to an actor per request.
        """
        super().__init__(samplers)
        # number of anchors bundled into each remote sampling request
        self.chunksize = kwargs.get('chunksize', 1)
        # schedules a sampling task on a remote actor; returns a ray object reference
        self.sample_fcn = lambda actor, anchor, n_samples, compute_labels=True: \
            actor.__call__.remote(anchor, n_samples, compute_labels=compute_labels)
        self.pool = ActorPool(samplers)
        self.samplers = samplers

    def _get_coverage_samples(self, coverage_samples: int,  # type: ignore[override]
                              samplers: List[Callable]) -> np.ndarray:
        """
        Sends a request for a coverage set to a process running sampling tasks.

        Parameters
        ----------
        coverage_samples, samplers
            See :py:meth:`alibi.explainers.anchors.anchor_base.AnchorBaseBeam._get_coverage_samples`
            implementation.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.anchor_base.AnchorBaseBeam._get_coverage_samples`
        implementation.
        """
        # only the first sampler is needed; the anchor (0, ()) selects the empty anchor
        [coverage_data] = ray.get(
            self.sample_fcn(samplers[0], (0, ()), coverage_samples, compute_labels=False)
        )
        return coverage_data

    def draw_samples(self, anchors: list, batch_size: int) -> Tuple[np.ndarray, np.ndarray]:  # type: ignore[override]
        """
        Distributes sampling requests among processes running sampling tasks.

        Parameters
        ----------
        anchors, batch_size
            See :py:meth:`alibi.explainers.anchors.anchor_base.AnchorBaseBeam.draw_samples`
            implementation.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.anchor_base.AnchorBaseBeam.draw_samples`
        implementation.
        """
        # partial anchors not generated by propose_anchors are not in the order dictionary
        for anchor in anchors:
            if anchor not in self.state['t_order']:
                self.state['t_order'][anchor] = list(anchor)

        pos, total = np.zeros((len(anchors),)), np.zeros((len(anchors),))
        # pair each anchor's request index with its feature ordering so unordered
        # results can be matched back to the right anchor
        order_map = [(i, tuple(self.state['t_order'][anchor])) for i, anchor in enumerate(anchors)]
        samples_iter = self.pool.map_unordered(
            partial(self.sample_fcn, n_samples=batch_size),
            order_map,
            self.chunksize,
        )
        for samples_batch in samples_iter:
            for samples in samples_batch:
                covered_true, covered_false, labels, *additionals, anchor_idx = samples
                positives, n_samples = self.update_state(
                    covered_true, covered_false, labels, additionals, anchors[anchor_idx],
                )
                # return statistics in the same order as the requests
                pos[anchor_idx], total[anchor_idx] = positives, n_samples

        return pos, total
class RemoteSampler:
    """A wrapper that facilitates the use of `TabularSampler` for distributed sampling."""

    def __init__(self, *args):
        # args = (train_id, d_train_id, sampler): object-store references to the training
        # set and its discretized version, plus a TabularSampler awaiting initialisation
        self.train_id, self.d_train_id, self.sampler = args
        # complete the sampler's construction now that the data references are available
        # in this process
        self.sampler = self.sampler.deferred_init(self.train_id, self.d_train_id)

    def __call__(self,
                 anchors_batch: Union[Tuple[int, tuple], List[Tuple[int, tuple]]],
                 num_samples: int,
                 compute_labels: bool = True) -> List:
        """
        Wrapper around :py:meth:`alibi.explainers.anchors.anchor_tabular.TabularSampler.__call__`.
        It allows sampling a batch of anchors in the same process, which can improve performance.

        Parameters
        ----------
        anchors_batch, num_samples, compute_labels
            See :py:meth:`alibi.explainers.anchors.anchor_tabular.TabularSampler.__call__`
            for details.

        Returns
        -------
        A list of result tuples (or the sampler's raw result for a single-anchor
        coverage request).
        """
        if isinstance(anchors_batch, tuple):  # DistributedAnchorBaseBeam._get_samples_coverage call
            return self.sampler(anchors_batch, num_samples, compute_labels=compute_labels)
        elif len(anchors_batch) == 1:  # batch size = 1
            return [self.sampler(*anchors_batch, num_samples, compute_labels=compute_labels)]
        else:  # batch size > 1
            batch_result = []
            for anchor in anchors_batch:
                batch_result.append(self.sampler(anchor, num_samples, compute_labels=compute_labels))
            return batch_result

    def set_instance_label(self, X: np.ndarray) -> int:
        """
        Sets the remote sampler instance label.

        Parameters
        ----------
        X
            The instance to be explained.

        Returns
        -------
        label
            The label of the instance to be explained.
        """
        self.sampler.set_instance_label(X)
        label = self.sampler.instance_label
        return label

    def set_n_covered(self, n_covered: int) -> None:
        """
        Sets the remote sampler number of examples to save for inspection.

        Parameters
        ----------
        n_covered
            Number of examples where the result (and partial anchors) apply.
        """
        self.sampler.set_n_covered(n_covered)

    def _get_sampler(self) -> 'TabularSampler':
        """
        A getter that returns the underlying tabular sampler object.

        Returns
        -------
        The tabular sampler object that is used in the process.
        """
        return self.sampler

    def build_lookups(self, X: np.ndarray):
        """
        Wrapper around :py:meth:`alibi.explainers.anchors.anchor_tabular.TabularSampler.build_lookups`.

        Parameters
        ----------
        X
            See :py:meth:`alibi.explainers.anchors.anchor_tabular.TabularSampler.build_lookups`.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.anchor_tabular.TabularSampler.build_lookups`.
        """
        cat_lookup_id, ord_lookup_id, enc2feat_idx_id = self.sampler.build_lookups(X)
        return [cat_lookup_id, ord_lookup_id, enc2feat_idx_id]
class DistributedAnchorTabular(AnchorTabular):
    """Distributed variant of :py:class:`alibi.explainers.anchors.anchor_tabular.AnchorTabular`
    that parallelises sampling over a pool of ray actors."""

    def __init__(self,
                 predictor: Callable,
                 feature_names: List[str],
                 categorical_names: Optional[Dict[int, List[str]]] = None,
                 dtype: Type[np.generic] = np.float32,
                 ohe: bool = False,
                 seed: Optional[int] = None) -> None:
        super().__init__(predictor, feature_names, categorical_names, dtype, ohe, seed)
        # start a local ray cluster unless the caller has already initialised one
        if not ray.is_initialized():
            ray.init()

    def fit(self,  # type: ignore[override]
            train_data: np.ndarray,
            disc_perc: tuple = (25, 50, 75),
            **kwargs) -> "AnchorTabular":
        """
        Creates a list of handles to parallel processes that are used for submitting
        sampling tasks.

        Parameters
        ----------
        train_data, disc_perc, **kwargs
            See :py:meth:`alibi.explainers.anchors.anchor_tabular.AnchorTabular.fit`.
        """
        try:
            ncpu = kwargs['ncpu']
        except KeyError:
            logging.warning('DistributedAnchorTabular object has been initialised but kwargs did not contain '
                            'expected argument, ncpu. Defaulting to ncpu=2!')
            ncpu = 2

        # transform one-hot encodings to labels if ohe == True
        train_data = ohe_to_ord(X_ohe=train_data, cat_vars_ohe=self.cat_vars_ohe)[0] if self.ohe else train_data

        disc = Discretizer(train_data, self.numerical_features, self.feature_names, percentiles=disc_perc)
        d_train_data = disc.discretize(train_data)
        self.feature_values.update(disc.feature_intervals)

        sampler_args = (
            self._predictor,
            disc_perc,
            self.numerical_features,
            self.categorical_features,
            self.feature_names,
            self.feature_values,
        )
        # place the (discretized) training data in the object store so every actor
        # can resolve it without copying through the driver
        train_data_id = ray.put(train_data)
        d_train_data_id = ray.put(d_train_data)
        samplers = [TabularSampler(*sampler_args, seed=self.seed) for _ in range(ncpu)]  # type: ignore[arg-type]
        d_samplers = []
        for sampler in samplers:
            d_samplers.append(
                ray.remote(RemoteSampler).remote(  # type: ignore[call-arg]
                    *(train_data_id, d_train_data_id, sampler)
                )
            )
        self.samplers = d_samplers

        # update metadata
        self.meta['params'].update(disc_perc=disc_perc)

        return self

    def _build_sampling_lookups(self, X: np.ndarray) -> None:
        """
        See :py:meth:`alibi.explainers.anchors.anchor_tabular.AnchorTabular._build_sampling_lookups`
        documentation.

        Parameters
        ----------
        X
            See :py:meth:`alibi.explainers.anchors.anchor_tabular.AnchorTabular._build_sampling_lookups`
            documentation.
        """
        # every actor builds its own lookups; only the first actor's result is
        # materialised on the driver
        lookups = [sampler.build_lookups.remote(X) for sampler in self.samplers][0]
        self.cat_lookup, self.ord_lookup, self.enc2feat_idx = ray.get(lookups)

    def explain(self,
                X: np.ndarray,
                threshold: float = 0.95,
                delta: float = 0.1,
                tau: float = 0.15,
                batch_size: int = 100,
                coverage_samples: int = 10000,
                beam_size: int = 1,
                stop_on_first: bool = False,
                max_anchor_size: Optional[int] = None,
                min_samples_start: int = 1,
                n_covered_ex: int = 10,
                binary_cache_size: int = 10000,
                cache_margin: int = 1000,
                verbose: bool = False,
                verbose_every: int = 1,
                **kwargs: Any) -> Explanation:
        """
        Explains the prediction made by a classifier on instance `X`. Sampling is done
        in parallel over a number of cores specified in `kwargs['ncpu']`.

        Parameters
        ----------
        X, threshold, delta, tau, batch_size, coverage_samples, beam_size, stop_on_first, max_anchor_size, \
        min_samples_start, n_covered_ex, binary_cache_size, cache_margin, verbose, verbose_every, **kwargs
            See :py:meth:`alibi.explainers.anchors.anchor_tabular.AnchorTabular.explain`.

        Returns
        -------
        See :py:meth:`alibi.explainers.anchors.anchor_tabular.AnchorTabular.explain` superclass.
        """
        # transform one-hot encodings to labels if ohe == True
        X = ohe_to_ord(X_ohe=X.reshape(1, -1), cat_vars_ohe=self.cat_vars_ohe)[0].reshape(-1) if self.ohe else X

        # get params for storage in meta
        params = locals()
        remove = ['X', 'self']
        for key in remove:
            params.pop(key)

        for sampler in self.samplers:
            label = sampler.set_instance_label.remote(X)
            sampler.set_n_covered.remote(n_covered_ex)
        self.instance_label = ray.get(label)

        # build feature encoding and mappings from the instance values to database rows
        # where similar records are found
        self._build_sampling_lookups(X)

        # get anchors and add metadata
        mab = DistributedAnchorBaseBeam(
            samplers=self.samplers,
            sample_cache_size=binary_cache_size,
            cache_margin=cache_margin,
            **kwargs,
        )
        result: Any = mab.anchor_beam(
            delta=delta,
            epsilon=tau,
            desired_confidence=threshold,
            beam_size=beam_size,
            min_samples_start=min_samples_start,
            max_anchor_size=max_anchor_size,
            batch_size=batch_size,
            coverage_samples=coverage_samples,
            verbose=verbose,
            verbose_every=verbose_every,
        )
        self.mab = mab

        return self._build_explanation(X, result, self.instance_label, params)

    def reset_predictor(self, predictor: Callable) -> None:
        """
        Resets the predictor function.

        Parameters
        ----------
        predictor
            New model prediction function.
        """
        raise NotImplementedError("Resetting predictor is currently not supported for distributed explainers.")
        # TODO: to support resetting a predictor we would need to re-run most of the code in `fit` instantiating the
        #  instances of RemoteSampler anew