FAQ

Freezing parameters

It is often useful to keep particular parameters fixed during training. The easiest way to achieve this is to use the flowjax.wrappers.NonTrainable wrapper class. For example, to avoid training the base distribution of a transformed distribution:

>>> import equinox as eqx
>>> from flowjax.wrappers import NonTrainable
>>> flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=NonTrainable)

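If you wish to avoid training, e.g., all instances of a specific type, it may be easier to use jax.tree_util.tree_map (with an is_leaf predicate that matches that type) to apply the NonTrainable wrapper where required. Note that the older jax.tree_map alias is deprecated.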

Standardizing variables

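In general, you should consider the form and scale of the target samples. For example, you could define a bijection that carries out the preprocessing (here, standardizing each dimension to zero mean and unit variance), fit the flow to the preprocessed data, and then transform the fitted flow with the inverse bijection to “undo” the preprocessing: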

>>> import jax
>>> from flowjax.bijections import Affine, Invert
>>> from flowjax.distributions import Transformed
>>> from flowjax.train import fit_to_data
>>> # Assumes x (samples of shape (n, dim)), key and flow are already defined
>>> preprocess = Affine(-x.mean(axis=0) / x.std(axis=0), 1 / x.std(axis=0))
>>> x_processed = jax.vmap(preprocess.transform)(x)
>>> flow, losses = fit_to_data(key, dist=flow, x=x_processed)
>>> flow = Transformed(flow, Invert(preprocess))  # "undo" the preprocessing
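To see why loc = -mean/std and scale = 1/std standardize the data, the sketch below reproduces the arithmetic with plain jax.numpy. It assumes, as the argument order above implies, that Affine(loc, scale) maps v to v * scale + loc, so the forward map equals (v - mean) / std, and inverting it recovers the original samples, which is what Invert(preprocess) relies on. The data are synthetic.

```python
import jax.numpy as jnp
import jax.random as jr

# Synthetic samples with a non-zero mean and non-unit scale in each dimension.
x = jnp.array([3.0, -1.0]) + jnp.array([5.0, 0.2]) * jr.normal(jr.PRNGKey(0), (1000, 2))

mean, std = x.mean(axis=0), x.std(axis=0)
loc, scale = -mean / std, 1 / std  # the arguments passed to Affine above


def forward(v):
    """Assumed Affine(loc, scale) forward map: v * scale + loc == (v - mean) / std."""
    return v * scale + loc


def inverse(u):
    """Inverse map, as applied by Invert(preprocess): (u - loc) / scale."""
    return (u - loc) / scale


x_processed = forward(x)
assert jnp.allclose(x_processed.mean(axis=0), 0.0, atol=1e-5)  # zero mean
assert jnp.allclose(x_processed.std(axis=0), 1.0, atol=1e-5)  # unit std
assert jnp.allclose(inverse(x_processed), x, atol=1e-4)  # inverse undoes it
```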

When to JIT

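The methods of distributions and bijections are not jitted by default. For example, if you want to sample several batches after training, it is usually worth compiling the sampling method with eqx.filter_jit, which (unlike jax.jit) traces only array arguments and holds all other arguments, such as the sample shape, static. Wrap the method once, outside the loop, so the compiled function is reused for every batch: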

>>> import equinox as eqx
>>> import jax.random as jr

>>> batch_size = 256
>>> keys = jr.split(jr.PRNGKey(0), 5)

>>> # Often slow - sample not jitted!
>>> results = []
>>> for batch_key in keys:
...     x = flow.sample(batch_key, (batch_size,))
...     results.append(x)

>>> # Fast - sample jitted once, outside the loop!
>>> sample = eqx.filter_jit(flow.sample)
>>> results = []
>>> for batch_key in keys:
...     x = sample(batch_key, (batch_size,))
...     results.append(x)
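The benefit comes from tracing and compiling once and then reusing the compiled function. The sketch below makes that visible with plain jax.jit and a hypothetical sample function standing in for flow.sample: a Python side-effect counter increments only when JAX traces the function. The sample size is marked static because it determines the output shape, mirroring how eqx.filter_jit treats the sample shape.

```python
import jax
import jax.random as jr

trace_count = 0


def sample(key, n):
    """Hypothetical stand-in for flow.sample: draws n samples of dimension 2."""
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX traces the function
    return 1.0 + 2.0 * jr.normal(key, (n, 2))


# Wrap once, outside the loop. n fixes the output shape, so it must be static.
sample_jit = jax.jit(sample, static_argnums=1)

keys = jr.split(jr.PRNGKey(0), 5)
results = [sample_jit(batch_key, 256) for batch_key in keys]
print(trace_count)  # 1: traced and compiled on the first call, then reused

sample_jit(keys[0], 128)  # a new static value triggers a fresh trace
print(trace_count)  # 2
```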

Serialization

As the distributions and bijections are equinox modules, we can serialize/deserialize them using the same method outlined in the equinox documentation.

Runtime type checking

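If you want to enable runtime type checking, you can use jaxtyping together with a typechecker such as beartype. Below is an example using jaxtyping's import hook. The hook only instruments modules imported inside the with block, so flowjax must not already have been imported earlier in the session: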

>>> from jaxtyping import install_import_hook

>>> with install_import_hook("flowjax", "beartype.beartype"):
...     from flowjax import bijections as bij

>>> bij.Exp(shape=2)  # Accidentally provide an integer shape instead of tuple
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Exp.
The problem arose whilst typechecking parameter 'shape'.
Actual value: 2
Expected type: tuple[int, ...].
----------------------
Called with parameters: {'self': Exp(...), 'shape': 2}
Parameter annotations: (self: Any, shape: tuple[int, ...]).