Loss functions
Loss functions from flowjax.train.losses.
Common loss functions for training normalizing flows.
The loss functions are callables whose first two arguments are the partitioned distribution, i.e. the two pytrees returned by equinox.partition (typically the trainable parameters and the remaining static components).
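To illustrate this calling convention, here is a minimal sketch of a maximum-likelihood style loss whose first two arguments are a parameter pytree and a static part. It deliberately avoids flowjax and equinox. As an assumption for illustration only, a plain dict of arrays stands in for the trainable half that equinox.partition would return, and a Python log-density function stands in for the static half. The names diag_gaussian_log_prob and nll_loss are hypothetical. Because jax.grad and jax.value_and_grad differentiate only the first argument by default, gradients are taken with respect to the parameters alone.

```python
import jax
import jax.numpy as jnp


def diag_gaussian_log_prob(params, x):
    """Hypothetical static component: log density of a diagonal Gaussian at one point."""
    scale = jnp.exp(params["log_scale"])
    z = (x - params["loc"]) / scale
    return jnp.sum(-0.5 * z**2 - params["log_scale"] - 0.5 * jnp.log(2 * jnp.pi))


def nll_loss(params, static, x):
    """Negative mean log-likelihood; `params` holds arrays, `static` is a plain callable."""
    log_probs = jax.vmap(lambda xi: static(params, xi))(x)
    return -jnp.mean(log_probs)


# Hand-built stand-ins for the two halves of a partitioned distribution (assumption).
params = {"loc": jnp.zeros(1), "log_scale": jnp.zeros(1)}
static = diag_gaussian_log_prob

x = jnp.array([[1.0], [2.0], [3.0], [4.0]])
# Only `params` (argument 0) is differentiated; `static` and `x` are passed through.
loss, grads = jax.value_and_grad(nll_loss)(params, static, x)
```

As a sanity check, the gradients vanish at the maximum-likelihood solution (loc equal to the sample mean 2.5, log_scale equal to the log of the population standard deviation).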
- class ContrastiveLoss(prior, n_contrastive)
Loss function for use in a sequential neural posterior estimation algorithm.
Learns a posterior \(p(x \mid \text{condition})\). Contrastive samples for each x are generated from the other x samples in the batch. Note that in a simulation-based inference context, \(x\) is often used to denote simulations and \(\theta\) the simulation parameters. However, for consistency with the rest of the package, we use x to represent the target variable (the simulator parameters) and condition for the conditioning variable (the simulator output/observed data).
- Parameters:
  - prior (AbstractDistribution) – The prior distribution over x (the target variable).
  - n_contrastive (int) – The number of contrastive samples/atoms to use when computing the loss.
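The contrastive idea can be sketched in plain JAX. This is an illustration of an atomic loss in the style of APT/SNPE-C under stated assumptions, not flowjax's implementation: log_q and log_prior are hypothetical callables standing in for the flow's conditional log density and the prior's log density, and the contrastive atoms for row i are drawn without replacement from the other rows of the batch. For each row, the true x is placed first among its n_contrastive + 1 atoms, each atom is scored by log q(atom | condition) - log p(atom), and the loss is the negative log-softmax probability of the true atom.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def contrastive_indices(key, batch_size, n_contrastive):
    """For each row i, draw n_contrastive distinct indices from the other rows."""
    if not 0 < n_contrastive < batch_size:
        raise ValueError("need 0 < n_contrastive < batch_size")

    def draw(k, i):
        # Sample from {0, ..., batch_size - 2}, then shift indices >= i up by one
        # so that row i itself is never selected.
        idx = jax.random.choice(k, batch_size - 1, shape=(n_contrastive,), replace=False)
        return idx + (idx >= i)

    keys = jax.random.split(key, batch_size)
    return jax.vmap(draw)(keys, jnp.arange(batch_size))


def contrastive_loss(log_q, log_prior, x, condition, key, n_contrastive):
    """Atomic contrastive loss; x: (batch, dim), condition: (batch, cond_dim)."""
    idx = contrastive_indices(key, x.shape[0], n_contrastive)
    # Atoms per row: the true x first, followed by the contrastive samples.
    atoms = jnp.concatenate([x[:, None, :], x[idx]], axis=1)

    def row_logits(row_atoms, cond):
        # Score every atom against this row's condition.
        return jax.vmap(lambda a: log_q(a, cond) - log_prior(a))(row_atoms)

    logits = jax.vmap(row_logits)(atoms, condition)  # (batch, n_contrastive + 1)
    return -jnp.mean(logits[:, 0] - logsumexp(logits, axis=1))
```

If log_q ignores the condition and equals log_prior, every logit is zero and the loss reduces to log(n_contrastive + 1), i.e. chance level among the atoms; a useful posterior estimate drives it below that value.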
- class ElboLoss(target, num_samples, *, stick_the_landing=False)
The negative evidence lower bound (ELBO), approximated using samples.
- Parameters:
  - num_samples (int) – Number of samples to use in the ELBO approximation.
  - target (Callable[[ArrayLike], Array]) – The target, i.e. the log posterior density up to an additive constant (equivalently, the negative of the potential function), evaluated for a single point.
  - stick_the_landing (bool) – Whether to use the (often) lower-variance ELBO gradient estimator introduced in https://arxiv.org/pdf/1703.09194.pdf. Note that for flows this requires evaluating the flow in both directions (running both the forward and the inverse transformation). For some flow architectures this may be computationally expensive, because the forward and inverse transformations can have asymmetric computational cost. Defaults to False.
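A minimal sketch of the sample-based negative ELBO in plain JAX, assuming a diagonal Gaussian variational family in place of a flow; gaussian_log_prob, elbo_loss and the params dict layout are hypothetical names, not flowjax API. Samples are reparameterised as z = loc + exp(log_scale) * eps so gradients flow through them. With stick_the_landing=True, log q is evaluated with jax.lax.stop_gradient applied to the parameters, which removes the score-function term from the gradient, following the idea of the paper linked above.

```python
import jax
import jax.numpy as jnp


def gaussian_log_prob(z, loc, log_scale):
    """Log density of a diagonal Gaussian evaluated at a single point z."""
    u = (z - loc) / jnp.exp(log_scale)
    return jnp.sum(-0.5 * u**2 - log_scale - 0.5 * jnp.log(2 * jnp.pi))


def elbo_loss(params, target, key, num_samples, stick_the_landing=False):
    """Monte Carlo estimate of the negative ELBO for a diagonal Gaussian q."""
    loc, log_scale = params["loc"], params["log_scale"]
    eps = jax.random.normal(key, (num_samples,) + loc.shape)
    z = loc + jnp.exp(log_scale) * eps  # reparameterised samples

    # Sticking the landing: evaluate log q with the parameters treated as
    # constants, so gradients reach log q only through the samples z.
    q_params = jax.lax.stop_gradient(params) if stick_the_landing else params
    log_q = jax.vmap(
        lambda zi: gaussian_log_prob(zi, q_params["loc"], q_params["log_scale"])
    )(z)
    log_p = jax.vmap(target)(z)  # target is evaluated for a single point
    return -jnp.mean(log_p - log_q)
```

When q exactly matches a normalised target, every sample gives log p - log q = 0, so the estimate is zero. With stick_the_landing=True the gradient is then zero for every draw, whereas the standard estimator's gradient is only zero in expectation.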