Training#

FlowJAX includes basic training scripts for convenience, although users may need to modify these for specific use cases. If we wish to fit a flow to samples from a distribution (and corresponding conditioning variables, if appropriate), we can use fit_to_data.

fit_to_data(
key,
dist,
x,
*,
condition=None,
loss_fn=None,
learning_rate=0.0005,
optimizer=None,
max_epochs=100,
max_patience=5,
batch_size=100,
val_prop=0.1,
return_best=True,
show_progress=True,
)[source]#

Train a PyTree (e.g. a distribution) to samples from the target.

The model can be unconditional \(p(x)\) or conditional \(p(x|\text{condition})\). Note that the last batch in each epoch is dropped if truncated (to avoid recompilation). This function can also be used to fit non-distribution pytrees as long as a compatible loss function is provided.

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • dist (PyTree) – The pytree to train (usually a distribution).

  • x (ArrayLike) – Samples from the target distribution.

  • learning_rate (float) – The learning rate for the default adam optimizer. Ignored if optimizer is provided.

  • optimizer (GradientTransformation | None) – Optax optimizer. Defaults to None.

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

  • loss_fn (Callable | None) – Loss function. Defaults to MaximumLikelihoodLoss.

  • max_epochs (int) – Maximum number of epochs. Defaults to 100.

  • max_patience (int) – Number of consecutive epochs with no validation loss improvement after which training is terminated. Defaults to 5.

  • batch_size (int) – Batch size. Defaults to 100.

  • val_prop (float) – Proportion of data to use in validation set. Defaults to 0.1.

  • return_best (bool) – Whether the result should use the parameters where the minimum loss was reached (when True), or the parameters after the last update (when False). Defaults to True.

  • show_progress (bool) – Whether to show progress bar. Defaults to True.

Returns:

A tuple containing the trained distribution and the losses.
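
For example, a minimal sketch of fitting a flow to unconditional samples. The masked_autoregressive_flow and Normal constructors, and the flowjax.flows, flowjax.distributions and flowjax.train import paths, are assumptions rather than part of the reference above; adjust them to the flow you are using.

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed import path
from flowjax.flows import masked_autoregressive_flow  # assumed flow constructor
from flowjax.train import fit_to_data  # assumed import path

key, flow_key, train_key, data_key = jr.split(jr.key(0), 4)

# Toy unconditional samples standing in for the target distribution.
x = jr.normal(data_key, (1000, 2))

flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# Returns the trained flow and the train/validation losses.
flow, losses = fit_to_data(train_key, flow, x, learning_rate=5e-4, max_epochs=50)

For a conditional density \(p(x|\text{condition})\), pass the conditioning variables via the condition keyword argument.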

Alternatively, we can use fit_to_key_based_loss to fit the flow using a key-based loss function, for example a variational inference objective.

fit_to_key_based_loss(
key,
tree,
*,
loss_fn,
steps,
learning_rate=0.0005,
optimizer=None,
show_progress=True,
)[source]#

Train a pytree, using a loss function that takes params, static and key as arguments.

This can be used e.g. to fit a distribution using a variational objective, such as the evidence lower bound.

Parameters:
  • key (PRNGKeyArray) – Jax random key.

  • tree (PyTree) – PyTree, from which trainable parameters are found using equinox.is_inexact_array.

  • loss_fn (Callable[[PyTree, PyTree, PRNGKeyArray], Shaped[Array, '']]) – The loss function to optimize.

  • steps (int) – The number of optimization steps.

  • learning_rate (float) – The adam learning rate. Ignored if optimizer is provided.

  • optimizer (GradientTransformation | None) – Optax optimizer. Defaults to None.

  • show_progress (bool) – Whether to show progress bar. Defaults to True.

Returns:

A tuple containing the trained pytree and the losses.
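
For example, a sketch of fitting a flow to an unnormalised target log density with an evidence lower bound objective. The ElboLoss class (taking a target log density and a Monte Carlo sample count) and the import paths are assumptions; check FlowJAX's losses module for the exact name and signature.

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal  # assumed import path
from flowjax.flows import masked_autoregressive_flow  # assumed flow constructor
from flowjax.train import fit_to_key_based_loss  # assumed import path
from flowjax.train.losses import ElboLoss  # assumed loss class

def target_log_prob(x):
    # Unnormalised log density of the target (a standard normal, for illustration).
    return -0.5 * jnp.sum(x**2)

key, flow_key, train_key = jr.split(jr.key(0), 3)
flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

flow, losses = fit_to_key_based_loss(
    train_key,
    flow,
    loss_fn=ElboLoss(target_log_prob, num_samples=100),
    steps=500,
    learning_rate=5e-4,
)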

Finally, for more control over the training loop, you may still find the step function useful.

step(params, *args, optimizer, opt_state, loss_fn, **kwargs)[source]#

Carry out a training step.

Parameters:
  • params (PyTree) – Parameters for the model.

  • *args – Arguments passed to the loss function (often the static components of the model).

  • optimizer (GradientTransformation) – Optax optimizer.

  • opt_state (PyTree) – Optimizer state.

  • loss_fn (Callable[[PyTree, PyTree], Shaped[Array, '']]) – The loss function. This should take params and static as the first two arguments.

  • **kwargs – Keyword arguments passed to the loss function.

Returns:

(params, opt_state, loss_val)

Return type:

tuple
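
For example, a sketch of a simple full-batch training loop built around step. The import paths, the MaximumLikelihoodLoss call signature, and the use of equinox.partition with equinox.is_inexact_array to split trainable from static components are assumptions mirroring how the built-in scripts behave; adapt them to your model.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import optax
from flowjax.distributions import Normal  # assumed import path
from flowjax.flows import masked_autoregressive_flow  # assumed flow constructor
from flowjax.train import step  # assumed import path (may live in a train_utils submodule)
from flowjax.train.losses import MaximumLikelihoodLoss  # assumed import path

key, flow_key, data_key = jr.split(jr.key(0), 3)
flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))
x = jr.normal(data_key, (1000, 2))  # toy samples from the target

# Split into trainable arrays (params) and everything else (static).
params, static = eqx.partition(flow, eqx.is_inexact_array)

optimizer = optax.adam(5e-4)
opt_state = optimizer.init(params)
loss_fn = MaximumLikelihoodLoss()  # assumed to take (params, static, x, ...)

for _ in range(100):  # one full-batch update per iteration, for simplicity
    params, opt_state, loss_val = step(
        params, static, x, optimizer=optimizer, opt_state=opt_state, loss_fn=loss_fn
    )

flow = eqx.combine(params, static)  # reassemble the trained flow

A real loop would typically shuffle and batch x and track a validation loss, which is what fit_to_data handles for you.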