Training

FlowJax includes basic training scripts for convenience, although users may need to modify these for specific use cases. If we wish to fit a flow to samples from a target distribution (and corresponding conditioning variables, if appropriate), we can use fit_to_data, sketched in the example after the parameter listing below.

fit_to_data(key, dist, x, *, condition=None, loss_fn=None, max_epochs=100, max_patience=5, batch_size=100, val_prop=0.1, learning_rate=0.0005, optimizer=None, return_best=True, show_progress=True)

Fit a distribution (e.g. a flow) to samples from the target distribution.

The distribution can be unconditional \(p(x)\) or conditional \(p(x|\text{condition})\). Note that the last batch in each epoch is dropped if truncated (to avoid recompilation). This function can also be used to fit non-distribution pytrees as long as a compatible loss function is provided.

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • dist (PyTree) – The distribution to train.

  • x (ArrayLike) – Samples from the target distribution.

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

  • loss_fn (Callable | None) – Loss function. Defaults to MaximumLikelihoodLoss.

  • max_epochs (int) – Maximum number of epochs. Defaults to 100.

  • max_patience (int) – Number of consecutive epochs with no validation loss improvement after which training is terminated. Defaults to 5.

  • batch_size (int) – Batch size. Defaults to 100.

  • val_prop (float) – Proportion of data to use in validation set. Defaults to 0.1.

  • learning_rate (float) – Adam learning rate. Defaults to 5e-4.

  • optimizer (GradientTransformation | None) – Optax optimizer. If provided, this overrides the default Adam optimizer, and the learning_rate is ignored. Defaults to None.

  • return_best (bool) – Whether the result should use the parameters where the minimum loss was reached (when True), or the parameters after the last update (when False). Defaults to True.

  • show_progress (bool) – Whether to show progress bar. Defaults to True.

Returns:

A tuple containing the trained distribution and the losses.
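For example, a maximum likelihood fit to samples might look like the following minimal sketch. The flow constructor and base distribution (masked_autoregressive_flow, Normal) are those documented elsewhere in FlowJax, and the correlated Gaussian samples are stand-in data for illustration.

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

key, x_key, flow_key, train_key, sample_key = jr.split(jr.PRNGKey(0), 5)

# Stand-in "target" samples: 5000 draws from a correlated 2D Gaussian.
x = jr.multivariate_normal(
    x_key,
    mean=jnp.zeros(2),
    cov=jnp.array([[1.0, 0.8], [0.8, 1.0]]),
    shape=(5000,),
)

# A flow with a two-dimensional standard normal base distribution.
flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# Maximum likelihood fit to the samples (the default MaximumLikelihoodLoss).
flow, losses = fit_to_data(train_key, flow, x, learning_rate=1e-3)

# The trained flow supports log density evaluation and sampling.
log_probs = flow.log_prob(x)
samples = flow.sample(sample_key, (1000,))

For a conditional density \(p(x|\text{condition})\), the conditioning variables are passed via the condition argument (and the flow is typically constructed with a matching cond_dim).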

Alternatively, we can use fit_to_variational_target to fit a flow to a target function (e.g. an unnormalized log density) by variational inference; a sketch follows the parameter listing below.

fit_to_variational_target(key, dist, loss_fn, *, steps=100, learning_rate=0.0005, optimizer=None, return_best=True, show_progress=True)

Train a distribution (e.g. a flow) by variational inference.

Parameters:
  • key (PRNGKeyArray) – Jax PRNGKey.

  • dist (PyTree) – Distribution object; trainable parameters are found using equinox.is_inexact_array.

  • loss_fn (Callable) – The loss function to optimize (e.g. the ElboLoss).

  • steps (int) – The number of training steps to run. Defaults to 100.

  • learning_rate (float) – Learning rate. Defaults to 5e-4.

  • optimizer (GradientTransformation | None) – Optax optimizer. If provided, this overrides the default Adam optimizer, and the learning_rate is ignored. Defaults to None.

  • return_best (bool) – Whether the result should use the parameters where the minimum loss was reached (when True), or the parameters after the last update (when False). Defaults to True.

  • show_progress (bool) – Whether to show progress bar. Defaults to True.

Return type:

tuple[PyTree, list]

Returns:

A tuple containing the trained distribution and the losses.
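As an illustration, the sketch below fits a flow to an unnormalized two-dimensional Gaussian log density via the ELBO. The import path and the (target, num_samples) constructor of ElboLoss are assumptions based on common FlowJax usage and may differ between versions; the target density itself is purely illustrative.

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_variational_target
from flowjax.train.losses import ElboLoss  # assumed import path

def target_log_prob(x):
    # Unnormalized log density of a correlated 2D Gaussian; x has shape (..., 2).
    cov_inv = jnp.linalg.inv(jnp.array([[1.0, 0.8], [0.8, 1.0]]))
    return -0.5 * jnp.einsum("...i,ij,...j->...", x, cov_inv, x)

flow_key, train_key = jr.split(jr.PRNGKey(0))

flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# Negative ELBO estimated from 500 flow samples per optimization step.
loss_fn = ElboLoss(target_log_prob, num_samples=500)

flow, losses = fit_to_variational_target(
    train_key, flow, loss_fn, steps=500, learning_rate=1e-3
)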