Wrappers#

AbstractUnwrappable objects and utilities.

These are "placeholder" values for specifying custom behaviour for nodes in a pytree. Many of these serve a similar purpose to PyTorch parametrizations. We use this, for example, to apply parameter constraints, masking of parameters, etc. To apply the behaviour, we use unwrap(), which replaces any AbstractUnwrappable nodes in a pytree with their unwrapped versions.
Unwrapping is automatically called in several places, primarily:

- Prior to calling the bijection methods: transform, inverse, transform_and_log_det and inverse_and_log_det.
- Prior to calling the distribution methods: log_prob, sample and sample_and_log_prob.
- Prior to computing the loss functions.
Note

If creating a custom unwrappable, remember that unwrapping will generally occur after initialization of the model. Because of this, we recommend ensuring that the unwrap method supports unwrapping if the model is constructed in a vectorized context, such as eqx.filter_vmap, e.g. through broadcasting or vectorization.
- unwrap(tree)[source]#

Recursively unwraps all AbstractUnwrappable nodes within a pytree. This leaves all other nodes unchanged. If nested, the innermost AbstractUnwrappable nodes are unwrapped first.

Example

Enforcing positivity.

>>> from flowjax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> params = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
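To make the innermost-first recursion concrete, here is a minimal, framework-free sketch of the mechanism. The `Parameterize` and `unwrap` names below are plain-Python stand-ins for illustration only; the real flowjax implementations operate on arbitrary JAX pytrees, not just tuples.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Parameterize:
    """Placeholder node: unwrapping calls ``fn(*args)``."""
    fn: Callable
    args: tuple


def unwrap(tree: Any) -> Any:
    """Recursively replace placeholder nodes, innermost first."""
    if isinstance(tree, tuple):
        return tuple(unwrap(leaf) for leaf in tree)
    if isinstance(tree, Parameterize):
        # Resolve the arguments first, so nested placeholders are
        # unwrapped from the inside out before ``fn`` is called.
        return tree.fn(*unwrap(tree.args))
    return tree


# Nested placeholders: the inner one resolves before the outer call.
inner = Parameterize(abs, (-2,))
outer = Parameterize(pow, (inner, 3))
print(unwrap(("abc", 1, outer)))  # ('abc', 1, 8)
```

Non-placeholder leaves (here the string and the integer) pass through unchanged, matching the behaviour described above.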
- class AbstractUnwrappable[source]#

An abstract class representing an unwrappable object.

Unwrappables generally replace nodes in a pytree in order to specify some custom behaviour to apply upon unwrapping, before use. This can be used e.g. to apply parameter constraints, such as making scale parameters positive, or to apply stop_gradient before accessing the parameters.
- class Parameterize(fn, *args, **kwargs)[source]#

Unwrap an object by calling fn with args and kwargs.

All of fn, args and kwargs may contain trainable parameters. If the Parameterize is created within eqx.filter_vmap, unwrapping is automatically vectorized correctly, as long as the vmapped constructor adds leading batch dimensions to all arrays (the default for eqx.filter_vmap).

Example

>>> from flowjax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive)  # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)
- Parameters:
  - fn (Callable) – Callable to call with args and kwargs.
  - *args – Positional arguments to pass to fn.
  - **kwargs – Keyword arguments to pass to fn.
- class NonTrainable(tree)[source]#

Applies stop_gradient to all arraylike leaves before unwrapping.

Useful to mark pytrees (arrays, submodules, etc.) as frozen/non-trainable. We also filter out NonTrainable nodes when partitioning parameters for training, or when parameterizing bijections in coupling/masked autoregressive flows (transformers).

See also

non_trainable(), which is generally a preferable way to achieve similar behaviour, as it wraps the arraylike leaves directly rather than the whole tree.
- non_trainable(tree)[source]#

Freezes parameters by wrapping inexact array leaves with NonTrainable.

Note

Regularization is likely to be applied before unwrapping. To avoid regularization impacting non-trainable parameters, they should be filtered out, for example using:

>>> eqx.partition(
...     ...,
...     is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
... )

This is done in both fit_to_data() and fit_to_key_based_loss(). Wrapping the arrays rather than the entire tree is often preferable, as it allows easier access to attributes compared to wrapping the entire tree.
- Parameters:
  - tree (PyTree) – The pytree.
- class WeightNormalization(weight)[source]#

Applies weight normalization (https://arxiv.org/abs/1602.07868).

- Parameters:
  - weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.
- weight: Union[Array, AbstractUnwrappable[Array]]#
- scale: Union[Array, AbstractUnwrappable[Array]]#
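The reparameterization from the linked paper, w = g · v / ||v||, can be sketched in plain Python. Normalizing per row of the matrix is an assumption made here for illustration; check the flowjax source for the exact axis convention it uses.

```python
import math


def weight_norm(v: list[list[float]], g: list[float]) -> list[list[float]]:
    """Weight normalization sketch: rescale each row of v to have
    Euclidean norm g[i], decoupling direction (v) from magnitude (g)."""
    out = []
    for gi, row in zip(g, v):
        norm = math.sqrt(sum(x * x for x in row))
        out.append([gi * x / norm for x in row])
    return out


v = [[3.0, 4.0], [0.0, 2.0]]  # unnormalized direction parameters
g = [1.0, 5.0]                # learned per-row scales
w = weight_norm(v, g)
print(w)  # [[0.6, 0.8], [0.0, 5.0]]
```

After unwrapping, each row of the effective weight matrix has norm equal to the corresponding entry of the scale parameter, regardless of the magnitude of v.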