Loss functions

Loss functions from flowjax.train.losses.

Common loss functions for training normalizing flows.

The loss functions are callables whose first two arguments are the parameters and static components of the partitioned distribution (see equinox.partition).
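
For example, a loss can be evaluated directly on a partitioned distribution. This is a minimal sketch: the flow construction and the exact positional arguments after the partition follow flowjax's conventions, but should be treated as assumptions.

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr
    from flowjax.distributions import Normal
    from flowjax.flows import masked_autoregressive_flow
    from flowjax.train.losses import MaximumLikelihoodLoss

    # A two-dimensional flow (construction details are illustrative).
    flow = masked_autoregressive_flow(jr.key(0), base_dist=Normal(jnp.zeros(2)))

    # Partition into trainable arrays and static components.
    params, static = eqx.partition(flow, eqx.is_inexact_array)

    loss = MaximumLikelihoodLoss()
    x = jr.normal(jr.key(1), (64, 2))  # a batch of observations
    loss_value = loss(params, static, x)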

class ContrastiveLoss(prior, n_contrastive)[source]

Loss function for use in a sequential neural posterior estimation algorithm.

Learns a posterior p(x|condition). Contrastive samples for each x are generated from other x samples in the batch.

Note that in a simulation-based inference context, \(x\) is often used to denote the simulations and \(\theta\) the simulation parameters. However, for consistency with the rest of the package, we use \(x\) to represent the target variable (the simulator parameters), and condition for the conditioning variable (the simulator output/observed data).

Parameters:
  • prior (AbstractDistribution) – The prior distribution over x (the target variable).

  • n_contrastive (int) – The number of contrastive samples/atoms to use when computing the loss.
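
A minimal construction and evaluation sketch: the prior, the conditional flow, the batch shapes, and the assumption that x and condition are passed after the partitioned components are all illustrative, not a confirmed signature.

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr
    from flowjax.distributions import Normal
    from flowjax.flows import masked_autoregressive_flow
    from flowjax.train.losses import ContrastiveLoss

    # Hypothetical prior over the 2-dimensional target variable x.
    prior = Normal(jnp.zeros(2))
    loss = ContrastiveLoss(prior=prior, n_contrastive=10)

    # A conditional flow q(x | condition), with a 3-dimensional condition.
    posterior = masked_autoregressive_flow(
        jr.key(0), base_dist=Normal(jnp.zeros(2)), cond_dim=3
    )
    params, static = eqx.partition(posterior, eqx.is_inexact_array)

    x = jr.normal(jr.key(1), (64, 2))          # target samples (batch)
    condition = jr.normal(jr.key(2), (64, 3))  # paired conditioning data
    loss_value = loss(params, static, x, condition)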

class ElboLoss(target, num_samples, *, stick_the_landing=False)[source]

The negative evidence lower bound (ELBO), approximated using samples.

Parameters:
  • target (Callable[[ArrayLike], Array]) – The target, i.e. the log posterior density up to an additive constant (equivalently, the negative of the potential function), evaluated at a single point.

  • num_samples (int) – Number of samples to use in the ELBO approximation.

  • stick_the_landing (bool) – Whether to use the (often) lower-variance ELBO gradient estimator introduced in https://arxiv.org/pdf/1703.09194.pdf. Note that for flows this requires evaluating the flow in both directions (running the forward and inverse transformations). For some flow architectures, this may be computationally expensive due to asymmetrical computational complexity between the forward and inverse transformations. Defaults to False.

num_samples: int

stick_the_landing: bool

target: Callable[[ArrayLike], Array]
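
A minimal variational-inference sketch: the toy target density is illustrative, and passing a PRNG key after the partitioned components (used to draw the ELBO samples) is an assumption.

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr
    from flowjax.distributions import Normal
    from flowjax.flows import masked_autoregressive_flow
    from flowjax.train.losses import ElboLoss

    # Unnormalized target: a standard Gaussian log density up to a constant.
    def target(x):
        return -0.5 * jnp.sum(x**2)

    loss = ElboLoss(target=target, num_samples=500)

    flow = masked_autoregressive_flow(jr.key(0), base_dist=Normal(jnp.zeros(2)))
    params, static = eqx.partition(flow, eqx.is_inexact_array)

    # Sampling from the flow requires a key.
    loss_value = loss(params, static, jr.key(1))
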
class MaximumLikelihoodLoss[source]

Loss for fitting a flow with maximum likelihood (negative log likelihood).

This loss can be used to learn either conditional or unconditional distributions.
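
For example, a conditional density estimation sketch using flowjax.train.fit_to_data; the argument names and return values of fit_to_data are assumptions based on flowjax's training utilities.

    import jax.random as jr
    import jax.numpy as jnp
    from flowjax.distributions import Normal
    from flowjax.flows import masked_autoregressive_flow
    from flowjax.train import fit_to_data
    from flowjax.train.losses import MaximumLikelihoodLoss

    # A conditional flow q(x | condition).
    flow = masked_autoregressive_flow(
        jr.key(0), base_dist=Normal(jnp.zeros(2)), cond_dim=1
    )

    x = jr.normal(jr.key(1), (500, 2))
    condition = jr.uniform(jr.key(2), (500, 1))

    # Fit by maximum likelihood; the loss is passed explicitly for illustration.
    flow, losses = fit_to_data(
        jr.key(3),
        dist=flow,
        x=x,
        condition=condition,
        loss_fn=MaximumLikelihoodLoss(),
    )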