Loss functions
Loss functions from flowjax.train.losses.
Common loss functions for training normalizing flows.
In order to be compatible with fit_to_data, the loss function arguments must match (params, static, x, condition, key), where params and static are the partitioned model (see equinox.partition).
For fit_to_key_based_loss, the loss function signature must match (params, static, key).
- class ContrastiveLoss(prior, n_contrastive)
Loss function for use in a sequential neural posterior estimation algorithm.
Learns a posterior p(x|condition). Contrastive samples for each x are generated from other x samples in the batch.
Note that in a simulation-based inference context, \(x\) is often used to denote simulations and \(\theta\) the simulation parameters. However, for consistency with the rest of the package, we use x to represent the target variable (the simulator parameters), and condition for the conditioning variable (the simulator output/observed data).
- Parameters:
prior (AbstractDistribution) – The prior distribution over x (the target variable).
n_contrastive (int) – The number of contrastive samples/atoms to use when computing the loss. Must be less than batch_size.
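The role of the contrastive samples, and why n_contrastive must be less than batch_size, can be seen in the following sketch of an atomic (APT-style) contrastive loss in plain JAX. It is an illustration under stated assumptions, not flowjax's implementation: `log_q(x, condition)` and `log_prior(x)` are hypothetical callables returning per-sample log densities, and the contrastive atoms for sample i are assumed to be the next n_contrastive samples in the batch (wrapping around), obtained with `jnp.roll`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def contrastive_samples(x, n_contrastive):
    """Shape (n_contrastive, batch_size, dim); entry [k, i] is x[(i + k + 1) % batch_size].

    Each sample's contrastive atoms are *other* samples from the same batch,
    which is only possible when n_contrastive < batch_size.
    """
    batch_size = x.shape[0]
    if n_contrastive >= batch_size:
        raise ValueError("n_contrastive must be less than batch_size.")
    return jnp.stack([jnp.roll(x, -(k + 1), axis=0) for k in range(n_contrastive)])


def atomic_contrastive_loss(log_q, log_prior, x, condition, n_contrastive):
    """Hypothetical sketch of an atomic contrastive loss (true x is atom 0)."""
    atoms = jnp.concatenate([x[None], contrastive_samples(x, n_contrastive)])
    # For every atom a and batch element i: log q(a_i | condition_i) - log p(a_i).
    logits = jax.vmap(lambda a: log_q(a, condition) - log_prior(a))(atoms)
    # Classify the true x against its contrastive atoms, per batch element.
    return -jnp.mean(logits[0] - logsumexp(logits, axis=0))


# Illustrative Gaussian densities (assumptions, not part of the package).
def log_prior(x):
    return norm.logpdf(x).sum(-1)


def log_q(x, condition):
    return norm.logpdf(x, loc=condition).sum(-1)


key_x, key_c = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 2))
condition = x + 0.1 * jax.random.normal(key_c, (8, 2))
loss = atomic_contrastive_loss(log_q, log_prior, x, condition, n_contrastive=3)
```

A density that ignores the condition cannot tell the true x apart from its atoms, so its loss is exactly log(n_contrastive + 1); lower values indicate that the posterior estimate is using the conditioning information.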
- class ElboLoss(target, num_samples, *, stick_the_landing=False)
The negative evidence lower bound (ELBO), approximated using samples.
- Parameters:
num_samples (int) – Number of samples to use in the ELBO approximation.
target (Callable[[ArrayLike], Array]) – The target, i.e. the log posterior density up to an additive constant (equivalently, the negative of the potential function), evaluated for a single point.
stick_the_landing (bool) – Whether to use the (often) lower-variance ELBO gradient estimator introduced in https://arxiv.org/pdf/1703.09194.pdf. Note that for flows this requires evaluating the flow in both directions (running both the forward and inverse transformations). For some flow architectures, this may be computationally expensive due to asymmetric computational complexity between the forward and inverse transformations. Defaults to False.
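The quantity being estimated, and what stick_the_landing changes, can be sketched in plain JAX. This sketch assumes a diagonal Gaussian variational distribution with reparameterized samples in place of a flow; `log_q` and `elbo_loss` are hypothetical helpers, not flowjax functions. The loss is the negative Monte Carlo ELBO estimate \(-\frac{1}{N}\sum_{n=1}^{N}\left[\log \tilde{p}(z_n) - \log q(z_n)\right]\) with \(z_n \sim q\), where \(\log \tilde{p}\) is the target. The stick-the-landing variant stops the gradient through the parameters inside \(\log q\), but not through the samples \(z_n\).

```python
import jax
import jax.numpy as jnp


def log_q(params, z):
    """Log density of a diagonal Gaussian with params = (loc, log_scale)."""
    loc, log_scale = params
    u = (z - loc) / jnp.exp(log_scale)
    return jnp.sum(-0.5 * u**2 - log_scale - 0.5 * jnp.log(2 * jnp.pi), axis=-1)


def elbo_loss(params, target, key, num_samples, stick_the_landing=False):
    """Negative Monte Carlo ELBO; `target` is an unnormalized log density of one point."""
    loc, log_scale = params
    eps = jax.random.normal(key, (num_samples, loc.shape[0]))
    z = loc + jnp.exp(log_scale) * eps  # reparameterized samples z ~ q
    if stick_the_landing:
        # Drop the score-function term: no gradient reaches q's parameters
        # through log q, though it still flows through the samples z.
        q_params = jax.tree_util.tree_map(jax.lax.stop_gradient, params)
    else:
        q_params = params
    return -jnp.mean(jax.vmap(target)(z) - log_q(q_params, z))


def standard_normal_log_density(z):
    return jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi))


params = (jnp.zeros(2), jnp.zeros(2))  # q is exactly the (normalized) target
key = jax.random.PRNGKey(0)
loss = elbo_loss(params, standard_normal_log_density, key, num_samples=16)
grad_std = jax.grad(elbo_loss)(params, standard_normal_log_density, key, 16)
grad_stl = jax.grad(elbo_loss)(params, standard_normal_log_density, key, 16, True)
```

When q matches the target exactly, the stick-the-landing gradient is zero for every batch of samples, whereas the standard reparameterized estimator is only zero in expectation and still fluctuates from batch to batch. This zero-variance behaviour near the optimum is the motivation for the estimator.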
- __call__(params, static, key)
Compute the ELBO loss.
- Parameters:
params (AbstractDistribution) – The trainable parameters of the model.
static (AbstractDistribution) – The static components of the model.
key (PRNGKeyArray) – JAX random key.
- Return type:
Float[Array, '']