alibi.explainers.backends.pytorch.cfrl_tabular module

This module contains utility functions for the tabular Counterfactual with Reinforcement Learning explainer, alibi.explainers.cfrl_tabular, for the PyTorch backend.

alibi.explainers.backends.pytorch.cfrl_tabular.consistency_loss(Z_cf_pred, Z_cf_tgt, **kwargs)

Computes the heterogeneous consistency loss.

Parameters
  • Z_cf_pred (Tensor) – Predicted counterfactual embedding.

  • Z_cf_tgt (Tensor) – Counterfactual embedding target.

Returns

Heterogeneous consistency loss.
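The consistency term penalizes the gap between the counterfactual embedding predicted by the actor and the target embedding. A minimal sketch, assuming the distance is a mean squared error; `consistency_loss_sketch` is a hypothetical stand-in written for illustration, and the library's exact formula and return container may differ.

```python
import torch
import torch.nn.functional as F


def consistency_loss_sketch(Z_cf_pred: torch.Tensor, Z_cf_tgt: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: MSE between predicted and target counterfactual embeddings."""
    # Assumption: mean squared error averaged over batch and embedding dimensions.
    return F.mse_loss(Z_cf_pred, Z_cf_tgt)


# Predicted counterfactual embedding (requires grad) and its target.
Z_cf_pred = torch.tensor([[0.0, 1.0], [2.0, 2.0]], requires_grad=True)
Z_cf_tgt = torch.tensor([[0.0, 0.0], [2.0, 4.0]])

loss = consistency_loss_sketch(Z_cf_pred, Z_cf_tgt)
loss.backward()  # gradients flow back into the predicted embedding
print(loss.item())  # 1.25
```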

alibi.explainers.backends.pytorch.cfrl_tabular.l0_ohe(input, target, reduction='none')

Computes the L0 loss for a one-hot encoding representation.

Parameters
  • input (Tensor) – Input tensor.

  • target (Tensor) – Target tensor.

  • reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.

Return type

Tensor

Returns

L0 loss.
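The exact L0 distance between two one-hot rows (0 if the category is unchanged, non-zero otherwise) has no useful gradient. A minimal sketch of one differentiable surrogate, assuming `input` rows lie on the probability simplex and `target` rows are one-hot: summing the positive part of `target - input` gives 1 - input[original category]. `l0_ohe_sketch` is hypothetical and is not claimed to match alibi's implementation.

```python
import torch


def l0_ohe_sketch(input: torch.Tensor, target: torch.Tensor, reduction: str = "none") -> torch.Tensor:
    """Hypothetical differentiable surrogate for the L0 distance between one-hot rows.

    Per row: sum(max(target - input, 0)) = 1 - input[original category],
    i.e. 0 if the same category is kept with certainty, 1 if it changed.
    """
    loss = torch.sum(torch.clamp(target - input, min=0.0), dim=1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    raise ValueError(f"Reduction '{reduction}' not implemented.")


x_true = torch.tensor([[0.0, 1.0, 0.0],
                       [1.0, 0.0, 0.0]])
x_pred = torch.tensor([[0.0, 1.0, 0.0],   # same category    -> 0
                       [0.0, 0.0, 1.0]])  # changed category -> 1
print(l0_ohe_sketch(x_pred, x_true))                    # tensor([0., 1.])
print(l0_ohe_sketch(x_pred, x_true, reduction="mean"))  # tensor(0.5000)
```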

alibi.explainers.backends.pytorch.cfrl_tabular.l1_loss(input, target, reduction='none')

Computes the L1 loss.

Parameters
  • input (Tensor) – Input tensor.

  • target (Tensor) – Target tensor.

  • reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.

Return type

Tensor

Returns

L1 loss.
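A minimal sketch of the three reduction modes, assuming 'none' returns the element-wise absolute difference (same shape as the inputs). `l1_loss_sketch` is a hypothetical name; with reduction='mean' it agrees with `torch.nn.functional.l1_loss`.

```python
import torch
import torch.nn.functional as F


def l1_loss_sketch(input: torch.Tensor, target: torch.Tensor, reduction: str = "none") -> torch.Tensor:
    """Hypothetical sketch: element-wise |input - target| with an optional reduction."""
    loss = torch.abs(input - target)
    if reduction == "none":
        return loss  # assumption: no reduction keeps the element-wise shape
    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    raise ValueError(f"Reduction '{reduction}' not implemented.")


x_pred = torch.tensor([[0.5, -1.0], [2.0, 0.0]])
x_true = torch.tensor([[0.0, 0.0], [2.0, 1.0]])
print(l1_loss_sketch(x_pred, x_true, reduction="mean").item())  # 0.625
print(F.l1_loss(x_pred, x_true, reduction="mean").item())       # 0.625
```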

alibi.explainers.backends.pytorch.cfrl_tabular.sample_differentiable(X_hat_split, category_map)

Samples a differentiable reconstruction.

Parameters
  • X_hat_split (List[Tensor]) – List of reconstructed columns from the auto-encoder.

  • category_map (Dict[int, List[str]]) – Dictionary of category mapping. The keys are column indexes and the values are lists containing the possible values for an attribute.

Return type

List[Tensor]

Returns

Differentiable reconstruction.
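One common way to sample a categorical feature while keeping gradients is the straight-through estimator: the forward pass emits a hard one-hot vector, while the backward pass uses the softmax gradients. The sketch below illustrates that technique under stated assumptions (numerical heads come first in `X_hat_split`, each remaining element holds the logits of one categorical feature); whether alibi uses exactly this scheme is an assumption, and `sample_differentiable_sketch` is a hypothetical name.

```python
from typing import Dict, List

import torch
import torch.nn.functional as F


def sample_differentiable_sketch(X_hat_split: List[torch.Tensor],
                                 category_map: Dict[int, List[str]]) -> List[torch.Tensor]:
    """Hypothetical straight-through sampling of the categorical heads."""
    # Assumption: the first len(X_hat_split) - len(category_map) heads are numerical.
    num_heads = len(X_hat_split) - len(category_map)
    out = []
    for i, x_hat in enumerate(X_hat_split):
        if i < num_heads:
            # Numerical head: pass the reconstruction through unchanged.
            out.append(x_hat)
            continue
        # Categorical head: hard one-hot sample in the forward pass ...
        probs = torch.softmax(x_hat, dim=1)
        idx = torch.argmax(probs, dim=1)
        one_hot = F.one_hot(idx, num_classes=x_hat.shape[1]).to(probs.dtype)
        # ... with the softmax gradient in the backward pass (straight-through).
        out.append(one_hot - probs.detach() + probs)
    return out


category_map = {1: ["low", "mid", "high"]}  # hypothetical: column 1 is categorical
X_num_hat = torch.tensor([[0.3], [0.7]])
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], requires_grad=True)
X_ohe_hat = sample_differentiable_sketch([X_num_hat, logits], category_map)

# Forward values are one-hot rows; gradients still reach the logits.
(X_ohe_hat[1] * torch.tensor([[1.0, 2.0, 3.0]])).sum().backward()
```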

alibi.explainers.backends.pytorch.cfrl_tabular.sparsity_loss(X_hat_split, X_ohe, category_map, weight_num=1.0, weight_cat=1.0)

Computes the heterogeneous sparsity loss.

Parameters
  • X_hat_split (List[Tensor]) – List of one-hot encoded reconstructed columns from the auto-encoder.

  • X_ohe (Tensor) – One-hot encoded representation of the input.

  • category_map (Dict[int, List[str]]) – Dictionary of category mapping. The keys are column indexes and the values are lists containing the possible values for an attribute.

  • weight_num (float) – Numerical loss weight.

  • weight_cat (float) – Categorical loss weight.

Returns

Heterogeneous sparsity loss.
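"Heterogeneous" here means numerical and categorical columns are penalized differently, each with its own weight. A self-contained sketch, assuming an L1 term on the numerical head and an L0 surrogate (1 - probability of the original category) on softmax-relaxed categorical heads; it further assumes `X_ohe` stores all numerical columns first (at least one), followed by one one-hot block per categorical feature in ascending key order. The softmax relaxation, the column layout, the returned dictionary and its keys are all hypothetical choices for illustration, not alibi's implementation.

```python
from typing import Dict, List

import torch


def sparsity_loss_sketch(X_hat_split: List[torch.Tensor],
                         X_ohe: torch.Tensor,
                         category_map: Dict[int, List[str]],
                         weight_num: float = 1.0,
                         weight_cat: float = 1.0) -> Dict[str, torch.Tensor]:
    """Hypothetical heterogeneous sparsity loss (numerical L1 + categorical L0 surrogate)."""
    # Assumed layout: numerical columns first, then one-hot blocks in key order.
    cat_sizes = [len(category_map[k]) for k in sorted(category_map)]
    num_size = X_ohe.shape[1] - sum(cat_sizes)
    X_num, *X_cat = torch.split(X_ohe, [num_size] + cat_sizes, dim=1)

    # Numerical part: mean absolute change w.r.t. the original values.
    num_loss = torch.mean(torch.abs(X_hat_split[0] - X_num))

    # Categorical part: per feature, mean over rows of 1 - p(original category),
    # then averaged over the categorical features.
    cat_loss = torch.zeros(())
    for logits, x_cat in zip(X_hat_split[1:], X_cat):
        probs = torch.softmax(logits, dim=1)
        cat_loss = cat_loss + torch.mean(torch.sum(torch.clamp(x_cat - probs, min=0.0), dim=1))
    cat_loss = cat_loss / max(len(cat_sizes), 1)

    return {"num": weight_num * num_loss, "cat": weight_cat * cat_loss}


category_map = {2: ["a", "b"]}  # hypothetical: columns 0-1 numerical, column 2 categorical
X_ohe = torch.tensor([[0.1, 0.5, 1.0, 0.0],
                      [0.9, 0.2, 0.0, 1.0]])
X_hat_split = [torch.tensor([[0.3, 0.5], [0.9, 0.2]]),  # numerical head (one value changed)
               torch.zeros(2, 2)]                         # uninformative categorical logits
losses = sparsity_loss_sketch(X_hat_split, X_ohe, category_map, weight_cat=2.0)
```

With these inputs the numerical term is about 0.05 (a single 0.2 change averaged over four entries) and the weighted categorical term is 2.0 × 0.5 = 1.0, since each row keeps its original category with probability 0.5.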