Wrappers
AbstractUnwrappable objects and utilities.
These are “placeholder” values for specifying custom behaviour for nodes in a pytree.
Many of these serve similar purposes to PyTorch parameterizations. We use them, for example, to apply parameter constraints or to mask parameters. To apply the behaviour, we use unwrap(), which replaces any AbstractUnwrappable nodes in a pytree with their unwrapped versions.
Unwrapping is automatically called in several places, primarily:
- Prior to calling the bijection methods: transform, inverse, transform_and_log_det and inverse_and_log_det.
- Prior to calling the distribution methods: log_prob, sample and sample_and_log_prob.
- Prior to computing the loss functions.
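A minimal JAX-only sketch of what unwrapping prior to a distribution method amounts to. PositiveExp, ToyNormal and unwrap_tree are hypothetical names introduced here (they are not FlowJAX classes), and unwrap_tree ignores nested wrappers for brevity. The method body only ever sees plain arrays, while gradients still reach the unconstrained value stored inside the wrapper.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Unwrappable:
    """Hypothetical stand-in for AbstractUnwrappable (not the FlowJAX class)."""

    def unwrap(self):
        raise NotImplementedError


def _is_wrapper(x):
    return isinstance(x, Unwrappable)


def unwrap_tree(tree):
    """Replace Unwrappable nodes with their unwrapped values (no nesting handled)."""
    return jax.tree_util.tree_map(
        lambda x: x.unwrap() if _is_wrapper(x) else x, tree, is_leaf=_is_wrapper
    )


@register_pytree_node_class
class PositiveExp(Unwrappable):
    """Stores an unconstrained value; unwraps to exp(value) > 0."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        return jnp.exp(self.value)


@register_pytree_node_class
class ToyNormal:
    """Toy distribution whose fields may be wrappers."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def log_prob(self, x):
        d = unwrap_tree(self)  # unwrap first: d.scale is now a plain array
        z = (x - d.loc) / d.scale
        return -0.5 * z**2 - jnp.log(d.scale) - 0.5 * jnp.log(2 * jnp.pi)


dist = ToyNormal(loc=jnp.array(0.0), scale=PositiveExp(jnp.log(2.0)))
print(dist.log_prob(1.0))  # log density of N(0, 2^2) at x = 1

# Gradients reach the unconstrained value stored inside the wrapper.
grads = jax.grad(lambda d: d.log_prob(1.0))(dist)
print(grads.scale.value)
```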
If implementing a custom unwrappable, bear in mind:
- The wrapper should avoid storing information or implementing logic beyond what is required for initialization and unwrapping, as this information will be lost when unwrapping.
- The unwrapping should support broadcasting/vmapped initializations. Otherwise, if the unwrappable is created within a batched context, it will fail to unwrap correctly.
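To make the broadcasting requirement concrete, here is a minimal JAX-only sketch. DiagWrapper is a hypothetical wrapper (not part of FlowJAX) that stores a vector and unwraps to a diagonal matrix. When it is created under jax.vmap, the stored vector gains a leading batch axis, so the unwrapping must tolerate extra leading dimensions; jnp.vectorize with a core signature is one assumed way to achieve that.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class DiagWrapper:
    """Hypothetical wrapper: stores a vector, unwraps to a diagonal matrix."""

    def __init__(self, vec):
        self.vec = vec

    def tree_flatten(self):
        return (self.vec,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        # Plain jnp.diag would misbehave on a batched (2D) input: it would
        # extract a diagonal instead of building one matrix per batch element.
        # Vectorizing over the core signature "(n)->(n,n)" broadcasts over
        # any leading batch dimensions.
        return jnp.vectorize(jnp.diag, signature="(n)->(n,n)")(self.vec)


vecs = jnp.arange(6.0).reshape(2, 3)

# Created in a batched context: the stored vector has shape (2, 3).
batched = jax.vmap(lambda v: DiagWrapper(v))(vecs)
mats = batched.unwrap()
print(mats.shape)  # (2, 3, 3)

# The same class still unwraps correctly when created without batching.
print(DiagWrapper(vecs[0]).unwrap().shape)  # (3, 3)
```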
- class AbstractUnwrappable
An abstract class representing an unwrappable object.
Unwrappables generally replace nodes in a pytree, in order to specify some custom behaviour to apply upon unwrapping before use. This can be used, e.g., to apply parameter constraints, such as making scale parameters positive, or to apply stop_gradient before accessing the parameters.
If _dummy is set to an array (which must have shape ()), it is used to infer the vmapped dimensions (and sizes) when calling unwrap, so that the method is automatically vectorized. In some cases this is important for supporting the case where an AbstractUnwrappable is created within, e.g., eqx.filter_vmap.
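One way a scalar placeholder like _dummy could drive vectorization is sketched below in plain JAX: the placeholder gains one leading axis per enclosing vmap, so its number of dimensions tells unwrap how many times to wrap the computation in jax.vmap. BatchAwareLambda, its dummy attribute and the outer helper are hypothetical; this illustrates the idea and is not FlowJAX's implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def outer(u, v):
    return jnp.outer(u, v)


@register_pytree_node_class
class BatchAwareLambda:
    """Hypothetical wrapper that unwraps to fn(*args), vectorizing via a dummy."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args
        # Scalar placeholder: shape () when created directly; each enclosing
        # vmap adds one leading batch axis to it.
        self.dummy = jnp.zeros(())

    def tree_flatten(self):
        return (self.args, self.dummy), self.fn  # fn is static auxiliary data

    @classmethod
    def tree_unflatten(cls, fn, children):
        obj = object.__new__(cls)  # bypass __init__, which would reset dummy
        obj.fn = fn
        obj.args, obj.dummy = children
        return obj

    def unwrap(self):
        def call(obj):
            return obj.fn(*obj.args)

        # One vmap per leading batch axis recorded by the dummy.
        for _ in range(jnp.ndim(self.dummy)):
            call = jax.vmap(call)
        return call(self)


vecs = jnp.arange(6.0).reshape(2, 3)
batched = jax.vmap(lambda v: BatchAwareLambda(outer, v, v))(vecs)
print(batched.dummy.shape)  # (2,)
print(batched.unwrap().shape)  # (2, 3, 3): one outer product per batch element

# Calling fn directly on the batched args is wrong: jnp.outer flattens inputs.
print(outer(*batched.args).shape)  # (6, 6)
```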
- class BijectionReparam(arr, bijection, *, invert_on_init=True)
Reparameterize a parameter using a bijection.
When unwrapping, bijection.transform is applied. By default, the inverse of the bijection is applied when setting the parameter values.
- Parameters:
  - arr (Union[Array, AbstractUnwrappable[Array]]) – The parameter to reparameterize. If invert_on_init is False, then this can be an AbstractUnwrappable[Array].
  - bijection (AbstractBijection) – A bijection whose shape is broadcastable to jnp.shape(arr).
  - invert_on_init (bool) – Whether to apply the inverse transformation when initializing. Defaults to True.
- arr: Union[Array, AbstractUnwrappable[Array]]
- bijection: AbstractBijection
- class Lambda(fn, *args, **kwargs)
Unwrap an object by calling fn with (possibly trainable) args and kwargs.
If the Lambda is created within eqx.filter_vmap, unwrapping is automatically vectorized correctly, as long as the vmapped constructor adds leading batch dimensions to all arrays in Lambda (the default for eqx.filter_vmap).
- Parameters:
  - fn – Function to call with args and kwargs.
  - *args – Positional arguments to pass to fn.
  - **kwargs – Keyword arguments to pass to fn.
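A minimal sketch of this behaviour in plain JAX. CallWrapper and lower_triangle are hypothetical names; the sketch only shows fn being treated as static while args and kwargs are pytree children, so arrays passed to fn can be trained.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def lower_triangle(m):
    return jnp.tril(m)


@register_pytree_node_class
class CallWrapper:
    """Hypothetical Lambda-like wrapper: unwraps to fn(*args, **kwargs)."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def tree_flatten(self):
        # args and kwargs are children, so arrays inside them are trainable;
        # fn is static auxiliary data.
        return (self.args, self.kwargs), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        args, kwargs = children
        return cls(fn, *args, **kwargs)

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)


# A lower-triangular matrix parameterized by a full (3, 3) array.
lower = CallWrapper(lower_triangle, jnp.ones((3, 3)))
print(lower.unwrap())

# Gradients flow into args; entries zeroed by jnp.tril get zero gradient.
grads = jax.grad(lambda w: w.unwrap().sum())(lower)
print(grads.args[0])
```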
- class NonTrainable(tree)
Applies stop_gradient to all array-like leaves before unwrapping.
Useful to mark pytrees (arrays, submodules, etc.) as frozen/non-trainable. We also filter out these modules when partitioning parameters for training, or when parameterizing bijections in coupling/masked autoregressive flows (transformers).
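The effect of applying stop_gradient to a wrapped leaf can be seen in this sketch. Frozen, Unwrappable and unwrap_tree are hypothetical stand-ins, not FlowJAX's NonTrainable, AbstractUnwrappable or unwrap; the gradient with respect to the frozen array comes out as zero.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Unwrappable:
    """Hypothetical stand-in for AbstractUnwrappable."""

    def unwrap(self):
        raise NotImplementedError


def _is_wrapper(x):
    return isinstance(x, Unwrappable)


def unwrap_tree(tree):
    """Replace (non-nested) Unwrappable nodes with their unwrapped values."""
    return jax.tree_util.tree_map(
        lambda x: x.unwrap() if _is_wrapper(x) else x, tree, is_leaf=_is_wrapper
    )


@register_pytree_node_class
class Frozen(Unwrappable):
    """Hypothetical NonTrainable-like wrapper around an arbitrary pytree."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        # stop_gradient on every JAX array leaf; other leaves pass through.
        return jax.tree_util.tree_map(
            lambda x: jax.lax.stop_gradient(x) if isinstance(x, jax.Array) else x,
            self.tree,
        )


params = {"w": jnp.array(2.0), "frozen": Frozen(jnp.array(3.0))}


def loss(params):
    p = unwrap_tree(params)
    return p["w"] * p["frozen"]


grads = jax.grad(loss)(params)
print(grads["w"])  # 3.0
print(grads["frozen"].tree)  # 0.0: no gradient reaches the frozen leaf
```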
- class WeightNormalization(weight)
Applies weight normalization (https://arxiv.org/abs/1602.07868).
- Parameters:
  - weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.
- scale: Union[Array, AbstractUnwrappable[Array]]
- unwrap()
  Returns the unwrapped pytree, assuming no wrapped subnodes exist.
  - Return type: Array
- weight: Union[Array, AbstractUnwrappable[Array]]
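Weight normalization, as described in the linked paper, reparameterizes each weight vector as w = g · v / ‖v‖, decoupling its direction v from its norm g. The sketch below is hypothetical (WeightNorm is not FlowJAX's class): it assumes normalization over the last axis (one norm per row) and initializes g to the row norms so that unwrapping initially returns the original weight. FlowJAX's own choices may differ.

```python
import jax.numpy as jnp


class WeightNorm:
    """Hypothetical weight-normalization wrapper: w = g * v / ||v|| per row."""

    def __init__(self, weight, scale=None):
        self.weight = weight  # direction parameter v
        # Assumed initialization: g = ||v|| per row, so that unwrapping
        # initially reproduces `weight`.
        if scale is None:
            scale = jnp.linalg.norm(weight, axis=-1, keepdims=True)
        self.scale = scale  # magnitude parameter g

    def unwrap(self):
        norms = jnp.linalg.norm(self.weight, axis=-1, keepdims=True)
        return self.scale * self.weight / norms


w = jnp.array([[3.0, 4.0], [1.0, 0.0]])
wn = WeightNorm(w)
print(wn.unwrap())  # matches w at initialization

# Rescaling the direction parameter leaves the result unchanged: only
# `scale` controls each row's norm.
print(WeightNorm(2.0 * w, scale=wn.scale).unwrap())
```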
- class Where(cond, if_true, if_false)
Applies jnp.where upon unwrapping.
This can be used to construct masks by setting cond=mask and if_false=0.
- cond: ArrayLike
- if_false: Union[ArrayLike, AbstractUnwrappable[Array]]
- if_true: Union[ArrayLike, AbstractUnwrappable[Array]]
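A sketch of masking with jnp.where, using a hypothetical Masked class that mirrors the cond/if_true/if_false fields above (it is not FlowJAX's Where). Entries where the mask is False are replaced by 0 and receive zero gradient.

```python
import jax
import jax.numpy as jnp


class Masked:
    """Hypothetical Where-like wrapper: unwraps to jnp.where(cond, if_true, if_false)."""

    def __init__(self, cond, if_true, if_false):
        self.cond = cond
        self.if_true = if_true
        self.if_false = if_false

    def unwrap(self):
        return jnp.where(self.cond, self.if_true, self.if_false)


# Keep the lower triangle (including the diagonal); zero everything above it.
mask = jnp.tril(jnp.ones((3, 3))) > 0
weight = jnp.arange(9.0).reshape(3, 3)
print(Masked(mask, weight, 0.0).unwrap())

# Masked-out entries of the underlying weight receive zero gradient.
g = jax.grad(lambda w: Masked(mask, w, 0.0).unwrap().sum())(weight)
print(g)  # 1.0 where mask is True, 0.0 elsewhere
```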
- unwrap(tree)
Unwrap all AbstractUnwrappable nodes within a pytree.
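A recursive, innermost-first version of tree unwrapping might look like the following. It is a sketch, not FlowJAX's unwrap: Unwrappable, PositiveExp, Scaled and unwrap_all are hypothetical, and it assumes every wrapper is registered with jax.tree_util.register_pytree_node_class so its children can be unwrapped before the wrapper itself.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Unwrappable:
    """Hypothetical stand-in for AbstractUnwrappable."""

    def unwrap(self):
        raise NotImplementedError


def _is_wrapper(x):
    return isinstance(x, Unwrappable)


def unwrap_all(tree):
    """Replace every Unwrappable node, innermost first, with its unwrapped value."""

    def replace(node):
        if not _is_wrapper(node):
            return node
        # Unwrap the node's children first, then the node itself.
        children, aux = node.tree_flatten()
        node = type(node).tree_unflatten(aux, [unwrap_all(c) for c in children])
        return node.unwrap()

    return jax.tree_util.tree_map(replace, tree, is_leaf=_is_wrapper)


@register_pytree_node_class
class PositiveExp(Unwrappable):
    """Unwraps to exp(value)."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        return jnp.exp(self.value)


@register_pytree_node_class
class Scaled(Unwrappable):
    """Unwraps to inner * factor; `inner` may itself be a wrapper."""

    def __init__(self, inner, factor):
        self.inner = inner
        self.factor = factor

    def tree_flatten(self):
        return (self.inner, self.factor), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        # By the time this runs, `inner` has already been unwrapped to an array.
        return self.inner * self.factor


tree = {"a": Scaled(PositiveExp(jnp.log(3.0)), 2.0), "b": jnp.ones(2)}
out = unwrap_all(tree)
print(out["a"])  # approximately 6.0, i.e. 2 * exp(log 3)
print(out["b"])  # plain arrays pass through unchanged
```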