FlowJAX: a package for continuous distributions, bijections and normalizing flows built with Equinox and JAX. Features:
- A wide range of distributions and bijections.
- Distributions and bijections are PyTrees (they are implemented as Equinox modules), which makes them compatible with JAX transformations.
- Many state-of-the-art normalizing flow models.
- First-class support for conditional distributions, which are important for applications such as amortized variational inference and simulation-based inference.
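Two of the points above work together: a distribution that is a PyTree can be passed straight through JAX transformations, and a conditional distribution evaluates log p(x | c) for a given condition c. The following is a minimal sketch in plain JAX of both ideas. It is a hypothetical toy (`ConditionalAffineGaussian` and `nll` are illustrative names, not FlowJAX's API). It assumes a single affine flow x = loc(c) + exp(log_scale) * z with a standard normal base. To stay dependency-light it registers the class by hand with `jax.tree_util.register_pytree_node_class`, whereas FlowJAX relies on Equinox modules for this.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ConditionalAffineGaussian:
    """Hypothetical toy conditional flow (not part of FlowJAX).

    Forward map: x = loc(c) + exp(log_scale) * z, with z ~ N(0, I)
    and loc(c) = weight @ c + bias.
    """

    def __init__(self, weight, bias, log_scale):
        self.weight = weight        # shape (dim, cond_dim)
        self.bias = bias            # shape (dim,)
        self.log_scale = log_scale  # shape (dim,)

    # PyTree registration: the arrays become leaves that jit/grad/vmap can see.
    def tree_flatten(self):
        return (self.weight, self.bias, self.log_scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def log_prob(self, x, condition):
        loc = self.weight @ condition + self.bias
        z = (x - loc) * jnp.exp(-self.log_scale)  # inverse of the forward map
        # Change of variables: log p(x|c) = log p_z(z) - log|det dx/dz|.
        return norm.logpdf(z).sum() - self.log_scale.sum()

    def sample(self, key, condition):
        z = jax.random.normal(key, self.bias.shape)
        return self.weight @ condition + self.bias + jnp.exp(self.log_scale) * z


def nll(dist, x, condition):
    """Mean negative log-likelihood over a batch of (x, condition) pairs."""
    return -jax.vmap(dist.log_prob)(x, condition).mean()


key_cond, key_sample = jax.random.split(jax.random.PRNGKey(0))
dist = ConditionalAffineGaussian(
    weight=jnp.zeros((2, 3)), bias=jnp.zeros(2), log_scale=jnp.zeros(2)
)
conditions = jax.random.normal(key_cond, (100, 3))
# vmap draws one sample per condition, each with its own key.
xs = jax.vmap(dist.sample)(jax.random.split(key_sample, 100), conditions)
# Because dist is a PyTree, jit and grad accept it directly; grads has the same structure.
grads = jax.jit(jax.grad(nll))(dist, xs, conditions)
```

The design choice to expose every parameter as a PyTree leaf is what lets `jax.grad` return gradients in the same shape as the distribution itself, so the result could be fed to an optimizer without any manual flattening.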
Installation:
pip install flowjax