Source code for alibi_detect.utils.tensorflow.misc

import tensorflow as tf


def zero_diag(mat: tf.Tensor) -> tf.Tensor:
    """
    Remove the diagonal of a matrix by subtracting it from the matrix.

    Parameters
    ----------
    mat
        A 2D square matrix

    Returns
    -------
    A 2D square matrix with zeros along the diagonal
    """
    # Build a matrix holding only the diagonal entries of `mat` ...
    diagonal_entries = tf.linalg.diag_part(mat)
    diagonal_only = tf.linalg.diag(diagonal_entries)
    # ... and subtract it, leaving the off-diagonal entries untouched.
    return mat - diagonal_only
def quantile(sample: tf.Tensor, p: float, type: int = 7, sorted: bool = False) -> float:
    """
    Estimate a desired quantile of a univariate distribution from a vector of samples

    Parameters
    ----------
    sample
        A 1D vector of values
    p
        The desired quantile in (0,1)
    type
        The method for computing the quantile. Must be 6, 7 or 8.
        See https://wikipedia.org/wiki/Quantile#Estimating_quantiles_from_a_sample
    sorted
        Whether or not the vector is already sorted into ascending order

    Returns
    -------
    An estimate of the quantile

    Raises
    ------
    ValueError
        If `sample` is not a 1D vector, if it is empty, if `p` lies outside [1/N, (N-1)/N] for a
        sample of size N, or if `type` is not one of 6, 7 or 8.
    """
    # Check the rank first: len() is only meaningful once we know this is a vector.
    if len(sample.shape) != 1:
        raise ValueError("Quantile estimation only supports vectors of univariate samples.")

    N = len(sample)
    if N == 0:
        # Guard explicitly; otherwise the 1/N bound below would raise ZeroDivisionError.
        raise ValueError("Quantile estimation requires at least one sample.")
    if not 1/N <= p <= (N-1)/N:
        raise ValueError(f"The {p}-quantile should not be estimated using only {N} samples.")

    # h is the 1-based (possibly fractional) position of the quantile in the sorted sample.
    if type == 6:
        h = (N+1)*p
    elif type == 7:
        h = (N-1)*p + 1
    elif type == 8:
        h = (N+1/3)*p + 1/3
    else:
        # Without this branch `h` would be unbound and the function would crash obscurely.
        raise ValueError("type must be an int with value 6, 7 or 8.")

    sorted_sample = sample if sorted else tf.sort(sample)

    h_floor = int(h)
    estimate = sorted_sample[h_floor-1]
    if h_floor != h:
        # Linearly interpolate between the two order statistics surrounding position h.
        estimate = estimate + (h - h_floor)*(sorted_sample[h_floor]-sorted_sample[h_floor-1])
    return float(estimate)
def subset_matrix(mat: tf.Tensor, inds_0: tf.Tensor, inds_1: tf.Tensor) -> tf.Tensor:
    """
    Extract the submatrix selected by the provided row and column indices.

    Parameters
    ----------
    mat
        A 2D matrix
    inds_0
        A vector of row indices
    inds_1
        A vector of column indices

    Returns
    -------
    A submatrix of shape (len(inds_0), len(inds_1))

    Raises
    ------
    ValueError
        If `mat` does not have exactly two dimensions.
    """
    rank = len(mat.shape)
    if rank != 2:
        raise ValueError("Subsetting only supported for matrices (2D)")
    # Pick the requested rows first, then the requested columns of that row subset.
    return tf.gather(tf.gather(mat, inds_0, axis=0), inds_1, axis=1)
def clone_model(model: tf.keras.Model) -> tf.keras.Model:
    """
    Clone a sequential, functional or subclassed tf.keras.Model.

    Parameters
    ----------
    model
        The model to clone.

    Returns
    -------
    The result of `tf.keras.models.clone_model` when that call succeeds; otherwise a new
    instance built via `from_config` on the model's own class.
    """
    try:
        # Works for sequential and functional models.
        cloned = tf.keras.models.clone_model(model)
    except ValueError:
        # Subclassed model: rebuild it from its config, falling back to an empty
        # config when the model does not implement `get_config`.
        try:
            model_config = model.get_config()
        except NotImplementedError:
            model_config = {}
        return model.__class__.from_config(model_config)
    else:
        return cloned