Bijections

Bijections from flowjax.bijections.

class AbstractBijection

Bases: Module

Bijection abstract class.

As with AbstractDistribution, bijections have shape and cond_shape attributes. To make bijections easy to compose, all bijections accept conditioning variables, even if they ignore them.

Bijection methods do not generally support additional batch dimensions. However, jax.vmap or eqx.filter_vmap can be used to vectorise specific methods if desired, and a bijection can be explicitly vectorised using the Vmap bijection.

Bijections are Equinox modules, so they are registered as JAX PyTrees and are compatible with standard JAX operations.

Implementing a bijection

  1. Inherit from AbstractBijection.

  2. Define the attributes shape and cond_shape. A cond_shape of None is used to represent unconditional bijections.

  3. Implement the abstract methods transform, transform_and_log_det, inverse and inverse_and_log_det. These should act on inputs with shape matching shape (for x or y) and cond_shape (for condition).
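The following sketch illustrates the three steps listed above: a shape and cond_shape, plus the four methods. It is a standalone NumPy sketch and deliberately does not import flowjax. The class name ScaleSketch is hypothetical. A real implementation would instead inherit from AbstractBijection, which is an Equinox module, and would use jax.numpy.

```python
import numpy as np


class ScaleSketch:
    """Hypothetical elementwise scaling bijection y = scale * x (not part of flowjax)."""

    def __init__(self, scale):
        self.scale = np.asarray(scale, dtype=float)
        self.shape = self.scale.shape  # shape of x (and of y)
        self.cond_shape = None  # None marks an unconditional bijection

    def transform(self, x, condition=None):
        # The condition is accepted so the bijection composes with conditional ones, then ignored.
        return self.scale * x

    def transform_and_log_det(self, x, condition=None):
        # The Jacobian is diagonal, so log|det J| is the sum of log|scale|.
        return self.transform(x), np.sum(np.log(np.abs(self.scale)))

    def inverse(self, y, condition=None):
        return y / self.scale

    def inverse_and_log_det(self, y, condition=None):
        # The inverse log determinant is the negation of the forward one.
        return self.inverse(y), -np.sum(np.log(np.abs(self.scale)))


bijection = ScaleSketch([2.0, 0.5, 4.0])
y, log_det = bijection.transform_and_log_det(np.ones(3))
# y = [2.0, 0.5, 4.0] and log_det = log(2 * 0.5 * 4) = log(4)
```

Because each inverse method returns the negated log determinant, the sketch can be composed with other bijections in either direction.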

abstract inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y (Array) – Input array with shape matching bijection.shape.

  • condition (Array | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

Return type:

Array

abstract inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y (Array) – Input array with shape matching bijection.shape.

  • condition (Array | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

Return type:

tuple[Array, Array]

abstract transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x (Array) – Input array with shape matching bijection.shape.

  • condition (Array | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

Return type:

Array

abstract transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x (Array) – Input array with shape matching bijection.shape.

  • condition (Array | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

Return type:

tuple[Array, Array]

cond_shape: AbstractVar[tuple[int, ...] | None]
shape: AbstractVar[tuple[int, ...]]
class AdditiveCondition(module, shape, cond_shape)

Bases: AbstractBijection

Given a callable f, carries out the transformation y = x + f(condition).

If used to transform a distribution, this allows the “location” to be changed as a function of the conditioning variables. Note that the callable can be a callable module with trainable parameters.

Parameters:
  • module (Callable[[ArrayLike], ArrayLike]) – A callable (e.g. a function or callable module) that maps an array with shape cond_shape to an array whose shape is broadcastable with the shape of the bijection.

  • shape (tuple[int, ...]) – The shape of the bijection.

  • cond_shape (tuple[int, ...]) – The condition shape of the bijection.

Example

Conditioning using a linear transformation

>>> from flowjax.bijections import AdditiveCondition
>>> from equinox.nn import Linear
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> bijection = AdditiveCondition(
...     Linear(2, 3, key=jr.PRNGKey(0)), shape=(3,), cond_shape=(2,)
...     )
>>> bijection.transform(jnp.ones(3), condition=jnp.ones(2))
Array([1.9670618, 0.8156546, 1.7763454], dtype=float32)
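Because f depends only on the condition, the Jacobian of y = x + f(condition) with respect to x is the identity. The log absolute determinant is therefore zero in both directions. The NumPy sketch below illustrates this. The fixed linear map standing in for module and the helper function names are hypothetical, and this is not flowjax's implementation.

```python
import numpy as np

# Hypothetical stand-in for `module`: a fixed linear map from cond_shape (2,) to shape (3,).
W = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = np.array([0.5, -0.5, 0.0])


def f(condition):
    return W @ condition + b


def additive_transform_and_log_det(x, condition):
    # y = x + f(condition); dy/dx is the identity, so log|det J| = 0.
    return x + f(condition), 0.0


def additive_inverse_and_log_det(y, condition):
    # Subtracting the same shift recovers x exactly.
    return y - f(condition), 0.0


x = np.ones(3)
c = np.array([1.0, 2.0])
y, log_det = additive_transform_and_log_det(x, c)  # f(c) = [1.5, 1.5, 3.0]
x_rec, _ = additive_inverse_and_log_det(y, c)
```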
inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: tuple[int, ...]
module: Callable[[ArrayLike], ArrayLike]
shape: tuple[int, ...]
class Affine(loc=0, scale=1)

Bases: AbstractBijection

Elementwise affine transformation y = a*x + b, where a is given by scale and b by loc.

loc and scale should broadcast to the desired shape of the bijection. By default, we constrain the scale parameter to be positive using SoftPlus. Other parameterizations can be achieved by replacing the scale parameter after construction, e.g. using eqx.tree_at.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.
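The sketch below shows what the SoftPlus constraint described above means for the transform and its log determinant. It is a standalone NumPy sketch, not flowjax's code. The parameter name unconstrained_scale is hypothetical, and flowjax's internal parameterisation may differ. Because the map is elementwise, log|det J| is the sum of log(scale) over every element after broadcasting to the input shape.

```python
import numpy as np


def softplus(z):
    # Numerically stable softplus: log(1 + exp(z)), strictly positive for real z.
    return np.logaddexp(0.0, z)


def affine_transform_and_log_det(x, loc, unconstrained_scale):
    # Hypothetical parameterisation: scale = softplus(unconstrained_scale) > 0.
    scale = np.broadcast_to(softplus(unconstrained_scale), np.shape(x))
    y = scale * x + loc
    # Diagonal Jacobian: sum log(scale) over all broadcast elements.
    return y, np.sum(np.log(scale))


def affine_inverse_and_log_det(y, loc, unconstrained_scale):
    scale = np.broadcast_to(softplus(unconstrained_scale), np.shape(y))
    return (y - loc) / scale, -np.sum(np.log(scale))


x = np.array([1.0, 2.0, 3.0])
# softplus(0) = log(2), so the scale is log(2) for each of the three elements.
y, log_det = affine_transform_and_log_det(x, loc=1.0, unconstrained_scale=0.0)
x_rec, inv_log_det = affine_inverse_and_log_det(y, loc=1.0, unconstrained_scale=0.0)
```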

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
loc: Array
scale: Union[Array, AbstractUnwrappable[Array]]
shape: tuple[int, ...]
class BlockAutoregressiveNetwork(key, *, dim, cond_dim=None, depth, block_dim, activation=None, inverter=None)

Bases: AbstractBijection

Block Autoregressive Network (https://arxiv.org/abs/1904.04676).

Note that, in contrast to the original paper (which uses tanh activations), we use LeakyTanh by default. This makes the codomain of the activation the set of real numbers, which ensures properly normalised densities (see https://github.com/danielward27/flowjax/issues/102).

Parameters:
  • key (PRNGKeyArray) – Jax PRNGKey

  • dim (int) – Dimension of the distribution.

  • cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.

  • depth (int) – Number of hidden layers in the network.

  • block_dim (int) – Block dimension (hidden layer size is dim*block_dim).

  • activation (AbstractBijection | Callable | None) – Activation function, either a scalar bijection or a callable that computes the activation for a scalar value. Note that the activation should be bijective to ensure invertibility of the network and in general should map real -> real to ensure that when transforming a distribution (either with the forward or inverse), the map is defined across the support of the base distribution. Defaults to LeakyTanh(3).

  • inverter (Callable | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Must have the signature inverter(bijection, y, condition=None). Defaults to AutoregressiveBisectionInverter.

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

activation: AbstractBijection
block_dim: int
cond_linear: Linear | None
cond_shape: tuple[int, ...] | None
depth: int
inverter: Callable
layers: list
shape: tuple[int, ...]
class Chain(bijections)

Bases: AbstractBijection

Chain together arbitrary bijections to form another bijection.

Parameters:

bijections (Sequence[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]]]) – Sequence of bijections. The bijection shapes must match, and any condition shapes that are not None must also match.
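Chaining composes the maps, so by the chain rule the log absolute Jacobian determinants add. The inverse applies the inverses in reverse order. The NumPy sketch below assumes that the first bijection in the sequence is applied first in the forward direction. ShiftSketch, ExpSketch and the chain_* helpers are hypothetical illustrations, not flowjax classes.

```python
import numpy as np


class ShiftSketch:
    """Hypothetical y = x + c, with log det 0."""

    def __init__(self, c):
        self.c = c

    def transform_and_log_det(self, x, condition=None):
        return x + self.c, 0.0

    def inverse_and_log_det(self, y, condition=None):
        return y - self.c, 0.0


class ExpSketch:
    """Hypothetical elementwise y = exp(x), with log det sum(x)."""

    def transform_and_log_det(self, x, condition=None):
        return np.exp(x), np.sum(x)

    def inverse_and_log_det(self, y, condition=None):
        log_y = np.log(y)
        return log_y, -np.sum(log_y)


def chain_transform_and_log_det(bijections, x, condition=None):
    # Apply each bijection in order and accumulate the log determinants.
    log_det = 0.0
    for bijection in bijections:
        x, ld = bijection.transform_and_log_det(x, condition)
        log_det += ld
    return x, log_det


def chain_inverse_and_log_det(bijections, y, condition=None):
    # Undo the bijections in reverse order.
    log_det = 0.0
    for bijection in reversed(bijections):
        y, ld = bijection.inverse_and_log_det(y, condition)
        log_det += ld
    return y, log_det


chain = [ShiftSketch(1.0), ExpSketch()]
x = np.array([0.0, 1.0])
y, log_det = chain_transform_and_log_det(chain, x)  # y = exp([1, 2]), log_det = 3
x_rec, inv_log_det = chain_inverse_and_log_det(chain, y)
```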

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

merge_chains()

Return an equivalent Chain object in which nested chains are flattened.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

bijections: tuple[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]], ...]
cond_shape: tuple[int, ...] | None
shape: tuple[int, ...]
class Concatenate(bijections, axis=0)

Bases: AbstractBijection

Concatenate bijections along an existing axis, similar to jnp.concatenate.

See also Stack.

Parameters:
  • bijections (Sequence[AbstractBijection]) – Bijections to concatenate into a single bijection.

  • axis (int) – Axis along which to concatenate. Defaults to 0.
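Concatenating bijections along an axis gives a block-diagonal Jacobian. The input is split into consecutive chunks, each chunk is transformed by its own bijection, and the log determinants add. The NumPy sketch below uses hypothetical helper functions that each return a (y, log_det) pair. The split indices play the role suggested by the split_idxs attribute listed for this class.

```python
import numpy as np


def concatenate_transform_and_log_det(transforms, sizes, x, axis=0):
    # Split x into consecutive chunks along `axis`, one per transform.
    split_idxs = np.cumsum(sizes)[:-1]
    parts = np.split(x, split_idxs, axis=axis)
    outputs, log_det = [], 0.0
    for f, part in zip(transforms, parts):
        y_part, ld = f(part)
        outputs.append(y_part)
        log_det += ld  # block-diagonal Jacobian: the log dets add
    return np.concatenate(outputs, axis=axis), log_det


def exp_with_log_det(x):
    # Hypothetical elementwise exp with its log det, sum(x).
    return np.exp(x), np.sum(x)


def identity_with_log_det(x):
    return x, 0.0


x = np.array([0.0, 1.0, 5.0])
# exp acts on the first two elements and the identity acts on the last one.
y, log_det = concatenate_transform_and_log_det(
    [exp_with_log_det, identity_with_log_det], sizes=[2, 1], x=x
)
```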

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

axis: int
bijections: Sequence[AbstractBijection]
cond_shape: tuple[int, ...] | None
shape: tuple[int, ...]
split_idxs: tuple[int, ...]
class Coupling(key, *, transformer, untransformed_dim, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jax.nn.relu)

Bases: AbstractBijection

Coupling layer implementation (https://arxiv.org/abs/1605.08803).

Parameters:
  • key (PRNGKeyArray) – Jax PRNGKey

  • transformer (AbstractBijection) – Unconditional bijection with shape () to be parameterised by the conditioner neural network. Parameters wrapped with NonTrainable are excluded from being parameterized.

  • untransformed_dim (int) – Number of untransformed conditioning variables (e.g. dim//2).

  • dim (int) – Total dimension.

  • cond_dim (int | None) – Dimension of additional conditioning variables. Defaults to None.

  • nn_width (int) – Neural network hidden layer width.

  • nn_depth (int) – Number of hidden layers in the neural network.

  • nn_activation (Callable) – Neural network activation function. Defaults to jnn.relu.
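In a coupling layer, the first untransformed_dim variables pass through unchanged and condition the transformer applied to the remaining variables. The Jacobian is therefore triangular, and the inverse is cheap because the conditioner can be re-evaluated on the unchanged part. The NumPy sketch below is a RealNVP-style illustration under stated assumptions. It uses an affine transformer (flowjax accepts an arbitrary scalar transformer) and replaces the MLP conditioner with a hypothetical fixed random linear map. It also omits the extra conditioning variables (cond_dim).

```python
import numpy as np

rng = np.random.default_rng(0)
dim, untransformed_dim = 4, 2
n_transformed = dim - untransformed_dim
# Hypothetical conditioner: a fixed random linear map standing in for the MLP.
# It maps the untransformed part to a (log_scale, shift) pair per transformed element.
W = 0.5 * rng.normal(size=(2 * n_transformed, untransformed_dim))


def conditioner(x_cond):
    log_scale, shift = np.split(W @ x_cond, 2)
    return log_scale, shift


def coupling_transform_and_log_det(x):
    x_cond, x_trans = x[:untransformed_dim], x[untransformed_dim:]
    log_scale, shift = conditioner(x_cond)
    # Affine transformer applied elementwise to the transformed part.
    y_trans = x_trans * np.exp(log_scale) + shift
    # Triangular Jacobian: only the transformed block contributes to log|det J|.
    return np.concatenate([x_cond, y_trans]), np.sum(log_scale)


def coupling_inverse_and_log_det(y):
    y_cond, y_trans = y[:untransformed_dim], y[untransformed_dim:]
    # y_cond equals x_cond, so the conditioner output can be recomputed exactly.
    log_scale, shift = conditioner(y_cond)
    x_trans = (y_trans - shift) * np.exp(-log_scale)
    return np.concatenate([y_cond, x_trans]), -np.sum(log_scale)


x = rng.normal(size=dim)
y, log_det = coupling_transform_and_log_det(x)
x_rec, inv_log_det = coupling_inverse_and_log_det(y)
```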

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: tuple[int, ...] | None
conditioner: MLP
dim: int
shape: tuple[int, ...]
transformer_constructor: Callable
untransformed_dim: int
class EmbedCondition(bijection, embedding_net, raw_cond_shape)

Bases: AbstractBijection

Wrap a bijection to include an embedding network.

This is generally used to reduce the dimensionality of the conditioning variable. The returned bijection has cond_shape equal to raw_cond_shape.

Parameters:
  • bijection (AbstractBijection) – Bijection whose bijection.cond_shape matches the shape of the embedding network output.

  • embedding_net (Callable) – A callable (e.g. an Equinox module) that maps a raw conditioning variable to an array with shape bijection.cond_shape.

  • raw_cond_shape (tuple[int, ...]) – The shape of the raw conditioning variable.
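EmbedCondition only changes what the wrapped bijection receives as its condition: the raw condition is passed through embedding_net first. The NumPy sketch below illustrates this flow with a hypothetical embedding (a fixed random projection followed by tanh) and a hypothetical inner conditional transform. Neither is a flowjax component.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical embedding network: maps a raw condition of shape (10,) to shape (2,).
E = rng.normal(size=(2, 10))


def embedding_net(raw_condition):
    return np.tanh(E @ raw_condition)


def inner_transform(x, condition):
    # Hypothetical inner conditional bijection expecting cond_shape (2,).
    return x + condition.sum()


def embedded_transform(x, raw_condition):
    # Embed the raw condition, then pass the embedding to the inner bijection.
    return inner_transform(x, embedding_net(raw_condition))


raw_condition = rng.normal(size=10)
y = embedded_transform(np.zeros(3), raw_condition)
```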

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

bijection: AbstractBijection
cond_shape: tuple[int, ...]
embedding_net: Callable
property shape
class Exp(shape=())

Bases: AbstractBijection

Elementwise exponential transform (forward) and log transform (inverse).

Parameters:

shape (tuple[int, ...]) – Shape of the bijection. Defaults to ().
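Since d/dx exp(x) = exp(x), the forward log absolute Jacobian determinant is simply sum(x). The inverse (log) is only defined for positive inputs, which reflects that this bijection maps the real line onto the positive reals. The helper names below are hypothetical and form a standalone NumPy sketch.

```python
import numpy as np


def exp_transform_and_log_det(x):
    # log|det J| = sum(log(exp(x))) = sum(x).
    return np.exp(x), np.sum(x)


def exp_inverse_and_log_det(y):
    # The inverse is log(y), defined for y > 0; its log det is -sum(log(y)).
    x = np.log(y)
    return x, -np.sum(x)


x = np.array([-1.0, 0.0, 2.0])
y, log_det = exp_transform_and_log_det(x)  # log_det = -1 + 0 + 2 = 1
```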

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
shape: tuple[int, ...] = ()
class Flip(shape=())

Bases: AbstractBijection

Flip the input array. Condition argument is ignored.

Parameters:

shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
shape: tuple[int, ...] = ()
class Identity(shape=())

Bases: AbstractBijection

The identity bijection.

Parameters:

shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
shape: tuple[int, ...] = ()
class Invert(bijection)

Bases: AbstractBijection

Invert a bijection.

This wraps a bijection so that the transform methods become the inverse methods and vice versa. In general, we define bijections so that the forward methods are preferred, i.e. they are faster or are the ones actually implemented. When training flows, we generally want the inverse method (used in density evaluation) to be the faster one, so this class is often useful for swapping the two directions.

Parameters:

bijection (AbstractBijection) – Bijection to invert.
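Inverting a bijection swaps the forward and inverse methods, including their log determinants. The NumPy sketch below shows the idea with hypothetical ExpSketch and InvertSketch classes, which are not flowjax's classes. Wrapping an exp bijection yields a bijection whose forward direction is log.

```python
import numpy as np


class ExpSketch:
    """Hypothetical forward-preferred bijection: y = exp(x)."""

    def transform_and_log_det(self, x, condition=None):
        return np.exp(x), np.sum(x)

    def inverse_and_log_det(self, y, condition=None):
        x = np.log(y)
        return x, -np.sum(x)


class InvertSketch:
    """Hypothetical wrapper mirroring the idea of Invert: swap the directions."""

    def __init__(self, bijection):
        self.bijection = bijection

    def transform_and_log_det(self, x, condition=None):
        return self.bijection.inverse_and_log_det(x, condition)

    def inverse_and_log_det(self, y, condition=None):
        return self.bijection.transform_and_log_det(y, condition)


log_bijection = InvertSketch(ExpSketch())  # the forward direction is now log
y, log_det = log_bijection.transform_and_log_det(np.array([1.0, np.e]))
```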

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

bijection: AbstractBijection
property cond_shape
property shape
class LeakyTanh(max_val, shape=())

Bases: AbstractBijection

Tanh bijection, with a linear transformation beyond +/- max_val.

The value and gradient of the linear segments are set to match tanh at +/- max_val. This bijection can be useful to encourage values to lie within an interval whilst avoiding numerical precision issues, or in cases where we require a real -> real mapping, for which Tanh is not appropriate.

Parameters:
  • max_val (float | int) – Value above or below which the function becomes linear.

  • shape (tuple[int, ...]) – The shape of the bijection. Defaults to ().
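The matching conditions described above determine the linear pieces. The slope is tanh'(max_val) = 1 - tanh(max_val)^2, and the intercept makes the value equal tanh(max_val) at the boundary, with the mirrored intercept for negative inputs. The NumPy sketch below reconstructs this from the description. The variable names follow the linear_grad and intercept attributes listed for this class, but the code is an illustration, not flowjax's source.

```python
import numpy as np


def leaky_tanh(x, max_val):
    # Slope of the linear segments: tanh'(max_val).
    linear_grad = 1.0 - np.tanh(max_val) ** 2
    # Chosen so that linear_grad * max_val + intercept == tanh(max_val).
    intercept = np.tanh(max_val) - linear_grad * max_val
    linear = linear_grad * x + np.sign(x) * intercept
    return np.where(np.abs(x) >= max_val, linear, np.tanh(x))


def leaky_tanh_log_det(x, max_val):
    # Elementwise derivative: tanh'(x) inside the interval, linear_grad outside.
    linear_grad = 1.0 - np.tanh(max_val) ** 2
    grad = np.where(np.abs(x) >= max_val, linear_grad, 1.0 - np.tanh(x) ** 2)
    return np.sum(np.log(grad))


# Beyond +/- max_val the output keeps growing, so the map is real -> real.
values = leaky_tanh(np.array([-10.0, 0.0, 10.0]), max_val=3.0)
```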

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
intercept: float
linear_grad: float
max_val: float
shape: tuple[int, ...] = ()
class Loc(loc)

Bases: AbstractBijection

Location transformation y = x + c.

Parameters:

loc (ArrayLike) – Location parameter.

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
loc: Array
shape: tuple[int, ...]
class MaskedAutoregressive(key, *, transformer, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jax.nn.relu)

Bases: AbstractBijection

Masked autoregressive bijection.

The transformer is parameterised by a neural network, with weights masked to ensure an autoregressive structure.

Refs:
Parameters:
  • key (PRNGKeyArray) – Jax PRNGKey

  • transformer (AbstractBijection) – Bijection with shape () to be parameterised by the autoregressive network. Parameters wrapped with NonTrainable are excluded.

  • dim (int) – Dimension.

  • cond_dim (int | None) – Dimension of any conditioning variables. Defaults to None.

  • nn_width (int) – Neural network width.

  • nn_depth (int) – Neural network depth.

  • nn_activation (Callable) – Neural network activation. Defaults to jnn.relu.
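The masking that enforces an autoregressive structure can be sketched with the degree-based construction popularised by MADE. This generic NumPy illustration assumes one output per input dimension. It is not flowjax's masked_autoregressive_mlp, whose degree assignment may differ. Composing the masks shows that output i only depends on inputs with a lower index. Assuming the network is evaluated on x in the forward direction, this structure is consistent with the inverse being computed one step at a time (compare inv_scan_fn).

```python
import numpy as np


def made_masks(dim, hidden_sizes, rng):
    # Inputs (and outputs) get degrees 1..dim; hidden units get degrees in [1, dim - 1].
    in_degrees = np.arange(1, dim + 1)
    degrees = [in_degrees]
    for size in hidden_sizes:
        degrees.append(rng.integers(1, dim, size=size))  # high is exclusive
    masks = []
    for prev, curr in zip(degrees[:-1], degrees[1:]):
        # A hidden unit of degree k may connect to units/inputs of degree <= k.
        masks.append((curr[:, None] >= prev[None, :]).astype(float))
    # An output of degree d may only connect to hidden units of degree < d (strict).
    masks.append((in_degrees[:, None] > degrees[-1][None, :]).astype(float))
    return masks


masks = made_masks(4, [8, 8], np.random.default_rng(0))
# Compose the masks to get input -> output connectivity; it should be strictly lower triangular.
connectivity = masks[0]
for mask in masks[1:]:
    connectivity = mask @ connectivity
```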

inv_scan_fn(init, _, condition)

One ‘step’ in computing the inverse.

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: tuple[int, ...] | None
masked_autoregressive_mlp: MLP
shape: tuple[int, ...]
transformer_constructor: Callable
class Partial(bijection, idxs, shape)

Bases: AbstractBijection

Apply a bijection to specific indices of an input.

Parameters:
  • bijection (AbstractBijection) – Bijection that is compatible with the subset of x indexed by idxs.

  • idxs (int | slice | Array | tuple) – Indices of the transformed portion: an integer, a slice, or an array with integer/bool dtype.

  • shape (tuple[int, ...]) – Shape of the bijection.
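Entries outside idxs pass through unchanged and contribute identity rows to the Jacobian. As a result, the log determinant of the whole map equals that of the bijection applied to the indexed subset. The NumPy sketch below illustrates this with a hypothetical exp-based helper. Integer arrays, boolean masks and slices all work as idxs here.

```python
import numpy as np


def partial_transform_and_log_det(f, idxs, x):
    # Apply f only to x[idxs]; all other entries are copied unchanged.
    y_sub, log_det = f(x[idxs])
    y = x.copy()
    y[idxs] = y_sub
    # Unchanged entries contribute nothing to log|det J|.
    return y, log_det


def exp_with_log_det(x):
    # Hypothetical elementwise exp with its log det, sum(x).
    return np.exp(x), np.sum(x)


x = np.array([0.0, 1.0, 2.0, 3.0])
y, log_det = partial_transform_and_log_det(exp_with_log_det, np.array([1, 3]), x)
```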

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

bijection: AbstractBijection
property cond_shape
idxs: int | slice | Array | tuple
shape: tuple[int, ...]
class Permute(permutation)

Bases: AbstractBijection

Permutation transformation.

Parameters:

permutation (Union[Int[Array, '...'], Int[ndarray, '...']]) – An array with shape matching the array to transform, containing the elements 0 to array.size - 1, which give the new order of the flattened array (using C-like ordering).
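The NumPy sketch below shows one reading of the description above. Output position p takes the element at flattened (C-order) index permutation[p] of the input. The inverse permutation is the argsort of the flattened permutation, and the log determinant is zero because a permutation matrix has determinant +/-1. The helper names are hypothetical, not flowjax's.

```python
import numpy as np

shape = (2, 3)
rng = np.random.default_rng(0)
# Same shape as the input, containing each of 0..size-1 exactly once.
permutation = rng.permutation(np.prod(shape)).reshape(shape)
# inverse_permutation satisfies permutation.ravel()[inverse_permutation] == arange(size).
inverse_permutation = np.argsort(permutation.ravel()).reshape(shape)


def permute(x):
    # Index the flattened input with the permutation array (output keeps `shape`).
    return x.ravel()[permutation]


def unpermute(y):
    return y.ravel()[inverse_permutation]


x = np.arange(6.0).reshape(shape)
y = permute(x)
```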

inverse(y, condition=None)

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)

Compute the inverse transformation and the corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)

Apply the forward transformation.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)

Apply the forward transformation and compute the log absolute Jacobian determinant.

Parameters:
  • x – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.

cond_shape: ClassVar[None] = None
inverse_permutation: tuple[Array, ...]
permutation: tuple[Array, ...]
shape: tuple[int, ...]
class Planar(key, *, dim, cond_dim=None, **mlp_kwargs)

Bases: AbstractBijection

Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.

Uses the transformation \(y = x + u \cdot \text{tanh}(w \cdot x + b)\), where \(u \in \mathbb{R}^D, \ w \in \mathbb{R}^D\) and \(b \in \mathbb{R}\). In the unconditional case, \(w\), \(u\) and \(b\) are learned directly. In the conditional case, they are parameterised by an MLP.

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • dim (int) – Dimension of the bijection.

  • cond_dim (int | None) – Dimension of extra conditioning variables. Defaults to None.

  • **mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
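For the planar transformation, the Jacobian is I + u psi^T with psi = (1 - tanh(w·x + b)^2) w. By the matrix determinant lemma, its determinant is 1 + psi·u, so the log determinant costs O(D) rather than O(D^3). The standalone NumPy sketch below computes this and checks it against a finite-difference Jacobian. The linked paper notes that the map is invertible when w·u >= -1. The example parameters are chosen to satisfy this, and the sketch does not show how flowjax enforces it.

```python
import numpy as np


def planar_transform_and_log_det(x, u, w, b):
    # y = x + u * tanh(w . x + b) for a single vector x of shape (D,).
    a = np.dot(w, x) + b
    y = x + u * np.tanh(a)
    # Jacobian I + u psi^T has determinant 1 + psi . u (matrix determinant lemma).
    psi = (1.0 - np.tanh(a) ** 2) * w
    return y, np.log(np.abs(1.0 + np.dot(psi, u)))


u = np.array([0.5, -0.2, 0.1])
w = np.array([1.0, 0.5, -0.3])  # w . u = 0.37 >= -1
x = np.array([0.2, -0.1, 0.4])
y, log_det = planar_transform_and_log_det(x, u, w, b=0.1)

# Finite-difference check: column j of the Jacobian is dy/dx_j.
eps = 1e-6
J = np.stack(
    [
        (planar_transform_and_log_det(x + eps * e, u, w, 0.1)[0]
         - planar_transform_and_log_det(x - eps * e, u, w, 0.1)[0]) / (2 * eps)
        for e in np.eye(3)
    ],
    axis=1,
)
```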

get_planar(condition=None)[source]

Get the planar bijection with the conditioning applied if conditional.

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
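  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.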

cond_shape: tuple[int, ...] | None
conditioner: Callable | None
params: Array | None
shape: tuple[int, ...]
class RationalQuadraticSpline(*, knots, interval, min_derivative=0.001, softmax_adjust=0.01)[source]

Bases: AbstractBijection

Scalar RationalQuadraticSpline transformation (https://arxiv.org/abs/1906.04032).

Parameters:
  • knots (int) – Number of knots.

  • interval (float | int) – Interval to transform, [-interval, interval].

  • min_derivative (float) – Minimum derivative. Defaults to 1e-3.

  • softmax_adjust (float | int) – Controls the minimum bin width and height by rescaling the softmax output: 0 applies no adjustment, 1 averages the softmax output with evenly spaced widths, and values >1 promote more evenly spaced widths. See real_to_increasing_on_interval. Defaults to 1e-2.

derivative(x)[source]

The derivative dy/dx of the forward transformation.

Return type:

Array
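The spline and its derivative follow the rational-quadratic construction of Durkan et al. (https://arxiv.org/abs/1906.04032). Below is a minimal plain-JAX sketch of those formulas for a scalar input inside the knot range. The knot positions and derivatives are hypothetical hand-picked values; flowjax derives x_pos, y_pos and derivatives from learned parameters together with min_derivative and softmax_adjust, and its handling of inputs outside [-interval, interval] is not reproduced here.

```python
import jax
import jax.numpy as jnp


def _bin(x, x_pos):
    # Bin index k with x_pos[k] <= x < x_pos[k + 1]; the right endpoint uses the last bin.
    return jnp.clip(jnp.searchsorted(x_pos, x, side="right") - 1, 0, x_pos.shape[0] - 2)


def rqs_transform(x, x_pos, y_pos, derivs):
    """Hypothetical helper: monotonic rational-quadratic spline for scalar x."""
    k = _bin(x, x_pos)
    width = x_pos[k + 1] - x_pos[k]
    height = y_pos[k + 1] - y_pos[k]
    s = height / width                      # slope of the bin
    xi = (x - x_pos[k]) / width             # relative position in the bin, in [0, 1]
    d_k, d_k1 = derivs[k], derivs[k + 1]
    denom = s + (d_k1 + d_k - 2 * s) * xi * (1 - xi)
    return y_pos[k] + height * (s * xi**2 + d_k * xi * (1 - xi)) / denom


def rqs_derivative(x, x_pos, y_pos, derivs):
    """Hypothetical helper: closed-form dy/dx of rqs_transform."""
    k = _bin(x, x_pos)
    width = x_pos[k + 1] - x_pos[k]
    s = (y_pos[k + 1] - y_pos[k]) / width
    xi = (x - x_pos[k]) / width
    d_k, d_k1 = derivs[k], derivs[k + 1]
    denom = s + (d_k1 + d_k - 2 * s) * xi * (1 - xi)
    numer = s**2 * (d_k1 * xi**2 + 2 * s * xi * (1 - xi) + d_k * (1 - xi) ** 2)
    return numer / denom**2


# Hypothetical knots on [-2, 2] (i.e. interval=2) with positive knot derivatives.
x_pos = jnp.array([-2.0, -0.5, 1.0, 2.0])
y_pos = jnp.array([-2.0, -1.0, 0.5, 2.0])
derivs = jnp.array([1.0, 0.7, 1.5, 1.0])

xs = jnp.linspace(-2.0, 2.0, 101)
ys = jax.vmap(lambda v: rqs_transform(v, x_pos, y_pos, derivs))(xs)
dydx = jax.vmap(lambda v: rqs_derivative(v, x_pos, y_pos, derivs))(xs)
dydx_autodiff = jax.vmap(jax.grad(lambda v: rqs_transform(v, x_pos, y_pos, derivs)))(xs)
```

With positive bin slopes and positive knot derivatives the spline is strictly increasing, which is what makes it a bijection on the interval.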

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
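  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.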

cond_shape: ClassVar[None] = None
derivatives: Union[Array, AbstractUnwrappable[Array]]
interval: float | int
knots: int
min_derivative: float
shape: ClassVar[tuple] = ()
softmax_adjust: float | int
x_pos: Union[Array, AbstractUnwrappable[Array]]
y_pos: Union[Array, AbstractUnwrappable[Array]]
class Reshape(bijection, shape=None, cond_shape=None)[source]

Bases: AbstractBijection

Wraps bijection methods with reshaping operations.

One use case is wrapping bijections that do not directly support a scalar shape: such a bijection can be constructed with shape (1,) and then reshaped to ().

Parameters:
  • bijection (AbstractBijection) – The bijection to wrap.

  • shape (tuple[int, ...] | None) – The new input and output shape of the bijection. Defaults to unchanged.

  • cond_shape (tuple[int, ...] | None) – The new cond_shape of the bijection. Defaults to unchanged.

Example

>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Reshape
>>> affine = Affine(loc=jnp.arange(4))
>>> affine.shape
(4,)
>>> affine = Reshape(affine, (2,2))
>>> affine.shape
(2, 2)
>>> affine.transform(jnp.zeros((2,2)))
Array([[0., 1.],
       [2., 3.]], dtype=float32)
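
Conceptually, the wrapper only reshapes inputs and outputs around the wrapped bijection, and the log Jacobian determinant passes through unchanged because reshaping merely relabels coordinates. A minimal plain-JAX sketch of that idea; exp_transform_and_log_det and reshaped are hypothetical helpers, not part of flowjax.

```python
import jax.numpy as jnp


def exp_transform_and_log_det(x):
    """Hypothetical elementwise bijection y = exp(x), defined only on shape (4,)."""
    assert x.shape == (4,)
    return jnp.exp(x), jnp.sum(x)  # log|det J| = sum(log(exp(x))) = sum(x)


def reshaped(fn, inner_shape, outer_shape):
    """Hypothetical wrapper: make fn accept and return arrays of outer_shape."""
    def wrapped(x):
        y, log_det = fn(x.reshape(inner_shape))
        # Reshaping only relabels coordinates, so the log determinant is unchanged.
        return y.reshape(outer_shape), log_det
    return wrapped


transform_2x2 = reshaped(exp_transform_and_log_det, (4,), (2, 2))
x = jnp.arange(4.0).reshape(2, 2)
y, log_det = transform_2x2(x)
```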
inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
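  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.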

bijection: AbstractBijection
cond_shape: tuple[int, ...] | None = None
shape: tuple[int, ...]
class Scale(scale)[source]

Bases: AbstractBijection

Scale transformation \(y = a \cdot x\), where \(a\) is the scale parameter.

Parameters:

scale (ArrayLike) – Scale parameter.
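
For an elementwise scale, the Jacobian is the diagonal matrix \(\mathrm{diag}(a)\), so the log absolute determinant is \(\sum_i \log|a_i|\) and the inverse simply divides by \(a\). A minimal plain-JAX sketch, assuming a positive elementwise scale vector; the values and function names are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical positive, elementwise scale parameter a.
scale = jnp.array([2.0, 0.5, 4.0])


def scale_transform_and_log_det(x):
    # Diagonal Jacobian diag(a): log|det J| = sum_i log|a_i|.
    return scale * x, jnp.sum(jnp.log(jnp.abs(scale)))


def scale_inverse_and_log_det(y):
    # x = y / a; the inverse log determinant is the negation of the forward one.
    return y / scale, -jnp.sum(jnp.log(jnp.abs(scale)))


x = jnp.array([1.0, -2.0, 0.25])
y, log_det = scale_transform_and_log_det(x)
x_back, inv_log_det = scale_inverse_and_log_det(y)
```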

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
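  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.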

cond_shape: ClassVar[None] = None
scale: Union[Array, AbstractUnwrappable[Array]]
shape: tuple[int, ...]
class Scan(bijection)[source]

Bases: AbstractBijection

Repeatedly apply the same bijection with different parameter values.

Internally, this uses jax.lax.scan to reduce compilation time. Often it is convenient to construct these using equinox.filter_vmap.

Parameters:

bijection (AbstractBijection) – A bijection in which the array leaves have an additional leading axis to scan over (for example, one created with equinox.filter_vmap).

Example

The example below is equivalent to Chain([Affine(p) for p in params]).

>>> from flowjax.bijections import Scan, Affine
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> params = jnp.ones((3, 2))
>>> affine = eqx.filter_vmap(Affine)(params)
>>> affine = Scan(affine)
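
A scanned bijection behaves like applying each layer in turn, but jax.lax.scan traces the layer body once instead of once per layer. A minimal plain-JAX sketch of that equivalence for stacked affine parameters; the parameter values and helper names are hypothetical, and this mirrors the idea rather than flowjax's internals.

```python
import jax
import jax.numpy as jnp

# Hypothetical stacked affine parameters: 3 layers acting on 2-dimensional inputs.
locs = jnp.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
scales = jnp.array([[1.0, 2.0], [0.5, 1.0], [3.0, 1.0]])


def affine_layer(carry, params):
    """One layer: y = x * scale + loc, accumulating log|det J| in the carry."""
    x, log_det = carry
    loc, scale = params
    y = x * scale + loc
    return (y, log_det + jnp.sum(jnp.log(jnp.abs(scale)))), None


def scan_transform_and_log_det(x):
    # lax.scan slices the leading axis of (locs, scales) and traces affine_layer once.
    (y, log_det), _ = jax.lax.scan(affine_layer, (x, jnp.zeros(())), (locs, scales))
    return y, log_det


def loop_transform_and_log_det(x):
    # Equivalent unrolled Python loop, applying the layers one after another.
    log_det = jnp.zeros(())
    for loc, scale in zip(locs, scales):
        x = x * scale + loc
        log_det = log_det + jnp.sum(jnp.log(jnp.abs(scale)))
    return x, log_det


x = jnp.array([1.0, -2.0])
y_scan, ld_scan = scan_transform_and_log_det(x)
y_loop, ld_loop = loop_transform_and_log_det(x)
```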
inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
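  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.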

bijection: AbstractBijection
property cond_shape
property shape
class SoftPlus(shape=())[source]

Bases: AbstractBijection

Transforms to positive domain using softplus \(y = \log(1 + \exp(x))\).
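
The forward map, its inverse and the log derivative all have closed forms: the inverse is \(x = \log(e^y - 1)\), and \(\frac{dy}{dx} = \sigma(x)\), the logistic sigmoid, so \(\log \frac{dy}{dx} = -\text{softplus}(-x)\). A minimal plain-JAX sketch; the function names are hypothetical and flowjax's own numerics may differ.

```python
import jax
import jax.numpy as jnp


def softplus_transform_and_log_det(x):
    y = jax.nn.softplus(x)  # log(1 + exp(x))
    # dy/dx = sigmoid(x), and log(sigmoid(x)) = -softplus(-x)
    return y, jnp.sum(-jax.nn.softplus(-x))


def softplus_inverse(y):
    # x = log(exp(y) - 1), rewritten as y + log(1 - exp(-y)) to avoid overflow for large y
    return y + jnp.log(-jnp.expm1(-y))


x = jnp.array([-3.0, 0.0, 2.5])
y, log_det = softplus_transform_and_log_det(x)
```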

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
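  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.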

cond_shape: ClassVar[None] = None
shape: tuple[int, ...] = ()
class Stack(bijections, axis=0)[source]

Bases: AbstractBijection

Stack bijections along a new axis (analogous to jnp.stack).

See also Concatenate.
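
Each bijection in the sequence acts on one slice of the input taken along axis. Because the slices are transformed independently, the Jacobian is block diagonal and the log determinants add. A minimal plain-JAX sketch of that behaviour, with two hypothetical elementwise bijections written as functions returning (y, log_det); this is an illustration, not flowjax's implementation.

```python
import jax.numpy as jnp


def exp_bij(x):
    return jnp.exp(x), x  # y = exp(x), log|dy/dx| = x (elementwise)


def scale2_bij(x):
    return 2.0 * x, jnp.log(2.0) * jnp.ones_like(x)  # y = 2x


def stack_transform_and_log_det(bijections, x, axis=0):
    """Apply bijections[i] to the i-th slice of x along axis and stack the outputs."""
    slices = jnp.moveaxis(x, axis, 0)
    ys, log_dets = [], []
    for bij, x_i in zip(bijections, slices):
        y_i, ld_i = bij(x_i)
        ys.append(y_i)
        log_dets.append(jnp.sum(ld_i))
    # Block-diagonal Jacobian across slices: the log determinants add.
    return jnp.stack(ys, axis=axis), sum(log_dets)


x = jnp.array([[0.0, 1.0, 2.0], [1.0, 2.0, 3.0]])
y, log_det = stack_transform_and_log_det([exp_bij, scale2_bij], x, axis=0)
y_t, log_det_t = stack_transform_and_log_det([exp_bij, scale2_bij], x.T, axis=1)
```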

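Parameters:
  • bijections (Sequence[AbstractBijection]) – The bijections to stack.

  • axis (int) – The axis along which to stack. Defaults to 0.
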
inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
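  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.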

axis: int
bijections: Sequence[AbstractBijection]
cond_shape: tuple[int, ...] | None
shape: tuple[int, ...]
class Tanh(shape=())[source]

Bases: AbstractBijection

Tanh bijection \(y = \tanh(x)\), mapping the real line onto \((-1, 1)\).
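
The derivative is \(1 - \tanh^2(x)\), and taking its logarithm directly gives \(-\infty\) once \(\tanh(x)\) rounds to 1 for large \(|x|\). A common stable rewrite is \(\log(1 - \tanh^2(x)) = 2(\log 2 - x - \text{softplus}(-2x))\). A minimal plain-JAX sketch; the function names are hypothetical and this is not necessarily how flowjax computes it.

```python
import jax
import jax.numpy as jnp


def tanh_transform_and_log_det(x):
    y = jnp.tanh(x)
    # log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)), finite even for large |x|
    log_det = jnp.sum(2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x)))
    return y, log_det


def tanh_inverse(y):
    return jnp.arctanh(y)


x = jnp.array([-1.0, 0.0, 0.5])
y, log_det = tanh_transform_and_log_det(x)
naive = jnp.sum(jnp.log(1.0 - jnp.tanh(x) ** 2))
_, log_det_large = tanh_transform_and_log_det(jnp.array([20.0]))
```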

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
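  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.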

cond_shape: ClassVar[None] = None
shape: tuple[int, ...] = ()
class TriangularAffine(loc, arr, *, lower=True)[source]

Bases: AbstractBijection

A triangular affine transformation.

Transformation has the form \(Ax + b\), where \(A\) is a lower or upper triangular matrix, and \(b\) is the bias vector. We assume the diagonal entries are positive, and constrain the values using SoftPlus. Other parameterizations can be achieved by e.g. replacing self.triangular after construction.

Parameters:
  • loc (Shaped[ArrayLike, '#dim']) – Location parameter. If this is scalar, it is broadcast to the dimension inferred from arr.

  • arr (Shaped[Array, 'dim dim']) – Triangular matrix.

  • lower (bool) – Whether the mask should select the lower or upper triangular matrix (other elements ignored). Defaults to True (lower).
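
Because \(A\) is triangular, its determinant is the product of its diagonal, so the log determinant is a sum of logs, and the inverse is a triangular solve rather than a general matrix inverse. A minimal plain-JAX sketch for the lower-triangular case, assuming (as the description suggests) that the diagonal is made positive with softplus; the parameter values and function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Hypothetical unconstrained parameters for a 3-dimensional lower-triangular map.
loc = jnp.array([0.5, -1.0, 2.0])
raw = jnp.array([[0.3, 0.0, 0.0],
                 [1.2, -0.4, 0.0],
                 [-0.7, 0.9, 0.1]])

# Keep the strictly lower part as-is; map the diagonal through softplus so it is positive.
diag = jax.nn.softplus(jnp.diag(raw))
A = jnp.tril(raw, k=-1) + jnp.diag(diag)


def tri_affine_transform_and_log_det(x):
    # det(A) is the product of the diagonal, so log|det A| = sum(log(diag)).
    return A @ x + loc, jnp.sum(jnp.log(diag))


def tri_affine_inverse(y):
    # Solve A x = y - loc by substitution, exploiting the triangular structure.
    return solve_triangular(A, y - loc, lower=True)


x = jnp.array([1.0, 2.0, -0.5])
y, log_det = tri_affine_transform_and_log_det(x)
```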

inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
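  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.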

cond_shape: ClassVar[None] = None
loc: Array
lower: bool
shape: tuple[int, ...]
triangular: Union[Array, AbstractUnwrappable[Array]]
class Vmap(bijection, *, in_axes=None, axis_size=None, in_axes_condition=None)[source]

Bases: AbstractBijection

Applies vmap to bijection methods to add a batch dimension to the bijection.

Parameters:
  • bijection (AbstractBijection) – The bijection to vectorize.

  • in_axes (PyTree | None | int | Callable) – Specifies which axes of the bijection parameters to vectorize over. It should be a PyTree of None or int values whose tree structure is a prefix of the bijection's, or a callable mapping each leaf to None or an int. Defaults to None. Note that if the bijection contains unwrappables, in_axes should be specified for the unwrapped structure of the bijection.

  • axis_size (int | None) – The size of the new axis. This should be left unspecified if in_axes is provided, as the size can be inferred from the bijection parameters. Defaults to None.

  • in_axes_condition (int | None) – Optionally define an axis of the conditioning variable to vectorize over. Defaults to None.

Example

The two most common use cases are shown below.

Add a batch dimension to a bijection, mapping over bijection parameters:

>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> from flowjax.bijections import Vmap, RationalQuadraticSpline, Affine
>>> bijection = eqx.filter_vmap(
...    lambda: RationalQuadraticSpline(knots=5, interval=2),
...    axis_size=10
... )()
>>> bijection = Vmap(bijection, in_axes=eqx.if_array(0))
>>> bijection.shape
(10,)

Add a batch dimension to a bijection, broadcasting bijection parameters:

>>> bijection = RationalQuadraticSpline(knots=5, interval=2)
>>> bijection = Vmap(bijection, axis_size=10)
>>> bijection.shape
(10,)

A more advanced use case is to create bijections with more fine-grained control over parameter broadcasting. For example, the Affine constructor broadcasts the location and scale parameters during initialization. What if we want an Affine bijection with a global scale parameter but an elementwise location parameter? We could achieve this as follows.

>>> from jax.tree_util import tree_map
>>> from flowjax.wrappers import unwrap
>>> bijection = Affine(jnp.zeros(()), jnp.ones(()))
>>> bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
>>> in_axes = tree_map(lambda _: None, unwrap(bijection))
>>> in_axes = eqx.tree_at(
...     lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None
...     )
>>> bijection = Vmap(bijection, in_axes=in_axes)
>>> bijection.shape
(3,)
>>> bijection.bijection.loc.shape
(3,)
>>> unwrap(bijection.bijection.scale).shape
()
>>> x = jnp.ones(3)
>>> bijection.transform(x)
Array([1., 2., 3.], dtype=float32)
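
In plain JAX terms, the advanced example above corresponds to mapping over the input and loc while broadcasting scale. The minimal sketch below uses jax.vmap directly with a hypothetical scalar affine helper. It also shows why the per-copy log determinants are summed: the copies act on separate coordinates, so the Jacobian is diagonal.

```python
import jax
import jax.numpy as jnp


def affine(x, loc, scale):
    """Hypothetical scalar affine bijection y = scale * x + loc, returning (y, log|dy/dx|)."""
    return scale * x + loc, jnp.log(jnp.abs(scale))


loc = jnp.arange(3.0)   # elementwise location: mapped over (in_axes=0)
scale = jnp.array(1.0)  # global scale: broadcast (in_axes=None)

# Map over x and loc, broadcast scale; each call sees scalars only.
ys, log_dets = jax.vmap(affine, in_axes=(0, 0, None))(jnp.ones(3), loc, scale)

# Diagonal Jacobian: the total log determinant is the sum over the batch axis.
log_det = jnp.sum(log_dets)
```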
get_cond_shape(cond_ax)[source]
inverse(y, condition=None)[source]

Compute the inverse transformation.

Parameters:
  • y – Input array with shape matching bijection.shape

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

inverse_and_log_det(y, condition=None)[source]

Inverse transformation and corresponding log absolute Jacobian determinant.

Parameters:
  • y – Input array with shape matching bijection.shape.

  • condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.

transform(x, condition=None)[source]

Apply the forward transformation.

Parameters:
  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.

transform_and_log_det(x, condition=None)[source]

Apply transformation and compute the log absolute Jacobian determinant.

Parameters:
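  • x – Input with shape matching bijection.shape.

  • condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.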

vmap(f)[source]
axis_size: int
bijection: AbstractBijection
cond_shape: tuple[int, ...] | None
in_axes: tuple
property shape