Wrappers

AbstractUnwrappable objects and utilities.

These are “placeholder” values for specifying custom behaviour for nodes in a pytree. Many of these serve a similar purpose to PyTorch parametrizations. We use them, for example, to apply parameter constraints or to mask parameters. To apply the behaviour, we call unwrap(), which replaces any AbstractUnwrappable nodes in a pytree with their unwrapped versions.
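To make the replacement step concrete, here is a minimal sketch in plain JAX; it is not flowjax's implementation. `Square` is a hypothetical wrapper that stores an unconstrained array and unwraps to its square, and this toy `unwrap` only handles non-nested wrappers of that one class.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Square:
    """Hypothetical wrapper: stores an unconstrained array, unwraps to its square."""

    def __init__(self, arr):
        self.arr = arr

    def unwrap(self):
        return self.arr**2

    def tree_flatten(self):
        return (self.arr,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def unwrap(tree):
    """Replace every Square node in `tree` with its unwrapped value.

    Assumes wrappers are not nested inside one another.
    """
    is_wrapper = lambda node: isinstance(node, Square)
    return jax.tree_util.tree_map(
        lambda node: node.unwrap() if is_wrapper(node) else node,
        tree,
        is_leaf=is_wrapper,  # stop at wrapper nodes instead of flattening into them
    )


params = {"loc": jnp.array([0.0, 1.0]), "scale": Square(jnp.array([-2.0, 3.0]))}
unwrapped = unwrap(params)
# unwrapped["scale"] is [4.0, 9.0]; unwrapped["loc"] is unchanged.
```

Passing `is_leaf` is what lets the traversal hand whole wrapper objects to the mapped function rather than their internal arrays.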

Unwrapping is automatically called in several places, primarily:

  • Prior to calling the bijection methods: transform, inverse, transform_and_log_det and inverse_and_log_det.

  • Prior to calling distribution methods: log_prob, sample and sample_and_log_prob.

  • Prior to computing the loss functions.

If implementing a custom unwrappable, bear in mind:

  • The wrapper should not store information or implement logic beyond what is required for initialization and unwrapping, as anything else is lost when unwrapping.

  • The unwrapping should support broadcasting/vmapped initializations. Otherwise, if the unwrappable is created within a batched context, it will fail to unwrap correctly.

class AbstractUnwrappable

An abstract class representing an unwrappable object.

Unwrappables generally replace nodes in a pytree to specify custom behaviour that is applied upon unwrapping, before the node is used. This can be used, e.g., to apply parameter constraints such as making scale parameters positive, or to apply stop_gradient before accessing the parameters.

If _dummy is set to an array (which must have shape ()), it is used to infer the vmapped dimensions (and their sizes) when calling unwrap, so that the method is vectorized automatically. In some cases, this is needed to support an AbstractUnwrappable created within, e.g., eqx.filter_vmap.
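The failure mode this guards against, and how a scalar marker fixes it, can be sketched in plain JAX. This is a hypothetical illustration rather than flowjax's code: `Normalize` and `unwrap_vectorized` are made-up names, and `jax.vmap` stands in for `eqx.filter_vmap`. The key fact used is that `jax.vmap` broadcasts unbatched outputs along the mapped axis, so a scalar created inside the constructor ends up carrying the batch dimensions.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Normalize:
    """Hypothetical wrapper: unwraps a vector to vec / vec.sum()."""

    def __init__(self, vec, _dummy=None):
        self.vec = vec
        # Scalar marker. Created inside jax.vmap it is an unbatched constant,
        # so vmap broadcasts it to shape (batch,), recording the batch dims.
        self._dummy = jnp.zeros(()) if _dummy is None else _dummy

    def unwrap(self):
        # Not elementwise: the sum reduces over *all* axes of self.vec.
        return self.vec / self.vec.sum()

    def tree_flatten(self):
        return (self.vec, self._dummy), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def unwrap_vectorized(wrapper):
    """Vmap wrapper.unwrap once per leading dimension recorded in _dummy."""
    fn = lambda w: w.unwrap()
    for _ in range(jnp.ndim(wrapper._dummy)):
        fn = jax.vmap(fn)
    return fn(wrapper)


batch = jnp.array([[1.0, 3.0], [2.0, 2.0]])
wrapped = jax.vmap(lambda v: Normalize(v))(batch)  # wrapped._dummy.shape == (2,)

naive = wrapped.unwrap()  # divides by the sum over the whole batch (8.0)
correct = unwrap_vectorized(wrapped)  # each row divided by its own sum
```

An elementwise unwrap (such as exp) would give the same result either way; the marker only matters for unwraps that reduce or reshape, which is why the guidance above asks custom unwrappables to support batched initialization.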

recursive_unwrap()

Returns the unwrapped pytree, unwrapping subnodes as required.

Return type:

TypeVar(T)

abstract unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

TypeVar(T)
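The difference between the two methods can be illustrated with a small hypothetical sketch in plain JAX; the `Wrapper`, `Exp`, and `Add` classes are illustrative and not part of flowjax. Each subclass's `unwrap` assumes its fields are already plain arrays, while `recursive_unwrap` first unwraps any nested wrappers, bottom-up.

```python
import jax
import jax.numpy as jnp


class Wrapper:
    """Hypothetical base class: subclasses keep their fields in self.children."""

    def __init__(self, *children):
        self.children = children

    def tree_flatten(self):
        return self.children, None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def unwrap(self):
        raise NotImplementedError  # subclasses assume children are plain arrays

    def recursive_unwrap(self):
        # Unwrap nested wrappers first (bottom-up), then apply our own unwrap.
        children = [unwrap(c) for c in self.children]
        return type(self)(*children).unwrap()


def unwrap(tree):
    is_wrapper = lambda node: isinstance(node, Wrapper)
    return jax.tree_util.tree_map(
        lambda node: node.recursive_unwrap() if is_wrapper(node) else node,
        tree,
        is_leaf=is_wrapper,
    )


@jax.tree_util.register_pytree_node_class
class Exp(Wrapper):
    def unwrap(self):
        return jnp.exp(self.children[0])


@jax.tree_util.register_pytree_node_class
class Add(Wrapper):
    def unwrap(self):
        return self.children[0] + self.children[1]


nested = Add(Exp(jnp.array(0.0)), jnp.array(2.0))
# Add.unwrap alone would try to add an Exp object to an array;
# recursive_unwrap first turns Exp(0.0) into 1.0, then returns 1.0 + 2.0.
result = unwrap(nested)
```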

class BijectionReparam(arr, bijection, *, invert_on_init=True)

Reparameterize a parameter using a bijection.

When unwrapping, bijection.transform is applied to arr. By default (invert_on_init=True), the inverse of the bijection is applied to arr on initialization, so that unwrapping initially recovers the value that was passed in.

Parameters:
  • arr (Union[Array, AbstractUnwrappable[Array]]) – The parameter to reparameterize. If invert_on_init is False, this can be an AbstractUnwrappable[Array].

  • bijection (AbstractBijection) – A bijection whose shape is broadcastable to jnp.shape(arr).

  • invert_on_init (bool) – Whether to apply the inverse transformation when initializing. Defaults to True.

arr: Union[Array, AbstractUnwrappable[Array]]
bijection: AbstractBijection
unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

Array
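As a sketch of this pattern, here is a plain-JAX stand-in that fixes the bijection to exp (with log as its inverse), so that a scale parameter stays positive whatever value the optimizer assigns to the stored array. `ExpReparam` is hypothetical and not flowjax's implementation, which accepts an arbitrary `bijection`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ExpReparam:
    """Hypothetical stand-in for BijectionReparam with the bijection fixed to exp."""

    def __init__(self, arr, *, invert_on_init=True):
        # Store the unconstrained value: log(arr) if inverting on init.
        self.arr = jnp.log(arr) if invert_on_init else arr

    def unwrap(self):
        # Forward transform: always returns a strictly positive array.
        return jnp.exp(self.arr)

    def tree_flatten(self):
        return (self.arr,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Children are already unconstrained; do not invert them again.
        return cls(children[0], invert_on_init=False)


scale = ExpReparam(jnp.array([0.5, 2.0]))
stored = scale.arr         # log([0.5, 2.0]): the value an optimizer would update
recovered = scale.unwrap() # [0.5, 2.0]: the value passed in is recovered

# Any update to the unconstrained value still unwraps to a positive array.
stepped = jax.tree_util.tree_map(lambda a: a - 10.0, scale)
```

Note that `tree_unflatten` must skip the inversion; otherwise every tree operation (including gradient steps) would re-apply log to an already-unconstrained value.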

class Lambda(fn, *args, **kwargs)

Unwrap an object by calling fn with (possibly trainable) args and kwargs.

If the Lambda is created within eqx.filter_vmap, unwrapping is automatically vectorized correctly, as long as the vmapped constructor adds leading batch dimensions to all arrays in Lambda (the default for eqx.filter_vmap).

Parameters:
  • fn – Function to call with args and kwargs.

  • *args – Positional arguments to pass to fn.

  • **kwargs – Keyword arguments to pass to fn.

args: Iterable
fn: Callable[..., TypeVar(T)]
kwargs: dict
unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

TypeVar(T)
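A minimal sketch of the idea in plain JAX follows, here used to parameterize a lower-triangular matrix. `LambdaSketch` is a hypothetical re-implementation for illustration, not flowjax's class: `fn` is treated as static metadata and `args`/`kwargs` as pytree children, so gradients reach the arguments. In this simple version every argument becomes a traced leaf, so non-array options (such as an integer `k` for `jnp.tril`) would need to be closed over inside `fn` instead.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class LambdaSketch:
    """Hypothetical sketch: fn is static metadata; args/kwargs are pytree children."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)

    def tree_flatten(self):
        return (self.args, self.kwargs), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        args, kwargs = children
        return cls(fn, *args, **kwargs)


# Only entries on or below the diagonal of `raw` affect the unwrapped value.
raw = jnp.arange(1.0, 10.0).reshape(3, 3)
tril = LambdaSketch(jnp.tril, raw)


def loss(wrapper):
    return wrapper.unwrap().sum()


# Gradients flow to the (trainable) args: ones on/below the diagonal, zeros above.
grads = jax.grad(loss)(tril)
```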

class NonTrainable(tree)

Applies stop_gradient to all array-like leaves upon unwrapping.

Useful to mark pytrees (arrays, submodules, etc.) as frozen/non-trainable. These nodes are also filtered out when partitioning parameters for training, and when parameterizing bijections (transformers) in coupling and masked autoregressive flows.

tree: TypeVar(T)
unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

TypeVar(T)
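The effect on gradients can be checked with a short hypothetical sketch in plain JAX. `NonTrainableSketch` is an illustrative re-implementation, not flowjax's class, and the `isinstance(leaf, jax.Array)` filter is an assumed stand-in for "array-like leaves" (it also matches tracers inside `jax.grad`).

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class NonTrainableSketch:
    """Hypothetical sketch: unwrapping applies stop_gradient to array leaves."""

    def __init__(self, tree):
        self.tree = tree

    def unwrap(self):
        return jax.tree_util.tree_map(
            lambda leaf: jax.lax.stop_gradient(leaf) if isinstance(leaf, jax.Array) else leaf,
            self.tree,
        )

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


params = {
    "weight": jnp.array([1.0, 2.0]),
    "bias": NonTrainableSketch(jnp.array([0.5, -0.5])),
}


def loss(params):
    return jnp.sum(params["weight"] * 3.0 + params["bias"].unwrap())


grads = jax.grad(loss)(params)
# grads["weight"] is [3.0, 3.0]; grads["bias"].tree is [0.0, 0.0] (frozen).
```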

class WeightNormalization(weight)

Applies weight normalization (https://arxiv.org/abs/1602.07868).

Parameters:

weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.

scale: Union[Array, AbstractUnwrappable[Array]]
unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

Array

weight: Union[Array, AbstractUnwrappable[Array]]
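The referenced paper writes each weight vector as w = g * v / ||v||, decoupling its norm g from its direction v. A hypothetical plain-JAX sketch of that formula is shown below; it is not flowjax's implementation, and both the row-wise normalization axis (one row per output unit) and the initialization of `scale` to the row norms are assumptions.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class WeightNormSketch:
    """Hypothetical sketch of weight normalization, w = scale * v / ||v||.

    Assumes weight has shape (out_features, in_features), so each row is the
    weight vector of one output unit and is normalized separately.
    """

    def __init__(self, weight, scale=None):
        self.weight = weight
        # Assumed initialization: scale = row norms, so unwrap() initially
        # returns the original weight matrix.
        if scale is None:
            scale = jnp.linalg.norm(weight, axis=-1, keepdims=True)
        self.scale = scale

    def unwrap(self):
        norms = jnp.linalg.norm(self.weight, axis=-1, keepdims=True)
        return self.scale * self.weight / norms

    def tree_flatten(self):
        return (self.weight, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


w = jnp.array([[3.0, 4.0], [0.0, 2.0]])
wn = WeightNormSketch(w)
unwrapped = wn.unwrap()  # equals w at initialization

# Rescaling the direction parameter leaves the result unchanged: only `scale`
# controls each row's norm.
rescaled = WeightNormSketch(10.0 * w, scale=wn.scale).unwrap()
```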
class Where(cond, if_true, if_false)

Applies jnp.where upon unwrapping.

This can be used to construct masks by setting cond=mask and if_false=0.

cond: ArrayLike
if_false: Union[ArrayLike, AbstractUnwrappable[Array]]
if_true: Union[ArrayLike, AbstractUnwrappable[Array]]
unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.
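For example, masking a weight matrix so that only its strictly lower-triangular entries are used can be sketched as follows. `WhereSketch` is a hypothetical plain-JAX re-implementation for illustration, not flowjax's class; the sketch also shows that gradients reach only the unmasked entries of `if_true`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class WhereSketch:
    """Hypothetical sketch: unwraps to jnp.where(cond, if_true, if_false)."""

    def __init__(self, cond, if_true, if_false):
        self.cond = cond
        self.if_true = if_true
        self.if_false = if_false

    def unwrap(self):
        return jnp.where(self.cond, self.if_true, self.if_false)

    def tree_flatten(self):
        return (self.cond, self.if_true, self.if_false), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


# Strictly lower-triangular boolean mask, e.g. for an autoregressive weight matrix.
mask = jnp.tril(jnp.ones((3, 3)), k=-1).astype(bool)
weight = jnp.full((3, 3), 2.0)


def loss(w):
    return WhereSketch(mask, w, 0.0).unwrap().sum()


masked = WhereSketch(mask, weight, 0.0).unwrap()  # 2.0 below the diagonal, 0.0 elsewhere
grad = jax.grad(loss)(weight)  # 1.0 below the diagonal, 0.0 elsewhere
```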

unwrap(tree)

Unwrap all AbstractUnwrappable nodes within a pytree.