FAQ
Freezing parameters
It is often useful to avoid training particular parameters. The easiest way to achieve this
is to wrap them with the flowjax.wrappers.NonTrainable
wrapper class. For example, to
avoid training the base distribution of a transformed distribution:
>>> import equinox as eqx
>>> from flowjax.wrappers import NonTrainable
>>> flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=NonTrainable)
If you wish to avoid training all parameters of, e.g., a specific type, it may be easier to use
jax.tree_util.tree_map
to apply the NonTrainable wrapper where required.
Standardizing variables
In general, you should consider the form and scale of the target samples. For example, you could define a bijection to carry out the preprocessing, then transform the fitted flow with the inverse to "undo" the preprocessing, e.g.
>>> import jax
>>> from flowjax.bijections import Affine, Invert
>>> from flowjax.distributions import Transformed
>>> from flowjax.train import fit_to_data
>>> # Standardize the training data x to zero mean and unit variance
>>> preprocess = Affine(-x.mean(axis=0)/x.std(axis=0), 1/x.std(axis=0))
>>> x_processed = jax.vmap(preprocess.transform)(x)
>>> flow, losses = fit_to_data(key, dist=flow, x=x_processed)
>>> flow = Transformed(flow, Invert(preprocess)) # "undo" the preprocessing
When to JIT
The methods of distributions and bijections are not jitted by default. For example, if you want to sample several batches after training, it is usually worth jitting the sampling function:
>>> import equinox as eqx
>>> import jax.random as jr
>>> batch_size = 256
>>> keys = jr.split(jr.PRNGKey(0), 5)
>>> # Often slow - sample not jitted!
>>> results = []
>>> for batch_key in keys:
... x = flow.sample(batch_key, (batch_size,))
... results.append(x)
>>> # Fast - sample jitted!
>>> results = []
>>> for batch_key in keys:
... x = eqx.filter_jit(flow.sample)(batch_key, (batch_size,))
... results.append(x)
Serialization
As the distributions and bijections are equinox modules, we can serialize/deserialize them using the same method outlined in the equinox documentation.
Runtime type checking
If you want to enable runtime type checking, you can use jaxtyping alongside a type checker such as beartype. Below is an example using jaxtyping's import hook:
>>> from jaxtyping import install_import_hook
>>> with install_import_hook("flowjax", "beartype.beartype"):
... from flowjax import bijections as bij
>>> bij.Exp(shape=2) # Accidentally provide an integer shape instead of a tuple
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Exp.
The problem arose whilst typechecking parameter 'shape'.
Actual value: 2
Expected type: tuple[int, ...].
----------------------
Called with parameters: {'self': Exp(...), 'shape': 2}
Parameter annotations: (self: Any, shape: tuple[int, ...]).