FlowJAX

FlowJAX is a package for continuous distributions, bijections and normalizing flows, built on Equinox and JAX. Highlights:

  • Includes a wide range of distributions and bijections.

  • Distributions and bijections are PyTrees, implemented as Equinox modules, which makes them compatible with JAX transformations.

  • Includes many state-of-the-art normalizing flow models.

  • First-class support for conditional distributions, which are important for applications such as amortized variational inference and simulation-based inference.
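The features above rest on one idea: a normalizing flow pushes a simple base distribution through a bijection, and the density of the result follows from the change-of-variables formula, log p_Y(y) = log p_X(f⁻¹(y)) + log|d f⁻¹(y)/dy|. The sketch below illustrates that arithmetic in plain Python with a single one-dimensional affine bijection. It does not use FlowJAX, and the names `Affine`, `Transformed` and `conditional_log_prob` are hypothetical illustrations, not FlowJAX's API. The conditional example assumes a toy conditioner in which the shift equals 2u for a conditioning value u.

```python
import math
import random
from typing import Tuple


class Affine:
    """Hypothetical 1-D bijection y = loc + scale * x, with scale > 0."""

    def __init__(self, loc: float, scale: float) -> None:
        if scale <= 0:
            raise ValueError("scale must be positive")
        self.loc = loc
        self.scale = scale

    def transform(self, x: float) -> float:
        # Forward pass: map a base sample x to y.
        return self.loc + self.scale * x

    def inverse_and_log_det(self, y: float) -> Tuple[float, float]:
        # Inverse pass: x = (y - loc) / scale, and log|dx/dy| = -log(scale).
        return (y - self.loc) / self.scale, -math.log(self.scale)


def standard_normal_log_prob(x: float) -> float:
    # Log-density of N(0, 1).
    return -0.5 * (x * x + math.log(2 * math.pi))


class Transformed:
    """Hypothetical flow: a standard normal base pushed through a bijection."""

    def __init__(self, bijection: Affine) -> None:
        self.bijection = bijection

    def log_prob(self, y: float) -> float:
        # Change of variables: base log-density at f^{-1}(y) plus log|d f^{-1}/dy|.
        x, log_det = self.bijection.inverse_and_log_det(y)
        return standard_normal_log_prob(x) + log_det

    def sample(self, rng: random.Random) -> float:
        # Draw from the base distribution, then apply the forward bijection.
        return self.bijection.transform(rng.gauss(0.0, 1.0))


def conditional_log_prob(y: float, u: float) -> float:
    # Toy conditional flow (assumption): the shift depends on the condition u.
    return Transformed(Affine(loc=2.0 * u, scale=0.5)).log_prob(y)
```

For an affine bijection this reproduces the log-density of a normal distribution with mean `loc` and standard deviation `scale`, which makes the formula easy to check. In FlowJAX itself, bijections and distributions are Equinox modules operating on JAX arrays, so computations like this can be combined with JAX transformations; this plain-Python version only shows the underlying arithmetic.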

Installation

pip install flowjax
