Flows
Normalizing flows define flexible distributions by transforming a simple base distribution through a bijection (more precisely, a diffeomorphism). For a detailed introduction to normalizing flows, we recommend the following review paper: Papamakarios et al., 2021.
In FlowJAX, all the normalizing flows are convenient constructors of Transformed distributions. Generally, the overall transform is built from multiple layers, by composing many individual transforms.
In FlowJAX, there are two ways to compose bijections:

- Chain: allows chaining arbitrary and heterogeneous bijections, but compiles each layer separately.
- Scan: requires the layers to share the same structure, but avoids compiling each layer separately.

All FlowJAX flows use Scan to reduce compilation times; a minimal composition sketch follows below.
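For illustration, here is a minimal sketch of composing heterogeneous bijections with Chain. The Affine and Tanh constructor arguments used here are assumptions for the example, not part of this page's reference:

```python
import jax.numpy as jnp
from flowjax.bijections import Affine, Chain, Tanh

# Compose two heterogeneous bijections; under Chain each layer is compiled separately.
# The Affine/Tanh constructor arguments below are illustrative assumptions.
bijection = Chain([Affine(loc=jnp.zeros(2), scale=jnp.ones(2)), Tanh(shape=(2,))])
y = bijection.transform(jnp.zeros(2))  # forward transformation
```

Scan instead stacks identically structured layers, trading this flexibility for a single compilation of the repeated layer.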
Note
Bijections in normalizing flows typically have asymmetric computational efficiency. Generally:

- Bijections are implemented to favor efficiency in the forward transformation.
- The forward transformation is used for sampling (or sample_and_log_prob), while the inverse is used for density evaluation.
- By default, flows invert the bijection with Invert before transforming the base distribution with Transformed. This prioritizes a faster log_prob method at the cost of slower sample and sample_and_log_prob methods.
- If faster sample and sample_and_log_prob methods are needed (e.g., for certain variational objectives), setting invert=False is recommended (see the sketch below).
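For example, a flow intended mainly for sampling could be constructed with invert=False. A minimal sketch, assuming flowjax.distributions.Normal as the base distribution:

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed base distribution for illustration
from flowjax.flows import coupling_flow

key = jr.key(0)
base_dist = Normal(jnp.zeros(3))  # base_dist.ndim == 1, as required

# invert=False favours fast sample / sample_and_log_prob over fast log_prob.
flow = coupling_flow(key, base_dist=base_dist, invert=False)
samples, log_probs = flow.sample_and_log_prob(jr.key(1), (100,))
```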
Premade versions of common flow architectures are available in flowjax.flows. All of these functions return a Transformed distribution.
- block_neural_autoregressive_flow(key, *, base_dist, cond_dim=None, nn_depth=1, nn_block_dim=8, flow_layers=1, invert=True, activation=None, inverter=None)
Block neural autoregressive flow (BNAF) (https://arxiv.org/abs/1904.04676).
Each flow layer contains a BlockAutoregressiveNetwork bijection. The bijection does not have an analytic inverse, so it must be inverted using numerical methods (by default a bisection search). Note that this means only one of log_prob or sample{_and_log_prob} can be efficient, controlled by the invert argument.

- Parameters:
  - key (PRNGKeyArray) – Jax key.
  - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
  - cond_dim (int | None) – Dimension of conditional variables. Defaults to None.
  - nn_depth (int) – Number of hidden layers within the networks. Defaults to 1.
  - nn_block_dim (int) – Block size. Hidden layer width is dim*nn_block_dim. Defaults to 8.
  - flow_layers (int) – Number of BNAF layers. Defaults to 1.
  - invert (bool) – Use True for an efficient log_prob (e.g. when fitting by maximum likelihood), and False for efficient sample and sample_and_log_prob methods (e.g. for fitting variationally).
  - activation (AbstractBijection | Callable | None) – Activation function used within block neural autoregressive networks. Note this should be bijective and in some use cases should map real -> real. For more information, see BlockAutoregressiveNetwork. Defaults to LeakyTanh.
  - inverter (Callable | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Must have the signature inverter(bijection, y, condition=None). Defaults to using a bisection search via AutoregressiveBisectionInverter.
- Return type: Transformed
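A minimal usage sketch (the Normal base distribution is an assumption for illustration). With the default invert=True, log_prob is the efficient direction:

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed base distribution
from flowjax.flows import block_neural_autoregressive_flow

flow = block_neural_autoregressive_flow(
    jr.key(0),
    base_dist=Normal(jnp.zeros(2)),
    nn_block_dim=8,
    flow_layers=1,
)
# Efficient density evaluation; sampling would fall back to the numerical inverter.
log_p = flow.log_prob(jnp.array([0.1, -0.4]))
```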
- coupling_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=jnn.relu, invert=True)
Create a coupling flow (https://arxiv.org/abs/1605.08803).
- Parameters:
  - key (PRNGKeyArray) – Jax random key.
  - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
  - transformer (AbstractBijection | None) – Bijection to be parameterised by the conditioner. Defaults to affine.
  - cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
  - flow_layers (int) – Number of coupling layers. Defaults to 8.
  - nn_width (int) – Conditioner hidden layer size. Defaults to 50.
  - nn_depth (int) – Conditioner depth. Defaults to 1.
  - nn_activation (Callable) – Conditioner activation function. Defaults to jnn.relu.
  - invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse method, leading to faster log_prob; False prioritises a faster transform method, leading to faster sample. Defaults to True.
- Return type: Transformed
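A sketch of a conditional coupling flow with a spline transformer. The RationalQuadraticSpline arguments and the Normal base distribution are assumptions for illustration:

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.bijections import RationalQuadraticSpline  # transformer arguments below are assumed
from flowjax.distributions import Normal  # assumed base distribution
from flowjax.flows import coupling_flow

flow = coupling_flow(
    jr.key(0),
    base_dist=Normal(jnp.zeros(2)),
    transformer=RationalQuadraticSpline(knots=8, interval=4),
    cond_dim=3,
)
# Conditional density evaluation: x has shape (2,), condition has shape (3,).
log_p = flow.log_prob(jnp.zeros(2), condition=jnp.ones(3))
```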
- masked_autoregressive_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=jnn.relu, invert=True)
Masked autoregressive flow.
Parameterises a transformer bijection with an autoregressive neural network. Refs: https://arxiv.org/abs/1606.04934; https://arxiv.org/abs/1705.07057v4.
- Parameters:
  - key (PRNGKeyArray) – Random seed.
  - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
  - transformer (AbstractBijection | None) – Bijection parameterised by the autoregressive network. Defaults to affine.
  - cond_dim (int | None) – Dimension of the conditioning variable. Defaults to None.
  - flow_layers (int) – Number of flow layers. Defaults to 8.
  - nn_width (int) – Hidden layer width of the autoregressive network. Defaults to 50.
  - nn_depth (int) – Depth of the autoregressive network. Defaults to 1.
  - nn_activation (Callable) – Activation function of the autoregressive network. Defaults to jnn.relu.
  - invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse, leading to faster log_prob; False prioritises a faster forward transform, leading to faster sample. Defaults to True.
- Return type: Transformed
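A sketch of the sampling direction, with invert=False to favour sample and sample_and_log_prob (the Normal base distribution is an assumption):

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed base distribution
from flowjax.flows import masked_autoregressive_flow

flow = masked_autoregressive_flow(
    jr.key(0),
    base_dist=Normal(jnp.zeros(4)),
    invert=False,  # favour fast sample / sample_and_log_prob
)
samples = flow.sample(jr.key(1), (1000,))  # shape (1000, 4)
```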
- planar_flow(key, *, base_dist, cond_dim=None, flow_layers=8, invert=True, negative_slope=None, **mlp_kwargs)
Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.
This flow alternates between Planar layers and permutations. Note the definition here is inverted compared to the original paper.

- Parameters:
  - key (PRNGKeyArray) – Jax key.
  - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
  - cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
  - flow_layers (int) – Number of flow layers. Defaults to 8.
  - invert (bool) – Whether to invert the bijection. Broadly, True prioritises a faster inverse method, leading to faster log_prob; False prioritises a faster transform method, leading to faster sample. Defaults to True.
  - negative_slope (float | None) – A positive float. If provided, then a leaky relu activation (with the corresponding negative slope) is used instead of tanh. This also provides the advantage that the bijection can be inverted analytically.
  - **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
- Return type: Transformed
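A sketch using a leaky relu activation via negative_slope, so the planar layers can be inverted analytically (the Normal base distribution is an assumption):

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed base distribution
from flowjax.flows import planar_flow

flow = planar_flow(
    jr.key(0),
    base_dist=Normal(jnp.zeros(2)),
    negative_slope=0.01,  # leaky relu instead of tanh; analytically invertible
)
log_p = flow.log_prob(jnp.zeros(2))
```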
- triangular_spline_flow(key, *, base_dist, cond_dim=None, flow_layers=8, knots=8, tanh_max_val=3.0, invert=True, init=None)
Triangular spline flow.
Each layer consists of a triangular affine transformation with weight normalisation, followed by an elementwise rational quadratic spline. Tanh is used to constrain the input to [-1, 1] before the spline transformation.
- Parameters:
  - key (PRNGKeyArray) – Jax random key.
  - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
  - cond_dim (int | None) – The number of conditioning features. Defaults to None.
  - flow_layers (int) – Number of flow layers. Defaults to 8.
  - knots (int) – Number of knots in the splines. Defaults to 8.
  - tanh_max_val (float | int) – Maximum absolute value beyond which we use linear “tails” in the tanh function. Defaults to 3.0.
  - invert (bool) – Whether to invert the bijection before transforming the base distribution. Defaults to True.
  - init (Callable | None) – Initialisation method for the lower triangular weights. Defaults to glorot_uniform().
- Return type: Transformed
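A minimal construction sketch (the Normal base distribution is an assumption for illustration):

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed base distribution
from flowjax.flows import triangular_spline_flow

flow = triangular_spline_flow(
    jr.key(0),
    base_dist=Normal(jnp.zeros(2)),
    knots=8,
    tanh_max_val=3.0,
)
log_p = flow.log_prob(jnp.zeros(2))
```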