Paramax

A small package for applying parameterizations and constraints to nodes in JAX PyTrees.

Installation

pip install paramax

How it works

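Parameterizations and constraints are applied by wrapping nodes of a PyTree in AbstractUnwrappable objects, which paramax.unwrap replaces with their unwrapped values. A simple example of an AbstractUnwrappable is Parameterize. This class takes a callable and any positional or keyword arguments, which are stored and passed to the callable when unwrapping.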

>>> import paramax
>>> import jax.numpy as jnp
>>> scale = jnp.ones(3)  # Keep this positive
>>> constrained_scale = paramax.Parameterize(jnp.exp, jnp.log(scale))
>>> model = ("abc", 1, constrained_scale)  # Any PyTree
>>> paramax.unwrap(model)  # Unwraps any AbstractUnwrappables
('abc', 1, Array([1., 1., 1.], dtype=float32))
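
The mechanism can be sketched without paramax. The toy version below is a hypothetical illustration, not paramax's implementation: ToyParameterize and toy_unwrap are made-up names, and the sketch relies only on jax.tree_util.tree_map. It stores a callable together with its positional and keyword arguments, and only calls it when the tree is unwrapped.

```python
import jax
import jax.numpy as jnp


class ToyParameterize:
    """Hypothetical stand-in for paramax.Parameterize; not the library's code."""

    def __init__(self, fn, *args, **kwargs):
        # Store the callable and its arguments; nothing is computed yet.
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        # Unwrapping calls the stored callable on the stored arguments.
        return self.fn(*self.args, **self.kwargs)


def _is_wrapper(x):
    return isinstance(x, ToyParameterize)


def toy_unwrap(tree):
    """Replace each ToyParameterize node in a PyTree by its unwrapped value."""
    return jax.tree_util.tree_map(
        lambda x: x.unwrap() if _is_wrapper(x) else x, tree, is_leaf=_is_wrapper
    )


scale = jnp.ones(3)
model = ("abc", 1, ToyParameterize(jnp.exp, jnp.log(scale)))
print(toy_unwrap(model))  # non-wrapper leaves such as "abc" and 1 pass through

# Keyword arguments are forwarded too: k=-1 keeps only the strictly lower triangle.
strict = ToyParameterize(jnp.tril, jnp.ones((2, 2)), k=-1)
print(toy_unwrap(strict))
```

Because ToyParameterize is not registered as a PyTree, JAX treats it as an opaque leaf, so transformations such as jax.grad cannot reach the arrays it holds; a practical version would need that registration.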

Many simple parameterizations can be handled with this class. For example, we can parameterize a lower triangular matrix using

>>> import paramax
>>> import jax.numpy as jnp
>>> tril = jnp.tril(jnp.ones((3,3)))
>>> tril = paramax.Parameterize(jnp.tril, tril)
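
A plain-JAX sketch of why this parameterization helps during optimization (no paramax calls; the loss and step size are hypothetical): jnp.tril zeroes the upper triangle, so those entries of the raw array never affect the loss, receive zero gradient, and cannot break the lower triangular structure of the parameterized matrix.

```python
import jax
import jax.numpy as jnp

raw = jnp.ones((3, 3))  # unconstrained storage, like the array passed to Parameterize


def loss(raw):
    tril = jnp.tril(raw)  # apply the parameterization inside the (hypothetical) loss
    return jnp.sum(tril**2)


grads = jax.grad(loss)(raw)  # upper-triangle entries of the gradient are exactly zero
raw = raw - 0.1 * grads  # one plain gradient-descent step on the raw array
tril = jnp.tril(raw)  # the parameterized matrix is lower triangular by construction
print(tril)
```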

See Wrappers for more AbstractUnwrappable objects.

When to unwrap

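  • Unwrap whenever the parameterizations need to have been applied, typically at the top of a loss function, or of any other function or method that uses the parameterized values.

  • Unwrapping before a gradient computation used for optimization is usually a mistake! The gradients, and therefore the optimizer updates, would then be taken with respect to the unwrapped values (for example the positive scale itself rather than its logarithm), so an update step can move them outside the constraint.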
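
The two rules above can be illustrated with a toy example. It defines hypothetical ToyParameterize and toy_unwrap helpers (not paramax's API), registered as a PyTree with jax.tree_util.register_pytree_node_class so that jax.grad can differentiate with respect to the stored log-scale. The objective and the deliberately large learning rate are also assumptions chosen for illustration: the update made on the wrapped model keeps scale positive, while the same update applied after unwrapping drives it negative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyParameterize:
    """Hypothetical stand-in for paramax.Parameterize; not the library's code."""

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def tree_flatten(self):
        # Stored arrays are PyTree children, so jax.grad can see them;
        # the callable is static auxiliary data.
        return self.args, self.fn

    @classmethod
    def tree_unflatten(cls, fn, args):
        return cls(fn, *args)

    def unwrap(self):
        return self.fn(*self.args)


def _is_wrapper(x):
    return isinstance(x, ToyParameterize)


def toy_unwrap(tree):
    """Replace each ToyParameterize node in a PyTree by its unwrapped value."""
    return jax.tree_util.tree_map(
        lambda x: x.unwrap() if _is_wrapper(x) else x, tree, is_leaf=_is_wrapper
    )


def loss(model):
    model = toy_unwrap(model)  # unwrap at the top of the loss
    return jnp.sum((model["scale"] - 0.5) ** 2)  # hypothetical objective


learning_rate = 10.0  # deliberately large step

# Recommended: differentiate with respect to the *wrapped* model.
model = {"scale": ToyParameterize(jnp.exp, jnp.log(jnp.ones(3)))}
grads = jax.grad(loss)(model)  # same structure as model: gradients w.r.t. log(scale)
model = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)
good_scale = toy_unwrap(model)["scale"]  # exp(-10): still positive

# Mistake: unwrap first, then differentiate and update the unwrapped values.
unwrapped = toy_unwrap({"scale": ToyParameterize(jnp.exp, jnp.log(jnp.ones(3)))})
bad_grads = jax.grad(loss)(unwrapped)  # toy_unwrap leaves plain arrays unchanged
bad_scale = unwrapped["scale"] - learning_rate * bad_grads["scale"]  # -9: negative
```

With paramax itself, the corresponding pattern would be to call paramax.unwrap at the top of the loss and pass the still-wrapped model to jax.grad.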