Source code for alibi.models.pytorch.metrics

"""
This module contains a loss wrapper and a definition of various monitoring metrics used during training. The model
to be trained inherits from :py:class:`alibi.explainers.models.pytorch.model.Model` and represents a simplified
version of the `tensorflow.keras` API for training and monitoring the model. Currently it is used internally to test
the functionalities for the PyTorch backend. To be discussed if the module will be exposed to the user in future
versions.
"""

import torch
import numpy as np
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, Union, Callable


class Reduction(Enum):
    """
    Reduction operation supported by the monitoring metrics.
    """

    # Members are deliberately left unannotated: annotating an enum member
    # (e.g. `SUM: str = 'sum'`) is misleading because the member's type is
    # `Reduction`, not `str`, and static type checkers reject it.
    SUM = 'sum'    # report the accumulated total
    MEAN = 'mean'  # report the accumulated total divided by the count
class LossContainer:
    """
    Loss wrapper that monitors the average loss throughout training.

    Every call evaluates the wrapped loss, adds its scalar value to a running
    total and increments a step counter, so that the average loss per step can
    be reported at any time.
    """

    def __init__(self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], name: str):
        """
        Constructor.

        Parameters
        ----------
        loss
            Loss function.
        name
            Name of the loss function. Used as the key of the dictionary returned by `result`.
        """
        self.name = name
        self.loss = loss
        self.total, self.count = 0., 0

    def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Computes and accumulates the loss given the prediction labels and the true labels.

        Parameters
        ----------
        y_pred
            Prediction labels.
        y_true
            True labels.

        Returns
        -------
        Loss value, as returned by the wrapped loss function.
        """
        # compute loss
        loss = self.loss(y_pred, y_true)

        # accumulate the python scalar, not the tensor, so the running total
        # does not hold a reference to the returned tensor
        self.total += loss.item()
        self.count += 1
        return loss

    def result(self) -> Dict[str, float]:
        """
        Computes the average loss obtained by dividing the accumulated loss by the number of steps.

        Returns
        -------
        Dictionary mapping the loss name to the average loss. If no loss has been accumulated yet
        (e.g. right after construction or after `reset`), the average is reported as `0.0`.
        """
        # guard the empty state; dividing by a zero step count would raise `ZeroDivisionError`
        if self.count == 0:
            return {self.name: 0.0}
        return {self.name: self.total / self.count}

    def reset(self):
        """
        Resets the accumulated loss and the step counter.
        """
        # keep the total a float, consistent with the constructor
        self.total = 0.
        self.count = 0
class Metric(ABC):
    """
    Monitoring metric object. Supports two types of reduction: mean and sum.

    Subclasses implement `compute_metric`; the accumulated state is kept in `total` and `count`
    and is updated through `update_state`.
    """

    def __init__(self, reduction: Reduction = Reduction.MEAN, name: str = "unknown"):
        """
        Constructor.

        Parameters
        ----------
        reduction
            Metric's reduction type. Possible values `mean`|`sum`. By default `mean`.
        name
            Name of the metric. Used as the key of the dictionary returned by `result`.
        """
        self.reduction = reduction
        self.name = name
        self.total, self.count = 0, 0

    @abstractmethod
    def compute_metric(self,
                       y_pred: Union[torch.Tensor, np.ndarray],
                       y_true: Union[torch.Tensor, np.ndarray]):
        """
        Computes the metric given the prediction and the true label. To be implemented by subclasses.

        Parameters
        ----------
        y_pred
            Prediction.
        y_true
            True label.
        """
        pass

    def update_state(self, values: np.ndarray):
        """
        Update the state of the metric by summing up the metric values and updating the counts by adding
        the number of instances for which the metric was computed (first dimension).

        Parameters
        ----------
        values
            Array of metric values. Its first dimension is added to the instance count.
        """
        self.total += values.sum()
        self.count += values.shape[0]

    def result(self) -> Dict[str, float]:
        """
        Computes the result according to the reduction procedure.

        Returns
        -------
        Dictionary mapping the metric name to the monitoring metric. For the `mean` reduction, `0.0` is
        reported if no instance has been accumulated yet.

        Raises
        ------
        NotImplementedError
            If the reduction type is neither `sum` nor `mean`.
        """
        if self.reduction == Reduction.SUM:
            return {self.name: self.total}

        if self.reduction == Reduction.MEAN:
            # In the initial/reset state `total` and `count` are plain python ints, so `0 / 0` would raise
            # `ZeroDivisionError` before `np.nan_to_num` could map the undefined mean to 0. Handle the
            # empty state explicitly.
            if self.count == 0:
                return {self.name: 0.0}
            return {self.name: np.nan_to_num(self.total / self.count)}

        raise NotImplementedError(f"Reduction {self.reduction} not implemented")

    def reset(self):
        """
        Resets the monitoring metric.
        """
        self.total = 0
        self.count = 0
class AccuracyMetric(Metric):
    """
    Accuracy monitoring metric: mean of the element-wise matches between predictions and labels.
    """

    def __init__(self, name: str = "accuracy"):
        super().__init__(reduction=Reduction.MEAN, name=name)

    def compute_metric(self,
                       y_pred: Union[torch.Tensor, np.ndarray],
                       y_true: Union[torch.Tensor, np.ndarray]) -> None:
        """
        Computes the accuracy metric given the predicted label and the true label, and accumulates
        the element-wise matches into the metric state.

        Parameters
        ----------
        y_pred
            Predicted label. If it has more dimensions than `y_true`, it is reduced to labels by
            taking the `argmax` over the last axis.
        y_true
            True label.

        Raises
        ------
        ValueError
            If the shapes of the predictions and the labels do not match.
        """
        # bring both inputs to numpy
        if isinstance(y_pred, torch.Tensor):
            y_pred = y_pred.detach().cpu().numpy()

        if isinstance(y_true, torch.Tensor):
            y_true = y_true.detach().cpu().numpy()

        # an extra dimension means `y_pred` is a distribution and not a label
        predicted = np.argmax(y_pred, axis=-1) if y_pred.ndim > y_true.ndim else y_pred

        # predictions and labels must line up element by element
        if predicted.shape != y_true.shape:
            raise ValueError("The shape of the prediction and labels do not match")

        super().update_state(np.array(predicted == y_true))