Flows

Premade versions of common flow architectures from flowjax.flows.

All these functions return a Transformed distribution.
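A Transformed distribution pairs a base distribution with a bijection. Samples are base samples pushed through the bijection, and log_prob follows the change-of-variables rule log p(y) = log p_base(f^{-1}(y)) + log|det J_{f^{-1}}(y)|. The NumPy sketch below illustrates that rule with a standard normal base and a hypothetical elementwise affine bijection. ToyAffine and the helper functions are illustrative names, not flowjax classes.

```python
import numpy as np

def standard_normal_log_prob(z):
    # Log density of an isotropic standard normal, summed over the last axis.
    return -0.5 * np.sum(z**2 + np.log(2 * np.pi), axis=-1)

class ToyAffine:
    """Hypothetical elementwise bijection y = loc + scale * x (not a flowjax class)."""

    def __init__(self, loc, scale):
        self.loc = np.asarray(loc, dtype=float)
        self.scale = np.asarray(scale, dtype=float)

    def transform(self, x):
        return self.loc + self.scale * x

    def inverse_and_log_det(self, y):
        x = (y - self.loc) / self.scale
        # The inverse map y -> x has log|det J| = -sum(log|scale|).
        return x, -np.sum(np.log(np.abs(self.scale)))

def transformed_log_prob(bijection, y):
    # Change of variables: log p(y) = log p_base(f^{-1}(y)) + log|det J_{f^{-1}}(y)|.
    x, log_det = bijection.inverse_and_log_det(y)
    return standard_normal_log_prob(x) + log_det

def transformed_sample(bijection, rng, n, dim):
    # Sampling pushes base-distribution samples forward through the bijection.
    return bijection.transform(rng.standard_normal((n, dim)))

bij = ToyAffine(loc=[1.0, -2.0], scale=[0.5, 3.0])
y = np.array([1.2, 0.4])
log_p = transformed_log_prob(bij, y)
samples = transformed_sample(bij, np.random.default_rng(0), 10_000, 2)
```

For this affine case the result coincides with the log density of a normal with mean loc and standard deviation scale, which makes the rule easy to check by hand.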

block_neural_autoregressive_flow(key, *, base_dist, cond_dim=None, nn_depth=1, nn_block_dim=8, flow_layers=1, invert=True, activation=None, inverter=None)

Block neural autoregressive flow (BNAF) (https://arxiv.org/abs/1904.04676).

Each flow layer contains a BlockAutoregressiveNetwork bijection. The bijection has no analytic inverse, so it must be inverted numerically (by default with a bisection search). As a result, only one of log_prob or the sampling methods (sample and sample_and_log_prob) can be efficient; which one is controlled by the invert argument.
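To illustrate why numerical inversion is costly, the sketch below inverts a toy autoregressive bijection one dimension at a time by bisection, using the inverter(bijection, y, condition=None) signature described under the inverter parameter. ToyAutoregressive and bisection_inverter are hypothetical NumPy stand-ins, not flowjax's AutoregressiveBisectionInverter, and the sketch assumes every root lies in a fixed bracket [lower, upper].

```python
import numpy as np

class ToyAutoregressive:
    """Hypothetical autoregressive bijection: y_i = x_i**3 + x_i + sum_{j<i} x_j.

    Each y_i is strictly increasing in x_i and depends only on x_1..x_i;
    we treat it as having no usable closed-form inverse.
    """

    def transform(self, x, condition=None):
        x = np.asarray(x, dtype=float)
        prefix = np.concatenate([[0.0], np.cumsum(x)[:-1]])  # sum_{j<i} x_j
        return x**3 + x + prefix

def bisection_inverter(bijection, y, condition=None, lower=-10.0, upper=10.0, iters=60):
    """Solve transform(x) = y dimension by dimension with bisection.

    Assumes y_i is increasing in x_i and each root lies in [lower, upper].
    """
    y = np.asarray(y, dtype=float)
    x = np.zeros_like(y)
    for i in range(y.shape[0]):
        lo, hi = lower, upper
        for _ in range(iters):
            x[i] = 0.5 * (lo + hi)
            # y_i depends only on x_1..x_i, and x_1..x_{i-1} are already solved.
            if bijection.transform(x, condition)[i] < y[i]:
                lo = x[i]
            else:
                hi = x[i]
        x[i] = 0.5 * (lo + hi)
    return x

bij = ToyAutoregressive()
x_true = np.array([0.5, -1.2, 2.0])
x_rec = bisection_inverter(bij, bij.transform(x_true))
```

Each inversion calls transform dim × iters times, compared with a single call in the forward direction, which is why only one of log_prob or sampling can be made cheap.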

Parameters:
  • key (PRNGKeyArray) – JAX PRNG key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – Dimension of the conditioning variables. Defaults to None.

  • nn_depth (int) – Number of hidden layers within the networks. Defaults to 1.

  • nn_block_dim (int) – Block size. The hidden layer width is dim*nn_block_dim, where dim is the dimension of the base distribution. Defaults to 8.

  • flow_layers (int) – Number of BNAF layers. Defaults to 1.

  • invert (bool) – Use True for an efficient log_prob (e.g. when fitting by maximum likelihood), and False for efficient sample and sample_and_log_prob methods (e.g. when fitting variationally). Defaults to True.

  • activation (AbstractBijection | Callable | None) – Activation function used within the block neural autoregressive networks. Note that this should be bijective and, in some use cases, should map real -> real. For more information, see BlockAutoregressiveNetwork. Defaults to LeakyTanh.

  • inverter (Callable | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Must have the signature inverter(bijection, y, condition=None). Defaults to using a bisection search via AutoregressiveBisectionInverter.

Return type:

Transformed
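In the BNAF paper, the square hidden-to-hidden layers of a block autoregressive network use a block lower-triangular weight matrix made of dim × dim blocks, each of size nn_block_dim × nn_block_dim. Diagonal blocks are kept positive so that each output is increasing in its own input dimension, and upper blocks are zero to preserve the autoregressive structure. The simplified NumPy sketch below builds such a matrix for dim=3 and nn_block_dim=2 (hidden width 6). block_masks and block_weight are hypothetical helpers, and further details of the paper's parameterisation are omitted.

```python
import numpy as np

def block_masks(dim, block_dim):
    """Masks for a square hidden layer of width dim * block_dim.

    Unit k belongs to block k // block_dim, i.e. it tracks input dimension
    k // block_dim. Rows index output units, columns index input units.
    """
    block_id = np.arange(dim * block_dim) // block_dim
    diag = block_id[:, None] == block_id[None, :]   # same dimension
    lower = block_id[:, None] > block_id[None, :]   # depends on earlier dimensions
    return diag, lower

def block_weight(V, dim, block_dim):
    """Block lower-triangular weight with positive diagonal blocks."""
    diag, lower = block_masks(dim, block_dim)
    # exp keeps diagonal blocks positive; strictly lower blocks are
    # unconstrained; everything above the diagonal blocks is zeroed.
    return np.where(diag, np.exp(V), 0.0) + np.where(lower, V, 0.0)

dim, block_dim = 3, 2
V = np.random.default_rng(0).normal(size=(dim * block_dim, dim * block_dim))
W = block_weight(V, dim, block_dim)  # shape (6, 6)
```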

coupling_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=jax.nn.relu, invert=True)

Create a coupling flow (https://arxiv.org/abs/1605.08803).

Parameters:
  • key (PRNGKeyArray) – JAX random number generator key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • transformer (AbstractBijection | None) – Bijection to be parameterised by the conditioner. Defaults to an affine transformation.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • flow_layers (int) – Number of coupling layers. Defaults to 8.

  • nn_width (int) – Conditioner hidden layer size. Defaults to 50.

  • nn_depth (int) – Conditioner depth (number of hidden layers). Defaults to 1.

  • nn_activation (Callable) – Conditioner activation function. Defaults to jnn.relu.

  • invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse method, leading to a faster log_prob, whereas False prioritises a faster transform method, leading to faster sampling. Defaults to True.

Return type:

Transformed
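In a coupling layer, part of the input passes through unchanged and conditions the transformer applied to the rest, so both directions need only a single pass of the conditioner. The NumPy sketch below shows an affine coupling layer (matching the default affine transformer) with a hypothetical linear conditioner standing in for the MLP; the names are illustrative, not flowjax APIs.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d = 4, 2  # total dimension, and size of the half left unchanged
# Hypothetical conditioner: one linear map from the first half to the
# shift and log-scale of the second half (a stand-in for the MLP).
A_shift = rng.normal(size=(D - d, d))
A_scale = 0.1 * rng.normal(size=(D - d, d))

def coupling_forward(x):
    x1, x2 = x[:d], x[d:]
    shift, log_scale = A_shift @ x1, A_scale @ x1
    y2 = x2 * np.exp(log_scale) + shift  # affine transformer on the second half
    return np.concatenate([x1, y2]), np.sum(log_scale)

def coupling_inverse(y):
    y1, y2 = y[:d], y[d:]
    # y1 equals x1, so the conditioner outputs are available without iteration.
    shift, log_scale = A_shift @ y1, A_scale @ y1
    x2 = (y2 - shift) * np.exp(-log_scale)
    return np.concatenate([y1, x2]), -np.sum(log_scale)

x = np.array([0.3, -1.0, 2.0, 0.5])
y, log_det = coupling_forward(x)
x_rec, inv_log_det = coupling_inverse(y)
```

In practice several such layers are stacked, with the halves swapped or permuted between layers so that every dimension is eventually transformed.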

masked_autoregressive_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=jax.nn.relu, invert=True)

Masked autoregressive flow.

Parameterises a transformer bijection with an autoregressive neural network. Refs: https://arxiv.org/abs/1606.04934; https://arxiv.org/abs/1705.07057v4.
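The autoregressive structure makes the two directions unequal in cost. In the NumPy sketch below, an affine layer uses strictly lower-triangular (masked) weights as a hypothetical stand-in for the autoregressive network. maf_forward needs one conditioner pass because all parameters depend on the known input, whereas maf_inverse must recover the dimensions one at a time. This asymmetry is what the invert argument trades off; the names are illustrative and not part of flowjax.

```python
import numpy as np

rng = np.random.default_rng(1)
D = 3
# Hypothetical masked linear conditioner: strictly lower-triangular weights
# mean the shift and log-scale for dimension i depend only on x_1..x_{i-1}.
W_shift = np.tril(rng.normal(size=(D, D)), k=-1)
W_scale = 0.1 * np.tril(rng.normal(size=(D, D)), k=-1)

def conditioner(x):
    return W_shift @ x, W_scale @ x  # (shift, log_scale)

def maf_forward(x):
    # The conditioner only needs the known input x, so one pass suffices.
    shift, log_scale = conditioner(x)
    return x * np.exp(log_scale) + shift, np.sum(log_scale)

def maf_inverse(y):
    # The conditioner needs the unknown x, so dimensions are solved in order.
    x = np.zeros_like(y)
    for i in range(len(y)):
        shift, log_scale = conditioner(x)  # entry i uses only x[:i], already solved
        x[i] = (y[i] - shift[i]) * np.exp(-log_scale[i])
    return x

x = np.array([0.3, -0.7, 1.1])
y, log_det = maf_forward(x)
```

Because the Jacobian of maf_forward is lower triangular with diagonal exp(log_scale), its log-determinant is simply the sum of the log-scales.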

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • transformer (AbstractBijection | None) – Bijection parameterised by autoregressive network. Defaults to affine.

  • cond_dim (int | None) – Dimension of the conditioning variable. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • nn_width (int) – Width of the hidden layers in the neural network. Defaults to 50.

  • nn_depth (int) – Depth of the neural network (number of hidden layers). Defaults to 1.

  • nn_activation (Callable) – Activation function used in the neural network. Defaults to jnn.relu.

  • invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse, leading to a faster log_prob, whereas False prioritises a faster forward transform, leading to faster sampling. Defaults to True.

Return type:

Transformed

planar_flow(key, *, base_dist, cond_dim=None, flow_layers=8, invert=True, **mlp_kwargs)

Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.

This alternates between Planar layers and permutations. Note the definition here is inverted compared to the original paper.
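For reference, the planar layer of the original paper is f(z) = z + u h(w^T z + b), with log-determinant log|1 + u^T psi(z)| where psi(z) = h'(w^T z + b) w. The NumPy sketch below implements that formula with h = tanh, following the paper's direction; as noted above, flowjax's definition is inverted relative to it. For tanh, the paper requires w^T u >= -1 for invertibility, which the example values satisfy (w^T u = 0.38). The function name planar is illustrative only.

```python
import numpy as np

def planar(z, u, w, b):
    """Planar layer in the original paper's direction: f(z) = z + u * tanh(w.z + b)."""
    a = np.dot(w, z) + b
    y = z + u * np.tanh(a)
    psi = (1.0 - np.tanh(a) ** 2) * w  # h'(w.z + b) * w with h = tanh
    # Matrix determinant lemma: det(I + u psi^T) = 1 + u.psi
    log_det = np.log(np.abs(1.0 + np.dot(u, psi)))
    return y, log_det

z = np.array([0.2, -0.5])
u = np.array([0.5, 0.3])
w = np.array([1.0, -0.4])
y, log_det = planar(z, u, w, b=0.1)
```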

Parameters:
  • key (PRNGKeyArray) – JAX PRNG key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse method, leading to a faster log_prob, whereas False prioritises a faster transform method, leading to faster sampling. Defaults to True.

  • **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.

Return type:

Transformed

triangular_spline_flow(key, *, base_dist, cond_dim=None, flow_layers=8, knots=8, tanh_max_val=3.0, invert=True, init=None)

Triangular spline flow.

Each layer consists of a triangular affine transformation with weight normalisation and an elementwise rational quadratic spline. Tanh is used to constrain the input to [-1, 1] before the spline transformation.
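A lower-triangular affine map y = Ax + b has log|det A| equal to the sum of log|A_ii|, which keeps each layer cheap. The NumPy sketch below combines this with one common form of weight normalisation, rescaling each row of the matrix to a learned norm g_i. The exponentiated diagonal and the helper names are assumptions for illustration; flowjax's exact parameterisation may differ.

```python
import numpy as np

def weight_normalised_lower(V, g):
    """Lower-triangular matrix whose i-th row has Euclidean norm g[i].

    Assumed parameterisation: the strictly lower part of V is used as-is and
    its diagonal is exponentiated, so diag(A) > 0 whenever g > 0.
    """
    L = np.tril(V, k=-1) + np.diag(np.exp(np.diag(V)))
    return (g / np.linalg.norm(L, axis=1))[:, None] * L

def triangular_affine(x, V, g, b):
    A = weight_normalised_lower(V, g)
    # A is triangular, so log|det A| is the sum of the log diagonal entries.
    return A @ x + b, np.sum(np.log(np.diag(A)))

rng = np.random.default_rng(0)
V = rng.normal(size=(3, 3))
g = np.array([1.0, 2.0, 0.5])
x = np.array([0.4, -1.0, 0.7])
y, log_det = triangular_affine(x, V, g, b=np.zeros(3))
```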

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.

  • cond_dim (int | None) – The number of conditioning features. Defaults to None.

  • flow_layers (int) – Number of flow layers. Defaults to 8.

  • knots (int) – Number of knots in the splines. Defaults to 8.

  • tanh_max_val (float | int) – Maximum absolute value beyond which we use linear “tails” in the tanh function. Defaults to 3.0.

  • invert (bool) – Whether to invert the bijection before transforming the base distribution. Defaults to True.

  • init (Callable | None) – Initialisation method for the lower triangular weights. Defaults to glorot_uniform().

Return type:

Transformed
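The linear tails controlled by tanh_max_val can be sketched as below: tanh inside [-max_val, max_val], continued linearly outside. The sketch assumes the tails match tanh's value and slope at +/- max_val (an assumption about the construction, not stated on this page); leaky_tanh is a hypothetical NumPy helper.

```python
import numpy as np

def leaky_tanh(x, max_val=3.0):
    """tanh on (-max_val, max_val), with linear tails beyond +/- max_val.

    Assumes the tails match tanh's value and slope at +/- max_val, so the
    map is continuous, strictly increasing, and maps real -> real.
    """
    x = np.asarray(x, dtype=float)
    slope = 1.0 - np.tanh(max_val) ** 2  # derivative of tanh at max_val
    tail = np.sign(x) * (np.tanh(max_val) + slope * (np.abs(x) - max_val))
    return np.where(np.abs(x) < max_val, np.tanh(x), tail)

xs = np.linspace(-10, 10, 2001)
ys = leaky_tanh(xs)
```

Compared with plain tanh, the tails keep the map real -> real with a non-vanishing gradient, at the cost of outputs slightly outside [-1, 1] once |x| exceeds max_val.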