Flows

Normalizing flows define flexible distributions by transforming a simple base distribution through a parameterized bijection (more precisely, a diffeomorphism). For a detailed introduction to normalizing flows, we recommend the review paper by Papamakarios et al. (2021).
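For reference, the density of a flow follows from the standard change-of-variables formula. Writing $f$ for the bijection, $p_U$ for the base density and $J$ for a Jacobian (notation introduced here for illustration, not part of the FlowJAX API):

```latex
\begin{aligned}
x &= f(u), \qquad u \sim p_U, \\
\log p_X(x) &= \log p_U\!\left(f^{-1}(x)\right) + \log\left|\det J_{f^{-1}}(x)\right| \\
            &= \log p_U(u) - \log\left|\det J_{f}(u)\right|.
\end{aligned}
```

The second line needs only the inverse $f^{-1}$ (density evaluation), while the third needs only the forward map $f$ (sampling together with log probabilities), which is the asymmetry discussed in the note below.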

In FlowJAX, all the normalizing flows are convenient constructors of Transformed distributions. The overall transform is generally built from multiple layers, by composing many individual transforms. FlowJAX provides two ways to compose bijections:

  • Chain: allows chaining arbitrary, heterogeneous bijections, but compiles each layer separately.

  • Scan: requires the layers to share the same structure, but avoids compiling each layer separately.

All FlowJAX flows use Scan to reduce compilation times.

Note

Bijections in normalizing flows typically have asymmetric computational efficiency. Generally:

  • Bijections are implemented to favor efficiency in the forward transformation.

  • The forward transformation is used for sampling (or sample_and_log_prob), while the inverse is used for density evaluation.

  • By default, flows invert the bijection with Invert before transforming the base distribution with Transformed. This prioritizes a faster log_prob method at the cost of slower sample and sample_and_log_prob methods.

  • If faster sample and sample_and_log_prob methods are needed (e.g., for certain variational objectives), setting invert=False is recommended.

Premade versions of common flow architectures are available from flowjax.flows.

All these functions return a Transformed distribution.

block_neural_autoregressive_flow(
key,
*,
base_dist,
cond_dim=None,
nn_depth=1,
nn_block_dim=8,
flow_layers=1,
invert=True,
activation=None,
inverter=None,
)

Block neural autoregressive flow (BNAF) (https://arxiv.org/abs/1904.04676).

Each flow layer contains a BlockAutoregressiveNetwork bijection. This bijection has no analytic inverse, so it must be inverted numerically (by default, with a bisection search). Consequently, only one of log_prob or sample{_and_log_prob} can be efficient, as controlled by the invert argument. Reasonably scaled base and target distributions will also improve the efficiency of the numerical inverse.

Parameters:
  • key (PRNGKeyArray) – Jax key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – Dimension of conditional variables. Defaults to None.

  • nn_depth (int) – Number of hidden layers within the networks. Defaults to 1.

  • nn_block_dim (int) – Block size. The hidden layer width is dim*nn_block_dim, where dim is the dimension of base_dist. Defaults to 8.

  • flow_layers (int) – Number of BNAF layers. Defaults to 1.

  • invert (bool) – Use True for efficient log_prob (e.g. when fitting by maximum likelihood), and False for efficient sample and sample_and_log_prob methods (e.g. for fitting variationally). Defaults to True.

  • activation (AbstractBijection | Callable | None) – Activation function used within block neural autoregressive networks. Note this should be bijective and in some use cases should map real -> real. For more information, see BlockAutoregressiveNetwork. Defaults to LeakyTanh.

  • inverter (Callable[[AbstractBijection, Array, Array | None], Array] | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Passed to NumericalInverse. Defaults to using elementwise_autoregressive_bisection.

Return type:

Transformed

coupling_flow(
key,
*,
base_dist,
transformer=None,
cond_dim=None,
flow_layers=8,
nn_width=50,
nn_depth=1,
nn_activation=jax.nn.relu,
invert=True,
)

Create a coupling flow (https://arxiv.org/abs/1605.08803).

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • transformer (AbstractBijection | None) – Bijection to be parameterised by the conditioner. Defaults to affine.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • flow_layers (int) – Number of coupling layers. Defaults to 8.

  • nn_width (int) – Conditioner hidden layer size. Defaults to 50.

  • nn_depth (int) – Conditioner depth. Defaults to 1.

  • nn_activation (Callable) – Conditioner activation function. Defaults to jnn.relu.

  • invert (bool) – Whether to invert the bijection. Broadly, True will prioritise a faster inverse method, leading to faster log_prob, while False will prioritise a faster transform method, leading to faster sample. Defaults to True.

Return type:

Transformed

masked_autoregressive_flow(
key,
*,
base_dist,
transformer=None,
cond_dim=None,
flow_layers=8,
nn_width=50,
nn_depth=1,
nn_activation=jax.nn.relu,
invert=True,
)

Masked autoregressive flow.

Parameterises a transformer bijection with an autoregressive neural network. Refs: https://arxiv.org/abs/1606.04934; https://arxiv.org/abs/1705.07057v4.

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • transformer (AbstractBijection | None) – Bijection parameterised by autoregressive network. Defaults to affine.

  • cond_dim (int | None) – Dimension of the conditioning variable. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • nn_width (int) – Number of neurons in each hidden layer of the neural network. Defaults to 50.

  • nn_depth (int) – Depth of neural network. Defaults to 1.

  • nn_activation (Callable) – Activation function used within the neural network. Defaults to jnn.relu.

  • invert (bool) – Whether to invert the bijection. Broadly, True will prioritise a faster inverse, leading to faster log_prob, False will prioritise faster forward, leading to faster sample. Defaults to True.

Return type:

Transformed

planar_flow(
key,
*,
base_dist,
cond_dim=None,
flow_layers=8,
invert=True,
negative_slope=None,
**mlp_kwargs,
)

Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.

This alternates between Planar layers and permutations. Note the definition here is inverted compared to the original paper.

Parameters:
  • key (PRNGKeyArray) – Jax key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • invert (bool) – Whether to invert the bijection. Broadly, True will prioritise a faster inverse method, leading to faster log_prob, while False will prioritise a faster transform method, leading to faster sample. Defaults to True.

  • negative_slope (float | None) – A positive float. If provided, then a leaky relu activation (with the corresponding negative slope) is used instead of tanh. This also provides the advantage that the bijection can be inverted analytically.

  • **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.

Return type:

Transformed

triangular_spline_flow(
key,
*,
base_dist,
cond_dim=None,
flow_layers=8,
knots=8,
tanh_max_val=3.0,
invert=True,
init=None,
)

Triangular spline flow.

Each layer consists of a triangular affine transformation with weight normalisation, followed by an elementwise rational quadratic spline. Tanh is used to constrain the input to [-1, 1] before the spline transformations.

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – The number of conditioning features. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • knots (int) – Number of knots in the splines. Defaults to 8.

  • tanh_max_val (float | int) – Maximum absolute value beyond which we use linear “tails” in the tanh function. Defaults to 3.0.

  • invert (bool) – Whether to invert the bijection before transforming the base distribution. Defaults to True.

  • init (Callable | None) – Initialisation method for the lower triangular weights. Defaults to glorot_uniform().

Return type:

Transformed