Bijections
For an introduction to using bijections in FlowJAX, see Getting started.
Bijections from flowjax.bijections.
- class AbstractBijection
Bases: Module
Bijection abstract class.
Similar to AbstractDistribution, bijections have a shape and a cond_shape attribute. To allow easy composition of bijections, all bijections support passing conditioning variables (even if they are ignored). Bijections are registered as JAX PyTrees (as they are Equinox modules), so they are compatible with normal JAX operations. The methods of bijections do not support passing additional batch dimensions; however, jax.vmap or eqx.filter_vmap can be used to vmap specific methods if desired, and a bijection can be explicitly vectorised using the Vmap bijection.
Implementing a bijection:
1. Inherit from AbstractBijection.
2. Define the attributes shape and cond_shape. A cond_shape of None is used to represent unconditional bijections.
3. Implement the abstract methods transform_and_log_det and inverse_and_log_det. These should act on inputs compatible with the shapes shape for x, and cond_shape for condition.
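To make the contract concrete, here is a minimal, framework-free sketch. `ToyAffine` is a hypothetical stand-in and does not subclass AbstractBijection (a real implementation would be an Equinox module); it only mirrors the shape/cond_shape attributes and the two *_and_log_det methods. It checks the returned log-determinant against jax.jacobian and shows batching with jax.vmap, since the methods themselves take a single unbatched input.

```python
import jax
import jax.numpy as jnp


class ToyAffine:
    """Hypothetical stand-in illustrating the bijection contract (not FlowJAX's class)."""

    cond_shape = None  # None marks an unconditional bijection

    def __init__(self, loc, scale):
        self.loc = jnp.asarray(loc)
        self.scale = jnp.asarray(scale)
        self.shape = self.loc.shape

    def transform_and_log_det(self, x, condition=None):
        # condition is accepted so bijections compose, but it is ignored here
        y = self.scale * x + self.loc
        return y, jnp.sum(jnp.log(jnp.abs(self.scale)))

    def inverse_and_log_det(self, y, condition=None):
        x = (y - self.loc) / self.scale
        return x, -jnp.sum(jnp.log(jnp.abs(self.scale)))


bij = ToyAffine(loc=jnp.array([0.0, 1.0, 2.0]), scale=jnp.array([1.0, 2.0, 3.0]))
x = jnp.array([1.0, -1.0, 0.5])

y, log_det = bij.transform_and_log_det(x)
x_rec, inv_log_det = bij.inverse_and_log_det(y)

# Compare with log|det J| computed by autodiff.
jac = jax.jacobian(lambda v: bij.transform_and_log_det(v)[0])(x)
_, autodiff_log_det = jnp.linalg.slogdet(jac)

# The methods act on one input of shape bij.shape; add a batch axis with vmap.
xs = jnp.ones((5, 3))
ys, log_dets = jax.vmap(bij.transform_and_log_det)(xs)
```

Note that the inverse log-determinant is the negation of the forward one evaluated at the corresponding point, which is what allows either direction to be used for density evaluation.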
- inverse(y, condition=None)
Compute the inverse transformation.
- Parameters:
y (ArrayLike) – Input array with shape matching bijection.shape.
condition (ArrayLike | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- Return type:
Array
- abstract inverse_and_log_det(y, condition=None)
Inverse transformation and corresponding log absolute Jacobian determinant.
- transform(x, condition=None)
Apply the forward transformation.
- Parameters:
x (ArrayLike) – Input with shape matching bijection.shape.
condition (ArrayLike | None) – Condition, with shape matching bijection.cond_shape. Required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- Return type:
Array
- class AdditiveCondition(module, shape, cond_shape)
Bases: AbstractBijection
Given a callable f, carries out the transformation y = x + f(condition).
If used to transform a distribution, this allows the “location” to be changed as a function of the conditioning variables. Note that the callable can be a callable module with trainable parameters.
- Parameters:
Example
Conditioning using a linear transformation
>>> from flowjax.bijections import AdditiveCondition
>>> from equinox.nn import Linear
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> bijection = AdditiveCondition(
...     Linear(2, 3, key=jr.key(0)), shape=(3,), cond_shape=(2,)
... )
>>> bijection.transform(jnp.ones(3), condition=jnp.ones(2))
Array([1.9670618, 0.8156546, 1.7763454], dtype=float32)
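A framework-free sketch of the arithmetic behind AdditiveCondition, assuming a hypothetical fixed linear conditioner `f` (weights `W`, bias `b`) in place of the Linear module above. Because f depends only on the condition, the Jacobian with respect to x is the identity, so the log-determinant is zero and the inverse is a subtraction.

```python
import jax
import jax.numpy as jnp

# Hypothetical conditioner f: a fixed linear map from a 2-d condition to 3-d.
W = jnp.array([[1.0, 0.5], [-1.0, 2.0], [0.0, 3.0]])
b = jnp.array([0.1, 0.2, 0.3])


def f(condition):
    return W @ condition + b


def transform_and_log_det(x, condition):
    # y = x + f(condition) is a pure shift in x, so dy/dx is the identity
    return x + f(condition), jnp.zeros(())


def inverse_and_log_det(y, condition):
    return y - f(condition), jnp.zeros(())


x = jnp.array([1.0, 2.0, 3.0])
c = jnp.array([1.0, -1.0])
y, log_det = transform_and_log_det(x, c)
x_rec, _ = inverse_and_log_det(y, c)
jac = jax.jacobian(lambda v: transform_and_log_det(v, c)[0])(x)
```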
- class Affine(loc=0, scale=1)
Bases: AbstractBijection
Elementwise affine transformation \(y = a \cdot x + b\), where \(a\) is the scale and \(b\) is the location.
loc and scale should broadcast to the desired shape of the bijection. By default, we constrain the scale parameter to be positive using softplus, but other parameterizations can be achieved by replacing the scale parameter after construction, e.g. using eqx.tree_at.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
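A sketch of the softplus constraint on the scale, written in plain JAX. It assumes the unconstrained value `raw_scale` is stored and mapped through jax.nn.softplus when used; it does not reproduce how FlowJAX stores the parameter internally. Because the map is elementwise, log|det J| is the sum of log scales over the broadcast shape, and a scalar `loc` simply broadcasts.

```python
import jax
import jax.numpy as jnp


def inv_softplus(s):
    # Inverse of softplus for s > 0: log(exp(s) - 1), computed with expm1
    return jnp.log(jnp.expm1(s))


# Store an unconstrained parameter; the positive scale is softplus(raw_scale).
scale_init = jnp.array([0.5, 1.0, 2.0])
raw_scale = inv_softplus(scale_init)
loc = jnp.zeros(())  # scalar loc broadcasts against shape (3,)


def transform_and_log_det(x):
    scale = jax.nn.softplus(raw_scale)
    y = scale * x + loc
    # Elementwise map: sum log(scale) over the (broadcast) event shape
    log_det = jnp.sum(jnp.log(jnp.broadcast_to(scale, x.shape)))
    return y, log_det


x = jnp.array([1.0, 1.0, 1.0])
y, log_det = transform_and_log_det(x)
```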
- class BlockAutoregressiveNetwork(key, *, dim, cond_dim=None, depth, block_dim, activation=None)
Bases: AbstractBijection
Block Autoregressive Network (https://arxiv.org/abs/1904.04676).
Note that, in contrast to the original paper, which uses tanh activations, by default we use \(\tanh(x) + 0.01x\). This ensures the image of the activation is the whole real line, so the transform maps real -> real, avoiding various issues (see danielward27/flowjax#102).
- Parameters:
key (PRNGKeyArray) – JAX key.
dim (int) – Dimension of the distribution.
cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
depth (int) – Number of hidden layers in the network.
block_dim (int) – Block dimension (hidden layer size is dim*block_dim).
activation (AbstractBijection | Callable | None) – Activation function, either a scalar bijection or a callable that computes the activation for a scalar value. Note that the activation should be bijective to ensure invertibility of the network, and in general should map real -> real to ensure that, when transforming a distribution (with either the forward or inverse), the map is defined across the support of the base distribution. Defaults to \(\tanh(x) + 0.01x\).
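A quick numerical check of why the default activation is a safe choice: its derivative 1 - tanh(x)^2 + 0.01 is at least 0.01, so it is strictly increasing, and unlike tanh it is unbounded. The grid and test points below are arbitrary.

```python
import jax
import jax.numpy as jnp


def default_activation(x):
    # tanh(x) + 0.01 * x, the default described above
    return jnp.tanh(x) + 0.01 * x


grid = jnp.linspace(-50.0, 50.0, 1001)
derivs = jax.vmap(jax.grad(default_activation))(grid)

# Derivative is 1 - tanh(x)**2 + 0.01 >= 0.01 > 0, so the map is strictly increasing.
min_deriv = derivs.min()
# Plain tanh is bounded in (-1, 1); the leaky version keeps growing.
large = default_activation(jnp.array(1000.0))
```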
- class Chain(bijections)
Bases: AbstractBijection
Compose arbitrary bijections to form another bijection.
If the layers you are chaining have consistent structure, consider using Scan, which avoids separately compiling each layer.
- Parameters:
bijections (Sequence[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]]]) – Sequence of bijections. The bijection shapes must match, and any non-None condition shapes must match.
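The composition rule behind Chain can be sketched in plain JAX with two hypothetical layers given as (forward, inverse) function pairs. The forward pass applies layers in order and sums their log-determinants (the Jacobian of a composition is the product of the Jacobians), while the inverse undoes the layers in reverse order.

```python
import jax.numpy as jnp


# Hypothetical layers as (forward, inverse) pairs, each returning (value, log_det).
def exp_fwd(x):
    return jnp.exp(x), jnp.sum(x)  # log|d exp(x)/dx| = x


def exp_inv(y):
    return jnp.log(y), -jnp.sum(jnp.log(y))


def affine_fwd(x, a=2.0, b=1.0):
    return a * x + b, x.size * jnp.log(a)


def affine_inv(y, a=2.0, b=1.0):
    return (y - b) / a, -y.size * jnp.log(a)


layers = [(affine_fwd, affine_inv), (exp_fwd, exp_inv)]


def chain_transform(x):
    log_det = 0.0
    for fwd, _ in layers:  # apply layers first to last
        x, ld = fwd(x)
        log_det = log_det + ld
    return x, log_det


def chain_inverse(y):
    log_det = 0.0
    for _, inv in reversed(layers):  # undo layers last to first
        y, ld = inv(y)
        log_det = log_det + ld
    return y, log_det


x = jnp.array([0.0, 0.5])
y, fwd_ld = chain_transform(x)
x_rec, inv_ld = chain_inverse(y)
```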
- class Concatenate(bijections, axis=0)
Bases: AbstractBijection
Concatenate bijections along an existing axis, similar to jnp.concatenate.
See also Stack.
Example
>>> import jax.numpy as jnp
>>> from flowjax.bijections import Concatenate, Affine, Exp
>>> concat = Concatenate([Affine(jnp.ones((2, 3))), Exp((2, 3))])
>>> concat.shape
(4, 3)
- Parameters:
bijections (Sequence[AbstractBijection]) – Bijections to concatenate into a single bijection.
axis (int) – Axis along which to concatenate. Defaults to 0.
- class Coupling(key, *, transformer, untransformed_dim, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jnn.relu)
Bases: AbstractBijection
Coupling layer implementation (https://arxiv.org/abs/1605.08803).
- Parameters:
key (PRNGKeyArray) – JAX key.
transformer (AbstractBijection) – Unconditional bijection with shape () to be parameterised by the conditioner neural network. Parameters wrapped with NonTrainable are excluded from being parameterised.
untransformed_dim (int) – Number of untransformed conditioning variables (e.g. dim//2).
dim (int) – Total dimension.
cond_dim (int | None) – Dimension of additional conditioning variables. Defaults to None.
nn_width (int) – Neural network hidden layer width.
nn_depth (int) – Number of neural network hidden layers.
nn_activation (Callable) – Neural network activation function. Defaults to jnn.relu.
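A minimal sketch of a single coupling step, assuming an elementwise affine transformer (scale exp(log_scale) plus shift) and a hypothetical fixed linear conditioner in place of the neural network; FlowJAX's Coupling accepts any shape-() transformer and uses an MLP conditioner. The first untransformed_dim entries pass through unchanged, which is what lets the inverse recompute the same transformer parameters.

```python
import jax
import jax.numpy as jnp

dim, untransformed_dim = 4, 2

# Hypothetical conditioner: a fixed linear map from x[:2] to (log_scale, shift) for x[2:].
W = jnp.array([[0.5, -0.2], [0.1, 0.3], [1.0, 0.0], [0.0, -1.0]])


def conditioner(x_cond):
    params = W @ x_cond
    return params[:2], params[2:]  # log_scale, shift (one each per transformed dim)


def transform_and_log_det(x):
    x_cond, x_trans = x[:untransformed_dim], x[untransformed_dim:]
    log_scale, shift = conditioner(x_cond)
    y_trans = x_trans * jnp.exp(log_scale) + shift  # elementwise affine transformer
    y = jnp.concatenate([x_cond, y_trans])  # conditioning part passes through
    return y, jnp.sum(log_scale)


def inverse_and_log_det(y):
    y_cond, y_trans = y[:untransformed_dim], y[untransformed_dim:]
    log_scale, shift = conditioner(y_cond)  # y_cond == x_cond, so params match
    x_trans = (y_trans - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y_cond, x_trans]), -jnp.sum(log_scale)


x = jnp.array([0.3, -1.2, 2.0, 0.7])
y, log_det = transform_and_log_det(x)
x_rec, inv_log_det = inverse_and_log_det(y)
jac = jax.jacobian(lambda v: transform_and_log_det(v)[0])(x)
_, autodiff_log_det = jnp.linalg.slogdet(jac)
```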
- class EmbedCondition(bijection, embedding_net, raw_cond_shape)
Bases: AbstractBijection
Wrap a bijection to include an embedding network.
Generally this is used to reduce the dimensionality of the conditioning variable. The returned bijection has cond_dim equal to the raw condition size.
- Parameters:
bijection (AbstractBijection) – Bijection with bijection.cond_dim equal to the embedded size.
embedding_net (Callable) – A callable (e.g. an Equinox module) that embeds a conditioning variable to size bijection.cond_dim.
raw_cond_shape (tuple[int, ...]) – The shape of the raw conditioning variable.
- class Exp(shape=())
Bases: AbstractBijection
Elementwise exponential transform (forward) and log transform (inverse).
- class Flip(shape=())
Bases: AbstractBijection
Flip the input array. Condition argument is ignored.
- class Identity(shape=())
Bases: AbstractBijection
The identity bijection.
- class Indexed(bijection, idxs, shape)
Bases: AbstractBijection
Applies a bijection to specific indices of an input.
- Parameters:
- class Invert(bijection)
Bases: AbstractBijection
Invert a bijection.
This wraps a bijection such that the transform methods become the inverse methods and vice versa. Note that, in general, we define bijections such that the forward methods are preferred, i.e. faster or actually implemented. For training flows, we generally want the inverse method (used in density evaluation) to be faster, so it is often useful to wrap a bijection with this class.
- Parameters:
bijection (AbstractBijection) – Bijection to invert.
- class LeakyTanh(max_val, shape=())
Bases: AbstractBijection
Tanh bijection, with a linear transformation beyond +/- max_val.
The value and gradient of the linear segments are set to match tanh at +/- max_val. This bijection can be useful to encourage values to be within an interval whilst avoiding numerical precision issues, or in cases where we require a real -> real mapping, so Tanh is not appropriate.
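One way to build the function described above, in plain JAX; treat it as an illustration of the matching conditions rather than FlowJAX's exact implementation. Beyond +/- max_val the linear pieces start at tanh(+/- max_val) with slope 1 - tanh(max_val)^2, so the value and first derivative are continuous at the boundary and the map is unbounded.

```python
import jax
import jax.numpy as jnp


def leaky_tanh(x, max_val=2.0):
    # Inside [-max_val, max_val] use tanh; outside, continue linearly with the
    # value and slope tanh has at the boundary.
    boundary = jnp.sign(x) * max_val
    slope = 1.0 - jnp.tanh(max_val) ** 2
    linear = jnp.tanh(boundary) + slope * (x - boundary)
    return jnp.where(jnp.abs(x) > max_val, linear, jnp.tanh(x))


max_val = 2.0
eps = 1e-3
# Value and derivative agree on either side of the boundary.
left, right = leaky_tanh(max_val - eps), leaky_tanh(max_val + eps)
d_in = jax.grad(leaky_tanh)(max_val - eps)
d_out = jax.grad(leaky_tanh)(max_val + eps)
```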
- class Loc(loc)
Bases: AbstractBijection
Location transformation \(y = x + b\).
- Parameters:
loc (ArrayLike) – Location parameter.
- class MaskedAutoregressive(key, *, transformer, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jnn.relu)
Bases: AbstractBijection
Masked autoregressive bijection.
The transformer is parameterised by a neural network, with weights masked to ensure an autoregressive structure.
- Parameters:
key (PRNGKeyArray) – JAX key.
transformer (AbstractBijection) – Bijection with shape () to be parameterised by the autoregressive network. Parameters wrapped with NonTrainable are excluded.
dim (int) – Dimension.
cond_dim (int | None) – Dimension of any conditioning variables. Defaults to None.
nn_width (int) – Neural network width.
nn_depth (int) – Neural network depth.
nn_activation (Callable) – Neural network activation. Defaults to jnn.relu.
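A sketch of MADE-style weight masks, one standard way to obtain the autoregressive structure described above; the degree assignment is an assumption for illustration and not necessarily FlowJAX's exact scheme. The check confirms that the Jacobian of the masked network is strictly lower triangular, i.e. the transformer parameters for dimension j depend only on x[:j].

```python
import jax
import jax.numpy as jnp

dim, hidden = 4, 8

# MADE-style degrees: inputs get 1..dim, hidden units get degrees in 1..dim-1.
in_deg = jnp.arange(1, dim + 1)
hid_deg = jnp.arange(hidden) % (dim - 1) + 1
out_deg = jnp.arange(1, dim + 1)

# Hidden unit k may see input i if hid_deg[k] >= in_deg[i];
# output j may see hidden unit k only if out_deg[j] > hid_deg[k] (strict).
mask_in = (hid_deg[:, None] >= in_deg[None, :]).astype(jnp.float32)
mask_out = (out_deg[:, None] > hid_deg[None, :]).astype(jnp.float32)

k1, k2 = jax.random.split(jax.random.key(0))
W1 = jax.random.normal(k1, (hidden, dim))
W2 = jax.random.normal(k2, (dim, hidden))


def conditioner(x):
    h = jnp.tanh((W1 * mask_in) @ x)
    return (W2 * mask_out) @ h  # output j depends only on x[:j]


x = jnp.ones(dim)
jac = jax.jacobian(conditioner)(x)
```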
- class NumericalInverse(bijection, inverter)
Bases: AbstractBijection
Bijection wrapper to provide inverse methods using e.g. root finding.
- Parameters:
bijection (AbstractBijection) – The bijection to add an inverse to.
inverter (Callable[[AbstractBijection, Array, Array | None], Array]) – Callable implementing the numerical inversion method. Should accept the bijection, y and condition as arguments, and return the inverse.
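A sketch of an inverter based on bisection, one possible root-finding method for a strictly increasing scalar map. To stay self-contained it takes a plain function `f` in place of the bijection (the real inverter receives the bijection, y and condition); `forward`, the bracket [-10, 10] and the iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def forward(x):
    # Hypothetical strictly increasing scalar map without a closed-form inverse.
    return x + 0.5 * jnp.tanh(x) + 0.1 * x**3


def bisection_inverter(f, y, condition=None, lower=-10.0, upper=10.0, n_iter=60):
    """Solve f(x) = y by bisection on [lower, upper] (condition unused here)."""

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        too_small = f(mid) < y  # f is increasing, so the root lies above mid
        return jnp.where(too_small, mid, lo), jnp.where(too_small, hi, mid)

    init = (jnp.asarray(lower, dtype=jnp.float32), jnp.asarray(upper, dtype=jnp.float32))
    lo, hi = jax.lax.fori_loop(0, n_iter, body, init)
    return 0.5 * (lo + hi)


x_true = jnp.array(1.3)
y = forward(x_true)
x_est = bisection_inverter(forward, y)
```

Bisection needs only monotonicity and a bracket, which makes it robust but slow; a Newton-type method would converge faster when derivatives are cheap.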
- class Permute(permutation)
Bases: AbstractBijection
Permutation transformation.
- Parameters:
permutation (Union[Int[Array, '...'], Int[ndarray, '...']]) – An array with shape matching the array to transform, with elements 0 to (array.size - 1) representing the new order based on the flattened array (using C-like ordering).
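A plain-JAX reading of the permutation semantics, assuming element i of the flattened output takes element permutation.ravel()[i] of the flattened input; the inverse uses argsort. A permutation has |det J| = 1, so its log-determinant is zero. Check the FlowJAX source if the exact direction convention matters.

```python
import jax.numpy as jnp

x = jnp.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]])
permutation = jnp.array([[5, 0, 1], [4, 2, 3]])  # a reordering of 0..x.size-1


def permute(x, permutation):
    # y.ravel()[i] = x.ravel()[permutation.ravel()[i]]  (C-order flattening)
    return x.ravel()[permutation.ravel()].reshape(x.shape)


def unpermute(y, permutation):
    inverse_perm = jnp.argsort(permutation.ravel())  # undoes the reordering
    return y.ravel()[inverse_perm].reshape(y.shape)


y = permute(x, permutation)
x_rec = unpermute(y, permutation)
```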
- class Planar(key, *, dim, cond_dim=None, negative_slope=None, **mlp_kwargs)
Bases: AbstractBijection
Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.
Uses the transformation
\[\boldsymbol{y}=\boldsymbol{x} + \boldsymbol{u} \cdot \tanh(\boldsymbol{w}^T \boldsymbol{x} + b)\]
where \(\boldsymbol{u} \in \mathbb{R}^D, \ \boldsymbol{w} \in \mathbb{R}^D\) and \(b \in \mathbb{R}\). In the unconditional case, the (unbounded) parameters are learned directly. In the conditional case, they are parameterised by an MLP.
- Parameters:
key (PRNGKeyArray) – JAX random key.
dim (int) – Dimension of the bijection.
cond_dim (int | None) – Dimension of extra conditioning variables. Defaults to None.
negative_slope (float | None) – A positive float. If provided, a leaky ReLU activation (with the corresponding negative slope) is used instead of tanh. This also has the advantage that the bijection can be inverted analytically.
**mlp_kwargs – Keyword arguments (excluding in_size and out_size) passed to the MLP (equinox.nn.MLP). Ignored when cond_dim is None.
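For the tanh case the Jacobian is I + h'(a) u w^T with a = w^T x + b, so by the matrix determinant lemma log|det J| = log|1 + h'(a) w^T u|. The sketch below checks this against autodiff; the parameter values are arbitrary and chosen so that w^T u >= -1, the condition from the planar-flow paper that guarantees invertibility with tanh.

```python
import jax
import jax.numpy as jnp

# Example parameters (w @ u = 0.26 >= -1).
u = jnp.array([0.5, -0.3, 0.2])
w = jnp.array([1.0, 0.4, -0.6])
b = 0.1


def planar(x):
    return x + u * jnp.tanh(w @ x + b)


def planar_log_det(x):
    # det(I + h'(a) u w^T) = 1 + h'(a) w^T u, with a = w^T x + b and h = tanh
    a = w @ x + b
    h_prime = 1.0 - jnp.tanh(a) ** 2
    return jnp.log(jnp.abs(1.0 + h_prime * (w @ u)))


x = jnp.array([0.2, -1.0, 0.7])
_, autodiff_log_det = jnp.linalg.slogdet(jax.jacobian(planar)(x))
```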
- class Power(exponent, shape=())
Bases: AbstractBijection
Power transform \(y = x^p\).
Supports positive values, over which this is a bijection.
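The restriction to positive inputs matters: for an even integer p, x and -x map to the same value, and for non-integer p, x^p is not real when x < 0. A sketch of the forward map, its log-determinant sum(log|p| + (p - 1) log x) and the inverse, using an arbitrary exponent p = 3:

```python
import jax
import jax.numpy as jnp

p = 3.0


def power_transform_and_log_det(x):
    # y = x**p for x > 0; dy/dx = p * x**(p - 1), nonzero there for p != 0
    y = x**p
    log_det = jnp.sum(jnp.log(jnp.abs(p)) + (p - 1.0) * jnp.log(x))
    return y, log_det


def power_inverse(y):
    return y ** (1.0 / p)


x = jnp.array([0.5, 1.0, 2.0])
y, log_det = power_transform_and_log_det(x)
jac = jax.jacobian(lambda v: power_transform_and_log_det(v)[0])(x)
```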
- class RationalQuadraticSpline(*, knots, interval, min_derivative=0.001, softmax_adjust=0.01)
Bases: AbstractBijection
Scalar RationalQuadraticSpline transformation (https://arxiv.org/abs/1906.04032).
- Parameters:
knots (int) – Number of knots.
interval (float | int | tuple[int | float, int | float]) – Interval to transform. If a scalar value, uses [-interval, interval]; if a tuple, uses [interval[0], interval[1]].
min_derivative (float) – Minimum derivative. Defaults to 1e-3.
softmax_adjust (float | int) – Controls the minimum bin width and height by rescaling the softmax output, e.g. 0 = no adjustment, 1 = average the softmax output with evenly spaced widths, >1 promotes more evenly spaced widths. See real_to_increasing_on_interval. Defaults to 1e-2.
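A sketch of how a softmax_adjust-style rescaling can bound bin widths from below. This is one construction consistent with the description above (mixing the softmax output with uniform widths), not necessarily FlowJAX's real_to_increasing_on_interval; `adjusted_widths` and `bin_edges` are hypothetical helpers, and no assumption is made about how knots maps to the number of bins. With softmax_adjust=1 and 4 bins, every width is at least 1/8 of the interval.

```python
import jax
import jax.numpy as jnp


def adjusted_widths(unnormalized, softmax_adjust=1e-2):
    # Mix the softmax output with uniform widths; larger softmax_adjust -> more even.
    num_bins = unnormalized.shape[-1]
    widths = jax.nn.softmax(unnormalized)
    return (widths + softmax_adjust / num_bins) / (1.0 + softmax_adjust)


def bin_edges(unnormalized, interval=(-3.0, 3.0), softmax_adjust=1e-2):
    lower, upper = interval
    widths = adjusted_widths(unnormalized, softmax_adjust)
    # Cumulative widths give strictly increasing edges from lower to upper.
    return lower + (upper - lower) * jnp.concatenate([jnp.zeros(1), jnp.cumsum(widths)])


raw = jnp.array([5.0, -5.0, 0.0, 2.0])
widths = adjusted_widths(raw, softmax_adjust=1.0)
edges = bin_edges(raw, softmax_adjust=1.0)
```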
- class Reshape(bijection, shape=None, cond_shape=None)
Bases: AbstractBijection
Wraps bijection methods with reshaping operations.
One use case is bijections that do not directly support a scalar shape: these can be constructed with shape (1,) and reshaped to ().
- Parameters:
Example
>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Reshape
>>> affine = Affine(loc=jnp.arange(4))
>>> affine.shape
(4,)
>>> affine = Reshape(affine, (2,2))
>>> affine.shape
(2, 2)
>>> affine.transform(jnp.zeros((2,2)))
Array([[0., 1.],
       [2., 3.]], dtype=float32)
- class Scale(scale)
Bases: AbstractBijection
Scale transformation \(y = a \cdot x\).
- Parameters:
scale (ArrayLike) – Scale parameter.
- class Scan(bijection)
Bases: AbstractBijection
Repeatedly apply the same bijection with different parameter values.
Internally, uses jax.lax.scan to reduce compilation time. Often it is convenient to construct these using equinox.filter_vmap.
- Parameters:
bijection (AbstractBijection) – A bijection in which the array leaves have an additional leading axis to scan over. It is often convenient to create compatible bijections with equinox.filter_vmap.
Example
Below is equivalent to Chain([Affine(p) for p in params]).
>>> from flowjax.bijections import Scan, Affine
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> params = jnp.ones((3, 2))
>>> affine = eqx.filter_vmap(Affine)(params)
>>> affine = Scan(affine)
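The idea behind Scan can be sketched directly with jax.lax.scan: one layer body is traced once and applied to each slice of stacked parameters, giving the same result as an unrolled loop over separate layers (the Chain equivalent). Plain affine layers with hypothetical parameter values stand in for vmapped Affine bijections.

```python
import jax
import jax.numpy as jnp

# Stacked parameters for three elementwise affine layers on a 2-d input
# (leading axis = layer), analogous to vmapping a constructor over params.
locs = jnp.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
scales = jnp.array([[1.0, 2.0], [0.5, 3.0], [2.0, 1.0]])


def affine_layer(carry, params):
    x, log_det = carry
    loc, scale = params
    return (scale * x + loc, log_det + jnp.sum(jnp.log(scale))), None


def scan_transform(x):
    # One traced layer body, reused for every slice of the stacked parameters.
    (y, log_det), _ = jax.lax.scan(affine_layer, (x, jnp.zeros(())), (locs, scales))
    return y, log_det


def loop_transform(x):
    # Equivalent unrolled Python loop, as a chain of separate layers would do.
    log_det = jnp.zeros(())
    for loc, scale in zip(locs, scales):
        x, log_det = scale * x + loc, log_det + jnp.sum(jnp.log(scale))
    return x, log_det


x = jnp.array([1.0, -2.0])
y_scan, ld_scan = scan_transform(x)
y_loop, ld_loop = loop_transform(x)
```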
- class Sigmoid(shape=())
Bases: AbstractBijection
Sigmoid bijection \(y = \sigma(x) = \frac{1}{1 + \exp(-x)}\).
- class SoftPlus(shape=())
Bases: AbstractBijection
Transforms to positive domain using softplus \(y = \log(1 + \exp(x))\).
- class Stack(bijections, axis=0)
Bases: AbstractBijection
Stack bijections along a new axis (analogous to jnp.stack).
See also Concatenate.
Example
>>> import jax.numpy as jnp
>>> from flowjax.bijections import Stack, Affine, Exp
>>> stack = Stack([Affine(jnp.ones(3)), Exp((3,))])
>>> stack.shape
(2, 3)
- Parameters:
bijections (list[AbstractBijection]) – Bijections.
axis (int) – Axis along which to stack. Defaults to 0.
- class Tanh(shape=())
Bases: AbstractBijection
Tanh bijection \(y=\tanh(x)\).
- class TriangularAffine(loc, arr, *, lower=True)
Bases: AbstractBijection
A triangular affine transformation.
Transformation has the form \(Ax + b\), where \(A\) is a lower or upper triangular matrix, and \(b\) is the bias vector. We assume the diagonal entries are positive, and constrain the values using softplus. Other parameterizations can be achieved by e.g. replacing self.triangular after construction.
- Parameters:
loc (Shaped[ArrayLike, '#dim']) – Location parameter. If this is scalar, it is broadcast to the dimension inferred from arr.
arr (Shaped[Array, 'dim dim']) – Triangular matrix.
lower (bool) – Whether the mask should select the lower or upper triangular matrix (other elements ignored). Defaults to True (lower).
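A plain-JAX sketch of the lower-triangular case. Here the diagonal of `arr` is treated as unconstrained and passed through softplus, an assumption made for illustration that may differ from how FlowJAX interprets the arr argument. The log-determinant is the sum of log diagonal entries, and the inverse uses a triangular solve (jax.scipy.linalg.solve_triangular) rather than a general matrix inverse.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

arr = jnp.array([[0.3, 0.0, 0.0], [1.0, -0.5, 0.0], [2.0, 0.4, 1.2]])
loc = jnp.array([1.0, 0.0, -1.0])

# Keep the strictly lower part; make the diagonal positive with softplus.
tri = jnp.tril(arr, k=-1) + jnp.diag(jax.nn.softplus(jnp.diag(arr)))


def transform_and_log_det(x):
    # The determinant of a triangular matrix is the product of its diagonal
    return tri @ x + loc, jnp.sum(jnp.log(jnp.diag(tri)))


def inverse(y):
    # Triangular solve instead of a general matrix inverse
    return solve_triangular(tri, y - loc, lower=True)


x = jnp.array([0.5, -1.0, 2.0])
y, log_det = transform_and_log_det(x)
jac = jax.jacobian(lambda v: transform_and_log_det(v)[0])(x)
_, autodiff_log_det = jnp.linalg.slogdet(jac)
```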
- class Vmap(bijection, *, in_axes=None, axis_size=None, in_axes_condition=None)
Bases: AbstractBijection
Applies vmap to bijection methods to add a batch dimension to the bijection.
- Parameters:
bijection (AbstractBijection) – The bijection to vectorize.
in_axes (PyTree | None | int | Callable) – Specify which axes of the bijection parameters to vectorise over. It should be a PyTree of None or int leaves, with the tree structure being a prefix of the bijection, or a callable mapping Leaf -> Union[None, int]. Note, if the bijection contains unwrappables, then in_axes should be specified for the unwrapped structure of the bijection. Defaults to None.
axis_size (int | None) – The size of the new axis. This should be left unspecified if in_axes is provided, as the size can be inferred from the bijection parameters. Defaults to None.
in_axes_condition (int | None) – Optionally define an axis of the conditioning variable to vectorize over. Defaults to None.
Example
>>> # Add a bijection batch dimension, mapping over bijection parameters
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> from flowjax.bijections import Vmap, RationalQuadraticSpline, Affine
>>> bijection = eqx.filter_vmap(
...     lambda: RationalQuadraticSpline(knots=5, interval=2),
...     axis_size=10
... )()
>>> bijection = Vmap(bijection, in_axes=eqx.if_array(0))
>>> bijection.shape
(10,)
>>> # Add a bijection batch dimension, broadcasting bijection parameters:
>>> bijection = RationalQuadraticSpline(knots=5, interval=2)
>>> bijection = Vmap(bijection, axis_size=10)
>>> bijection.shape
(10,)
A more advanced use case is to create bijections with more fine-grained control over parameter broadcasting. For example, the Affine constructor broadcasts the location and scale parameters during initialization. What if we want an Affine bijection with a global scale parameter but an elementwise location parameter? We could achieve this as follows.
>>> from jax.tree_util import tree_map
>>> import paramax
>>> bijection = Affine(jnp.zeros(()), jnp.ones(()))
>>> bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
>>> in_axes = tree_map(lambda _: None, paramax.unwrap(bijection))
>>> in_axes = eqx.tree_at(
...     lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None
... )
>>> bijection = Vmap(bijection, in_axes=in_axes)
>>> bijection.shape
(3,)
>>> bijection.bijection.loc.shape
(3,)
>>> paramax.unwrap(bijection.bijection.scale).shape
()
>>> x = jnp.ones(3)
>>> bijection.transform(x)
Array([1., 2., 3.], dtype=float32)