Wrappers
AbstractUnwrappable objects and utilities.
These are placeholder values for specifying custom behaviour for nodes in a pytree, applied using unwrap().
- class AbstractUnwrappable
An abstract class representing an unwrappable object.
Unwrappables replace PyTree nodes, applying custom behavior upon unwrapping.
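The contract can be pictured with a minimal pure-Python sketch. It is not paramax's implementation: `Unwrappable` and `Exp` are hypothetical names, and the real `AbstractUnwrappable` additionally behaves as a JAX pytree node. A subclass stores its raw (trainable) state and defines `unwrap()` to turn that state into the value the rest of the model sees.

```python
from abc import ABC, abstractmethod
import math


class Unwrappable(ABC):
    """Hypothetical stand-in for AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self):
        """Return the value this placeholder stands for."""


class Exp(Unwrappable):
    """Hypothetical unwrappable: stores an unconstrained value, unwraps to exp(raw)."""

    def __init__(self, raw):
        self.raw = raw  # unconstrained, trainable state

    def unwrap(self):
        # Custom behaviour applied on unwrapping: the result is always positive.
        return math.exp(self.raw)


positive = Exp(0.0)
print(positive.raw)       # 0.0, the stored parameter
print(positive.unwrap())  # 1.0, the value used by the model
```

Optimizing `raw` rather than the positive value itself is what keeps the constraint satisfied for any parameter value.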
- unwrap(tree)
Map across a PyTree and unwrap all AbstractUnwrappable nodes.
This leaves all other nodes unchanged. If nested, the innermost AbstractUnwrappable nodes are unwrapped first.
Example
Enforcing positivity.
>>> import paramax
>>> import jax.numpy as jnp
>>> params = paramax.Parameterize(jnp.exp, jnp.zeros(3))
>>> paramax.unwrap(("abc", 1, params))  # other nodes pass through unchanged
('abc', 1, Array([1., 1., 1.], dtype=float32))
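The traversal can be sketched in pure Python as follows. It illustrates the semantics described above (unwrap nested placeholders innermost first, leave all other nodes unchanged) and is not paramax's implementation, which works over JAX pytrees. `Unwrappable`, `Exp` and `Shift` are hypothetical, and only tuples, lists and dicts are treated as containers.

```python
from abc import ABC, abstractmethod
import copy
import math


class Unwrappable(ABC):
    """Hypothetical stand-in for AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self):
        """Return the value this placeholder stands for."""


class Exp(Unwrappable):
    def __init__(self, raw):
        self.raw = raw

    def unwrap(self):
        return math.exp(self.raw)


class Shift(Unwrappable):
    def __init__(self, value, offset):
        self.value = value  # may itself be an Unwrappable
        self.offset = offset

    def unwrap(self):
        # Only works if self.value has already been unwrapped to a number.
        return self.value + self.offset


def unwrap(tree):
    """Unwrap every Unwrappable in a tree of tuples, lists and dicts."""
    if isinstance(tree, Unwrappable):
        # Innermost first: unwrap the node's children on a shallow copy,
        # then call the node's own unwrap(). The input is left untouched.
        node = copy.copy(tree)
        for name, child in vars(tree).items():
            setattr(node, name, unwrap(child))
        return node.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap(x) for x in tree)
    if isinstance(tree, list):
        return [unwrap(x) for x in tree]
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    return tree  # all other nodes are returned unchanged


tree = ("abc", 1, Shift(Exp(0.0), 2.0))
print(unwrap(tree))  # ('abc', 1, 3.0)
```

If the outer `Shift` were unwrapped before its inner `Exp`, `Shift.unwrap()` would try to add a float to an `Exp` object and fail, which is why the innermost nodes go first.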
- class Parameterize(fn, *args, **kwargs)
Unwrap an object by calling fn with args and kwargs.
All of fn, args and kwargs may contain trainable parameters.
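Note
Unwrapping typically occurs after model initialization. Therefore, if the Parameterize object may be created in a vectorized context (e.g. under jax.vmap), its arguments can carry an additional leading batch dimension by the time it is unwrapped, so we recommend ensuring that fn still unwraps correctly in that case, e.g. by supporting broadcasting.
Example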
>>> from paramax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive)  # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)
- Parameters:
fn (Callable) – Callable to call with args and kwargs.
*args – Positional arguments to pass to fn.
**kwargs – Keyword arguments to pass to fn.
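To show that fn, args and kwargs are all unwrapped before fn is called, here is a pure-Python sketch. `Parameterize` below is a hypothetical re-implementation for illustration only (the real class operates on JAX pytrees and is normally used with jax.numpy functions), and `affine` is an illustrative fn.

```python
from abc import ABC, abstractmethod
import copy
import math


class Unwrappable(ABC):
    """Hypothetical stand-in for AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self):
        """Return the value this placeholder stands for."""


class Parameterize(Unwrappable):
    """Hypothetical sketch: unwrapping returns fn(*args, **kwargs)."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)


def unwrap(tree):
    """Unwrap innermost first through tuples, lists, dicts and Unwrappables."""
    if isinstance(tree, Unwrappable):
        node = copy.copy(tree)
        for name, child in vars(tree).items():
            setattr(node, name, unwrap(child))
        return node.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap(x) for x in tree)
    if isinstance(tree, list):
        return [unwrap(x) for x in tree]
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    return tree


def affine(x, loc, scale):
    return loc + scale * x


# scale is kept positive by its own Parameterize and passed as a kwarg.
scale = Parameterize(math.exp, 0.0)
y = Parameterize(affine, 3.0, loc=1.0, scale=scale)
print(unwrap(y))  # 4.0, i.e. 1.0 + exp(0.0) * 3.0
```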
- non_trainable(tree)
Freezes parameters by wrapping inexact array leaves with NonTrainable.
Note
Regularization is likely to be applied before unwrapping, i.e. to the still-wrapped parameters. To prevent it from affecting non-trainable parameters, filter them out first, for example using:
>>> import equinox as eqx
>>> from paramax import wrappers
>>> eqx.partition(
...     ...,
...     is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
... )
Wrapping the arrays in a model, rather than wrapping the entire tree, is often preferable because it allows easier access to the model's attributes.
- Parameters:
tree (PyTree) – The pytree.
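The filtering advice in the note can be sketched without JAX. In this hypothetical pure-Python model, floats stand in for inexact arrays, nested dicts for the pytree, and `l2_penalty` for a regularizer; `NonTrainable` and `non_trainable` here are illustrative stand-ins, and in real code the filtering would go through equinox utilities such as eqx.partition.

```python
class NonTrainable:
    """Hypothetical stand-in for paramax's NonTrainable marker."""

    def __init__(self, tree):
        self.tree = tree


def non_trainable(tree):
    """Wrap each float leaf (standing in for an inexact array) in NonTrainable."""
    if isinstance(tree, float):
        return NonTrainable(tree)
    if isinstance(tree, dict):
        return {k: non_trainable(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(non_trainable(v) for v in tree)
    return tree  # ints, strings, etc. are left as they are


def l2_penalty(tree):
    """Sum of squared float leaves, skipping anything wrapped in NonTrainable."""
    if isinstance(tree, NonTrainable):
        return 0.0  # filtered out, so frozen parameters are not regularized
    if isinstance(tree, float):
        return tree * tree
    if isinstance(tree, dict):
        return sum(l2_penalty(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(l2_penalty(v) for v in tree)
    return 0.0


model = {"encoder": {"w": 3.0, "n_layers": 2}, "head": {"w": 2.0}}
model["encoder"] = non_trainable(model["encoder"])
print(type(model["encoder"]["w"]).__name__)  # NonTrainable
print(model["encoder"]["n_layers"])         # 2 (not an inexact leaf, left as-is)
print(l2_penalty(model))                    # 4.0, only head.w contributes
```

Because only the leaves are wrapped, `model["encoder"]` keeps its original structure, which is the attribute-access advantage mentioned above.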
- class NonTrainable(tree)
Applies stop gradient to all array-like leaves before unwrapping.
See also
non_trainable(), which is generally preferable for achieving similar behaviour, as it wraps the array-like leaves directly rather than the whole tree. NonTrainable is useful for marking pytrees (arrays, submodules, etc.) as frozen/non-trainable. Note that the underlying parameters may still be affected by regularization, which is typically applied before unwrapping, so it is generally advisable to also use this suggestively named class to filter such parameters out.
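Why a stop gradient freezes a parameter, and why regularization can still reach it, can be seen with a toy forward-mode differentiator. This is a self-contained illustration, not how JAX or paramax work internally: `Dual`, `stop_gradient`, `loss` and `penalty` are hypothetical, and the penalty is assumed to be computed on the raw parameter before unwrapping, when no stop gradient has been applied yet.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number: a value and its derivative (tangent)."""

    value: float
    tangent: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule.
        return Dual(
            self.value * other.value,
            self.tangent * other.value + self.value * other.tangent,
        )


def stop_gradient(x):
    """Keep the value, drop the derivative (cf. jax.lax.stop_gradient)."""
    return Dual(x.value, 0.0)


def loss(w, b, frozen):
    if frozen:
        w = stop_gradient(w)  # the effect NonTrainable has when unwrapped
    return w * w + b


def penalty(w):
    return w * w  # a regularizer applied to the raw parameter, before unwrapping


w = Dual(3.0, 1.0)  # seed the tangent to get d/dw
b = Dual(1.0)
print(loss(w, b, frozen=False).tangent)  # 6.0
print(loss(w, b, frozen=True).tangent)   # 0.0
print(penalty(w).tangent)                # 6.0, the regularizer still moves w
```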
- class WeightNormalization(weight)
Applies weight normalization (https://arxiv.org/abs/1602.07868).
- Parameters:
weight (Union[Array, AbstractUnwrappable[Array]]) – The (possibly wrapped) weight matrix.
- weight: Union[Array, AbstractUnwrappable[Array]]
- scale: Union[Array, AbstractUnwrappable[Array]]
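The reparameterization from the linked paper writes a weight as w = g · v / ‖v‖, decoupling its direction (v) from its length (g). Below is a pure-Python sketch of that formula for a matrix stored as a list of rows. Normalizing each row separately, and the correspondence of v and g to the weight and scale attributes, are assumptions made for illustration rather than a description of paramax's exact implementation.

```python
import math


def weight_normalize(v, g):
    """Return w with rows w[i] = g[i] * v[i] / ||v[i]|| (Salimans & Kingma, 2016).

    v: weight matrix as a list of rows (the direction parameters).
    g: one scale per row (the length parameters).
    """
    w = []
    for row, scale in zip(v, g):
        norm = math.sqrt(sum(x * x for x in row))  # Euclidean norm of the row
        w.append([scale * x / norm for x in row])
    return w


v = [[3.0, 4.0], [1.0, 0.0]]
g = [2.0, 0.5]
w = weight_normalize(v, g)
print(w)  # [[1.2, 1.6], [0.5, 0.0]]

# Each row of w has norm g[i], whatever the magnitude of v's rows.
print([math.sqrt(sum(x * x for x in row)) for row in w])
```

Since the row norms of w depend only on g, rescaling a row of v leaves w unchanged; v only controls direction.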