Distributions

For an introduction to using distributions in FlowJAX, see Getting started. All distributions inherit from AbstractDistribution.

Distributions available in flowjax.distributions.

class AbstractDistribution

Bases: Module

Abstract distribution class.

Distributions are registered as JAX PyTrees (as they are Equinox modules), and as such they are compatible with standard JAX operations.

Concrete subclasses can be implemented as follows:

  • Inherit from AbstractDistribution.

  • Define the abstract attributes shape and cond_shape. cond_shape should be None for unconditional distributions.

  • Define the abstract method _sample, which returns a single sample with shape dist.shape (given a single conditioning variable, if needed).

  • Define the abstract method _log_prob, which returns the scalar log probability of a single sample (given a single conditioning variable, if needed).

The abstract class then uses these methods to define vectorized versions, with shape checking, for the public API (log_prob, sample and sample_and_log_prob). See the source code for StandardNormal for a simple concrete example.

Variables:
  • shape – Tuple denoting the shape of a single sample from the distribution.

  • cond_shape – Tuple denoting the shape of an instance of the conditioning variable. This should be None for unconditional distributions.

log_prob(x, condition=None)

Evaluate the log probability.

If x (or condition) has additional leading dimensions, these are treated as batch dimensions and broadcast following NumPy-style rules.

Parameters:
  • x (ArrayLike) – Points at which to evaluate density.

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Returns:

JAX array of log probabilities.

Return type:

Array

sample(key, sample_shape=(), condition=None)

Sample from the distribution.

For unconditional distributions, the output will have shape sample_shape + dist.shape. For conditional distributions, batch dimensions in the condition are supported, and the output will have shape sample_shape + condition_batch_shape + dist.shape.

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • sample_shape (tuple[int, ...]) – Sample shape. Defaults to ().

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Return type:

Array

sample_and_log_prob(key, sample_shape=(), condition=None)

Sample the distribution and return the samples with their log probabilities.

For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately. Refer to the sample() documentation for more information.

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • sample_shape (tuple[int, ...]) – Sample shape. Defaults to ().

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Return type:

tuple[Array, Array]

property cond_ndim: None | int

Number of dimensions of the conditioning variable (length of cond_shape).

property ndim: int

Number of dimensions in the distribution (the length of the shape).

class AbstractLocScaleDistribution

Bases: AbstractTransformed

Abstract distribution class for affine transformed distributions.

property loc

Location of the distribution.

property scale

Scale of the distribution.

class AbstractTransformed

Bases: AbstractDistribution

Abstract class representing transformed distributions.

We take the forward bijection for use in sampling, and the inverse for use in density evaluation. See also Transformed. Concrete implementations should subclass AbstractTransformed and define the abstract attributes base_dist and bijection. See the source code for Normal as a simple example.

Warning

It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.

Variables:
  • base_dist – The base distribution.

  • bijection – The transformation to apply.

merge_transforms()

Unnests nested transformed distributions.

Returns an equivalent distribution in which nested AbstractTransformed distributions are flattened, so that the base distribution of the returned distribution is not itself an AbstractTransformed instance.

class Beta(alpha, beta)

Bases: AbstractDistribution

Beta distribution.

Parameters:
  • alpha (ArrayLike) – The alpha shape parameter.

  • beta (ArrayLike) – The beta shape parameter.

class Cauchy(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Cauchy distribution.

loc and scale should broadcast to the dimension of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class Exponential(rate=1)

Bases: AbstractTransformed

Exponential distribution.

Parameters:

rate (ArrayLike) – The rate parameter (1 / scale). Defaults to 1.

class Gamma(concentration, scale)

Bases: AbstractTransformed

Gamma distribution.

Parameters:
  • concentration (ArrayLike) – Positive concentration parameter.

  • scale (ArrayLike) – The scale (inverse of rate) parameter.

class Gumbel(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Gumbel distribution.

loc and scale should broadcast to the dimension of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class Laplace(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Laplace distribution.

loc and scale should broadcast to the dimension of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class LogNormal(loc=0, scale=1)

Bases: AbstractTransformed

Log normal distribution.

loc and scale here refer to the underlying normal distribution, i.e. the distribution of log(x).

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class Logistic(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Logistic distribution.

loc and scale should broadcast to the shape of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class MultivariateNormal(loc, covariance)

Bases: AbstractTransformed

Multivariate normal distribution.

Internally this is parameterised using the Cholesky decomposition of the covariance matrix.

Parameters:
  • loc (Shaped[ArrayLike, '#dim']) – The location/mean parameter vector. If this is scalar it is broadcast to the dimension implied by the covariance matrix.

  • covariance (Shaped[Array, 'dim dim']) – Covariance matrix.

property covariance

The covariance matrix.

property loc

Location (mean) of the distribution.

class Normal(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

An independent Normal distribution with mean and std for each dimension.

loc and scale should broadcast to the desired shape of the distribution.

Parameters:
  • loc (ArrayLike) – Means. Defaults to 0.

  • scale (ArrayLike) – Standard deviations. Defaults to 1.

class StandardNormal(shape=())

Bases: AbstractDistribution

Standard normal distribution.

Note that, unlike Normal, this distribution has no trainable parameters.

Parameters:

shape (tuple[int, ...]) – The shape of the distribution. Defaults to ().

class StudentT(df, loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Student T distribution.

df, loc and scale should broadcast to the dimension of the distribution.

Parameters:
  • df (ArrayLike) – The degrees of freedom.

  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

property df

The degrees of freedom of the distribution.

class Transformed(base_dist, bijection)

Bases: AbstractTransformed

Form a distribution-like object using a base distribution and a bijection.

We take the forward bijection for use in sampling, and the inverse bijection for use in density evaluation.

Warning

It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.

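Parameters:
  • base_dist (AbstractDistribution) – The base distribution.

  • bijection (AbstractBijection) – The transformation to apply.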

Example

>>> from flowjax.distributions import StandardNormal, Transformed
>>> from flowjax.bijections import Affine
>>> normal = StandardNormal()
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)

class Uniform(minval, maxval)

Bases: AbstractLocScaleDistribution

Uniform distribution.

minval and maxval should broadcast to the desired distribution shape.

Parameters:
  • minval (ArrayLike) – Minimum values.

  • maxval (ArrayLike) – Maximum values.

property maxval

Maximum value of the uniform distribution.

property minval

Minimum value of the uniform distribution.

class VmapMixture(dist, weights)

Bases: AbstractDistribution

Create a mixture distribution.

Given a distribution in which the arrays have a leading dimension with size matching the number of components, and a set of weights, create a mixture distribution.

Example

>>> # Creating a 3 component, 2D gaussian mixture
>>> from flowjax.distributions import Normal, VmapMixture
>>> import equinox as eqx
>>> import jax.numpy as jnp
>>> normals = eqx.filter_vmap(Normal)(jnp.zeros((3, 2)))
>>> mixture = VmapMixture(normals, weights=jnp.ones(3))
>>> mixture.shape
(2,)

Parameters:
  • dist (AbstractDistribution) – Distribution with a leading dimension in arrays with size equal to the number of mixture components. Often it is convenient to construct this with a pattern like eqx.filter_vmap(MyDistribution)(my_params).

  • weights (ArrayLike) – Positive, but possibly unnormalized, component weights.