Training
FlowJax includes basic training functions for convenience, although users may need to modify them for specific use cases. If we wish to fit the flow to samples from a distribution (and the corresponding conditioning variables, if appropriate), we can use fit_to_data.
- fit_to_data(key, dist, x, *, condition=None, loss_fn=None, max_epochs=100, max_patience=5, batch_size=100, val_prop=0.1, learning_rate=0.0005, optimizer=None, return_best=True, show_progress=True)
Train a distribution (e.g. a flow) on samples from the target distribution.
The distribution can be unconditional \(p(x)\) or conditional \(p(x \mid \text{condition})\). The last batch in each epoch is dropped if it is truncated (i.e. smaller than batch_size), because a batch with a different shape would trigger recompilation. This function can also be used to fit pytrees that are not distributions, as long as a compatible loss function is provided.
- Parameters:
  - key (PRNGKeyArray) – JAX random key.
  - dist (PyTree) – The distribution to train.
  - x (ArrayLike) – Samples from the target distribution.
  - condition (ArrayLike | None) – Conditioning variables corresponding to x. Defaults to None.
  - loss_fn (Callable | None) – Loss function. Defaults to None, in which case MaximumLikelihoodLoss is used.
  - max_epochs (int) – Maximum number of epochs. Defaults to 100.
  - max_patience (int) – Number of consecutive epochs without improvement in the validation loss after which training is terminated. Defaults to 5.
  - batch_size (int) – Batch size. Defaults to 100.
  - val_prop (float) – Proportion of the data to use as the validation set. Defaults to 0.1.
  - learning_rate (float) – Adam learning rate. Defaults to 5e-4.
  - optimizer (GradientTransformation | None) – Optax optimizer. If provided, this overrides the default Adam optimizer and learning_rate is ignored. Defaults to None.
  - return_best (bool) – Whether to return the parameters at which the minimum validation loss was reached (True) or the parameters after the last update (False). Defaults to True.
  - show_progress (bool) – Whether to show a progress bar. Defaults to True.
- Returns:
A tuple containing the trained distribution and the losses.
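The interaction between batch_size, max_patience and return_best can be modelled in plain Python. This is an illustrative sketch of the behaviour described above, not FlowJax's implementation: the helpers full_batches and early_stopping are hypothetical, and the sketch assumes patience is counted from the epoch with the lowest validation loss so far.

```python
def full_batches(n_samples: int, batch_size: int) -> list[range]:
    """Index ranges for one epoch's batches; a truncated final batch is dropped."""
    n_full = n_samples // batch_size
    return [range(i * batch_size, (i + 1) * batch_size) for i in range(n_full)]


def early_stopping(
    val_losses: list[float], max_patience: int, return_best: bool = True
) -> tuple[int, int]:
    """Replay a non-empty sequence of per-epoch validation losses.

    Returns (epoch whose parameters would be returned, number of epochs run).
    """
    best_epoch, best_loss = 0, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New minimum validation loss: these parameters become the "best" ones.
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= max_patience:
            # max_patience consecutive epochs without improvement: stop training.
            return (best_epoch if return_best else epoch), epoch + 1
    last_epoch = len(val_losses) - 1
    return (best_epoch if return_best else last_epoch), len(val_losses)


# 250 samples with batch_size=100: two full batches, the remaining 50 are not used.
print([len(b) for b in full_batches(250, 100)])  # -> [100, 100]

# No improvement after epoch 1, so with max_patience=2 training stops after epoch 3.
losses = [3.0, 2.0, 2.5, 2.4, 1.0]
print(early_stopping(losses, max_patience=2))  # -> (1, 4)
print(early_stopping(losses, max_patience=2, return_best=False))  # -> (3, 4)
```

Note that in this example the lower loss at epoch 4 is never seen, which is the usual trade-off of a small patience value.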
Alternatively, we can use fit_to_variational_target to fit the flow to a target function (for example, an unnormalized log posterior density wrapped in a loss such as ElboLoss) using variational inference.
- fit_to_variational_target(key, dist, loss_fn, *, steps=100, learning_rate=0.0005, optimizer=None, return_best=True, show_progress=True)
Train a distribution (e.g. a flow) by variational inference.
- Parameters:
  - key (PRNGKeyArray) – JAX random key.
  - dist (PyTree) – Distribution object. Its trainable parameters are the leaves for which equinox.is_inexact_array is True.
  - loss_fn (Callable) – The loss function to optimize (e.g. ElboLoss).
  - steps (int) – Number of training steps to run. Defaults to 100.
  - learning_rate (float) – Adam learning rate. Defaults to 5e-4.
  - optimizer (GradientTransformation | None) – Optax optimizer. If provided, this overrides the default Adam optimizer and learning_rate is ignored. Defaults to None.
  - return_best (bool) – Whether to return the parameters at which the minimum loss was reached (True) or the parameters after the last update (False). Defaults to True.
  - show_progress (bool) – Whether to show a progress bar. Defaults to True.
- Returns:
A tuple containing the trained distribution and the losses.