Wrappers

AbstractUnwrappable objects and utilities.

These are placeholder values for specifying custom behaviour for nodes in a pytree, applied using unwrap().

class AbstractUnwrappable

An abstract class representing an unwrappable object.

Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.

abstract unwrap()

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Return type:

TypeVar(T)
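The contract can be illustrated without paramax. The sketch below is a minimal stand-in built on the standard library only. `Unwrappable` and `Softplus` are hypothetical names, not part of paramax, and unlike the documented class they are not pytree nodes. It shows a subclass holding an unconstrained value and returning a constrained one from `unwrap()`.

```python
import math
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar("T")


class Unwrappable(ABC, Generic[T]):
    """Hypothetical stand-in for AbstractUnwrappable (not the paramax class)."""

    @abstractmethod
    def unwrap(self) -> T:
        """Return the unwrapped value, assuming no wrapped subnodes remain."""


class Softplus(Unwrappable[float]):
    """Hypothetical wrapper: stores an unconstrained value, unwraps to a positive one."""

    def __init__(self, raw: float):
        self.raw = raw  # the quantity an optimizer would update

    def unwrap(self) -> float:
        # softplus(raw) = log(1 + exp(raw)) > 0 for any real raw
        return math.log1p(math.exp(self.raw))


print(Softplus(0.0).unwrap())  # ln 2, about 0.693
```

Calling `unwrap()` on a single wrapper resolves only that node; nested wrappers are meant to be resolved by the tree-level `unwrap()` function, innermost first.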

unwrap(tree)

Map across a PyTree and unwrap all AbstractUnwrappable nodes.

This leaves all other nodes unchanged. If nested, the innermost AbstractUnwrappable nodes are unwrapped first.

Example

Enforcing positivity.

>>> import paramax
>>> import jax.numpy as jnp
>>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3))
>>> paramax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
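The innermost-first traversal can be sketched with `jax.tree_util` alone. This is an illustrative re-implementation under stated assumptions, not paramax's code: `Apply` and `unwrap_tree` are hypothetical names, and `Apply` plays the role of an unwrappable registered as a pytree node, with `fn` kept as static metadata.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Apply:
    """Hypothetical wrapper node: unwraps to fn(child); fn is static metadata."""

    def __init__(self, fn, child):
        self.fn = fn
        self.child = child

    def tree_flatten(self):
        return (self.child,), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(fn, *children)

    def unwrap(self):
        return self.fn(self.child)


def is_wrapper(x):
    return isinstance(x, Apply)


def unwrap_tree(tree):
    def unwrap_node(node):
        if is_wrapper(node):
            # Resolve any wrappers inside this node first, then the node itself.
            return Apply(node.fn, unwrap_tree(node.child)).unwrap()
        return node  # non-wrapper leaves pass through unchanged

    # is_leaf stops tree_map at wrapper nodes instead of descending into them.
    return jax.tree_util.tree_map(unwrap_node, tree, is_leaf=is_wrapper)


print(unwrap_tree(("abc", 1, Apply(jnp.exp, jnp.zeros(3)))))
nested = Apply(lambda x: x + 1.0, Apply(lambda x: 10.0 * x, jnp.ones(2)))
print(unwrap_tree(nested))  # inner first: 10 * 1, then + 1, so each entry is 11
```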

class Parameterize(fn, *args, **kwargs)

Unwrap an object by calling fn with args and kwargs.

All of fn, args and kwargs may contain trainable parameters.

Note

Unwrapping typically occurs after model initialization. Therefore, if the Parameterize object may be created in a vectorized context, we recommend ensuring that fn still unwraps correctly, e.g. by supporting broadcasting.
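To illustrate the broadcasting point: if the wrapper is constructed under vectorization (for example, when a batch of models is built with a vmapped constructor), its array arguments gain a leading batch axis, while `fn` is later called on those batched arrays outside the vectorized context. The two functions below are hypothetical candidates for `fn`; only the second handles a batch axis correctly.

```python
import jax.numpy as jnp


def diag_exp_unbatched(x):
    # Assumes x has shape (n,); for a 2-D input jnp.diag extracts a diagonal instead.
    return jnp.diag(jnp.exp(x))


def diag_exp_broadcasting(x):
    # Works for x of shape (..., n): returns a stack of (n, n) diagonal matrices.
    return jnp.exp(x)[..., None] * jnp.eye(x.shape[-1])


batched = jnp.zeros((4, 3))  # e.g. the arguments of 4 models created under vmap
print(diag_exp_unbatched(batched).shape)  # (3,): silently the wrong result
print(diag_exp_broadcasting(batched).shape)  # (4, 3, 3)
```

For unbatched input both functions agree, so the broadcasting version is a drop-in replacement.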

Example

>>> from paramax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive)  # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)

Parameters:
  • fn (Callable) – Callable to call with args and kwargs.

  • *args – Positional arguments to pass to fn.

  • **kwargs – Keyword arguments to pass to fn.

fn: Callable[..., TypeVar(T)]#
args: tuple[Any, ...]#
kwargs: dict[str, Any]#
unwrap()
Return type:

TypeVar(T)

non_trainable(tree)

Freezes parameters by wrapping inexact array leaves with NonTrainable.

Note

Regularization is often applied to the parameters before unwrapping, where the stop gradient of NonTrainable has no effect. To prevent regularization from affecting non-trainable parameters, filter them out, for example using:

>>> import equinox as eqx
>>> from paramax import wrappers
>>> eqx.partition(
...     ...,
...     is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
... )

Wrapping the individual arrays in a model, rather than the entire tree, is often preferable, because it keeps the model's attributes easier to access.

Parameters:

tree (PyTree) – The pytree.
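Both points above can be sketched with `jax.tree_util` alone, assuming a hypothetical `Frozen` pytree node in place of NonTrainable. `freeze` wraps each inexact (floating-point) array leaf individually, as described for non_trainable(), and `l2_penalty` passes `is_leaf` so that frozen nodes are treated as single leaves and skipped, mirroring the `eqx.partition` filter. All names here are illustrative and not part of paramax.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical stand-in for NonTrainable: a pytree node wrapping a subtree."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def is_frozen(x):
    return isinstance(x, Frozen)


def is_inexact_array(x):
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def freeze(tree):
    # Wrap each floating-point array leaf on its own. is_leaf hands existing
    # Frozen nodes to the lambda whole, where they fail the test and are kept as-is.
    return jax.tree_util.tree_map(
        lambda x: Frozen(x) if is_inexact_array(x) else x,
        tree,
        is_leaf=is_frozen,
    )


def l2_penalty(tree):
    # is_leaf keeps each Frozen node whole, so its contents are never visited;
    # Frozen nodes then fail the inexact-array test and are excluded.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_frozen)
    return sum(jnp.sum(leaf**2) for leaf in leaves if is_inexact_array(leaf))


params = {"w": jnp.ones(3), "head": freeze({"b": jnp.full(2, 10.0), "n": jnp.arange(2)})}
print(l2_penalty(params))  # only "w" contributes: 1 + 1 + 1
```

Because only the leaves are wrapped, `params["head"]["b"]` stays addressable at its usual path. Without `is_leaf`, flattening would descend into the `Frozen` nodes and the frozen values would be penalized too.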

class NonTrainable(tree)

When unwrapped, applies a stop gradient to all array-like leaves.

Useful for marking pytrees (arrays, submodules, etc.) as frozen/non-trainable. See also non_trainable(), which is generally a preferable way to achieve similar behaviour, as it wraps the inexact array leaves directly rather than the whole tree. Note that the underlying parameters may still be affected by regularization, so it is generally advisable to also use this class as a marker for filtering those parameters out.

tree: TypeVar(T)#
unwrap()
Return type:

TypeVar(T)
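The effect of a stop gradient, and the regularization caveat, can be seen with plain JAX. In this sketch `jax.lax.stop_gradient` is applied by hand to stand in for what unwrapping a NonTrainable does, which is an assumption based on the description above; the parameter names and the 0.1 penalty weight are hypothetical.

```python
import jax
import jax.numpy as jnp


def loss(params):
    w = params["w"]
    b = jax.lax.stop_gradient(params["b"])  # stands in for unwrapping NonTrainable(b)
    return jnp.sum(w * b)


def regularized_loss(params):
    # The penalty reads the raw leaf, bypassing the stop gradient.
    return loss(params) + 0.1 * jnp.sum(params["b"] ** 2)


params = {"w": jnp.ones(2), "b": jnp.array([2.0, 3.0])}
grads = jax.grad(loss)(params)
reg_grads = jax.grad(regularized_loss)(params)
print(grads["b"])  # zeros: b receives no gradient from the loss
print(reg_grads["b"])  # 0.2 * b: the penalty still pushes b around
```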

class WeightNormalization(weight)

Applies weight normalization (https://arxiv.org/abs/1602.07868).
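For reference, the cited paper reparameterizes each weight vector w in terms of an unnormalized direction v and a scalar scale g, so that the norm of w equals |g| whatever v is. Reading the weight attribute as v and the scale attribute as g is an assumption based on the attribute names, not something stated here.

```latex
\mathbf{w} = \frac{g}{\lVert \mathbf{v} \rVert}\, \mathbf{v},
\qquad \lVert \mathbf{w} \rVert = |g|
```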

Parameters:

weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.

weight: Union[Array, AbstractUnwrappable[Array]]#
scale: Union[Array, AbstractUnwrappable[Array]]#
unwrap()
Return type:

Array
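A minimal JAX sketch of the unwrap step, assuming the norm is taken over the last axis so that each row of the weight matrix is rescaled independently, with one scale per row. The axis choice and the helper name `weight_normalize` are assumptions, not taken from paramax.

```python
import jax.numpy as jnp


def weight_normalize(weight, scale):
    """Sketch: rescale each row of `weight` to have norm |scale| (row-wise is assumed)."""
    norms = jnp.linalg.norm(weight, axis=-1, keepdims=True)  # shape (rows, 1)
    return scale * weight / norms


v = jnp.array([[3.0, 4.0], [0.0, 2.0]])  # row norms 5 and 2
g = jnp.array([[2.0], [0.5]])  # one scale per row
w = weight_normalize(v, g)  # rows become [1.2, 1.6] and [0.0, 0.5]
print(jnp.linalg.norm(w, axis=-1))  # row norms now equal the scales
```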

contains_unwrappables(pytree)

Check if a pytree contains unwrappables.
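One way such a check can work, sketched with `jax.tree_util` and a hypothetical `Wrapped` pytree node standing in for an unwrappable; `contains_wrapped` is illustrative, not paramax's implementation. Passing `is_leaf` is what makes wrapper nodes visible: without it, flattening would descend into them and only their contents would be seen.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Wrapped:
    """Hypothetical wrapper node standing in for an AbstractUnwrappable."""

    def __init__(self, child):
        self.child = child

    def tree_flatten(self):
        return (self.child,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def contains_wrapped(tree):
    def is_wrapped(x):
        return isinstance(x, Wrapped)

    # Stop flattening at wrapper nodes so they appear in the leaf list.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_wrapped)
    return any(is_wrapped(leaf) for leaf in leaves)


print(contains_wrapped({"a": jnp.ones(2), "b": [Wrapped(jnp.zeros(2))]}))  # True
print(contains_wrapped({"a": jnp.ones(2)}))  # False
```

A natural use is checking that a tree has been fully unwrapped before handing it to code that expects plain arrays.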