import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from skimage.transform import resize
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
from alibi.api.defaults import (DEFAULT_DATA_PROTOSELECT,
DEFAULT_META_PROTOSELECT)
from alibi.api.interfaces import Explanation, FitMixin, Summariser
from alibi.utils.distance import batch_compute_kernel_matrix
from alibi.utils.kernel import EuclideanDistance
logger = logging.getLogger(__name__)
class ProtoSelect(Summariser, FitMixin):
def __init__(self,
kernel_distance: Callable[[np.ndarray, np.ndarray], np.ndarray],
eps: float,
lambda_penalty: Optional[float] = None,
batch_size: int = int(1e10),
preprocess_fn: Optional[Callable[[Union[list, np.ndarray]], np.ndarray]] = None,
verbose: bool = False):
"""
Prototype selection for dataset distillation and interpretable classification proposed by
Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933
Parameters
----------
kernel_distance
Kernel distance to be used. Expected to support computation in batches. Given an input `x` of
size `Nx x f1 x f2 x ...` and an input `y` of size `Ny x f1 x f2 x ...`, the kernel distance
should return a kernel matrix of size `Nx x Ny`.
eps
Epsilon ball size.
lambda_penalty
Penalty for each prototype. Encourages a lower number of prototypes to be selected. Corresponds to
:math:`\\lambda` in the paper notation. If not specified, the default value is set to `1 / N` where
`N` is the size of the dataset to choose the prototype instances from, passed to the
:py:meth:`alibi.prototypes.protoselect.ProtoSelect.fit` method.
batch_size
Batch size to be used for kernel matrix computation.
preprocess_fn
Preprocessing function used for kernel matrix computation. The preprocessing function takes the input as
a `list` or a `numpy` array and transforms it into a `numpy` array which is then fed to the
`kernel_distance` function. The use of `preprocess_fn` allows the method to be applied to any data modality.
verbose
Whether to display a progress bar while computing the prototype points.
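Examples
--------
A minimal, illustrative sketch on synthetic data (all values below are placeholders, not
recommendations)::

    import numpy as np

    from alibi.prototypes import ProtoSelect
    from alibi.utils.kernel import EuclideanDistance

    X = np.random.randn(100, 2)            # dataset to be summarised
    y = np.random.randint(0, 2, size=100)  # integer labels in [0, L-1]
    summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=0.5)
    summary = summariser.fit(X=X, y=y).summarise(num_prototypes=10)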
"""
super().__init__(meta=deepcopy(DEFAULT_META_PROTOSELECT))
self.kernel_distance = kernel_distance
self.eps = eps
self.lambda_penalty = lambda_penalty
self.batch_size = batch_size
self.preprocess_fn = preprocess_fn
self.verbose = verbose
# get kernel tag
if hasattr(self.kernel_distance, '__name__'):
kernel_distance_tag = self.kernel_distance.__name__
elif hasattr(self.kernel_distance, '__class__'):
kernel_distance_tag = self.kernel_distance.__class__.__name__
else:
kernel_distance_tag = 'unknown kernel distance'
# update metadata
self.meta['params'].update({
'kernel_distance': kernel_distance_tag,
'eps': eps,
'lambda_penalty': lambda_penalty,
'batch_size': batch_size,
'verbose': verbose
})
def fit(self,  # type: ignore[override]
X: Union[list, np.ndarray],
y: Optional[np.ndarray] = None,
Z: Optional[Union[list, np.ndarray]] = None) -> 'ProtoSelect':
"""
Fit the summariser. This step forms the kernel matrix in memory. If the optional dataset `Z` is
not provided, the kernel matrix has a shape of `NX x NX`, where `NX` is the number of instances
in `X`. Otherwise, the kernel matrix has a shape of `NZ x NX`, where `NZ` is the number of
instances in `Z`.
Parameters
----------
X
Dataset to be summarised.
y
Labels of the dataset `X` to be summarised. The labels are expected to be represented as integers
`[0, 1, ..., L-1]`, where `L` is the number of classes in the dataset `X`.
Z
Optional dataset to choose the prototypes from. If ``Z=None``, the prototypes will be selected from the
dataset `X`. Otherwise, if `Z` is provided, the dataset to be summarised is still `X`, but
it is summarised by prototypes belonging to the dataset `Z`.
Returns
-------
self
Reference to itself.
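Examples
--------
An illustrative sketch of summarising `X` with prototypes selected from a separate dataset `Z`
(synthetic, placeholder data)::

    import numpy as np

    from alibi.prototypes import ProtoSelect
    from alibi.utils.kernel import EuclideanDistance

    X = np.random.randn(100, 2)
    y = np.random.randint(0, 2, size=100)
    Z = np.random.randn(50, 2)   # candidate prototypes
    summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=0.5).fit(X=X, y=y, Z=Z)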
"""
if y is not None:
y = y.flatten()
if len(X) != len(y):
raise ValueError('The number of data instances does not match the number of labels. '
f'Got len(X)={len(X)} and len(y)={len(y)}.')
self.X = X
# if y is not provided, consider that all the elements belong to the same class. This means that
# the loss term which penalizes the inclusion of elements belonging to other classes inside an
# epsilon ball will always be 0. Still, the first term of the loss tries to cover as many examples
# as possible with minimal overlap between the epsilon balls corresponding to the selected prototypes.
self.y = y.astype(np.int32) if (y is not None) else np.zeros((len(X),), dtype=np.int32)
# redefine the labels, so they are in the interval [0, len(np.unique(y)) - 1].
# For example, if the labels provided were [40, 51], internally, we relabel them as [0, 1].
# This approach can reduce computation and memory allocation, as without the intermediate mapping we would
# have to allocate memory corresponding to 52 labels, [0, ..., 51], for some internal matrices.
self.label_mapping = {l: i for i, l in enumerate(np.unique(self.y))}
self.label_inv_mapping = {v: k for k, v in self.label_mapping.items()}
idx = np.nonzero(np.asarray(list(self.label_mapping.keys())) == self.y[:, None])[1]
self.y = np.asarray(list(self.label_mapping.values()))[idx]
# if the set of prototypes is not provided, then find the prototypes belonging to the X dataset.
self.Z = Z if (Z is not None) else self.X
# initialize penalty for adding a prototype
if self.lambda_penalty is None:
self.lambda_penalty = 1. / len(self.Z)
self.meta['params'].update({'lambda_penalty': self.lambda_penalty})
self.kmatrix = batch_compute_kernel_matrix(x=self.Z,
y=self.X,
kernel=self.kernel_distance,
batch_size=self.batch_size,
preprocess_fn=self.preprocess_fn)
return self
def summarise(self, num_prototypes: int = 1) -> Explanation:
"""
Searches for the requested number of prototypes. Note that the algorithm can return fewer
prototypes than requested. To increase the number of prototypes, reduce the epsilon-ball radius
(`eps`) and the penalty for adding a prototype (`lambda_penalty`).
Parameters
----------
num_prototypes
Maximum number of prototypes to be selected.
Returns
-------
An `Explanation` object containing the prototypes, prototype indices and prototype labels with additional \
metadata as attributes.
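Examples
--------
Accessing the returned summary for a fitted summariser (the field names below follow
`DEFAULT_DATA_PROTOSELECT`)::

    summary = summariser.summarise(num_prototypes=10)
    protos = summary.data['prototypes']
    proto_indices = summary.data['prototype_indices']
    proto_labels = summary.data['prototype_labels']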
"""
if num_prototypes > len(self.Z):
num_prototypes = len(self.Z)
logger.warning('The number of prototypes requested is larger than the number of elements from '
f'the prototypes selection set. Automatically setting `num_prototypes={num_prototypes}`.')
# dictionary of prototype indices for each class
protos: Dict[int, List[int]] = {l: [] for l in range(len(self.label_mapping))} # noqa: E741
# set of available prototype indices. Note that initially we start with the entire set of Z,
# but as the algorithm progresses, we remove the indices of the prototypes that we already selected.
available_indices = set(range(len(self.Z)))
# matrix of size [NZ, NX], where NZ = len(Z) and NX = len(X).
# Represents a mask which indicates, for each element z in Z, which elements of X lie inside the
# epsilon ball centered at z.
B = (self.kmatrix <= self.eps).astype(np.int32)
# matrix of size [L, NX], where L is the number of labels
# each row l indicates the elements from X that are covered by prototypes belonging to class l.
B_P = np.zeros((len(self.label_mapping), len(self.X)), dtype=np.int32)
# matrix of size [L, NX]. Each row l indicates which elements from X are labeled as l.
Xl = np.concatenate([(self.y == l).reshape(1, -1)
for l in range(len(self.label_mapping))], axis=0).astype(np.int32) # noqa: E741
# vectorized implementation of the prototype scores.
# See the paper (page 8): https://arxiv.org/pdf/1202.5933.pdf for more details.
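# In the paper's notation, the greedy score of adding a candidate prototype z for class l is:
#     score(z, l) = delta_xi(z, l) - delta_nu(z, l) - lambda,
# where delta_xi(z, l) counts the not-yet-covered instances of class l that fall inside the epsilon
# ball centered at z, delta_nu(z, l) counts the covered instances belonging to classes other than l,
# and lambda is the per-prototype penalty (`lambda_penalty`).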
B_diff = B[:, np.newaxis, :] - B_P[np.newaxis, :, :] # [NZ, 1, NX] - [1, L, NX] -> [NZ, L, NX]
# [NZ, L, NX] + [1, L, NX] -> [NZ, L, NX]
delta_xi_all = B_diff + Xl[np.newaxis, ...] >= 2
# [NZ, L]. For every row z and every column l, we compute how many new instances belonging to class
# l will be covered if we add the prototype z.
delta_xi_summed = np.sum(delta_xi_all, axis=-1)
# [NZ, 1, NX] + [1, L, NX] -> [NZ, L, NX]
delta_nu_all = B[:, np.newaxis, :] + (1 - Xl[np.newaxis, ...]) >= 2
# [NZ, L]. For every row z and every column l, we compute how many new instances belonging to all the
# other classes different from l will be covered if we add the prototype z.
delta_nu_summed = np.sum(delta_nu_all, axis=-1)
# compute the tradeoff score - each prototype tries to cover as many new elements as possible
# belonging to the same class, while trying to avoid covering elements belonging to another class
scores_all = delta_xi_summed - delta_nu_summed - self.lambda_penalty
for _ in tqdm(range(num_prototypes), disable=(not self.verbose)):
j = np.array(list(available_indices)).astype(np.int32)
scores = scores_all[j]
# stopping criterion. The number of the returned prototypes might be lower than
# the number of requested prototypes.
if np.all(scores < 0):
break
# find the index i of the best prototype and the class l that it covers.
row, col = np.unravel_index(np.argmax(scores), scores.shape)
i, l = j[row.item()], col.item() # noqa: E741
# update the scores: instances of class l newly covered by prototype i no longer count as new
# coverage for the remaining candidates.
covered = np.sum(delta_xi_all[:, l, B[i].astype(bool)], axis=-1)
delta_xi_all[:, l, B[i].astype(bool)] = 0
delta_xi_summed[:, l] -= covered
scores_all[:, l] -= covered
# add prototype to the corresponding list according to the class label l that it covers
# and remove the index i from list of available indices.
protos[l].append(i)
available_indices.remove(i)
return self._build_summary(protos)
def _build_summary(self, protos: Dict[int, List[int]]) -> Explanation:
"""
Helper method to build the summary as an `Explanation` object.
"""
data = deepcopy(DEFAULT_DATA_PROTOSELECT)
data['prototype_indices'] = np.concatenate(list(protos.values())).astype(np.int32)
data['prototype_labels'] = np.concatenate([[self.label_inv_mapping[l]] * len(protos[l])
for l in protos]).astype(np.int32) # noqa: E741
data['prototypes'] = self.Z[data['prototype_indices']]
return Explanation(meta=self.meta, data=data)
def _helper_protoselect_euclidean_1knn(summariser: ProtoSelect,
num_prototypes: int,
eps: float,
knn_kw: dict) -> Optional[KNeighborsClassifier]:
"""
Helper function that sets the epsilon radius of the `summariser`, computes the summary, and fits
a 1-KNN classifier on the returned prototypes.
Parameters
----------
summariser
Fitted `ProtoSelect` summariser.
num_prototypes
Number of requested prototypes.
eps
Epsilon radius to be set and used for the computation of prototypes.
knn_kw
Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
Returns
-------
Fitted 1-KNN classifier using the Euclidean distance metric, or ``None`` if no prototypes were selected.
"""
# update summariser eps and get the summary
summariser.eps = eps
summary = summariser.summarise(num_prototypes=num_prototypes)
# train 1-knn classifier
X_protos, y_protos = summary.data['prototypes'], summary.data['prototype_labels']
if len(X_protos) == 0:
return None
# note that the knn_kw are updated in `cv_protoselect_euclidean` to define a 1-KNN with Euclidean distance
knn = KNeighborsClassifier(**knn_kw)
return knn.fit(X=X_protos, y=y_protos)
def _get_splits(trainset: Tuple[np.ndarray, np.ndarray],
valset: Tuple[Optional[np.ndarray], Optional[np.ndarray]],
kfold_kw: dict) -> Tuple[
Tuple[np.ndarray, np.ndarray],
Tuple[np.ndarray, np.ndarray],
List[Tuple[np.ndarray, np.ndarray]]
]:
"""
Helper function to obtain appropriate train-validation splits.
If the validation dataset is not provided, the method returns the appropriate datasets and indices
to perform k-fold cross-validation. Otherwise, if the validation dataset is provided, it is used
instead of performing k-fold cross-validation.
Parameters
----------
trainset
Tuple `(X_train, y_train)` consisting of the training data instances with the corresponding labels.
valset
Optional tuple, `(X_val, y_val)`, consisting of validation data instances with the corresponding labels.
kfold_kw
Keyword arguments passed to `sklearn.model_selection.KFold`. See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
Returns
-------
Tuple consisting of training dataset, validation dataset (can overlap with training if validation is not
provided), and a list of splits containing indices from the training and validation datasets.
"""
X_train, y_train = trainset
X_val, y_val = valset
if X_val is None:
kfold = KFold(**kfold_kw)
splits = kfold.split(X=X_train, y=y_train)
return trainset, trainset, list(splits)
splits = [(np.arange(len(X_train)), np.arange(len(X_val)))]
return trainset, valset, splits # type: ignore
def cv_protoselect_euclidean(trainset: Tuple[np.ndarray, np.ndarray],
protoset: Optional[Tuple[np.ndarray, ]] = None,
valset: Optional[Tuple[np.ndarray, np.ndarray]] = None,
num_prototypes: int = 1,
eps_grid: Optional[np.ndarray] = None,
quantiles: Optional[Tuple[float, float]] = None,
grid_size: int = 25,
n_splits: int = 2,
batch_size: int = int(1e10),
preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
protoselect_kw: Optional[dict] = None,
knn_kw: Optional[dict] = None,
kfold_kw: Optional[dict] = None) -> dict:
"""
Cross-validation parameter selection for `ProtoSelect` with Euclidean distance. The method computes
the best epsilon radius.
Parameters
----------
trainset
Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
protoset
Tuple, `(Z, )`, consisting of the dataset to choose the prototypes from. If `Z` is not provided
(i.e., ``protoset=None``), the prototypes will be selected from the training dataset `X`. Otherwise, if `Z`
is provided, the dataset to be summarised is still `X`, but it is summarised by prototypes belonging to
the dataset `Z`. Note that the argument is passed as a tuple with a single element for consistency reasons.
valset
Optional tuple `(X_val, y_val)` consisting of validation data instances with the corresponding
validation labels. A 1-KNN classifier is evaluated on the validation dataset to obtain the best
epsilon radius. In case ``valset=None``, `n_splits`-fold cross-validation is performed on the `trainset`.
num_prototypes
The number of prototypes to be selected.
eps_grid
Optional grid of values to select the epsilon radius from. If not specified, the search grid is
automatically proposed based on the pairwise distances between `X` and `Z`. The distances are
filtered by considering only values between the `quantiles` bounds. The minimum and maximum of the
remaining distance values define the range in which to search for the epsilon radius. The interval
is discretized into `grid_size` equidistant bins.
quantiles
Quantiles, `(q_min, q_max)`, to be used to filter the range of values of the epsilon radius. The expected
quantile values are in `[0, 1]` and clipped to `[0, 1]` if outside the range. See `eps_grid` for usage.
If not specified, no filtering is applied. Only used if ``eps_grid=None``.
grid_size
The number of equidistant bins used to discretize the automatically proposed `eps_grid` interval.
Only used if ``eps_grid=None``.
n_splits
The number of cross-validation splits to be used. Only used if ``valset=None``.
batch_size
Batch size to be used for kernel matrix computation.
preprocess_fn
Preprocessing function to be applied to the data instances before applying the kernel.
protoselect_kw
Keyword arguments passed to :py:meth:`alibi.prototypes.protoselect.ProtoSelect.__init__`.
knn_kw
Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
set automatically to 1 and the `metric` will be set to ``'euclidean'``. See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
kfold_kw
Keyword arguments passed to `sklearn.model_selection.KFold`. See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
Returns
-------
Dictionary containing
- ``'best_eps'``: ``float`` - the best epsilon radius according to the accuracy of a 1-KNN classifier.
- ``'meta'``: ``dict`` - dictionary containing the arguments and data gathered throughout cross-validation.
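Examples
--------
An illustrative sketch of the cross-validation search on synthetic data (placeholder values)::

    import numpy as np

    from alibi.prototypes.protoselect import cv_protoselect_euclidean

    X_train = np.random.randn(200, 2)
    y_train = np.random.randint(0, 2, size=200)
    cv = cv_protoselect_euclidean(trainset=(X_train, y_train),
                                  num_prototypes=10,
                                  quantiles=(0.1, 0.9),
                                  n_splits=2)
    best_eps = cv['best_eps']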
"""
if protoselect_kw is None:
protoselect_kw = {}
if kfold_kw is None:
kfold_kw = {}
# ensure the requested number of cross-validation splits is used by `KFold`
kfold_kw.update({'n_splits': n_splits})
if knn_kw is None:
knn_kw = {}
# ensure that we are training a 1-KNN classifier with Euclidean distance metric
knn_kw.update({'n_neighbors': 1, 'metric': 'euclidean'})
X_train, y_train = trainset
Z = protoset[0] if (protoset is not None) else X_train
X_val, y_val = valset if (valset is not None) else (None, None)
if preprocess_fn is not None:
X_train = _batch_preprocessing(X=X_train, preprocess_fn=preprocess_fn, batch_size=batch_size)
Z = _batch_preprocessing(X=Z, preprocess_fn=preprocess_fn, batch_size=batch_size)
if X_val is not None:
X_val = _batch_preprocessing(X_val, preprocess_fn=preprocess_fn, batch_size=batch_size)
# propose eps_grid if not specified
if eps_grid is None:
dist = batch_compute_kernel_matrix(x=X_train, y=Z, kernel=EuclideanDistance()).reshape(-1)
if quantiles is not None:
if quantiles[0] > quantiles[1]:
raise ValueError('The quantile lower-bound is greater than the quantile upper-bound.')
quantiles = np.clip(quantiles, a_min=0, a_max=1) # type: ignore[assignment]
min_dist, max_dist = np.quantile(a=dist, q=np.array(quantiles))
else:
min_dist, max_dist = np.min(dist), np.max(dist)
# define list of values for eps
eps_grid = np.linspace(min_dist, max_dist, num=grid_size)
(X_train, y_train), (X_val, y_val), splits = _get_splits(trainset=(X_train, y_train),
valset=(X_val, y_val),
kfold_kw=kfold_kw)
scores = np.zeros((len(eps_grid), len(splits)))
for i, (train_index, val_index) in enumerate(splits):
X_train_i, y_train_i = X_train[train_index], y_train[train_index]
X_val_i, y_val_i = X_val[val_index], y_val[val_index]
# define and fit the summariser here, so we don't repeat the kernel matrix computation in the next for loop
summariser = ProtoSelect(kernel_distance=EuclideanDistance(), eps=0, **protoselect_kw)
summariser = summariser.fit(X=X_train_i, y=y_train_i, Z=Z)
for j in range(len(eps_grid)):
knn = _helper_protoselect_euclidean_1knn(summariser=summariser,
num_prototypes=num_prototypes,
eps=eps_grid[j],
knn_kw=knn_kw)
if knn is None:
continue
scores[j][i] = knn.score(X_val_i, y_val_i)
return {
'best_eps': eps_grid[np.argmax(np.mean(scores, axis=-1))],
'meta': {
'num_prototypes': num_prototypes,
'eps_grid': eps_grid,
'quantiles': quantiles,
'grid_size': grid_size,
'n_splits': n_splits,
'batch_size': batch_size,
'scores': scores,
}
}
def _batch_preprocessing(X: np.ndarray,
preprocess_fn: Callable[[np.ndarray], np.ndarray],
batch_size: int = 32) -> np.ndarray:
"""
Preprocess a dataset `X` in batches by applying the preprocessor function.
Parameters
----------
X
Dataset to be preprocessed.
preprocess_fn
Preprocessor function.
batch_size
Batch size to be used for each call to `preprocess_fn`.
Returns
-------
Preprocessed dataset.
"""
X_ft = []
num_iter = int(np.ceil(len(X) / batch_size))
for i in range(num_iter):
istart, iend = batch_size * i, min(batch_size * (i + 1), len(X))
X_ft.append(preprocess_fn(X[istart:iend]))
return np.concatenate(X_ft, axis=0)
def _imscatterplot(x: np.ndarray,
y: np.ndarray,
images: np.ndarray,
ax: Optional[plt.Axes] = None,
fig_kw: Optional[dict] = None,
image_size: Tuple[int, int] = (28, 28),
zoom: Optional[np.ndarray] = None,
zoom_lb: float = 1.0,
zoom_ub: float = 2.0,
sort_by_zoom: bool = True) -> plt.Axes:
"""
2D image scatter plot.
Parameters
----------
x
Images x-coordinates.
y
Images y-coordinates.
images
Array of images to be placed at coordinates `(x, y)`.
ax
A `matplotlib` axes object to plot on.
fig_kw
Keyword arguments passed to the `fig.set` function.
image_size
Size of the generated output image as `(rows, cols)`.
zoom
Images zoom to be used.
zoom_lb
Zoom lower bound. The zoom values will be scaled linearly between `[zoom_lb, zoom_ub]`.
zoom_ub
Zoom upper bound. The zoom values will be scaled linearly between `[zoom_lb, zoom_ub]`.
sort_by_zoom
Whether to plot the images in decreasing order of their zoom values, so that smaller images are
drawn on top of larger ones.
Returns
-------
A `matplotlib` axes object.
"""
if fig_kw is None:
fig_kw = {}
if zoom is None:
zoom = np.ones(len(images))
else:
zoom_min, zoom_max = np.min(zoom), np.max(zoom)
zoom = (zoom - zoom_min) / (zoom_max - zoom_min) * (zoom_ub - zoom_lb) + zoom_lb
if sort_by_zoom:
idx = np.argsort(zoom)[::-1] # type: ignore
zoom = zoom[idx] # type: ignore
x, y, images = x[idx], y[idx], images[idx]
if ax is None:
fig, ax = plt.subplots()
ax.set_xticks([])
ax.set_yticks([])
else:
fig = ax.figure
resized_imgs = [resize(images[i], image_size) for i in range(len(images))]
imgs = [OffsetImage(img, zoom=zoom[i], cmap='gray') for i, img in enumerate(resized_imgs)] # type: ignore
artists = []
for i in range(len(imgs)):
x0, y0, im = x[i], y[i], imgs[i]
ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
artists.append(ax.add_artist(ab))
ax.update_datalim(np.column_stack([x, y]))
ax.autoscale()
fig.set(**fig_kw)
return ax
def compute_prototype_importances(summary: 'Explanation',
trainset: Tuple[np.ndarray, np.ndarray],
preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
knn_kw: Optional[dict] = None) -> Dict[str, Optional[np.ndarray]]:
"""
Computes the importance of each prototype. The importance of a prototype is the number of assigned
training instances correctly classified according to the 1-KNN classifier
(Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
Parameters
----------
summary
An `Explanation` object produced by a call to the
:py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` method.
trainset
Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
preprocess_fn
Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
knn_kw
Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
If the `metric` is not specified, it will be set by default to ``'euclidean'``.
See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
Returns
-------
A dictionary containing:
- ``'prototype_indices'`` - an array of the prototype indices.
- ``'prototype_importances'`` - an array of prototype importances.
- ``'X_protos'`` - an array of raw prototypes.
- ``'X_protos_ft'`` - an optional array of preprocessed prototypes. If the ``preprocess_fn=None``, \
no preprocessing is applied and ``None`` is returned instead.
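Examples
--------
An illustrative sketch, reusing a `summary` produced by
:py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` (placeholder data)::

    from alibi.prototypes.protoselect import compute_prototype_importances

    importances = compute_prototype_importances(summary=summary,
                                                trainset=(X_train, y_train))
    proto_idx = importances['prototype_indices']
    counts = importances['prototype_importances']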
"""
if knn_kw is None:
knn_kw = {}
if knn_kw.get('metric') is None:
knn_kw.update({'metric': 'euclidean'})
logger.warning("KNN metric was not specified. Automatically setting `metric='euclidean'`.")
X_train, y_train = trainset
X_protos = summary.data['prototypes']
y_protos = summary.data['prototype_labels']
# preprocess the dataset
X_train_ft = _batch_preprocessing(X=X_train, preprocess_fn=preprocess_fn) \
if (preprocess_fn is not None) else X_train
X_protos_ft = _batch_preprocessing(X=X_protos, preprocess_fn=preprocess_fn) \
if (preprocess_fn is not None) else X_protos
# train a 1-KNN classifier; set `n_neighbors=1` through `knn_kw` so a user-provided value cannot
# be passed twice to the constructor
knn_kw.update({'n_neighbors': 1})
knn = KNeighborsClassifier(**knn_kw)
knn = knn.fit(X=X_protos_ft, y=y_protos)
# get neighbors indices for each training instance
neigh_idx = knn.kneighbors(X=X_train_ft, n_neighbors=1, return_distance=False).reshape(-1)
# compute how many correctly labeled instances each prototype covers
idx, counts = np.unique(neigh_idx[y_protos[neigh_idx] == y_train], return_counts=True)
return {
'prototype_indices': idx,
'prototype_importances': counts,
'X_protos': X_protos[idx],
'X_protos_ft': None if (preprocess_fn is None) else X_protos_ft[idx]
}
def visualize_image_prototypes(summary: 'Explanation',
trainset: Tuple[np.ndarray, np.ndarray],
reducer: Callable[[np.ndarray], np.ndarray],
preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
knn_kw: Optional[dict] = None,
ax: Optional[plt.Axes] = None,
fig_kw: Optional[dict] = None,
image_size: Tuple[int, int] = (28, 28),
zoom_lb: float = 1.0,
zoom_ub: float = 3.0) -> plt.Axes:
"""
Plot the images of the prototypes at the location given by the `reducer` representation.
The size of each prototype is proportional to the logarithm of the number of assigned training instances correctly
classified according to the 1-KNN classifier (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
Parameters
----------
summary
An `Explanation` object produced by a call to the
:py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` method.
trainset
Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
reducer
2D reducer. Reduces the input feature representation to 2D. Note that the reducer operates directly on the
input instances if ``preprocess_fn=None``. If the `preprocess_fn` is specified, the reducer will be called
on the feature representation obtained after passing the input instances through the `preprocess_fn`.
preprocess_fn
Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
knn_kw
Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
If the `metric` is not specified, it will be set by default to ``'euclidean'``.
See parameters description:
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
ax
A `matplotlib` axes object to plot on.
fig_kw
Keyword arguments passed to the `fig.set` function.
image_size
Shape to which the prototype images will be resized. A zoom of 1 will display the image having the shape
`image_size`.
zoom_lb
Zoom lower bound. The zoom will be scaled linearly between `[zoom_lb, zoom_ub]`.
zoom_ub
Zoom upper bound. The zoom will be scaled linearly between `[zoom_lb, zoom_ub]`.
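Examples
--------
An illustrative sketch using a PCA-based 2D reducer (the choice of reducer is an assumption; any
callable mapping the feature representation to 2D works)::

    from sklearn.decomposition import PCA

    from alibi.prototypes.protoselect import visualize_image_prototypes

    pca = PCA(n_components=2).fit(X_train.reshape(len(X_train), -1))
    ax = visualize_image_prototypes(summary=summary,
                                    trainset=(X_train, y_train),
                                    reducer=lambda x: pca.transform(x.reshape(len(x), -1)))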
"""
# compute how many correctly labeled instances each prototype covers
protos_importance = compute_prototype_importances(summary=summary,
trainset=trainset,
preprocess_fn=preprocess_fn,
knn_kw=knn_kw)
# unpack values
counts = protos_importance['prototype_importances']
X_protos = protos_importance['X_protos']
X_protos_ft = protos_importance['X_protos_ft'] if (protos_importance['X_protos_ft'] is not None) else X_protos
# compute image zoom
zoom = np.log(counts) # type: ignore[arg-type]
# compute 2D embedding
protos_2d = reducer(X_protos_ft) # type: ignore[arg-type]
x, y = protos_2d[:, 0], protos_2d[:, 1]
# plot images
return _imscatterplot(x=x, y=y,
images=X_protos, # type: ignore[arg-type]
ax=ax, fig_kw=fig_kw,
image_size=image_size,
zoom=zoom, zoom_lb=zoom_lb, zoom_ub=zoom_ub)