Flows
Premade versions of common flow architectures from flowjax.flows. All these functions return a Transformed distribution.
- block_neural_autoregressive_flow(key, *, base_dist, cond_dim=None, nn_depth=1, nn_block_dim=8, flow_layers=1, invert=True, activation=None, inverter=None)[source]

  Block neural autoregressive flow (BNAF) (https://arxiv.org/abs/1904.04676).

  Each flow layer contains a BlockAutoregressiveNetwork bijection. The bijection does not have an analytic inverse, so it must be inverted using numerical methods (by default a bisection search). Note that this means only one of log_prob or sample{_and_log_prob} can be efficient, controlled by the invert argument.

  - Parameters:
    - key (PRNGKeyArray) – Jax PRNGKey.
    - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
    - cond_dim (int | None) – Dimension of conditional variables. Defaults to None.
    - nn_depth (int) – Number of hidden layers within the networks. Defaults to 1.
    - nn_block_dim (int) – Block size. Hidden layer width is dim*nn_block_dim. Defaults to 8.
    - flow_layers (int) – Number of BNAF layers. Defaults to 1.
    - invert (bool) – Use True for efficient log_prob (e.g. when fitting by maximum likelihood), and False for efficient sample and sample_and_log_prob methods (e.g. for fitting variationally).
    - activation (AbstractBijection | Callable | None) – Activation function used within the block neural autoregressive networks. Note this should be bijective, and in some use cases should map real -> real. For more information, see BlockAutoregressiveNetwork. Defaults to LeakyTanh.
    - inverter (Callable | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Must have the signature inverter(bijection, y, condition=None). Defaults to using a bisection search via AutoregressiveBisectionInverter.
  - Return type: Transformed
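Since the BNAF bijection has no analytic inverse, inversion falls back on a numerical root-finder. As a rough illustration of the underlying idea (not flowjax's AutoregressiveBisectionInverter), a bisection search inverts a strictly increasing scalar function like this:

```python
def bisection_invert(f, y, lower=-100.0, upper=100.0, tol=1e-8, max_iter=200):
    """Invert a strictly increasing scalar function f at y via bisection.

    Assumes the root lies in [lower, upper]; each iteration halves the bracket.
    """
    for _ in range(max_iter):
        mid = 0.5 * (lower + upper)
        if f(mid) < y:
            lower = mid  # root lies in the upper half of the bracket
        else:
            upper = mid  # root lies in the lower half of the bracket
        if upper - lower < tol:
            break
    return 0.5 * (lower + upper)

# Example: invert f(x) = x**3 + x (strictly increasing) at y = 10, giving x = 2
x = bisection_invert(lambda t: t**3 + t, 10.0)
```

Because each evaluation of the autoregressive bijection must be repeated per iteration, whichever direction relies on the inverter is markedly slower, which is why the invert argument matters.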
- coupling_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=<jax._src.custom_derivatives.custom_jvp object>, invert=True)[source]

  Create a coupling flow (https://arxiv.org/abs/1605.08803).

  - Parameters:
    - key (PRNGKeyArray) – Jax random number generator key.
    - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
    - transformer (AbstractBijection | None) – Bijection to be parameterised by the conditioner. Defaults to affine.
    - cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
    - flow_layers (int) – Number of coupling layers. Defaults to 8.
    - nn_width (int) – Conditioner hidden layer size. Defaults to 50.
    - nn_depth (int) – Conditioner depth. Defaults to 1.
    - nn_activation (Callable) – Conditioner activation function. Defaults to jnn.relu.
    - invert (bool) – Whether to invert the bijection. Broadly, True will prioritise faster inverse methods, leading to faster log_prob; False will prioritise faster transform methods, leading to faster sample. Defaults to True.
  - Return type: Transformed
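To see why coupling layers are cheap in both directions, here is a minimal numpy sketch of a single affine coupling transform (an illustration of the idea in the RealNVP paper, not flowjax's implementation; the toy conditioner stands in for the MLP that nn_width and nn_depth configure):

```python
import numpy as np

def coupling_forward(x, conditioner):
    """One affine coupling layer: transform x2 conditioned on x1."""
    x1, x2 = np.split(x, 2)
    log_scale, shift = conditioner(x1)   # parameters depend only on x1
    y2 = x2 * np.exp(log_scale) + shift
    log_det = np.sum(log_scale)          # Jacobian is triangular, so this is log|det|
    return np.concatenate([x1, y2]), log_det

def coupling_inverse(y, conditioner):
    """Exact inverse: y1 == x1 passes through, so the parameters are recoverable."""
    y1, y2 = np.split(y, 2)
    log_scale, shift = conditioner(y1)
    x2 = (y2 - shift) * np.exp(-log_scale)
    return np.concatenate([y1, x2])

# Hypothetical toy conditioner standing in for the neural network
conditioner = lambda h: (0.1 * h, h ** 2)
x = np.array([1.0, -2.0, 0.5, 3.0])
y, log_det = coupling_forward(x, conditioner)
x_rec = coupling_inverse(y, conditioner)  # recovers x exactly
```

Both directions cost a single conditioner evaluation, which is why coupling flows support fast log_prob and fast sample simultaneously; the invert flag only chooses which direction avoids an extra bijection inversion.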
- masked_autoregressive_flow(key, *, base_dist, transformer=None, cond_dim=None, flow_layers=8, nn_width=50, nn_depth=1, nn_activation=<jax._src.custom_derivatives.custom_jvp object>, invert=True)[source]

  Masked autoregressive flow.

  Parameterises a transformer bijection with an autoregressive neural network. Refs: https://arxiv.org/abs/1606.04934; https://arxiv.org/abs/1705.07057v4.

  - Parameters:
    - key (PRNGKeyArray) – Random seed.
    - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
    - transformer (AbstractBijection | None) – Bijection parameterised by the autoregressive network. Defaults to affine.
    - cond_dim (int | None) – Dimension of the conditioning variable. Defaults to None.
    - flow_layers (int) – Number of flow layers. Defaults to 8.
    - nn_width (int) – Hidden layer width of the neural network. Defaults to 50.
    - nn_depth (int) – Depth of the neural network. Defaults to 1.
    - nn_activation (Callable) – Neural network activation function. Defaults to jnn.relu.
    - invert (bool) – Whether to invert the bijection. Broadly, True will prioritise a faster inverse, leading to faster log_prob; False will prioritise a faster forward transform, leading to faster sample. Defaults to True.
  - Return type: Transformed
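The autoregressive structure means the parameters for dimension i depend only on dimensions before i, so one direction can be computed for all dimensions at once while the other is inherently sequential. A small numpy sketch of an autoregressive affine transform (illustrative only; flowjax uses a masked network rather than an explicit parameter loop, and the toy params function is a hypothetical stand-in for that network):

```python
import numpy as np

def maf_forward(x, params):
    """Autoregressive affine transform: parameters for x[i] depend only on x[:i].

    Written as a loop for clarity; a masked network computes all the
    parameters in a single pass over x, which is what makes this direction fast.
    """
    y = np.empty_like(x)
    for i in range(len(x)):
        log_scale, shift = params(x[:i])
        y[i] = x[i] * np.exp(log_scale) + shift
    return y

def maf_inverse(y, params):
    """The inverse is inherently sequential: x[:i] must be known before x[i]."""
    x = np.empty_like(y)
    for i in range(len(y)):
        log_scale, shift = params(x[:i])
        x[i] = (y[i] - shift) * np.exp(-log_scale)
    return x

# Hypothetical toy parameter function standing in for the masked network
params = lambda prefix: (0.05 * prefix.sum(), prefix.sum())
x = np.array([0.3, -1.2, 2.0])
x_rec = maf_inverse(maf_forward(x, params), params)  # round-trips back to x
```

This asymmetry is exactly what the invert argument trades off: one of log_prob or sample gets the single-pass direction, the other pays the per-dimension loop.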
- planar_flow(key, *, base_dist, cond_dim=None, flow_layers=8, invert=True, **mlp_kwargs)[source]

  Planar flow as introduced in https://arxiv.org/pdf/1505.05770.pdf.

  This alternates between Planar layers and permutations. Note the definition here is inverted compared to the original paper.

  - Parameters:
    - key (PRNGKeyArray) – Jax PRNGKey.
    - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
    - cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
    - flow_layers (int) – Number of flow layers. Defaults to 8.
    - invert (bool) – Whether to invert the bijection. Broadly, True will prioritise faster inverse methods, leading to faster log_prob; False will prioritise faster transform methods, leading to faster sample. Defaults to True.
    - **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
  - Return type: Transformed
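For reference, a single planar transformation in the original paper maps x to x + u * tanh(w @ x + b), with a rank-one Jacobian whose determinant is available in O(dim). A numpy sketch of that paper's form (illustrative only; recall the flowjax definition is inverted relative to this):

```python
import numpy as np

def planar_forward(x, w, u, b):
    """Planar transform y = x + u * tanh(w @ x + b) and its log|det Jacobian|."""
    a = np.tanh(w @ x + b)
    y = x + u * a
    # Jacobian is I + (1 - a**2) * outer(u, w): a rank-one update, so the
    # matrix determinant lemma gives det = 1 + (1 - a**2) * (u @ w).
    log_det = np.log(np.abs(1.0 + (1.0 - a**2) * (u @ w)))
    return y, log_det

w = np.array([0.5, -0.3])
u = np.array([0.2, 0.1])
x = np.array([1.0, 2.0])
y, log_det = planar_forward(x, w, u, b=0.1)
```

Because the transform only perturbs x along the single direction u, planar flows are weak individually; stacking flow_layers of them interleaved with permutations builds up expressiveness.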
- triangular_spline_flow(key, *, base_dist, cond_dim=None, flow_layers=8, knots=8, tanh_max_val=3.0, invert=True, init=None)[source]

  Triangular spline flow.

  Each layer consists of a triangular affine transformation with weight normalisation, and an elementwise rational quadratic spline. Tanh is used to constrain the input to [-1, 1] before the spline transformation.

  - Parameters:
    - key (PRNGKeyArray) – Jax random seed.
    - base_dist (AbstractDistribution) – Base distribution, with base_dist.ndim==1.
    - cond_dim (int | None) – The number of conditioning features. Defaults to None.
    - flow_layers (int) – Number of flow layers. Defaults to 8.
    - knots (int) – Number of knots in the splines. Defaults to 8.
    - tanh_max_val (float | int) – Maximum absolute value beyond which we use linear "tails" in the tanh function. Defaults to 3.0.
    - invert (bool) – Whether to invert the bijection before transforming the base distribution. Defaults to True.
    - init (Callable | None) – Initialisation method for the lower triangular weights. Defaults to glorot_uniform().
  - Return type: Transformed
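The triangular affine component is attractive because its Jacobian is the lower triangular matrix itself: the log-determinant reduces to the diagonal, and inversion is a forward substitution rather than a general matrix solve. A numpy sketch of just that component (illustrative only; the flowjax layer additionally applies weight normalisation and the elementwise spline):

```python
import numpy as np

def triangular_affine(x, L, b):
    """y = L @ x + b with lower-triangular L; log|det J| comes from the diagonal."""
    y = L @ x + b
    log_det = np.sum(np.log(np.abs(np.diag(L))))
    return y, log_det

def triangular_affine_inverse(y, L, b):
    """Invert by forward substitution, exploiting the triangular structure."""
    z = y - b
    x = np.zeros_like(z)
    for i in range(len(z)):
        x[i] = (z[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x

L = np.array([[2.0, 0.0], [1.0, 0.5]])  # lower triangular, nonzero diagonal
b = np.array([1.0, -1.0])
x = np.array([3.0, 4.0])
y, log_det = triangular_affine(x, L, b)   # here log_det = log(2) + log(0.5) = 0
x_rec = triangular_affine_inverse(y, L, b)
```

Since both directions are cheap for the affine part, the cost asymmetry of this flow comes mainly from the spline, which the invert argument controls.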