Training
FlowJAX includes basic training scripts for convenience, although users may need to modify these for specific use cases. If we wish to fit the flow to samples from a distribution (and corresponding conditioning variables, if appropriate), we can use fit_to_data.
- fit_to_data(key, dist, x, *, condition=None, loss_fn=None, learning_rate=0.0005, optimizer=None, max_epochs=100, max_patience=5, batch_size=100, val_prop=0.1, return_best=True, show_progress=True)
Train a PyTree (e.g. a distribution) on samples from the target.
The model can be unconditional \(p(x)\) or conditional \(p(x|\text{condition})\). Note that the last batch in each epoch is dropped if truncated (to avoid recompilation). This function can also be used to fit non-distribution pytrees as long as a compatible loss function is provided.
- Parameters:
  - key (PRNGKeyArray) – Jax random seed.
  - dist (PyTree) – The pytree to train (usually a distribution).
  - x (ArrayLike) – Samples from target distribution.
  - learning_rate (float) – The learning rate for the adam optimizer. Ignored if optimizer is provided.
  - optimizer (GradientTransformation | None) – Optax optimizer. Defaults to None.
  - condition (ArrayLike | None) – Conditioning variables. Defaults to None.
  - loss_fn (Callable | None) – Loss function. Defaults to MaximumLikelihoodLoss.
  - max_epochs (int) – Maximum number of epochs. Defaults to 100.
  - max_patience (int) – Number of consecutive epochs with no validation loss improvement after which training is terminated. Defaults to 5.
  - batch_size (int) – Batch size. Defaults to 100.
  - val_prop (float) – Proportion of data to use in validation set. Defaults to 0.1.
  - return_best (bool) – Whether the result should use the parameters where the minimum loss was reached (when True), or the parameters after the last update (when False). Defaults to True.
  - show_progress (bool) – Whether to show progress bar. Defaults to True.
- Returns:
A tuple containing the trained distribution and the losses.
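For example, a minimal sketch of fitting a flow to unconditional samples might look as follows; the data, the masked_autoregressive_flow architecture and the hyperparameters here are placeholder assumptions for illustration, not prescriptions:

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

data_key, flow_key, train_key = jr.split(jr.key(0), 3)
x = jr.normal(data_key, (1000, 2))  # stand-in for samples from the target

flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# Returns the trained flow and the losses.
flow, losses = fit_to_data(train_key, flow, x, max_epochs=50)
```

For a conditional model \(p(x|\text{condition})\), the corresponding conditioning variables would be passed via the condition argument.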
Alternatively, we can use fit_to_key_based_loss to fit the flow to a target function, e.g. using variational inference.
- fit_to_key_based_loss(key, tree, *, loss_fn, steps, learning_rate=0.0005, optimizer=None, show_progress=True)
Train a pytree, using a loss with params, static and key as arguments.
This can be used e.g. to fit a distribution using a variational objective, such as the evidence lower bound.
- Parameters:
  - key (PRNGKeyArray) – Jax random key.
  - tree (PyTree) – PyTree, from which trainable parameters are found using equinox.is_inexact_array.
  - loss_fn (Callable[[PyTree, PyTree, PRNGKeyArray], Shaped[Array, '']]) – The loss function to optimize.
  - steps (int) – The number of optimization steps.
  - learning_rate (float) – The adam learning rate. Ignored if optimizer is provided.
  - optimizer (GradientTransformation | None) – Optax optimizer. Defaults to None.
  - show_progress (bool) – Whether to show progress bar. Defaults to True.
- Returns:
A tuple containing the trained pytree and the losses.
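As an illustrative sketch, one could fit a flow to an unnormalized log density by minimizing the negative evidence lower bound. The ElboLoss import path, the toy target and the hyperparameters below are assumptions, not part of this page:

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_key_based_loss
from flowjax.train.losses import ElboLoss  # assumed import path

def log_target(x):
    # Toy unnormalized log density (here, a standard normal).
    return -0.5 * jnp.sum(x**2)

key, flow_key = jr.split(jr.key(0))
flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# The loss is called with (params, static, key), matching the signature above.
flow, losses = fit_to_key_based_loss(
    key, flow, loss_fn=ElboLoss(log_target, num_samples=100), steps=500
)
```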
Finally, for more control over the training loop, you may still find the step function useful.
- step(params, *args, optimizer, opt_state, loss_fn, **kwargs)
Carry out a training step.
- Parameters:
  - params (PyTree) – Parameters for the model.
  - *args – Arguments passed to the loss function (often the static components of the model).
  - optimizer (GradientTransformation) – Optax optimizer.
  - opt_state (PyTree) – Optimizer state.
  - loss_fn (Callable[[PyTree, PyTree], Shaped[Array, '']]) – The loss function. This should take params and static as the first two arguments.
  - **kwargs – Keyword arguments passed to the loss function.
- Returns:
(params, opt_state, loss_val)
- Return type:
tuple[PyTree, PyTree, Shaped[Array, '']]
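For instance, a manual full-batch training loop might look like the following sketch; the MaximumLikelihoodLoss import path, the partition filter and the data are assumptions for illustration:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import optax
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import step
from flowjax.train.losses import MaximumLikelihoodLoss  # assumed import path

data_key, flow_key = jr.split(jr.key(0))
x = jr.normal(data_key, (1000, 2))  # stand-in data

flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))

# Split into trainable parameters and static components.
params, static = eqx.partition(flow, eqx.is_inexact_array)

optimizer = optax.adam(5e-4)
opt_state = optimizer.init(params)
loss_fn = MaximumLikelihoodLoss()

for _ in range(100):  # full-batch gradient steps, for simplicity
    # static and x are the *args forwarded to loss_fn(params, static, x).
    params, opt_state, loss_val = step(
        params,
        static,
        x,
        optimizer=optimizer,
        opt_state=opt_state,
        loss_fn=loss_fn,
    )

flow = eqx.combine(params, static)
```

This reproduces the core of fit_to_data without batching, validation or early stopping, which is where the extra control comes in.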