Source code for alibi_detect.models.tensorflow.pixelcnn

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import numpy as np
import warnings
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.distributions import categorical
from tensorflow_probability.python.distributions import distribution
from tensorflow_probability.python.distributions import independent
from tensorflow_probability.python.distributions import logistic
from tensorflow_probability.python.distributions import mixture_same_family
from tensorflow_probability.python.distributions import quantized_distribution
from tensorflow_probability.python.distributions import transformed_distribution
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util


__all__ = [
    'Shift',
]


class WeightNorm(tf.keras.layers.Wrapper):
    def __init__(self, layer, data_init: bool = True, **kwargs):
        """Layer wrapper to decouple magnitude and direction of the layer's weights.

        This wrapper reparameterizes a layer by decoupling the weight's
        magnitude and direction. This speeds up convergence by improving the
        conditioning of the optimization problem. It has an optional data-dependent
        initialization scheme, in which initial values of weights are set as functions
        of the first minibatch of data. Both the weight normalization and data-
        dependent initialization are described in [Salimans and Kingma (2016)][1].

        Parameters
        ----------
        layer
            A `tf.keras.layers.Layer` instance. Supported layer types are
            `Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs
            are not supported.
        data_init
            If `True`, use data-dependent variable initialization.
        **kwargs
            Additional keyword args passed to `tf.keras.layers.Wrapper`.

        Raises
        ------
        ValueError
            If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` '
                'instance. You passed: {input}'.format(input=layer)
            )

        layer_type = type(layer).__name__
        if layer_type not in ['Dense', 'Conv2D', 'Conv2DTranspose']:
            warnings.warn('`WeightNorm` is tested only for `Dense`, `Conv2D`, and '
                          '`Conv2DTranspose` layers. You passed a layer of type `{}`'
                          .format(layer_type))

        super(WeightNorm, self).__init__(layer, **kwargs)

        self.data_init = data_init
        self._track_trackable(layer, name='layer')
        self.filter_axis = -2 if layer_type == 'Conv2DTranspose' else -1

    def _compute_weights(self):
        """Generate weights with normalization."""
        # Determine the axis along which to expand `g` so that `g` broadcasts to
        # the shape of `v`.
        new_axis = -self.filter_axis - 3

        self.layer.kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * tf.expand_dims(self.g, new_axis)
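        # For intuition: `v` stores the direction of the original kernel and `g` its
        # magnitude, so the line above computes w = g * v / ||v|| per output filter
        # (the norm runs over `kernel_norm_axes`, every axis except
        # `self.filter_axis`), which is the reparameterization of
        # Salimans and Kingma (2016).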

    def _init_norm(self):
        """Set the norm of the weight vector."""
        kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes))
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization."""
        # Normalize kernel first so that calling the layer calculates
        # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        activation = self.layer.activation
        self.layer.activation = None

        use_bias = self.layer.bias is not None
        if use_bias:
            bias = self.layer.bias
            self.layer.bias = tf.zeros_like(bias)

        # Since the bias is initialized as zero and the activation has been set to
        # `None` above, calling the initialized layer (with normalized kernel)
        # yields the correct computation ((5) in Salimans and Kingma (2016))
        x_init = self.layer(inputs)
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        scale_init = 1. / tf.sqrt(v_init + 1e-10)

        self.g.assign(self.g * scale_init)
        if use_bias:
            self.layer.bias = bias
            self.layer.bias.assign(-m_init * scale_init)
        self.layer.activation = activation

    def build(self, input_shape=None):
        """Build `Layer`.

        Parameters
        ----------
        input_shape
            The shape of the input to `self.layer`.

        Raises
        ------
        ValueError
            If `Layer` does not contain a `kernel` of weights.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        input_shape[0] = None
        self.input_spec = tf.keras.layers.InputSpec(shape=input_shape)

        if not self.layer.built:
            self.layer.build(input_shape)

            if not hasattr(self.layer, 'kernel'):
                raise ValueError('`WeightNorm` must wrap a layer that contains a `kernel` for weights')

            self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
            self.kernel_norm_axes.pop(self.filter_axis)

            self.v = self.layer.kernel

            # to avoid a duplicate `kernel` variable after `build` is called
            self.layer.kernel = None
            self.g = self.add_weight(
                name='g',
                shape=(int(self.v.shape[self.filter_axis]),),
                initializer='ones',
                dtype=self.v.dtype,
                trainable=True
            )
            self.initialized = self.add_weight(
                name='initialized',
                dtype=tf.bool,
                trainable=False
            )
            self.initialized.assign(False)

        super(WeightNorm, self).build()

    @tf.function
    def call(self, inputs):
        """Call `Layer`."""
        if not self.initialized:
            if self.data_init:
                self._data_dep_init(inputs)
            else:  # initialize `g` as the norm of the initialized kernel
                self._init_norm()

            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

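# Example (illustrative sketch, not part of the module API): wrapping a Keras
# layer in `WeightNorm`. The first call in eager mode triggers the (optional)
# data-dependent initialization described above:
#
#   layer = WeightNorm(tf.keras.layers.Dense(16), data_init=True)
#   x = tf.random.normal([8, 4])
#   y = layer(x)    # initializes `g` (and the bias) from the first minibatch
#   y.shape         # TensorShape([8, 16])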


class Shift(bijector.Bijector):
    def __init__(self, shift, validate_args=False, name='shift'):
        """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
        where `shift` is a numeric `Tensor`.

        Parameters
        ----------
        shift
            Floating-point `Tensor`.
        validate_args
            Python `bool` indicating whether arguments should be checked for correctness.
        name
            Python `str` name given to ops managed by this object.
        """
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([shift], dtype_hint=tf.float32)
            self._shift = tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name='shift')
            super(Shift, self).__init__(
                forward_min_event_ndims=0,
                is_constant_jacobian=True,
                dtype=dtype,
                validate_args=validate_args,
                name=name
            )

    @property
    def shift(self):
        """The `shift` `Tensor` in `Y = X + shift`."""
        return self._shift

    @classmethod
    def _is_increasing(cls):
        return True

    def _forward(self, x):
        return x + self.shift

    def _inverse(self, y):
        return y - self.shift

    def _forward_log_det_jacobian(self, x):
        # is_constant_jacobian = True for this bijector, hence the
        # `log_det_jacobian` need only be specified for a single input, as this will
        # be tiled to match `event_ndims`.
        return tf.zeros([], dtype=dtype_util.base_dtype(x.dtype))
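
# Example (illustrative, not part of the module API): `Shift(-0.5)` is used by
# `PixelCNN._make_mixture_dist` below to centre each logistic before it is
# quantized. Forward/inverse are plain additions and the log-det Jacobian is 0:
#
#   b = Shift(shift=tf.constant(-0.5))
#   b.forward(tf.constant([0., 1.]))                          # [-0.5,  0.5]
#   b.inverse(tf.constant([-0.5, 0.5]))                       # [ 0. ,  1. ]
#   b.forward_log_det_jacobian(tf.zeros([2]), event_ndims=0)  # [0., 0.]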


class PixelCNN(distribution.Distribution):
    def __init__(self,
                 image_shape: tuple,
                 conditional_shape: tuple = None,
                 num_resnet: int = 5,
                 num_hierarchies: int = 3,
                 num_filters: int = 160,
                 num_logistic_mix: int = 10,
                 receptive_field_dims: tuple = (3, 3),
                 dropout_p: float = 0.5,
                 resnet_activation: str = 'concat_elu',
                 l2_weight: float = 0.,
                 use_weight_norm: bool = True,
                 use_data_init: bool = True,
                 high: int = 255,
                 low: int = 0,
                 dtype=tf.float32,
                 name: str = 'PixelCNN') -> None:
        """Construct Pixel CNN++ distribution.

        Parameters
        ----------
        image_shape
            3D `TensorShape` or tuple for the `[height, width, channels]` dimensions
            of the image.
        conditional_shape
            `TensorShape` or tuple for the shape of the conditional input, or `None`
            if there is no conditional input.
        num_resnet
            The number of layers (shown in Figure 2 of [2]) within each highest-level
            block of Figure 2 of [1].
        num_hierarchies
            The number of highest-level blocks (separated by expansions/contractions
            of dimensions in Figure 2 of [1].)
        num_filters
            The number of convolutional filters.
        num_logistic_mix
            Number of components in the logistic mixture distribution.
        receptive_field_dims
            Height and width in pixels of the receptive field of the convolutional
            layers above and to the left of a given pixel. The width (second element
            of the tuple) should be odd. Figure 1 (middle) of [2] shows a receptive
            field of (3, 5) (the row containing the current pixel is included in the
            height). The default of (3, 3) was used to produce the results in [1].
        dropout_p
            The dropout probability. Should be between 0 and 1.
        resnet_activation
            The type of activation to use in the resnet blocks. May be 'concat_elu',
            'elu', or 'relu'.
        l2_weight
            The L2 regularization weight.
        use_weight_norm
            If `True` then use weight normalization (works only in Eager mode).
        use_data_init
            If `True` then use data-dependent initialization (has no effect if
            `use_weight_norm` is `False`).
        high
            The maximum value of the input data (255 for an 8-bit image).
        low
            The minimum value of the input data.
        dtype
            Data type of the `Distribution`.
        name
            The name of the `Distribution`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(PixelCNN, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=False,
                allow_nan_stats=True,
                parameters=parameters,
                name=name
            )

            if not tensorshape_util.is_fully_defined(image_shape):
                raise ValueError('`image_shape` must be fully defined.')

            if conditional_shape is not None and not tensorshape_util.is_fully_defined(conditional_shape):
                raise ValueError('`conditional_shape` must be fully defined.')

            if tensorshape_util.rank(image_shape) != 3:
                raise ValueError('`image_shape` must have length 3, representing '
                                 '[height, width, channels] dimensions.')

            self._high = tf.cast(high, self.dtype)
            self._low = tf.cast(low, self.dtype)
            self._num_logistic_mix = num_logistic_mix
            self.network = _PixelCNNNetwork(
                dropout_p=dropout_p,
                num_resnet=num_resnet,
                num_hierarchies=num_hierarchies,
                num_filters=num_filters,
                num_logistic_mix=num_logistic_mix,
                receptive_field_dims=receptive_field_dims,
                resnet_activation=resnet_activation,
                l2_weight=l2_weight,
                use_weight_norm=use_weight_norm,
                use_data_init=use_data_init,
                dtype=dtype
            )

            image_input_shape = tensorshape_util.concatenate([None], image_shape)
            if conditional_shape is None:
                input_shape = image_input_shape
            else:
                conditional_input_shape = tensorshape_util.concatenate([None], conditional_shape)
                input_shape = [image_input_shape, conditional_input_shape]

            self.image_shape = image_shape
            self.conditional_shape = conditional_shape
            self.network.build(input_shape)

    def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
        """Builds a mixture of quantized logistic distributions.

        Parameters
        ----------
        component_logits
            4D `Tensor` of logits for the Categorical distribution over Quantized
            Logistic mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix]`.
        locs
            4D `Tensor` of location parameters for the Quantized Logistic mixture
            components. Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_channels]`.
        scales
            4D `Tensor` of scale parameters for the Quantized Logistic mixture
            components. Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_channels]`.
        return_per_feature
            If True, return per pixel level log prob.

        Returns
        -------
        dist
            A quantized logistic mixture `tfp.distribution` over the input data.
        """
        mixture_distribution = categorical.Categorical(logits=component_logits)

        # Convert distribution parameters for pixel values in
        # `[self._low, self._high]` for use with `QuantizedDistribution`
        locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
        scales *= 0.5 * (self._high - self._low)
        logistic_dist = quantized_distribution.QuantizedDistribution(
            distribution=transformed_distribution.TransformedDistribution(
                distribution=logistic.Logistic(loc=locs, scale=scales),
                bijector=Shift(shift=tf.cast(-0.5, self.dtype))),
            low=self._low, high=self._high)

        # mixture with logistics for the loc and scale on each pixel for each component
        dist = mixture_same_family.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            components_distribution=independent.Independent(logistic_dist, reinterpreted_batch_ndims=1))
        if return_per_feature:
            return dist
        else:
            return independent.Independent(dist, reinterpreted_batch_ndims=2)

    def _log_prob(self, value, conditional_input=None, training=None, return_per_feature=False):
        """Log probability function with optional conditional input.

        Calculates the log probability of a batch of data under the modeled
        distribution (or conditional distribution, if conditional input is provided).

        Parameters
        ----------
        value
            `Tensor` or Numpy array of image data. May have leading batch
            dimension(s), which must broadcast to the leading batch dimensions of
            `conditional_input`.
        conditional_input
            `Tensor` on which to condition the distribution (e.g. class labels), or
            `None`. May have leading batch dimension(s), which must broadcast to the
            leading batch dimensions of `value`.
        training
            `bool` or `None`. If `bool`, it controls the dropout layer, where `True`
            implies dropout is active. If `None`, it defaults to
            `tf.keras.backend.learning_phase()`.
        return_per_feature
            `bool`. If True, return per pixel level log prob.

        Returns
        -------
        log_prob_values
            `Tensor`.
        """
        # Determine the batch shape of the input images
        image_batch_shape = prefer_static.shape(value)[:-3]

        # Broadcast `value` and `conditional_input` to the same batch_shape
        if conditional_input is None:
            image_batch_and_conditional_shape = image_batch_shape
        else:
            conditional_input = tf.convert_to_tensor(conditional_input)
            conditional_input_shape = prefer_static.shape(conditional_input)
            conditional_batch_rank = (prefer_static.rank(conditional_input) -
                                      tensorshape_util.rank(self.conditional_shape))
            conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]

            image_batch_and_conditional_shape = prefer_static.broadcast_shape(
                image_batch_shape, conditional_batch_shape)
            conditional_input = tf.broadcast_to(
                conditional_input,
                prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0))
            value = tf.broadcast_to(
                value,
                prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0))

            # Flatten batch dimension for input to Keras model
            conditional_input = tf.reshape(
                conditional_input,
                prefer_static.concat([(-1,), self.conditional_shape], axis=0))

        value = tf.reshape(value, prefer_static.concat([(-1,), self.event_shape], axis=0))

        transformed_value = (2. * (value - self._low) / (self._high - self._low)) - 1.
        inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]

        params = self.network(inputs, training=training)

        num_channels = self.event_shape[-1]
        if num_channels == 1:
            component_logits, locs, scales = params
        else:
            # If there is more than one channel, we create a linear autoregressive
            # dependency among the location parameters of the channels of a single
            # pixel (the scale parameters within a pixel are independent). For a pixel
            # with R/G/B channels, the `r`, `g`, and `b` saturation values are
            # distributed as:
            #
            # r ~ Logistic(loc_r, scale_r)
            # g ~ Logistic(coef_rg * r + loc_g, scale_g)
            # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
            #
            # Note: a fill_triangular/matrix multiplication could be used
            # on the coefficients instead of split/multiply/concat below.
            component_logits, locs, scales, coeffs = params
            num_coeffs = num_channels * (num_channels - 1) // 2
            loc_tensors = tf.split(locs, num_channels, axis=-1)
            coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
            channel_tensors = tf.split(value, num_channels, axis=-1)

            coef_count = 0
            for i in range(num_channels):
                channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
                for j in range(i):
                    loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
                    coef_count += 1
            locs = tf.concat(loc_tensors, axis=-1)

        dist = self._make_mixture_dist(component_logits, locs, scales, return_per_feature=return_per_feature)
        log_px = dist.log_prob(value)
        if return_per_feature:
            return log_px
        else:
            return tf.reshape(log_px, image_batch_and_conditional_shape)

    def _sample_n(self, n, seed=None, conditional_input=None, training=False):
        """Samples from the distribution, with optional conditional input.

        Parameters
        ----------
        n
            `int`, number of samples desired.
        seed
            `int`, seed for RNG. Setting a random seed enforces reproducibility of
            the samples between sessions (not within a single session).
        conditional_input
            `Tensor` on which to condition the distribution (e.g. class labels), or
            `None`.
        training
            `bool` or `None`. If `bool`, it controls the dropout layer, where `True`
            implies dropout is active. If `None`, it defers to Keras' handling of
            train/eval status.

        Returns
        -------
        samples
            a `Tensor` of shape `[n, height, width, num_channels]`.
        """
        if conditional_input is not None:
            conditional_input = tf.convert_to_tensor(conditional_input, dtype=self.dtype)
            conditional_event_rank = tensorshape_util.rank(self.conditional_shape)
            conditional_input_shape = prefer_static.shape(conditional_input)
            conditional_sample_rank = prefer_static.rank(conditional_input) - conditional_event_rank

            # If `conditional_input` has no sample dimensions, prepend a sample
            # dimension
            if conditional_sample_rank == 0:
                conditional_input = conditional_input[tf.newaxis, ...]
                conditional_sample_rank = 1

            # Assert that the conditional event shape in the `PixelCnnNetwork` is the
            # same as that implied by `conditional_input`.
            conditional_event_shape = conditional_input_shape[conditional_sample_rank:]
            with tf.control_dependencies([tf.assert_equal(self.conditional_shape, conditional_event_shape)]):
                conditional_sample_shape = conditional_input_shape[:conditional_sample_rank]
                repeat = n // prefer_static.reduce_prod(conditional_sample_shape)
                h = tf.reshape(conditional_input,
                               prefer_static.concat([(-1,), self.conditional_shape], axis=0))
                h = tf.tile(h, prefer_static.pad([repeat],
                                                 paddings=[[0, conditional_event_rank]],
                                                 constant_values=1))

        samples_0 = tf.random.uniform(
            prefer_static.concat([(n,), self.event_shape], axis=0),
            minval=-1., maxval=1., dtype=self.dtype, seed=seed)
        inputs = samples_0 if conditional_input is None else [samples_0, h]
        params_0 = self.network(inputs, training=training)
        samples_0 = self._sample_channels(*params_0, seed=seed)

        image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)

        def loop_body(index, samples):
            """Loop for iterative pixel sampling.

            Parameters
            ----------
            index
                0D `Tensor` of type `int32`. Index of the current pixel.
            samples
                4D `Tensor`. Images with pixels sampled in raster order, up to pixel
                `[index]`, with dimensions `[batch_size, height, width, num_channels]`.

            Returns
            -------
            samples
                4D `Tensor`. Images with pixels sampled in raster order, up to and
                including pixel `[index]`, with dimensions `[batch_size, height,
                width, num_channels]`.
            """
            inputs = samples if conditional_input is None else [samples, h]
            params = self.network(inputs, training=training)
            samples_new = self._sample_channels(*params, seed=seed)

            # Update the current pixel
            samples = tf.transpose(samples, [1, 2, 3, 0])
            samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
            row, col = index // image_width, index % image_width
            updates = samples_new[row, col, ...][tf.newaxis, ...]
            samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
            samples = tf.transpose(samples, [3, 0, 1, 2])

            return index + 1, samples

        index0 = tf.zeros([], dtype=tf.int32)

        # Construct the while loop for sampling
        total_pixels = image_height * image_width
        loop_cond = lambda ind, _: tf.less(ind, total_pixels)  # noqa: E731
        init_vars = (index0, samples_0)
        _, samples = tf.while_loop(loop_cond, loop_body, init_vars, parallel_iterations=1)

        transformed_samples = (self._low + 0.5 * (self._high - self._low) * (samples + 1.))
        return tf.round(transformed_samples)

    def _sample_channels(self, component_logits, locs, scales, coeffs=None, seed=None):
        """Sample a single pixel-iteration and apply channel conditioning.

        Parameters
        ----------
        component_logits
            4D `Tensor` of logits for the Categorical distribution over Quantized
            Logistic mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix]`.
        locs
            4D `Tensor` of location parameters for the Quantized Logistic mixture
            components. Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_channels]`.
        scales
            4D `Tensor` of scale parameters for the Quantized Logistic mixture
            components. Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_channels]`.
        coeffs
            4D `Tensor` of coefficients for the linear dependence among color
            channels, or `None` if there is only one channel. Dimensions are
            `[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
            `num_coeffs = num_channels * (num_channels - 1) // 2`.
        seed
            `int`, random seed.

        Returns
        -------
        samples
            4D `Tensor` of sampled image data with autoregression among channels.
            Dimensions are `[batch_size, height, width, num_channels]`.
        """
        num_channels = self.event_shape[-1]

        # sample mixture components once for the entire pixel
        component_dist = categorical.Categorical(logits=component_logits)
        mask = tf.one_hot(indices=component_dist.sample(seed=seed), depth=self._num_logistic_mix)
        mask = tf.cast(mask[..., tf.newaxis], self.dtype)

        # apply mixture component mask and separate out RGB parameters
        masked_locs = tf.reduce_sum(locs * mask, axis=-2)
        loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
        masked_scales = tf.reduce_sum(scales * mask, axis=-2)
        scale_tensors = tf.split(masked_scales, num_channels, axis=-1)

        if coeffs is not None:
            num_coeffs = num_channels * (num_channels - 1) // 2
            masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
            coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)

        channel_samples = []
        coef_count = 0
        for i in range(num_channels):
            loc = loc_tensors[i]
            for c in channel_samples:
                loc += c * coef_tensors[coef_count]
                coef_count += 1

            logistic_samp = logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
            logistic_samp = tf.clip_by_value(logistic_samp, -1., 1.)
            channel_samples.append(logistic_samp)

        return tf.concat(channel_samples, axis=-1)

    def _batch_shape(self):
        return tf.TensorShape([])

    def _event_shape(self):
        return tf.TensorShape(self.image_shape)
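
# Example (illustrative sketch with small, assumed settings; training the model
# is handled elsewhere in alibi_detect): once constructed, `PixelCNN` behaves
# like any `tfp` distribution, so images can be scored and sampled directly:
#
#   dist = PixelCNN(image_shape=(32, 32, 1), num_resnet=1, num_hierarchies=1,
#                   num_filters=32, num_logistic_mix=2)
#   x = tf.random.uniform([4, 32, 32, 1], maxval=256, dtype=tf.float32)
#   ll = dist.log_prob(x)       # shape [4]: one log-likelihood per image
#   imgs = dist.sample(2)       # shape [2, 32, 32, 1], pixel values in [0, 255]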


class _PixelCNNNetwork(tf.keras.layers.Layer):
    """Keras `Layer` to parameterize a Pixel CNN++ distribution.

    This is a Keras implementation of the Pixel CNN++ network, as described in
    Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
    (https://github.com/openai/pixel-cnn).

    #### References

    [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
         PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
         Likelihood and Other Modifications. In _International Conference on
         Learning Representations_, 2017.
         https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf
         Additional details at https://github.com/openai/pixel-cnn

    [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt,
         Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with
         PixelCNN Decoders. In _30th Conference on Neural Information Processing
         Systems_, 2016.
         https://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf
    """

    def __init__(self,
                 dropout_p: float = 0.5,
                 num_resnet: int = 5,
                 num_hierarchies: int = 3,
                 num_filters: int = 160,
                 num_logistic_mix: int = 10,
                 receptive_field_dims: tuple = (3, 3),
                 resnet_activation: str = 'concat_elu',
                 l2_weight: float = 0.,
                 use_weight_norm: bool = True,
                 use_data_init: bool = True,
                 dtype=tf.float32) -> None:
        """Initialize the neural network for the Pixel CNN++ distribution.

        Parameters
        ----------
        dropout_p
            `float`, the dropout probability. Should be between 0 and 1.
        num_resnet
            `int`, the number of layers (shown in Figure 2 of [2]) within each
            highest-level block of Figure 2 of [1].
        num_hierarchies
            `int`, the number of highest-level blocks (separated by
            expansions/contractions of dimensions in Figure 2 of [1].)
        num_filters
            `int`, the number of convolutional filters.
        num_logistic_mix
            `int`, number of components in the logistic mixture distribution.
        receptive_field_dims
            `tuple`, height and width in pixels of the receptive field of the
            convolutional layers above and to the left of a given pixel. The width
            (second element of the tuple) should be odd. Figure 1 (middle) of [2]
            shows a receptive field of (3, 5) (the row containing the current pixel
            is included in the height). The default of (3, 3) was used to produce
            the results in [1].
        resnet_activation
            `string`, the type of activation to use in the resnet blocks. May be
            'concat_elu', 'elu', or 'relu'.
        l2_weight
            `float`, the L2 regularization weight.
        use_weight_norm
            `bool`, if `True` then use weight normalization.
        use_data_init
            `bool`, if `True` then use data-dependent initialization (has no effect
            if `use_weight_norm` is `False`).
        dtype
            Data type of the layer.
        """
        super(_PixelCNNNetwork, self).__init__(dtype=dtype)
        self._dropout_p = dropout_p
        self._num_resnet = num_resnet
        self._num_hierarchies = num_hierarchies
        self._num_filters = num_filters
        self._num_logistic_mix = num_logistic_mix
        self._receptive_field_dims = receptive_field_dims  # first set desired receptive field, then infer kernel
        self._resnet_activation = resnet_activation
        self._l2_weight = l2_weight

        if use_weight_norm:
            def layer_wrapper(layer):
                def wrapped_layer(*args, **kwargs):
                    return WeightNorm(layer(*args, **kwargs), data_init=use_data_init)
                return wrapped_layer
            self._layer_wrapper = layer_wrapper
        else:
            self._layer_wrapper = lambda layer: layer

    def build(self, input_shape):
        dtype = self.dtype
        if len(input_shape) == 2:
            batch_image_shape, batch_conditional_shape = input_shape
            conditional_input = tf.keras.layers.Input(shape=batch_conditional_shape[1:], dtype=dtype)
        else:
            batch_image_shape = input_shape
            conditional_input = None

        image_shape = batch_image_shape[1:]
        image_input = tf.keras.layers.Input(shape=image_shape, dtype=dtype)

        if self._resnet_activation == 'concat_elu':
            activation = tf.keras.layers.Lambda(lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype)
        else:
            activation = tf.keras.activations.get(self._resnet_activation)

        # Define layers with default inputs and layer wrapper applied
        Conv2D = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Convolution2D),
            filters=self._num_filters,
            padding='same',
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        Dense = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Dense),
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        Conv2DTranspose = functools.partial(  # pylint:disable=invalid-name
            self._layer_wrapper(tf.keras.layers.Conv2DTranspose),
            filters=self._num_filters,
            padding='same',
            strides=(2, 2),
            kernel_regularizer=tf.keras.regularizers.l2(self._l2_weight),
            dtype=dtype)

        rows, cols = self._receptive_field_dims

        # Define the dimensions of the valid (unmasked) areas of the layer kernels
        # for stride 1 convolutions in the internal layers.
        kernel_valid_dims = {'vertical': (rows - 1, cols),  # vertical stack
                             'horizontal': (2, cols // 2 + 1)}  # horizontal stack

        # Define the size of the kernel necessary to center the current pixel
        # correctly for stride 1 convolutions in the internal layers.
        kernel_sizes = {'vertical': (2 * rows - 3, cols), 'horizontal': (3, cols)}

        # Make the kernel constraint functions for stride 1 convolutions in internal
        # layers.
        kernel_constraints = {
            k: _make_kernel_constraint(kernel_sizes[k], (0, v[0]), (0, v[1]))
            for k, v in kernel_valid_dims.items()}

        # Build the initial vertical stack/horizontal stack convolutional layers,
        # as shown in Figure 1 of [2]. The receptive field of the initial vertical
        # stack layer is a rectangular area centered above the current pixel.
        vertical_stack_init = Conv2D(
            kernel_size=(2 * rows - 1, cols),
            kernel_constraint=_make_kernel_constraint((2 * rows - 1, cols), (0, rows - 1), (0, cols)))(image_input)

        # In Figure 1 [2], the receptive field of the horizontal stack is
        # illustrated as the pixels in the same row and to the left of the current
        # pixel. [1] increases the height of this receptive field from one pixel to
        # two (`horizontal_stack_left`) and additionally includes a subset of the
        # row of pixels centered above the current pixel (`horizontal_stack_up`).
        horizontal_stack_up = Conv2D(
            kernel_size=(3, cols),
            kernel_constraint=_make_kernel_constraint((3, cols), (0, 1), (0, cols)))(image_input)

        horizontal_stack_left = Conv2D(
            kernel_size=(3, cols),
            kernel_constraint=_make_kernel_constraint((3, cols), (0, 2), (0, cols // 2)))(image_input)

        horizontal_stack_init = tf.keras.layers.add([horizontal_stack_up, horizontal_stack_left], dtype=dtype)

        layer_stacks = {
            'vertical': [vertical_stack_init],
            'horizontal': [horizontal_stack_init]
        }

        # Build the downward pass of the U-net (left-hand half of Figure 2 of [1]).
        # Each `i` iteration builds one of the highest-level blocks (identified as
        # 'Sequence of 6 layers' in the figure, consisting of `num_resnet=5` stride-
        # 1 layers, and one stride-2 layer that contracts the height/width
        # dimensions). The `_` iterations build the stride 1 layers. The layers of
        # the downward pass are stored in lists, since we'll later need them to make
        # skip-connections to layers in the upward pass of the U-net (the skip-
        # connections are represented by curved lines in Figure 2 [1]).
        for i in range(self._num_hierarchies):
            for _ in range(self._num_resnet):
                # Build a layer shown in Figure 2 of [2]. The 'vertical' iteration
                # builds the layers in the left half of the figure, and the
                # 'horizontal' iteration builds the layers in the right half.
                for stack in ['vertical', 'horizontal']:
                    input_x = layer_stacks[stack][-1]
                    x = activation(input_x)
                    x = Conv2D(kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    # Add the vertical-stack layer to the horizontal-stack layer
                    if stack == 'horizontal':
                        h = activation(layer_stacks['vertical'][-1])
                        h = Dense(self._num_filters)(h)
                        x = tf.keras.layers.add([h, x], dtype=dtype)

                    x = activation(x)
                    x = tf.keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
                    x = Conv2D(filters=2 * self._num_filters,
                               kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    if conditional_input is not None:
                        h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
                        x = tf.keras.layers.add([x, h_projection], dtype=dtype)

                    x = _apply_sigmoid_gating(x)

                    # Add a residual connection from the layer's input.
                    out = tf.keras.layers.add([input_x, x], dtype=dtype)
                    layer_stacks[stack].append(out)

            if i < self._num_hierarchies - 1:
                # Build convolutional layers that contract the height/width dimensions
                # on the downward pass between each set of layers (e.g. contracting
                # from 32x32 to 16x16 in Figure 2 of [1]).
                for stack in ['vertical', 'horizontal']:
                    # Define kernel dimensions/masking to maintain the autoregressive property.
                    x = layer_stacks[stack][-1]
                    h, w = kernel_valid_dims[stack]
                    kernel_height = 2 * h
                    if stack == 'vertical':
                        kernel_width = w + 1
                    else:
                        kernel_width = 2 * w
                    kernel_size = (kernel_height, kernel_width)
                    kernel_constraint = _make_kernel_constraint(kernel_size, (0, h), (0, w))
                    x = Conv2D(strides=(2, 2),
                               kernel_size=kernel_size,
                               kernel_constraint=kernel_constraint)(x)
                    layer_stacks[stack].append(x)

        # Upward pass of the U-net (right-hand half of Figure 2 of [1]). We stored
        # the layers of the downward pass in a list, in order to access them to make
        # skip-connections to the upward pass. For the upward pass, we need to keep
        # track of only the current layer, so we maintain a reference to the
        # current layer of the horizontal/vertical stack in the `upward_pass` dict.
        # The upward pass begins with the last layer of the downward pass.
        upward_pass = {key: stack.pop() for key, stack in layer_stacks.items()}

        # As with the downward pass, each `i` iteration builds a highest level block
        # in Figure 2 [1], and the `_` iterations build individual layers within the
        # block.
        for i in range(self._num_hierarchies):
            num_resnet = self._num_resnet if i == 0 else self._num_resnet + 1

            for _ in range(num_resnet):
                # Build a layer as shown in Figure 2 of [2], with a skip-connection
                # from the symmetric layer in the downward pass.
                for stack in ['vertical', 'horizontal']:
                    input_x = upward_pass[stack]
                    x_symmetric = layer_stacks[stack].pop()

                    x = activation(input_x)
                    x = Conv2D(kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    # Include the vertical-stack layer of the upward pass in the layers
                    # to be added to the horizontal layer.
                    if stack == 'horizontal':
                        x_symmetric = tf.keras.layers.Concatenate(axis=-1, dtype=dtype)(
                            [upward_pass['vertical'], x_symmetric])

                    # Add a skip-connection from the symmetric layer in the downward
                    # pass to the layer `x` in the upward pass.
                    h = activation(x_symmetric)
                    h = Dense(self._num_filters)(h)
                    x = tf.keras.layers.add([h, x], dtype=dtype)

                    x = activation(x)
                    x = tf.keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
                    x = Conv2D(filters=2 * self._num_filters,
                               kernel_size=kernel_sizes[stack],
                               kernel_constraint=kernel_constraints[stack])(x)

                    if conditional_input is not None:
                        h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
                        x = tf.keras.layers.add([x, h_projection], dtype=dtype)

                    x = _apply_sigmoid_gating(x)
                    upward_pass[stack] = tf.keras.layers.add([input_x, x], dtype=dtype)

            # Define deconvolutional layers that expand height/width dimensions on the
            # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with
            # the correct kernel dimensions/masking to maintain the autoregressive
            # property.
            if i < self._num_hierarchies - 1:
                for stack in ['vertical', 'horizontal']:
                    h, w = kernel_valid_dims[stack]
                    kernel_height = 2 * h - 2
                    if stack == 'vertical':
                        kernel_width = w + 1
                        kernel_constraint = _make_kernel_constraint(
                            (kernel_height, kernel_width), (h - 2, kernel_height), (0, w))
                    else:
                        kernel_width = 2 * w - 2
                        kernel_constraint = _make_kernel_constraint(
                            (kernel_height, kernel_width), (h - 2, kernel_height), (w - 2, kernel_width))

                    x = upward_pass[stack]
                    x = Conv2DTranspose(kernel_size=(kernel_height, kernel_width),
                                        kernel_constraint=kernel_constraint)(x)
                    upward_pass[stack] = x

        x_out = tf.keras.layers.ELU(dtype=dtype)(upward_pass['horizontal'])

        # Build final Dense/Reshape layers to output the correct number of
        # parameters per pixel.
        num_channels = tensorshape_util.as_list(image_shape)[-1]
        num_coeffs = num_channels * (num_channels - 1) // 2  # alpha, beta, gamma in eq.3 of paper
        num_out = num_channels * 2 + num_coeffs + 1  # mu, s + alpha, beta, gamma + 1 (mixture weight)
        num_out_total = num_out * self._num_logistic_mix
        params = Dense(num_out_total)(x_out)
        params = tf.reshape(params, prefer_static.concat(  # [-1, H, W, nb mixtures, params per mixture]
            [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]], axis=0))

        # If there is one color channel, split the parameters into a list of three
        # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
        # distribution, (2) location parameters for each component, and (3) scale
        # parameters for each component. If there is more than one color channel,
        # return a fourth `Tensor` for the coefficients for the linear dependence
        # among color channels (e.g. alpha, beta, gamma).
        # [logits, mu, s, linear dependence]
        splits = 3 if num_channels == 1 else [1, num_channels, num_channels, num_coeffs]
        outputs = tf.split(params, splits, axis=-1)

        # Squeeze singleton dimension from component logits
        outputs[0] = tf.squeeze(outputs[0], axis=-1)

        # Ensure scales are positive and do not collapse to near-zero
        outputs[2] = tf.nn.softplus(outputs[2]) + tf.cast(tf.exp(-7.), self.dtype)

        inputs = image_input if conditional_input is None else [image_input, conditional_input]
        self._network = tf.keras.Model(inputs=inputs, outputs=outputs)
        super(_PixelCNNNetwork, self).build(input_shape)

    def call(self, inputs, training=None):
        """Call the Pixel CNN network model.

        Parameters
        ----------
        inputs
            4D `Tensor` of image data with dimensions `[batch size, height, width,
            channels]` or a 2-element `list`. If `list`, the first element is the 4D
            image `Tensor` and the second element is a `Tensor` with conditional
            input data (e.g. VAE encodings or class labels) with the same leading
            batch dimension as the image `Tensor`.
        training
            `bool` or `None`. If `bool`, it controls the dropout layer, where `True`
            implies dropout is active. If `None`, it defaults to
            `tf.keras.backend.learning_phase()`.

        Returns
        -------
        outputs
            a 3- or 4-element `list` of `Tensor`s in the following order:
            component_logits: 4D `Tensor` of logits for the Categorical distribution
            over Quantized Logistic mixture components. Dimensions are
            `[batch_size, height, width, num_logistic_mix]`.
            locs: 4D `Tensor` of location parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
            scales: 4D `Tensor` of scale parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
            coeffs: 4D `Tensor` of coefficients for the linear dependence among
            color channels, included only if the image has more than one channel.
            Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_coeffs]`, where `num_coeffs = num_channels * (num_channels - 1) // 2`.
        """
        return self._network(inputs, training=training)


def _make_kernel_constraint(kernel_size, valid_rows, valid_columns):
    """Make the masking function for layer kernels."""
    mask = np.zeros(kernel_size)
    lower, upper = valid_rows
    left, right = valid_columns
    mask[lower:upper, left:right] = 1.
    mask = mask[:, :, np.newaxis, np.newaxis]
    return lambda x: x * mask


def _build_and_apply_h_projection(h, num_filters, dtype):
    """Project the conditional input."""
    h = tf.keras.layers.Flatten(dtype=dtype)(h)
    h_projection = tf.keras.layers.Dense(2 * num_filters, kernel_initializer='random_normal', dtype=dtype)(h)
    return h_projection[..., tf.newaxis, tf.newaxis, :]


def _apply_sigmoid_gating(x):
    """Apply the sigmoid gating in Figure 2 of [2]."""
    activation_tensor, gate_tensor = tf.split(x, 2, axis=-1)
    sigmoid_gate = tf.sigmoid(gate_tensor)
    return tf.keras.layers.multiply([sigmoid_gate, activation_tensor], dtype=x.dtype)
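
# Example (illustrative): the kernel mask built by `_make_kernel_constraint` for
# the stride-1 'vertical' stack with the default `receptive_field_dims=(3, 3)`,
# i.e. `kernel_size=(3, 3)`, `valid_rows=(0, 2)`, `valid_columns=(0, 3)`. Only
# the two rows above the current pixel are left unmasked, which preserves the
# autoregressive raster-scan property:
#
#   constraint = _make_kernel_constraint((3, 3), (0, 2), (0, 3))
#   constraint(np.ones((3, 3, 1, 1)))[..., 0, 0]
#   # array([[1., 1., 1.],
#   #        [1., 1., 1.],
#   #        [0., 0., 0.]])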