Bijections
Bijections from flowjax.bijections.
- class AbstractBijection[source]
Bases:
Module
Bijection abstract class.
Similar to AbstractDistribution, bijections have a shape and a cond_shape attribute. To allow easy composing of bijections, all bijections support passing of conditioning variables (even if ignored).
The methods of bijections do not generally support passing of additional batch dimensions. However, jax.vmap or eqx.filter_vmap can be used to vmap specific methods if desired, and a bijection can be explicitly vectorised using the Vmap bijection.
Bijections are registered as JAX PyTrees (as they are equinox modules), so are compatible with normal JAX operations.
Implementing a bijection
Inherit from AbstractBijection.
Define the attributes shape and cond_shape. A cond_shape of None is used to represent unconditional bijections.
Implement the abstract methods transform, transform_and_log_det, inverse and inverse_and_log_det. These should act on inputs compatible with the shapes shape for x, and cond_shape for condition.
- abstract inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y (Array) – Input array with shape matching bijection.shape.
condition (Array | None) – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- Return type:
Array
- abstract inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- abstract transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x (Array) – Input with shape matching bijection.shape.
condition (Array | None) – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- Return type:
Array
- class AdditiveCondition(module, shape, cond_shape)[source]
Bases:
AbstractBijection
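Given a callable f, carries out the transformation y = x + f(condition). If used to transform a distribution, this allows the “location” to be changed as a function of the conditioning variables. Note that the callable can be a callable module with trainable parameters.
- Parameters:
module – The callable f (e.g. an equinox module), whose output is added to x.
shape – Shape of the bijection.
cond_shape – Shape of the conditioning variable.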
Example
Conditioning using a linear transformation
>>> from flowjax.bijections import AdditiveCondition
>>> from equinox.nn import Linear
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> bijection = AdditiveCondition(
...     Linear(2, 3, key=jr.PRNGKey(0)), shape=(3,), cond_shape=(2,)
... )
>>> bijection.transform(jnp.ones(3), condition=jnp.ones(2))
Array([1.9670618, 0.8156546, 1.7763454], dtype=float32)
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Affine(loc=0, scale=1)[source]
Bases:
AbstractBijection
Elementwise affine transformation y = a*x + b, where a is scale and b is loc. loc and scale should broadcast to the desired shape of the bijection. By default, we constrain the scale parameter to be positive using SoftPlus, but other parameterizations can be achieved by replacing the scale parameter after construction, e.g. using eqx.tree_at.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- loc: Array
- scale: Union[Array, AbstractUnwrappable[Array]]
- class BlockAutoregressiveNetwork(key, *, dim, cond_dim=None, depth, block_dim, activation=None, inverter=None)[source]
Bases:
AbstractBijection
Block Autoregressive Network (https://arxiv.org/abs/1904.04676).
Note that in contrast to the original paper, which uses tanh activations, by default we use LeakyTanh. This ensures the codomain of the activation is the set of real values, which will ensure properly normalised densities (see https://github.com/danielward27/flowjax/issues/102).
- Parameters:
key (PRNGKeyArray) – Jax PRNGKey.
dim (int) – Dimension of the distribution.
cond_dim (int | None) – Dimension of conditioning variables. Defaults to None.
depth (int) – Number of hidden layers in the network.
block_dim (int) – Block dimension (hidden layer size is dim*block_dim).
activation (AbstractBijection | Callable | None) – Activation function, either a scalar bijection or a callable that computes the activation for a scalar value. Note that the activation should be bijective to ensure invertibility of the network and in general should map real -> real to ensure that when transforming a distribution (either with the forward or inverse), the map is defined across the support of the base distribution. Defaults to LeakyTanh(3).
inverter (Callable | None) – Callable that implements the required numerical method to invert the BlockAutoregressiveNetwork bijection. Must have the signature inverter(bijection, y, condition=None). Defaults to AutoregressiveBisectionInverter.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- activation: AbstractBijection
- class Chain(bijections)[source]
Bases:
AbstractBijection
Chain together arbitrary bijections to form another bijection.
- Parameters:
bijections (Sequence[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]]]) – Sequence of bijections. The bijection shapes must match, and any non-None condition shapes must match.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijections: tuple[Union[AbstractBijection, AbstractUnwrappable[AbstractBijection]], ...]
- class Concatenate(bijections, axis=0)[source]
Bases:
AbstractBijection
Concatenate bijections along an existing axis, similar to jnp.concatenate. See also Stack.
- Parameters:
bijections (Sequence[AbstractBijection]) – Bijections to concatenate into a single bijection.
axis (int) – Axis along which to concatenate. Defaults to 0.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijections: Sequence[AbstractBijection]
- class Coupling(key, *, transformer, untransformed_dim, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jax.nn.relu)[source]
Bases:
AbstractBijection
Coupling layer implementation (https://arxiv.org/abs/1605.08803).
- Parameters:
key (PRNGKeyArray) – Jax PRNGKey.
transformer (AbstractBijection) – Unconditional bijection with shape () to be parameterised by the conditioner neural network. Parameters wrapped with NonTrainable are excluded from being parameterized.
untransformed_dim (int) – Number of untransformed conditioning variables (e.g. dim//2).
dim (int) – Total dimension.
cond_dim (int | None) – Dimension of additional conditioning variables. Defaults to None.
nn_width (int) – Neural network hidden layer width.
nn_depth (int) – Number of neural network hidden layers.
nn_activation (Callable) – Neural network activation function. Defaults to jnn.relu.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- conditioner: MLP
- class EmbedCondition(bijection, embedding_net, raw_cond_shape)[source]
Bases:
AbstractBijection
Wrap a bijection to include an embedding network.
Generally this is used to reduce the dimensionality of the conditioning variable. The returned bijection has cond_shape equal to raw_cond_shape.
- Parameters:
bijection (AbstractBijection) – Bijection with bijection.cond_shape equal to the embedded shape.
embedding_net (Callable) – A callable (e.g. equinox module) that embeds a conditioning variable to shape bijection.cond_shape.
raw_cond_shape (tuple[int, ...]) – The shape of the raw conditioning variable.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijection: AbstractBijection
- property shape
- class Exp(shape=())[source]
Bases:
AbstractBijection
Elementwise exponential transform (forward) and log transform (inverse).
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Flip(shape=())[source]
Bases:
AbstractBijection
Flip the input array. Condition argument is ignored.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Identity(shape=())[source]
Bases:
AbstractBijection
The identity bijection.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Invert(bijection)[source]
Bases:
AbstractBijection
Invert a bijection.
This wraps a bijection such that the transform methods become the inverse methods and vice versa. Note that in general, we define bijections such that the forward methods are preferred, i.e. faster or actually implemented. For training flows, we generally want the inverse method (used in density evaluation) to be faster, so it is often useful to wrap a bijection with this class.
- Parameters:
bijection (AbstractBijection) – Bijection to invert.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijection: AbstractBijection
- property cond_shape
- property shape
- class LeakyTanh(max_val, shape=())[source]
Bases:
AbstractBijection
Tanh bijection, with a linear transformation beyond +/- max_val.
The value and gradient of the linear segments are set to match tanh at +/- max_val. This bijection can be useful to encourage values to be within an interval, whilst avoiding numerical precision issues, or in cases where we require a real -> real mapping, so Tanh is not appropriate.
- Parameters:
max_val – Absolute value beyond which the transformation becomes linear.
shape – Shape of the bijection. Defaults to ().
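The description above pins the function down: tanh inside [-max_val, max_val], continued along the tangent line of tanh at +/- max_val. The following pure-JAX sketch writes this out; leaky_tanh is a hypothetical helper for illustration, not flowjax's implementation:

```python
import jax
import jax.numpy as jnp


def leaky_tanh(x, max_val):
    """Hypothetical standalone sketch of the LeakyTanh forward map."""
    # Slope of tanh at +/- max_val: d/dx tanh(x) = 1 - tanh(x)**2.
    slope = 1.0 - jnp.tanh(max_val) ** 2
    # Tangent-line continuation beyond the interval, mirrored for negative x.
    linear = jnp.sign(x) * (jnp.tanh(max_val) + slope * (jnp.abs(x) - max_val))
    return jnp.where(jnp.abs(x) <= max_val, jnp.tanh(x), linear)


max_val = 3.0
xs = jnp.array([-5.0, -3.0, 0.0, 2.0, 3.0, 4.0])
ys = leaky_tanh(xs, max_val)
# Elementwise derivative with respect to x.
grads = jax.vmap(jax.grad(leaky_tanh), in_axes=(0, None))(xs, max_val)
```

Because the linear pieces have positive slope and meet tanh with matching value and gradient, the map is strictly increasing on the whole real line, which is what makes it a real -> real bijection.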
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Loc(loc)[source]
Bases:
AbstractBijection
Location transformation y = x + c, where c is loc.
- Parameters:
loc (ArrayLike) – Location parameter.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- loc: Array
- class MaskedAutoregressive(key, *, transformer, dim, cond_dim=None, nn_width, nn_depth, nn_activation=jax.nn.relu)[source]
Bases:
AbstractBijection
Masked autoregressive bijection.
The transformer is parameterised by a neural network, with weights masked to ensure an autoregressive structure.
- Parameters:
key (PRNGKeyArray) – Jax PRNGKey.
transformer (AbstractBijection) – Bijection with shape () to be parameterised by the autoregressive network. Parameters wrapped with NonTrainable are excluded.
dim (int) – Dimension.
cond_dim (int | None) – Dimension of any conditioning variables. Defaults to None.
nn_width (int) – Neural network width.
nn_depth (int) – Neural network depth.
nn_activation (Callable) – Neural network activation. Defaults to jnn.relu.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- masked_autoregressive_mlp: MLP
- class Partial(bijection, idxs, shape)[source]
Bases:
AbstractBijection
Applies bijection to specific indices of an input.
- Parameters:
bijection (AbstractBijection) – Bijection that is compatible with the subset of x indexed by idxs.
idxs (int | slice | Array | tuple) – Indices of the transformed portion (an integer, a slice, or an array with integer/bool dtype).
shape (tuple[int, ...]) – Shape of the bijection.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijection: AbstractBijection
- property cond_shape
- class Permute(permutation)[source]
Bases:
AbstractBijection
Permutation transformation.
- Parameters:
permutation (Union[Int[Array, '...'], Int[ndarray, '...']]) – An array with shape matching the array to transform, with elements 0 to array.size - 1 representing the new order based on the flattened array (uses C-like ordering).
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Planar(key, *, dim, cond_dim=None, **mlp_kwargs)[source]
Bases:
AbstractBijection
Planar bijection as used by https://arxiv.org/pdf/1505.05770.pdf.
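Uses the transformation \(y = x + u \cdot \text{tanh}(w \cdot x + b)\), where \(u \in \mathbb{R}^D, \ w \in \mathbb{R}^D\) and \(b \in \mathbb{R}\). In the unconditional case, \(w\), \(u\) and \(b\) are learned directly. In the conditional case they are parameterised by an MLP.
- Parameters:
key – Jax PRNGKey.
dim – Dimension of the bijection.
cond_dim – Dimension of the conditioning variable. Defaults to None.
**mlp_kwargs – Keyword arguments passed to the MLP used in the conditional case.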
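To make the formula concrete, here is a pure-JAX sketch of the unconditional planar map and its log determinant. With \(a = w \cdot x + b\), the Jacobian is \(I + \tanh'(a)\, u w^T\), and the matrix determinant lemma gives \(\det = 1 + \tanh'(a)\, w \cdot u\). The function names and the fixed values of u, w and b are illustrative (chosen so that w · u ≥ -1, the invertibility condition for tanh planar flows), and the sketch does not reproduce flowjax's internal parameterisation:

```python
import jax
import jax.numpy as jnp


def planar_transform(x, u, w, b):
    """Hypothetical sketch of y = x + u * tanh(w . x + b)."""
    return x + u * jnp.tanh(jnp.dot(w, x) + b)


def planar_log_det(x, u, w, b):
    # Matrix determinant lemma: det(I + tanh'(a) u w^T) = 1 + tanh'(a) (w . u).
    a = jnp.dot(w, x) + b
    dtanh = 1.0 - jnp.tanh(a) ** 2
    return jnp.log(jnp.abs(1.0 + dtanh * jnp.dot(w, u)))


u = jnp.array([0.5, -0.2, 0.1])
w = jnp.array([1.0, 0.3, -0.4])
b = 0.1
x = jnp.array([0.3, -1.2, 0.8])

y = planar_transform(x, u, w, b)
log_det = planar_log_det(x, u, w, b)
# Dense Jacobian with respect to x, for comparison with the closed form.
jac = jax.jacobian(planar_transform)(x, u, w, b)
```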
- get_planar(condition=None)[source]
Get the planar bijection with the conditioning applied if conditional.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class RationalQuadraticSpline(*, knots, interval, min_derivative=0.001, softmax_adjust=0.01)[source]
Bases:
AbstractBijection
Scalar RationalQuadraticSpline transformation (https://arxiv.org/abs/1906.04032).
- Parameters:
knots (int) – Number of knots.
interval (float | int) – Interval to transform, [-interval, interval].
min_derivative (float) – Minimum derivative. Defaults to 1e-3.
softmax_adjust (float | int) – Controls minimum bin width and height by rescaling softmax output, e.g. 0 = no adjustment, 1 = average softmax output with evenly spaced widths, >1 promotes more evenly spaced widths. See real_to_increasing_on_interval. Defaults to 1e-2.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- derivatives: Union[Array, AbstractUnwrappable[Array]]
- x_pos: Union[Array, AbstractUnwrappable[Array]]
- y_pos: Union[Array, AbstractUnwrappable[Array]]
- class Reshape(bijection, shape=None, cond_shape=None)[source]
Bases:
AbstractBijection
Wraps bijection methods with reshaping operations.
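One use case is bijections that do not directly support a scalar shape: these can be constructed with shape (1,) and reshaped to ().
- Parameters:
bijection – The bijection to wrap.
shape – The new shape of the bijection. Defaults to None.
cond_shape – The new conditioning shape of the bijection. Defaults to None.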
Example
>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Reshape
>>> affine = Affine(loc=jnp.arange(4))
>>> affine.shape
(4,)
>>> affine = Reshape(affine, (2,2))
>>> affine.shape
(2, 2)
>>> affine.transform(jnp.zeros((2,2)))
Array([[0., 1.],
       [2., 3.]], dtype=float32)
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Inverse transformation and corresponding log absolute jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- bijection: AbstractBijection
- class Scale(scale)[source]
Bases:
AbstractBijection
Scale transformation y = a*x, where a is scale.
- Parameters:
scale (ArrayLike) – Scale parameter.
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
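x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.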
- scale: Union[Array, AbstractUnwrappable[Array]]
- class Scan(bijection)[source]
Bases:
AbstractBijection
Repeatedly apply the same bijection with different parameter values.
Internally, this uses jax.lax.scan to reduce compilation time. It is often convenient to construct these using equinox.filter_vmap.
- Parameters:
bijection (AbstractBijection) – A bijection in which the array leaves have an additional leading axis to scan over. Compatible bijections can often be conveniently created with equinox.filter_vmap.
Example
The example below is equivalent to Chain([Affine(p) for p in params]).
>>> from flowjax.bijections import Scan, Affine
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> params = jnp.ones((3, 2))
>>> affine = eqx.filter_vmap(Affine)(params)
>>> affine = Scan(affine)
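The scan idea can be sketched in pure JAX: per-layer affine parameters are stacked along a leading axis, and jax.lax.scan threads the input and a running log determinant through each layer. This is an illustrative sketch with hypothetical names, not flowjax's implementation; the explicit Python loop plays the role of the equivalent Chain.

```python
import jax
import jax.numpy as jnp

# Per-layer parameters stacked along a leading axis (3 layers acting on shape (2,)).
locs = jnp.array([[0.0, 1.0], [2.0, 3.0], [-1.0, 0.5]])
scales = jnp.array([[1.0, 2.0], [0.5, 1.0], [3.0, 0.25]])

def affine_step(carry, params):
    x, log_det = carry
    loc, scale = params
    # One layer: y = loc + scale * x, accumulating log|det J|.
    new_carry = (loc + scale * x, log_det + jnp.sum(jnp.log(jnp.abs(scale))))
    return new_carry, None

def scan_transform_and_log_det(x):
    # A single scan traces one layer, rather than unrolling all of them.
    (y, log_det), _ = jax.lax.scan(affine_step, (x, jnp.zeros(())), (locs, scales))
    return y, log_det

def loop_transform_and_log_det(x):
    # Reference: apply the same layers one after another (Chain-style).
    log_det = 0.0
    for loc, scale in zip(locs, scales):
        x = loc + scale * x
        log_det = log_det + jnp.sum(jnp.log(jnp.abs(scale)))
    return x, log_det

x = jnp.array([1.0, -2.0])
y_scan, ld_scan = scan_transform_and_log_det(x)
y_loop, ld_loop = loop_transform_and_log_det(x)
```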
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
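x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.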
- bijection: AbstractBijection
- property cond_shape
- property shape
- class SoftPlus(shape=())[source]
Bases:
AbstractBijection
Transforms to the positive domain using softplus \(y = \log(1 + \exp(x))\).
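A minimal pure-JAX sketch of the softplus map, its inverse \(x = \log(\exp(y) - 1)\), and the log determinant \(\sum \log \sigma(x)\) is below; the derivative of softplus is the logistic sigmoid \(\sigma\). The function names are hypothetical and the sketch may differ from flowjax's implementation.

```python
import jax
import jax.numpy as jnp

def softplus_transform_and_log_det(x):
    y = jax.nn.softplus(x)
    # dy/dx = sigmoid(x), so log|det J| = sum(log sigmoid(x)).
    return y, jnp.sum(jax.nn.log_sigmoid(x))

def softplus_inverse_and_log_det(y):
    # Inverse: x = log(exp(y) - 1), written with expm1 for accuracy near y = 0.
    x = jnp.log(jnp.expm1(y))
    # The inverse log det is minus the forward log det evaluated at x.
    return x, -jnp.sum(jax.nn.log_sigmoid(x))

x = jnp.array([-3.0, 0.0, 2.5])
y, log_det = softplus_transform_and_log_det(x)
x_rec, inv_log_det = softplus_inverse_and_log_det(y)

# Cross-check each diagonal Jacobian entry with autodiff.
grads = jax.vmap(jax.grad(jax.nn.softplus))(x)
```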
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class Stack(bijections, axis=0)[source]
Bases:
AbstractBijection
Stack bijections along a new axis (analogous to jnp.stack). See also Concatenate.
- Parameters:
bijections (list[AbstractBijection]) – The bijections to stack.
axis (int) – Axis along which to stack. Defaults to 0.
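The stacking rule can be sketched in pure JAX: slice i along the stacking axis goes through bijection i, the outputs are re-stacked with jnp.stack, and the per-slice log determinants add because the slices are transformed independently. The per-slice maps (exp_bij, scale_bij) and the helper name are hypothetical, and this is not flowjax's implementation.

```python
import jax.numpy as jnp

# Hypothetical per-slice bijections, each returning (output, log|det J|).
def exp_bij(x):
    return jnp.exp(x), jnp.sum(x)  # d/dx exp(x) = exp(x), so log det = sum(x)

def scale_bij(x):
    return 2.0 * x, x.size * jnp.log(2.0)

def stack_transform_and_log_det(x, bijections, axis=0):
    # Assumes x.shape[axis] == len(bijections).
    slices = [jnp.take(x, i, axis=axis) for i in range(len(bijections))]
    outs = [bij(s) for bij, s in zip(bijections, slices)]
    y = jnp.stack([out[0] for out in outs], axis=axis)
    # Independent slices, so the total log det is the sum over slices.
    log_det = sum(out[1] for out in outs)
    return y, log_det

x = jnp.array([[0.0, 1.0, -1.0], [3.0, 4.0, 5.0]])  # shape (2, 3)
y, log_det = stack_transform_and_log_det(x, [exp_bij, scale_bij], axis=0)
```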
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
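x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.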
- bijections: Sequence[AbstractBijection]
- class Tanh(shape=())[source]
Bases:
AbstractBijection
Tanh bijection \(y = \tanh(x)\), mapping the real line to the interval (-1, 1).
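A pure-JAX sketch of the tanh map, its inverse arctanh, and a numerically stable log determinant is below. The stable form rewrites \(\log(1 - \tanh^2 x)\) as \(2(\log 2 - x - \mathrm{softplus}(-2x))\), which stays finite for large \(|x|\). This is an illustrative choice with hypothetical function names, not necessarily how flowjax implements it.

```python
import jax
import jax.numpy as jnp

def tanh_transform_and_log_det(x):
    # d/dx tanh(x) = 1 - tanh(x)**2 = 4 / (exp(x) + exp(-x))**2, whose log is
    # 2 * (log(2) - x - softplus(-2x)); this avoids log(0) when |x| is large.
    log_det = jnp.sum(2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x)))
    return jnp.tanh(x), log_det

def tanh_inverse_and_log_det(y):
    # Assumes every entry of y lies strictly inside (-1, 1).
    x = jnp.arctanh(y)
    return x, -tanh_transform_and_log_det(x)[1]

x = jnp.array([-2.0, 0.0, 1.5])
y, log_det = tanh_transform_and_log_det(x)
x_rec, inv_log_det = tanh_inverse_and_log_det(y)

# Cross-check the diagonal Jacobian entries with autodiff.
grads = jax.vmap(jax.grad(jnp.tanh))(x)
```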
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- class TriangularAffine(loc, arr, *, lower=True)[source]
Bases:
AbstractBijection
A triangular affine transformation.
The transformation has the form \(Ax + b\), where \(A\) is a lower or upper triangular matrix and \(b\) is the bias vector. We assume the diagonal entries are positive and constrain them using SoftPlus. Other parameterizations can be achieved by, e.g., replacing self.triangular after construction.
- Parameters:
loc (Shaped[ArrayLike, '#dim']) – Location parameter. If this is scalar, it is broadcast to the dimension inferred from arr.
arr (Shaped[Array, 'dim dim']) – Triangular matrix.
lower (bool) – Whether the mask should select the lower or upper triangular matrix (other elements ignored). Defaults to True (lower).
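To make the parameterization concrete, here is a pure-JAX sketch of the lower-triangular case: the diagonal of an unconstrained matrix is passed through softplus, entries above the diagonal are dropped, the log determinant is the sum of the log diagonal, and the inverse uses a triangular solve. The helper names and example matrix are hypothetical; flowjax's internals (for instance how self.triangular is wrapped) may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Hypothetical unconstrained parameters; entries above the diagonal are ignored.
arr = jnp.array([[0.3, 9.0, 9.0],
                 [1.0, -0.5, 9.0],
                 [2.0, 0.7, 1.2]])
loc = jnp.array([1.0, -1.0, 0.5])

def make_triangular(arr):
    positive_diag = jax.nn.softplus(jnp.diag(arr))  # constrain the diagonal to be > 0
    strict_lower = jnp.tril(arr, k=-1)  # keep only entries below the diagonal
    return strict_lower + jnp.diag(positive_diag)

A = make_triangular(arr)

def tri_affine_transform_and_log_det(x):
    # The determinant of a triangular matrix is the product of its diagonal.
    return A @ x + loc, jnp.sum(jnp.log(jnp.diag(A)))

def tri_affine_inverse_and_log_det(y):
    # Solve A x = y - loc by forward substitution rather than forming A^{-1}.
    x = solve_triangular(A, y - loc, lower=True)
    return x, -jnp.sum(jnp.log(jnp.diag(A)))

x = jnp.array([0.2, -1.3, 2.0])
y, log_det = tri_affine_transform_and_log_det(x)
x_rec, _ = tri_affine_inverse_and_log_det(y)
sign, ref_log_det = jnp.linalg.slogdet(A)
```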
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
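x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.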
- loc: Array
- triangular: Union[Array, AbstractUnwrappable[Array]]
- class Vmap(bijection, *, in_axes=None, axis_size=None, in_axes_condition=None)[source]
Bases:
AbstractBijection
Applies vmap to bijection methods to add a batch dimension to the bijection.
- Parameters:
bijection (AbstractBijection) – The bijection to vectorise.
in_axes (PyTree | None | int | Callable) – Specify which axes of the bijection parameters to vectorise over. It should be a PyTree of None or int leaves whose tree structure is a prefix of the bijection, or a callable mapping Leaf -> Union[None, int]. Defaults to None. Note that if the bijection contains unwrappables, in_axes should be specified for the unwrapped structure of the bijection.
axis_size (int | None) – The size of the new axis. This should be left unspecified if in_axes is provided, as the size can be inferred from the bijection parameters. Defaults to None.
in_axes_condition (int | None) – Optionally define an axis of the conditioning variable to vectorise over. Defaults to None.
Example
The two most common use cases are shown below:
Add a batch dimension to a bijection, mapping over bijection parameters:
>>> import jax.numpy as jnp
>>> import equinox as eqx
>>> from flowjax.bijections import Vmap, RationalQuadraticSpline, Affine
>>> bijection = eqx.filter_vmap(
...     lambda: RationalQuadraticSpline(knots=5, interval=2),
...     axis_size=10
... )()
>>> bijection = Vmap(bijection, in_axes=eqx.if_array(0))
>>> bijection.shape
(10,)

Add a batch dimension to a bijection, broadcasting bijection parameters:
>>> bijection = RationalQuadraticSpline(knots=5, interval=2)
>>> bijection = Vmap(bijection, axis_size=10)
>>> bijection.shape
(10,)
A more advanced use case is to create bijections with finer-grained control over parameter broadcasting. For example, the Affine constructor broadcasts the location and scale parameters during initialization. What if we want an Affine bijection with a global scale parameter but an elementwise location parameter? We could achieve this as follows.
>>> from jax.tree_util import tree_map
>>> from flowjax.wrappers import unwrap
>>> bijection = Affine(jnp.zeros(()), jnp.ones(()))
>>> bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
>>> in_axes = tree_map(lambda _: None, unwrap(bijection))
>>> in_axes = eqx.tree_at(
...     lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None
... )
>>> bijection = Vmap(bijection, in_axes=in_axes)
>>> bijection.shape
(3,)
>>> bijection.bijection.loc.shape
(3,)
>>> unwrap(bijection.bijection.scale).shape
()
>>> x = jnp.ones(3)
>>> bijection.transform(x)
Array([1., 2., 3.], dtype=float32)
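The underlying idea can be sketched with plain jax.vmap: write the map for a single scalar element, then use in_axes to choose which parameters are batched. Here loc is mapped over (axis 0) while scale is broadcast (None), mirroring the global-scale, elementwise-location example above. This is a pure-JAX illustration with hypothetical function names, not flowjax's implementation.

```python
import jax
import jax.numpy as jnp

def scalar_affine_transform_and_log_det(x, loc, scale):
    # Transform for a single scalar element.
    return loc + scale * x, jnp.log(jnp.abs(scale))

# Map over x and loc (axis 0); share a single scale across the batch (None).
batched = jax.vmap(scalar_affine_transform_and_log_det, in_axes=(0, 0, None))

def vmapped_transform_and_log_det(x, loc, scale):
    y, log_dets = batched(x, loc, scale)
    # Each batch element contributes its own log det; the total is their sum.
    return y, jnp.sum(log_dets)

loc = jnp.arange(3.0)
scale = jnp.array(2.0)
y, log_det = vmapped_transform_and_log_det(jnp.ones(3), loc, scale)
```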
- inverse(y, condition=None)[source]
Compute the inverse transformation.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- inverse_and_log_det(y, condition=None)[source]
Compute the inverse transformation and the corresponding log absolute Jacobian determinant.
- Parameters:
y – Input array with shape matching bijection.shape.
condition – Condition array with shape matching bijection.cond_shape. Required for conditional bijections. Defaults to None.
- transform(x, condition=None)[source]
Apply the forward transformation.
- Parameters:
x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.
- transform_and_log_det(x, condition=None)[source]
Apply transformation and compute the log absolute Jacobian determinant.
- Parameters:
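x – Input with shape matching bijection.shape.
condition – Condition, with shape matching bijection.cond_shape, required for conditional bijections and ignored for unconditional bijections. Defaults to None.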
- bijection: AbstractBijection
- property shape