Wrappers

AbstractUnwrappable objects and utilities.

These are “placeholder” values for specifying custom behaviour for nodes in a pytree. Many of them serve a similar purpose to PyTorch parameterizations; we use them, for example, to apply parameter constraints or to mask parameters. To apply the behaviour, we call unwrap(), which replaces every AbstractUnwrappable node in a pytree with its unwrapped version.

Unwrapping is automatically called in several places, primarily:

  • Prior to calling the bijection methods: transform, inverse, transform_and_log_det and inverse_and_log_det.

  • Prior to calling distribution methods: log_prob, sample and sample_and_log_prob.

  • Prior to computing the loss functions.
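
The unwrap-before-use pattern listed above can be sketched in plain JAX, without flowjax. In this standalone illustration, a hypothetical `Exp` placeholder stores an unconstrained array, a hypothetical `Affine` bijection unwraps itself at the start of `transform`, and a simplified `unwrap` helper (no nesting support) replaces the placeholders. None of these names are flowjax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Exp:
    """Hypothetical placeholder: stores an unconstrained array, unwraps to exp(raw)."""

    def __init__(self, raw):
        self.raw = raw

    def unwrap(self):
        return jnp.exp(self.raw)

    def tree_flatten(self):
        return (self.raw,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def unwrap(tree):
    # Replace every Exp node in the pytree with its unwrapped value (no nesting handled).
    return tree_map(
        lambda x: x.unwrap() if isinstance(x, Exp) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Exp),
    )


@register_pytree_node_class
class Affine:
    """Hypothetical bijection y = loc + scale * x with a positive scale."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale  # may be an Exp placeholder

    def transform(self, x):
        params = unwrap(self)  # unwrap first, so params.scale is a plain array
        return params.loc + params.scale * x

    def tree_flatten(self):
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


bijection = Affine(loc=jnp.zeros(2), scale=Exp(jnp.log(jnp.array([2.0, 3.0]))))
y = bijection.transform(jnp.ones(2))  # approximately [2., 3.]

# Gradients flow through exp to the unconstrained parameter stored in the placeholder.
grads = jax.grad(lambda b: b.transform(jnp.ones(2)).sum())(bijection)
```

Training then updates the unconstrained `raw` value, while every method that unwraps first only ever sees a positive scale.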

Note

When creating a custom unwrappable, remember that unwrapping generally occurs after the model has been initialized. We therefore recommend making sure the unwrap method still works if the model was constructed in a vectorized context such as eqx.filter_vmap, e.g. by relying on broadcasting or by explicitly vectorizing the computation.
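
As an illustration of this advice, the standalone sketch below defines a hypothetical `UnitNorm` placeholder (plain JAX, not flowjax's API) that rescales each row of a weight matrix to unit length. Plain `jax.vmap` stands in for `eqx.filter_vmap`. Because unwrap reduces over the last axis with keepdims=True, it gives the same per-model result whether the object was built directly or by a vmapped constructor that prepended a batch dimension.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class UnitNorm:
    """Hypothetical unwrappable: normalizes each row of `weight` to unit L2 norm."""

    def __init__(self, weight):
        self.weight = weight

    def unwrap(self):
        # Reduce over the last axis only; any leading (batch) axes broadcast through.
        norm = jnp.linalg.norm(self.weight, axis=-1, keepdims=True)
        return self.weight / norm

    def tree_flatten(self):
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def make_layer(key):
    # Constructor for a single (4, 3) weight matrix.
    return UnitNorm(jax.random.normal(key, (4, 3)))


keys = jax.random.split(jax.random.PRNGKey(0), 5)
batched = jax.vmap(make_layer)(keys)  # batched.weight has shape (5, 4, 3)
unwrapped = batched.unwrap()          # shape (5, 4, 3), every row has norm 1

# Same result as building and unwrapping one member on its own.
single = make_layer(keys[2]).unwrap()
```

Writing axis=1 instead of axis=-1 would give the same result for a single (4, 3) matrix, but would normalize over the wrong axis once a batch dimension is prepended.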

unwrap(tree)

Recursively unwraps all AbstractUnwrappable nodes within a pytree.

This leaves all other nodes unchanged. If nested, the innermost AbstractUnwrappable nodes are unwrapped first.

Example

Enforcing positivity.

>>> from flowjax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> params = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
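
The innermost-first recursion can be sketched in plain JAX. This is a simplified stand-in written for illustration, not flowjax's source: the hypothetical `Param` class plays the role of Parameterize, and `unwrap` first unwraps any placeholders nested inside a node's arguments, then calls the node's own `unwrap`.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Param:
    """Hypothetical Parameterize-like placeholder: unwraps to fn(*args)."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def unwrap(self):
        return self.fn(*self.args)

    def tree_flatten(self):
        return self.args, self.fn  # fn is static auxiliary data

    @classmethod
    def tree_unflatten(cls, fn, args):
        return cls(fn, *args)


def unwrap(tree):
    def _map(node):
        if isinstance(node, Param):
            # Unwrap placeholders nested in the arguments first (innermost first),
            # then apply this node's own transformation.
            return Param(node.fn, *unwrap(node.args)).unwrap()
        return node

    return tree_map(_map, tree, is_leaf=lambda x: isinstance(x, Param))


inner = Param(jnp.add, jnp.zeros(3), 1.0)  # unwraps to [1., 1., 1.]
outer = Param(jnp.exp, inner)              # unwraps to exp([1., 1., 1.])
result = unwrap(("abc", 1, outer))
```

Non-placeholder leaves such as "abc" and 1 pass through unchanged, mirroring the example above.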

class AbstractUnwrappable

An abstract class representing an unwrappable object.

Unwrappables generally replace nodes in a pytree to specify custom behaviour that is applied when the tree is unwrapped before use. This can be used, for example, to apply parameter constraints, such as making scale parameters positive, or to apply stop_gradient before the parameters are accessed.

abstract unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

TypeVar(T)
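
One way to express this abstract interface in plain Python is an abc base class whose subclasses must implement `unwrap`. The sketch below is an illustration under that assumption, not flowjax's class: the hypothetical `Unwrappable` base is generic in the unwrapped type, and the hypothetical `Masked` subclass unwraps to `weight * mask`, one way of masking parameters so that masked entries receive zero gradient.

```python
import abc
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

T = TypeVar("T")


class Unwrappable(abc.ABC, Generic[T]):
    """Hypothetical stand-in for AbstractUnwrappable."""

    @abc.abstractmethod
    def unwrap(self) -> T:
        """Return the unwrapped value, assuming no wrapped subnodes remain."""


@register_pytree_node_class
class Masked(Unwrappable[jax.Array]):
    """Hypothetical unwrappable that zeroes entries of `weight` where `mask` is 0."""

    def __init__(self, weight, mask):
        self.weight = weight
        self.mask = mask

    def unwrap(self) -> jax.Array:
        return self.weight * self.mask

    def tree_flatten(self):
        return (self.weight, self.mask), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


mask = jnp.array([[1.0, 0.0], [1.0, 1.0]])
weight = jnp.array([[2.0, 3.0], [4.0, 5.0]])


def loss(w):
    return jnp.sum(Masked(w, mask).unwrap() ** 2)


grad_w = jax.grad(loss)(weight)  # [[4., 0.], [8., 10.]]: the masked entry gets zero


class Incomplete(Unwrappable):
    pass  # does not implement unwrap


try:
    Incomplete()
    abstract_enforced = False
except TypeError:
    abstract_enforced = True  # the abstract method prevents instantiation
```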

class Parameterize(fn, *args, **kwargs)

Unwrap an object by calling fn with args and kwargs.

All of fn, args and kwargs may contain trainable parameters. If the Parameterize is created within eqx.filter_vmap, unwrapping is automatically vectorized correctly, as long as the vmapped constructor adds leading batch dimensions to all arrays (the default for eqx.filter_vmap).

Example

>>> from flowjax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive)  # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)

Parameters:
  • fn (Callable) – The callable to call with args and kwargs.

  • *args – Positional arguments to pass to fn.

  • **kwargs – Keyword arguments to pass to fn.

fn: Callable[..., TypeVar(T)]
args: Iterable
kwargs: dict[str, Any]
unwrap()
Return type:

TypeVar(T)
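
The vectorization behaviour described for Parameterize can be sketched by recording, at construction time, a scalar dummy array: under a vmapped constructor it gains one leading dimension per level of vmap, so unwrap can wrap fn in that many jax.vmap calls. This is a standalone illustration with a hypothetical `VmapAwareParam` class, using plain `jax.vmap` in place of eqx.filter_vmap; it is not claimed to be how flowjax implements Parameterize.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class VmapAwareParam:
    """Hypothetical Parameterize-like placeholder that unwraps to fn(*args, **kwargs)."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs
        # Scalar created in the constructor: a vmapped constructor gives it a batch axis.
        self._dummy = jnp.zeros(())

    def unwrap(self):
        fn = self.fn
        for _ in range(self._dummy.ndim):
            fn = jax.vmap(fn)  # map over each leading batch axis added at construction
        return fn(*self.args, **self.kwargs)

    def tree_flatten(self):
        return (self.args, self.kwargs, self._dummy), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        obj = cls.__new__(cls)  # bypass __init__ so the dummy keeps its batch shape
        obj.fn = fn
        obj.args, obj.kwargs, obj._dummy = children
        return obj


vecs = jnp.arange(6.0).reshape(2, 3)
batched = jax.vmap(lambda v: VmapAwareParam(jnp.diag, v))(vecs)
mats = batched.unwrap()  # shape (2, 3, 3): one diagonal matrix per row of vecs

# Calling fn directly on the stacked arrays would be wrong: jnp.diag of a 2D array
# extracts its diagonal instead of building matrices.
naive = jnp.diag(vecs)  # shape (2,)
```

As with the condition stated above, this relies on every array argument carrying the leading batch dimensions, since jax.vmap maps all positional and keyword arguments over axis 0.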

class NonTrainable(tree)

Applies stop_gradient to all array-like leaves when unwrapped.

Useful to mark pytrees (arrays, submodules, etc.) as frozen/non-trainable. NonTrainable nodes are also filtered out when partitioning parameters for training, and when parameterizing bijections (transformers) in coupling and masked autoregressive flows. See also non_trainable(), which is usually the preferable way to achieve similar behaviour, as it wraps the array-like leaves directly rather than the whole tree.

tree: TypeVar(T)
unwrap()
Return type:

TypeVar(T)
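
A minimal stand-in for this behaviour, written in plain JAX for illustration (the `Frozen` class below is hypothetical, not flowjax's NonTrainable), applies jax.lax.stop_gradient to the inexact array leaves when unwrapped, so a loss computed on the unwrapped tree produces zero gradient for the frozen leaves.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Frozen:
    """Hypothetical NonTrainable-like wrapper around an arbitrary pytree."""

    def __init__(self, tree):
        self.tree = tree

    def unwrap(self):
        def _stop(leaf):
            # Only floating-point (inexact) arrays carry gradients; leave others as is.
            if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.inexact):
                return jax.lax.stop_gradient(leaf)
            return leaf

        return tree_map(_stop, self.tree)

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


params = {"trainable": jnp.array([1.0, 2.0]), "frozen": Frozen(jnp.array([3.0, 4.0]))}


def loss(p):
    return p["trainable"].sum() + p["frozen"].unwrap().sum()


grads = jax.grad(loss)(params)
# grads["trainable"] is [1., 1.]; grads["frozen"].tree is [0., 0.]
```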

non_trainable(tree)

Freezes parameters by wrapping inexact array leaves with NonTrainable.

Note

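Regularization is likely to be applied before unwrapping, in which case the stop_gradient applied by NonTrainable offers no protection. To prevent regularization from affecting non-trainable parameters, filter them out first, for example using:

>>> import equinox as eqx
>>> from flowjax import wrappers
>>> params, static = eqx.partition(
...     ...,  # the pytree to partition, e.g. a distribution
...     eqx.is_inexact_array,
...     is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
... )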

This is done in both fit_to_data() and fit_to_key_based_loss().
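
The reason for filtering can be seen in a small plain-JAX sketch, again using a hypothetical `Frozen` wrapper rather than flowjax's NonTrainable: an L2 penalty computed on the wrapped parameters still sends gradient to the frozen leaves, because stop_gradient is only applied on unwrapping. Treating wrapper nodes as leaves via is_leaf and skipping them when summing the penalty avoids this.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves


@register_pytree_node_class
class Frozen:
    """Hypothetical wrapper marking a pytree as non-trainable."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def is_frozen(x):
    return isinstance(x, Frozen)


def l2_naive(params):
    # Sums over every array leaf, including those inside Frozen nodes.
    return sum(jnp.sum(x**2) for x in tree_leaves(params))


def l2_filtered(params):
    # Frozen nodes become leaves via is_leaf and are skipped.
    leaves = tree_leaves(params, is_leaf=is_frozen)
    return sum(jnp.sum(x**2) for x in leaves if not is_frozen(x))


params = {"w": jnp.array([1.0, 2.0]), "b": Frozen(jnp.array([3.0]))}
naive_grads = jax.grad(l2_naive)(params)        # b receives gradient [6.]
filtered_grads = jax.grad(l2_filtered)(params)  # b receives gradient [0.]
```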

Wrapping the arrays rather than the entire tree is often preferable, as it allows easier access to attributes.

Parameters:

tree (PyTree) – The pytree.

class WeightNormalization(weight)

Applies weight normalization (https://arxiv.org/abs/1602.07868).

Parameters:

weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.

weight: Union[Array, AbstractUnwrappable[Array]]
scale: Union[Array, AbstractUnwrappable[Array]]
unwrap()
Return type:

Array
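
Weight normalization reparameterizes each weight vector as w = g * v / ||v||, decoupling its direction v from its length g. The following plain-JAX sketch (a hypothetical `WeightNorm` class, not flowjax's WeightNormalization) normalizes each row of a matrix of shape (out, in). It assumes the scale is initialized to the row norms, so that the unwrapped weight initially equals the original weight.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class WeightNorm:
    """Hypothetical weight-normalized matrix: row i unwraps to scale[i] * v[i] / ||v[i]||."""

    def __init__(self, weight):
        self.weight = weight  # direction parameter v
        # Assumed initialization: scale = row norms, so unwrap() initially returns weight.
        self.scale = jnp.linalg.norm(weight, axis=-1, keepdims=True)

    def unwrap(self):
        row_norms = jnp.linalg.norm(self.weight, axis=-1, keepdims=True)
        return self.scale * self.weight / row_norms

    def tree_flatten(self):
        return (self.weight, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = cls.__new__(cls)  # bypass __init__ so scale is not recomputed
        obj.weight, obj.scale = children
        return obj


weight = jnp.array([[3.0, 4.0], [0.0, 2.0]])
wn = WeightNorm(weight)
initial = wn.unwrap()  # equals weight at initialization

# Rescaling the direction parameter leaves the unwrapped weight unchanged;
# only `scale` controls the row lengths.
direction_scaled = WeightNorm.tree_unflatten(None, (10.0 * wn.weight, wn.scale))
rescaled = direction_scaled.unwrap()
```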