from typing import Union, Tuple, Callable
import numpy as np
[docs]
def perturb(X: np.ndarray,
eps: Union[float, np.ndarray] = 1e-08,
proba: bool = False) -> Tuple[np.ndarray, np.ndarray]:
"""
Apply perturbation to instance or prediction probabilities. Used for numerical calculation of gradients.
Parameters
----------
X
Array to be perturbed.
eps
Size of perturbation.
proba
If ``True``, the net effect of the perturbation needs to be 0 to keep the sum of the probabilities equal to 1.
Returns
-------
Instances where a positive and negative perturbation is applied.
"""
# N = batch size; F = nb of features in X
shape = X.shape
X = np.reshape(X, (shape[0], -1)) # NxF
dim = X.shape[1] # F
pert = np.tile(np.eye(dim) * eps, (shape[0], 1)) # (N*F)xF
if proba:
eps_n = eps / (dim - 1)
pert += np.tile((np.eye(dim) - np.ones((dim, dim))) * eps_n, (shape[0], 1)) # (N*F)xF
X_rep = np.repeat(X, dim, axis=0) # (N*F)xF
X_pert_pos, X_pert_neg = X_rep + pert, X_rep - pert
shape = (dim * shape[0],) + shape[1:]
X_pert_pos = np.reshape(X_pert_pos, shape) # (N*F)x(shape of X[0])
X_pert_neg = np.reshape(X_pert_neg, shape) # (N*F)x(shape of X[0])
return X_pert_pos, X_pert_neg
[docs]
def num_grad_batch(func: Callable,
X: np.ndarray,
args: Tuple = (),
eps: Union[float, np.ndarray] = 1e-08) -> np.ndarray:
"""
Calculate the numerical gradients of a vector-valued function (typically a prediction function in classification)
with respect to a batch of arrays `X`.
Parameters
----------
func
Function to be differentiated.
X
A batch of vectors at which to evaluate the gradient of the function.
args
Any additional arguments to pass to the function.
eps
Gradient step to use in the numerical calculation, can be a single `float` or one for each feature.
Returns
-------
An array of gradients at each point in the batch `X`.
"""
# N = gradient batch size; F = nb of features in X, P = nb of prediction classes, B = instance batch size
batch_size = X.shape[0]
data_shape = X[0].shape
preds = func(X, *args)
X_pert_pos, X_pert_neg = perturb(X, eps) # (N*F)x(shape of X[0])
X_pert = np.concatenate([X_pert_pos, X_pert_neg], axis=0)
preds_concat = func(X_pert, *args) # make predictions
n_pert = X_pert_pos.shape[0]
grad_numerator = preds_concat[:n_pert] - preds_concat[n_pert:] # (N*F)*P
grad_numerator = np.reshape(np.reshape(grad_numerator, (batch_size, -1)),
(batch_size, preds.shape[1], -1), order='F') # NxPxF
grad = grad_numerator / (2 * eps) # NxPxF
grad = grad.reshape(preds.shape + data_shape) # BxPx(shape of X[0])
return grad