alibi.explainers.backends.tensorflow.cfrl_tabular module
This module contains utility functions for the tabular Counterfactual with Reinforcement Learning class (cfrl_tabular) for the TensorFlow backend.
- alibi.explainers.backends.tensorflow.cfrl_tabular.consistency_loss(Z_cf_pred, Z_cf_tgt, **kwargs)
Computes heterogeneous consistency loss.
- Parameters:
Z_cf_pred (Tensor) – Counterfactual embedding prediction.
Z_cf_tgt (Union[ndarray, Tensor]) – Counterfactual embedding target.
- Returns:
Heterogeneous consistency loss.
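The consistency loss penalizes the distance between the embedding predicted for the counterfactual and a target embedding. The NumPy sketch below assumes that distance is a mean squared error over all elements; the name `consistency_loss_sketch` is hypothetical, and the TensorFlow function may use a different distance or wrap its result differently.

```python
import numpy as np


def consistency_loss_sketch(Z_cf_pred, Z_cf_tgt):
    """Mean squared error between predicted and target embeddings (assumed formulation)."""
    Z_cf_pred = np.asarray(Z_cf_pred, dtype=float)
    Z_cf_tgt = np.asarray(Z_cf_tgt, dtype=float)
    # Average the squared difference over every element of the batch.
    return float(np.mean((Z_cf_pred - Z_cf_tgt) ** 2))


# Two 2-dimensional embeddings; the second pair differs by 1 in one coordinate.
Z_pred = np.array([[0.0, 1.0], [2.0, 3.0]])
Z_tgt = np.array([[0.0, 1.0], [2.0, 4.0]])
print(consistency_loss_sketch(Z_pred, Z_tgt))  # → 0.25
```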
- alibi.explainers.backends.tensorflow.cfrl_tabular.l0_ohe(input, target, reduction='none')
Computes the L0 loss for a one-hot encoding representation.
- Parameters:
input (Tensor) – Input tensor.
target (Tensor) – Target tensor.
reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
- Return type:
Tensor
- Returns:
L0 loss.
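An exact L0 distance between one-hot vectors counts mismatched entries, which is not differentiable. The NumPy sketch below shows one differentiable surrogate, assuming each row of `input` is a probability vector over the categories of one attribute and each row of `target` is one-hot: summing the positive part of `target - input` gives 0 when all mass sits on the target category and `1 - p_target` otherwise. Both the function name `l0_ohe_sketch` and this surrogate are assumptions, not necessarily what `l0_ohe` computes.

```python
import numpy as np


def l0_ohe_sketch(input, target, reduction="none"):
    """Differentiable L0 surrogate for one-hot encoded rows (assumed formulation)."""
    input = np.asarray(input, dtype=float)
    target = np.asarray(target, dtype=float)
    # For a one-hot target and non-negative input, only the deficit on the
    # target category is positive; everything else is clipped to 0.
    per_row = np.sum(np.maximum(target - input, 0.0), axis=1)
    if reduction == "none":
        return per_row
    if reduction == "mean":
        return per_row.mean()
    if reduction == "sum":
        return per_row.sum()
    raise ValueError(f"Reduction {reduction!r} not implemented.")


probs = np.array([[1.0, 0.0, 0.0],     # all mass on the target category
                  [0.25, 0.75, 0.0]])  # 0.75 of the mass on a wrong category
one_hot = np.array([[1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0]])
print(l0_ohe_sketch(probs, one_hot))          # per-row losses 0.0 and 0.75
print(l0_ohe_sketch(probs, one_hot, "mean"))  # → 0.375
```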
- alibi.explainers.backends.tensorflow.cfrl_tabular.l1_loss(input, target=tensorflow.Tensor, reduction='none')
Computes the L1 loss.
- Parameters:
input (Tensor) – Input tensor.
target – Target tensor.
reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
- Return type:
Tensor
- Returns:
L1 loss.
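The three `reduction` modes can be illustrated with a short NumPy sketch. It assumes `'none'` keeps the element-wise absolute errors and that `'mean'` and `'sum'` reduce over all elements; the TensorFlow function may reduce along a different axis. `l1_loss_sketch` is a hypothetical name.

```python
import numpy as np


def l1_loss_sketch(input, target, reduction="none"):
    """Element-wise absolute error with an optional reduction (assumed formulation)."""
    loss = np.abs(np.asarray(input, dtype=float) - np.asarray(target, dtype=float))
    if reduction == "none":
        return loss          # same shape as the inputs
    if reduction == "mean":
        return loss.mean()   # average over all elements
    if reduction == "sum":
        return loss.sum()    # total over all elements
    raise ValueError(f"Reduction {reduction!r} not implemented.")


x = np.array([[1.0, -2.0], [0.5, 0.0]])
y = np.array([[0.0, -2.0], [0.0, 1.0]])
print(l1_loss_sketch(x, y, "sum"))   # → 2.5
print(l1_loss_sketch(x, y, "mean"))  # → 0.625
```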
- alibi.explainers.backends.tensorflow.cfrl_tabular.sample_differentiable(X_hat_split, category_map)
Samples differentiable reconstruction.
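- Parameters:
X_hat_split (List[Tensor]) – List of reconstructed columns from the auto-encoder.
category_map (Dict[int, List[str]]) – Dictionary of category mapping. The keys are column indexes and the values are lists containing the possible values for an attribute.
- Return type:
List[Tensor]
- Returns:
Differentiable reconstruction.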
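Taking an argmax over each categorical block would cut the gradient path. The NumPy sketch below shows a differentiable alternative under stated assumptions: `X_hat_split` holds the numerical reconstruction first, followed by one block of logits per entry in `category_map`, and each categorical block is relaxed with a softmax. `sample_differentiable_sketch` and `softmax` are hypothetical helpers, not the module's implementation.

```python
import numpy as np


def softmax(logits, axis=1):
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def sample_differentiable_sketch(X_hat_split, category_map):
    """Keep numerical blocks as-is; map each categorical block of logits to probabilities."""
    num_numerical_blocks = len(X_hat_split) - len(category_map)
    out = []
    for i, block in enumerate(X_hat_split):
        block = np.asarray(block, dtype=float)
        if i < num_numerical_blocks:
            out.append(block)                   # numerical reconstruction
        else:
            out.append(softmax(block, axis=1))  # soft, differentiable one-hot
    return out


# One numerical column (index 0) and one categorical column (index 1) with two values.
category_map = {1: ["a", "b"]}
X_hat_split = [np.array([[0.3]]), np.array([[2.0, 2.0]])]
print(sample_differentiable_sketch(X_hat_split, category_map)[1])  # → [[0.5 0.5]]
```

Equal logits spread the probability mass evenly, so a gradient can still push the reconstruction toward either category.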
- alibi.explainers.backends.tensorflow.cfrl_tabular.sparsity_loss(X_hat_split, X_ohe, category_map, weight_num=1.0, weight_cat=1.0)
Computes heterogeneous sparsity loss.
- Parameters:
X_hat_split (List[Tensor]) – List of reconstructed columns from the auto-encoder.
X_ohe (Tensor) – One-hot encoded representation of the input.
category_map (Dict[int, List[str]]) – Dictionary of category mapping. The keys are column indexes and the values are lists containing the possible values for an attribute.
weight_num (float) – Numerical loss weight.
weight_cat (float) – Categorical loss weight.
- Returns:
Heterogeneous sparsity loss.
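A heterogeneous sparsity loss combines a numerical term and a categorical term, in the spirit of `l1_loss` and `l0_ohe` above. The NumPy sketch below assumes the numerical columns of `X_ohe` come first, followed by the one-hot blocks in the order of the sorted `category_map` keys; the numerical term is a mean absolute difference and the categorical term is the surrogate `sum(max(target - softmax(logits), 0))`, averaged over attributes. `sparsity_loss_sketch` and `softmax` are hypothetical, and the real `sparsity_loss` may weight, reduce, or return its terms differently.

```python
import numpy as np


def softmax(logits, axis=1):
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def sparsity_loss_sketch(X_hat_split, X_ohe, category_map, weight_num=1.0, weight_cat=1.0):
    """Weighted L1 (numerical) plus L0 surrogate (categorical) distance to the input."""
    X_ohe = np.asarray(X_ohe, dtype=float)
    cat_sizes = [len(category_map[k]) for k in sorted(category_map)]
    num_dim = X_ohe.shape[1] - sum(cat_sizes)  # numerical columns are assumed to come first

    # Numerical part: mean absolute change of the numerical columns.
    num_loss = 0.0
    if num_dim > 0:
        X_num_hat = np.asarray(X_hat_split[0], dtype=float)
        num_loss = np.mean(np.abs(X_num_hat - X_ohe[:, :num_dim]))

    # Categorical part: L0 surrogate between softmax probabilities and the input one-hot block.
    cat_blocks = X_hat_split[1:] if num_dim > 0 else X_hat_split
    cat_loss, start = 0.0, num_dim
    for size, logits in zip(cat_sizes, cat_blocks):
        probs = softmax(np.asarray(logits, dtype=float))
        target = X_ohe[:, start:start + size]
        cat_loss += np.mean(np.sum(np.maximum(target - probs, 0.0), axis=1))
        start += size
    cat_loss /= max(len(cat_sizes), 1)  # average over categorical attributes

    return weight_num * num_loss + weight_cat * cat_loss


# Input: numerical value 0.5 and category "a" (one-hot [1, 0]).
category_map = {1: ["a", "b"]}
X_ohe = np.array([[0.5, 1.0, 0.0]])
X_hat_split = [np.array([[0.75]]), np.array([[0.0, 0.0]])]
print(sparsity_loss_sketch(X_hat_split, X_ohe, category_map))  # → 0.75
```

Raising `weight_cat` makes changes to categorical attributes more expensive relative to numerical ones, which is the trade-off the two weight parameters expose.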