
For an introduction to using bijections in FlowJAX, see Getting started.

Bijections from flowjax.bijections.

class AbstractBijection

Bijection abstract class.

Similar to AbstractDistribution, bijections have a shape and a cond_shape attribute. To allow bijections to be composed easily, all bijections accept conditioning variables (even if they ignore them).

Bijections are registered as JAX PyTrees (as they are equinox modules), so they are compatible with normal JAX operations. The methods of bijections do not support additional batch dimensions; however, jax.vmap or eqx.filter_vmap can be used to vmap specific methods if desired, and a bijection can be explicitly vectorised using the Vmap bijection.

Implementing a bijection:

  • Inherit from AbstractBijection.

  • Define the attributes shape and cond_shape. A cond_shape of None is used to represent unconditional bijections.

  • Implement the abstract methods transform_and_log_det and inverse_and_log_det. These should act on an input x whose shape matches shape, and a condition whose shape matches cond_shape.

inverse(y, condition=None)

Compute the inverse transformation.

  • y (ArrayLike) – Input array with shape matching bijection.shape.

  • condition (ArrayLike | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

abstract inverse_and_log_det(y, condition=None)

Inverse transformation and corresponding log absolute jacobian determinant.

  • y (ArrayLike) – Input array with shape matching bijection.shape.

  • condition (ArrayLike | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

Return type: tuple[Array, Array]

transform(x, condition=None)

Apply the forward transformation.

  • x (ArrayLike) – Input with shape matching bijection.shape.

  • condition (ArrayLike | None) – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

abstract transform_and_log_det(x, condition=None)

Apply transformation and compute the log absolute Jacobian determinant.

  • x (ArrayLike) – Input with shape matching bijection.shape.

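  • condition (ArrayLike | None) – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.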

Return type: tuple[Array, Array]

class AdditiveCondition(module, shape, cond_shape)

Bases: AbstractBijection

Given a callable f, carries out the transformation y = x + f(condition).

If used to transform a distribution, this allows the “location” to be changed as a function of the conditioning variables. Note that the callable can be a callable module with trainable parameters.

  • module (Callable[[ArrayLike], ArrayLike]) – A callable (e.g. a function or callable module) that maps an array with shape cond_shape to a shape that is broadcastable with the shape of the bijection.

  • shape (tuple[int, ...]) – The shape of the bijection.

  • cond_shape (tuple[int, ...]) – The condition shape of the bijection.


Conditioning using a linear transformation

>>> from flowjax.bijections import AdditiveCondition
>>> from equinox.nn import Linear
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> bijection = AdditiveCondition(
...     Linear(2, 3, key=jr.key(0)), shape=(3,), cond_shape=(2,)
...     )
>>> y = bijection.transform(jnp.ones(3), condition=jnp.ones(2))

class Affine(loc=0, scale=1)

Bases: AbstractBijection

Elementwise affine transformation y=ax+b.

loc and scale should broadcast to the desired shape of the bijection. By default, we constrain the scale parameter to be positive using softplus, but other parameterizations can be achieved by replacing the scale parameter after construction, e.g. using eqx.tree_at.

  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class BlockAutoregressiveNetwork(

Bases: AbstractBijection

Block Autoregressive Network (

Note that, in contrast to the original paper, which uses tanh activations, by default we use tanh(x)+0.01x. This makes the image of the activation the whole real line, so the transform maps real -> real, avoiding various issues (see danielward27/flowjax#102).

  • key (PRNGKeyArray) – JAX random key.

  • dim (int) – Dimension of the distribution.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • depth (int) – Number of hidden layers in the network.

  • block_dim (int) – Block dimension (hidden layer size is dim*block_dim).

  • activation (AbstractBijection | Callable | None) – Activation function, either a scalar bijection or a callable that computes the activation for a scalar value. Note that the activation should be bijective to ensure invertibility of the network and in general should map real -> real to ensure that when transforming a distribution (either with the forward or inverse), the map is defined across the support of the base distribution. Defaults to tanh(x)+0.01x.
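A small pure-JAX check (no flowjax needed) of why the default activation is suitable: its derivative 1 - tanh(x)^2 + 0.01 is always at least 0.01, so it is strictly increasing, and the 0.01x term makes it unbounded, unlike tanh, whose image is (-1, 1).

```python
import jax
import jax.numpy as jnp


def leaky_tanh_activation(x):
    # The default BlockAutoregressiveNetwork activation described above.
    return jnp.tanh(x) + 0.01 * x


# Pointwise derivatives over a wide range of inputs.
xs = jnp.linspace(-50.0, 50.0, 101)
grads = jax.vmap(jax.grad(leaky_tanh_activation))(xs)

# Far from the origin the output keeps growing past tanh's bound of 1.
big = leaky_tanh_activation(jnp.array(1000.0))
```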

class Chain(bijections)

Bases: AbstractBijection

Compose arbitrary bijections to form another bijection.

If the layers you are chaining have a consistent structure, consider using Scan, which avoids compiling each layer separately.


bijections (Sequence[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]]]) – Sequence of bijections. The bijection shapes must match, and any non-None condition shapes must match.


Returns an equivalent Chain object, in which nested chains are flattened.

class Concatenate(bijections, axis=0)

Bases: AbstractBijection

Concatenate bijections along an existing axis, similar to jnp.concatenate.

See also Stack.


>>> import jax.numpy as jnp
>>> from flowjax.bijections import Concatenate, Affine, Exp
>>> concat = Concatenate([Affine(jnp.ones((2, 3))), Exp((2, 3))])
>>> concat.shape
(4, 3)

  • bijections (Sequence[AbstractBijection]) – Bijections to concatenate into a single bijection.

  • axis (int) – Axis along which to concatenate. Defaults to 0.

class Coupling(key, *, transformer, untransformed_dim, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jnn.relu)

Bases: AbstractBijection

Coupling layer implementation (

  • key (PRNGKeyArray) – JAX random key.

  • transformer (AbstractBijection) – Unconditional bijection with shape () to be parameterised by the conditioner neural network. Parameters wrapped with NonTrainable are excluded from being parameterized.

  • untransformed_dim (int) – Number of untransformed conditioning variables (e.g. dim//2).

  • dim (int) – Total dimension.

  • cond_dim (int | None) – Dimension of additional conditioning variables. Defaults to None.

  • nn_width (int) – Neural network hidden layer width.

  • nn_depth (int) – Number of neural network hidden layers.

  • nn_activation (Callable) – Neural network activation function. Defaults to jnn.relu.

class DiscreteCosine(shape, *, axis=-1)

Bases: AbstractBijection

Discrete Cosine Transform (DCT) bijection.

This bijection applies the DCT or its inverse along a specified axis.

  • shape – Shape of the input/output arrays.

  • axis (int) – Axis along which to apply the DCT.
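A pure-JAX sketch of the underlying operation using jax.scipy.fft, assuming orthonormal normalization (whether DiscreteCosine uses norm="ortho" internally is an assumption here). An orthonormal transform preserves the Euclidean norm, so its absolute Jacobian determinant is 1 and the log determinant is 0.

```python
import jax.numpy as jnp
from jax.scipy.fft import dct, idct

x = jnp.array([1.0, -2.0, 0.5, 3.0])

# Orthonormal DCT-II along the last axis, and its inverse.
y = dct(x, norm="ortho", axis=-1)
x_rec = idct(y, norm="ortho", axis=-1)
```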

class EmbedCondition(bijection, embedding_net, raw_cond_shape)

Bases: AbstractBijection

Wrap a bijection to include an embedding network.

Generally this is used to reduce the dimensionality of the conditioning variable. The returned bijection has cond_shape equal to raw_cond_shape.

  • bijection (AbstractBijection) – Bijection with bijection.cond_shape matching the shape of the embedded condition.

  • embedding_net (Callable) – A callable (e.g. equinox module) that embeds a conditioning variable to shape bijection.cond_shape.

  • raw_cond_shape (tuple[int, ...]) – The shape of the raw conditioning variable.

class Exp(shape=())

Bases: AbstractBijection

Elementwise exponential transform (forward) and log transform (inverse).


shape (tuple[int, ...]) – Shape of the bijection. Defaults to ().

class Flip(shape=())

Bases: AbstractBijection

Flip the input array. Condition argument is ignored.


shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().

class Householder(params)

Bases: AbstractBijection

A Householder reflection.

A linear transformation reflecting vectors across a hyperplane defined by a normal vector (params). The transformation is its own inverse and volume-preserving (determinant = -1). Given a unit vector $v$, the transformation is $y = x - 2(x^\top v)v$.

It is often desirable to stack multiple such transforms (e.g. up to the dimensionality of the data):

>>> from flowjax.bijections import Householder, Scan
>>> import jax.random as jr
>>> import equinox as eqx
>>> import jax.numpy as jnp

>>> dim = 5
>>> keys = jr.split(jr.key(0), dim)
>>> householder_stack = Scan(
...    eqx.filter_vmap(lambda key: Householder(jr.normal(key, (dim,))))(keys)
... )

params (ArrayLike) – Normal vector defining the reflection hyperplane. The vector is normalized in the transformation, so scaling params will have no effect on the bijection.

class Identity(shape=())

Bases: AbstractBijection

The identity bijection.


shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().

class Indexed(bijection, idxs, shape)

Bases: AbstractBijection

Applies bijection to specific indices of an input.

  • bijection (AbstractBijection) – Bijection that is compatible with the subset of x indexed by idxs.

  • idxs (int | slice | Array | tuple) – Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the transformed portion.

  • shape (tuple[int, ...]) – Shape of the bijection.

class Invert(bijection)

Bases: AbstractBijection

Invert a bijection.

This wraps a bijection, such that the transform methods become the inverse methods and vice versa. Note that in general, we define bijections such that the forward methods are preferred, i.e. they are faster or are the ones actually implemented. For training flows, we generally want the inverse method (used in density evaluation) to be faster. Hence it is often useful to use this class to achieve this aim.


bijection (AbstractBijection) – Bijection to invert.

class LeakyTanh(max_val, shape=())

Bases: AbstractBijection

Tanh bijection, with a linear transformation beyond +/- max_val.

The value and gradient of the linear segments are set to match tanh at +/- max_val. This bijection can be useful to encourage values to be within an interval whilst avoiding numerical precision issues, or in cases where we require a real -> real mapping, for which Tanh is not appropriate.

  • max_val (float | int) – Value above or below which the function becomes linear.

  • shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().
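One way to write the piecewise definition described above as a standalone pure-JAX function. `leaky_tanh` is a hypothetical helper for illustration, not flowjax's implementation; outside [-max_val, max_val] it continues linearly with the value and slope of tanh at the boundary.

```python
import jax
import jax.numpy as jnp


def leaky_tanh(x, max_val):
    """Sketch: tanh inside [-max_val, max_val], linear continuation outside."""
    tanh_m = jnp.tanh(max_val)
    slope = 1 - tanh_m**2  # derivative of tanh at +/- max_val
    linear = jnp.sign(x) * tanh_m + slope * (x - jnp.sign(x) * max_val)
    return jnp.where(jnp.abs(x) <= max_val, jnp.tanh(x), linear)


m = 2.0
ys = leaky_tanh(jnp.linspace(-4.0, 4.0, 9), m)
```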

class Loc(loc)

Bases: AbstractBijection

Location transformation y=x+b.


loc (ArrayLike) – Location parameter.

class MaskedAutoregressive(key, *, transformer, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jnn.relu)

Bases: AbstractBijection

Masked autoregressive bijection.

The transformer is parameterised by a neural network, with weights masked to ensure an autoregressive structure.

  • key (PRNGKeyArray) – JAX random key.

  • transformer (AbstractBijection) – Bijection with shape () to be parameterised by the autoregressive network. Parameters wrapped with NonTrainable are excluded.

  • dim (int) – Dimension.

  • cond_dim (int | None) – Dimension of any conditioning variables. Defaults to None.

  • nn_width (int) – Neural network width.

  • nn_depth (int) – Neural network depth.

  • nn_activation (Callable) – Neural network activation. Defaults to jnn.relu.

inv_scan_fn(init, _, condition)

One ‘step’ in computing the inverse.

class NumericalInverse(bijection, inverter)

Bases: AbstractBijection

Bijection wrapper to provide inverse methods using e.g. root finding.

  • bijection (AbstractBijection) – The bijection to add an inverse to.

  • inverter (Callable[[AbstractBijection, Array, Array | None], Array]) – Callable implementing the numerical inversion method. Should accept the bijection, y and condition as arguments, and return the inverse.

class Permute(permutation)

Bases: AbstractBijection

Permutation transformation.


permutation (Union[Int[Array, '...'], Int[ndarray, '...']]) – An array with shape matching the array to transform, with elements 0 to (array.size - 1) representing the new order based on the flattened array (using C-like ordering).

class Planar(key, *, dim, cond_dim=None, negative_slope=None, **mlp_kwargs)

Bases: AbstractBijection

Planar bijection as used by

Uses the transformation

$y = x + u \tanh(w^\top x + b)$

where $u \in \mathbb{R}^D$, $w \in \mathbb{R}^D$ and $b \in \mathbb{R}$. In the unconditional case, the (unbounded) parameters are learned directly. In the conditional case they are parameterised by an MLP.

  • key (PRNGKeyArray) – Jax random key.

  • dim (int) – Dimension of the bijection.

  • cond_dim (int | None) – Dimension of extra conditioning variables. Defaults to None.

  • negative_slope (float | None) – A positive float. If provided, then a leaky relu activation (with the corresponding negative slope) is used instead of tanh. This also provides the advantage that the bijection can be inverted analytically.

  • **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.


Get the planar bijection with the conditioning applied if conditional.
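A pure-JAX sketch (not flowjax's implementation) of why the planar transformation has a cheap log determinant. By the matrix determinant lemma, det(I + h'(a) u w^T) = 1 + h'(a) u^T w with a = w^T x + b and h = tanh; the code checks this against the full Jacobian from autodiff. The values of u, w and b are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp


def planar(x, u, w, b):
    # y = x + u * tanh(w^T x + b)
    return x + u * jnp.tanh(w @ x + b)


def planar_log_det(x, u, w, b):
    # Matrix determinant lemma with h'(a) = 1 - tanh(a)^2.
    a = w @ x + b
    return jnp.log(jnp.abs(1 + (1 - jnp.tanh(a) ** 2) * (u @ w)))


x = jnp.array([0.3, -0.7, 1.2])
u = jnp.array([0.5, 0.1, -0.2])
w = jnp.array([1.0, -0.5, 0.3])
b = 0.1

# log|det| of the full 3x3 Jacobian, for comparison.
_, autodiff_log_det = jnp.linalg.slogdet(jax.jacobian(planar)(x, u, w, b))
```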

class Power(exponent, shape=())

Bases: AbstractBijection

Power transform $y = x^p$.

Supports positive values only, over which this is a bijection.

  • exponent (int | float) – The exponent.

  • shape (tuple[int, ...]) – The shape of the bijection.

class RationalQuadraticSpline(*, knots, interval, min_derivative=0.001, softmax_adjust=0.01)

Bases: AbstractBijection

Scalar RationalQuadraticSpline transformation (

  • knots (int) – Number of knots.

  • interval (float | int | tuple[int | float, int | float]) – Interval to transform, if a scalar value, uses [-interval, interval], if a tuple, uses [interval[0], interval[1]]

  • min_derivative (float) – Minimum derivative. Defaults to 1e-3.

  • softmax_adjust (float | int) – Controls minimum bin width and height by rescaling softmax output, e.g. 0=no adjustment, 1=average softmax output with evenly spaced widths, >1 promotes more evenly spaced widths. See real_to_increasing_on_interval. Defaults to 1e-2.


The derivative dy/dx of the forward transformation.

Return type:


class Reshape(bijection, shape=None, cond_shape=None)

Bases: AbstractBijection

Wraps bijection methods with reshaping operations.

One use case is for bijections that do not directly support a scalar shape: such a bijection can be constructed with shape (1,) and reshaped to ().

  • bijection (AbstractBijection) – The bijection to wrap.

  • shape (tuple[int, ...] | None) – The new input and output shape of the bijection. Defaults to unchanged.

  • cond_shape (tuple[int, ...] | None) – The new cond_shape of the bijection. Defaults to unchanged.


>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Reshape
>>> affine = Affine(loc=jnp.arange(4))
>>> affine.shape
(4,)
>>> affine = Reshape(affine, (2,2))
>>> affine.shape
(2, 2)
>>> affine.transform(jnp.zeros((2,2)))
Array([[0., 1.],
       [2., 3.]], dtype=float32)

class Sandwich(inner, outer)

Bases: AbstractBijection

Composes bijections in a nested structure: $g^{-1} \circ f \circ g$.

Creates a new transformation by “sandwiching” one bijection between the forward and inverse applications of another. Given bijections f (inner) and g (outer), it computes

  • Forward: $y = g^{-1}(f(g(x)))$

  • Inverse: $x = g^{-1}(f^{-1}(g(y)))$

This can be used for e.g. creating symmetries in the transformation or to apply a transformation in a different coordinate system.

class Scale(scale)

Bases: AbstractBijection

Scale transformation y=ax.


scale (ArrayLike) – Scale parameter.

class Scan(bijection)

Bases: AbstractBijection

Repeatedly apply the same bijection with different parameter values.

Internally, uses jax.lax.scan to reduce compilation time. Often it is convenient to construct these using equinox.filter_vmap.


bijection (AbstractBijection) – A bijection in which the array leaves have an additional leading axis to scan over. It is often convenient to create compatible bijections with equinox.filter_vmap.


Below is equivalent to Chain([Affine(p) for p in params]).

>>> from flowjax.bijections import Scan, Affine
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> params = jnp.ones((3, 2))
>>> affine = eqx.filter_vmap(Affine)(params)
>>> affine = Scan(affine)

class Sigmoid(shape=())

Bases: AbstractBijection

Sigmoid bijection $y = \sigma(x) = \frac{1}{1 + \exp(-x)}$.


shape (tuple[int, ...]) – The shape of the transform.

class SoftPlus(shape=())

Bases: AbstractBijection

Transforms to positive domain using softplus y=log(1+exp(x)).

class Stack(bijections, axis=0)

Bases: AbstractBijection

Stack bijections along a new axis (analogous to jnp.stack).

See also Concatenate.


>>> import jax.numpy as jnp
>>> from flowjax.bijections import Stack, Affine, Exp
>>> stack = Stack([Affine(jnp.ones(3)), Exp((3,))])
>>> stack.shape
(2, 3)

class Tanh(shape=())

Bases: AbstractBijection

Tanh bijection y=tanh(x).

class TriangularAffine(loc, arr, *, lower=True)

Bases: AbstractBijection

A triangular affine transformation.

Transformation has the form Ax+b, where A is a lower or upper triangular matrix, and b is the bias vector. We assume the diagonal entries are positive, and constrain the values using softplus. Other parameterizations can be achieved by e.g. replacing self.triangular after construction.

  • loc (Shaped[ArrayLike, '#dim']) – Location parameter. If this is scalar, it is broadcast to the dimension inferred from arr.

  • arr (Shaped[Array, 'dim dim']) – Triangular matrix.

  • lower (bool) – Whether the mask should select the lower or upper triangular matrix (other elements ignored). Defaults to True (lower).

class Vmap(bijection, *, in_axes=None, axis_size=None, in_axes_condition=None)

Bases: AbstractBijection

Applies vmap to bijection methods to add a batch dimension to the bijection.

  • bijection (AbstractBijection) – The bijection to vectorize.

  • in_axes (PyTree | None | int | Callable) – Specify which axes of the bijection parameters to vectorise over. It should be a PyTree of None, int with the tree structure being a prefix of the bijection, or a callable mapping Leaf -> Union[None, int]. Note, if the bijection contains unwrappables, then in_axes should be specified for the unwrapped structure of the bijection. Defaults to None.

  • axis_size (int | None) – The size of the new axis. This should be left unspecified if in_axes is provided, as the size can be inferred from the bijection parameters. Defaults to None.

  • in_axes_condition (int | None) – Optionally define an axis of the conditioning variable to vectorize over. Defaults to None.


>>> # Add a bijection batch dimension, mapping over bijection parameters
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> from flowjax.bijections import Vmap, RationalQuadraticSpline, Affine
>>> bijection = eqx.filter_vmap(
...    lambda: RationalQuadraticSpline(knots=5, interval=2),
...    axis_size=10
... )()
>>> bijection = Vmap(bijection, in_axes=eqx.if_array(0))
>>> bijection.shape
(10,)
>>> # Add a bijection batch dimension, broadcasting bijection parameters:
>>> bijection = RationalQuadraticSpline(knots=5, interval=2)
>>> bijection = Vmap(bijection, axis_size=10)
>>> bijection.shape
(10,)

A more advanced use case is to create bijections with more fine-grained control over parameter broadcasting. For example, the Affine constructor broadcasts the location and scale parameters during initialization. What if we want an Affine bijection with a global scale parameter but an elementwise location parameter? We could achieve this as follows.

>>> from jax.tree_util import tree_map
>>> import paramax
>>> bijection = Affine(jnp.zeros(()), jnp.ones(()))
>>> bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
>>> in_axes = tree_map(lambda _: None, paramax.unwrap(bijection))
>>> in_axes = eqx.tree_at(
...     lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None
...     )
>>> bijection = Vmap(bijection, in_axes=in_axes)
>>> bijection.shape
(3,)
>>> bijection.bijection.loc.shape
(3,)
>>> paramax.unwrap(bijection.bijection.scale).shape
()
>>> x = jnp.ones(3)
>>> bijection.transform(x)
Array([1., 2., 3.], dtype=float32)