Distributions

Distributions from flowjax.distributions, including the abstract and concrete classes.

class AbstractDistribution

Bases: Module

Abstract distribution class.

Distributions are registered as JAX PyTrees (as they are Equinox modules), so they are compatible with standard JAX operations.

Concrete subclasses can be implemented as follows:
  1. Inherit from AbstractDistribution.

  2. Define the abstract attributes shape and cond_shape. cond_shape should be None for unconditional distributions.

  3. Define the abstract methods _sample and _log_prob.

See the source code for StandardNormal for a simple concrete example.

shape

Tuple denoting the shape of a single sample from the distribution.

cond_shape

Tuple denoting the shape of an instance of the conditioning variable. This should be None for unconditional distributions.

log_prob(x, condition=None)

Evaluate the log probability.

Uses numpy-like broadcasting if additional leading dimensions are passed.

Parameters:
  • x (ArrayLike) – Points at which to evaluate density.

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Returns:

JAX array of log probabilities.

Return type:

Array
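
The broadcasting rule for log_prob can be sketched in plain JAX with jnp.vectorize. The sketch assumes a hypothetical conditional density whose core shapes are x: (2,) and condition: (1,). It only illustrates the broadcasting semantics and is not flowjax's implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def single_log_prob(x, condition):
    """Hypothetical density for one sample: x ~ N(condition[0], 1) in each dimension."""
    return norm.logpdf(x, loc=condition[0]).sum()


# jnp.vectorize maps over extra leading dimensions and broadcasts them
# numpy-style; the signature gives the core (unbatched) shapes.
batched_log_prob = jnp.vectorize(single_log_prob, signature="(d),(c)->()")

x = jnp.zeros((3, 1, 2))      # leading dimensions (3, 1)
condition = jnp.ones((4, 1))  # leading dimensions (4,)
print(batched_log_prob(x, condition).shape)  # (3, 4)
```

The leading dimensions (3, 1) and (4,) broadcast to (3, 4), and each entry is the log density of one 2-dimensional point.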

sample(key, sample_shape=(), condition=None)

Sample from the distribution.

For unconditional distributions, the output will be of shape sample_shape + dist.shape. For conditional distributions, a batch dimension in the condition is supported, and the output shape will be sample_shape + condition_batch_shape + dist.shape. See the example for more information.

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • sample_shape (tuple[int, ...]) – Sample shape. Defaults to ().

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Return type:

Array

Example

The example below shows the sampling behaviour for an unconditional and a conditional distribution.

For an unconditional distribution:

>>> dist.shape
(2,)
>>> samples = dist.sample(key, (10, ))
>>> samples.shape
(10, 2)

For a conditional distribution:

>>> cond_dist.shape
(2,)
>>> cond_dist.cond_shape
(3,)
>>> # Sample 10 times for a particular condition
>>> samples = cond_dist.sample(key, (10,), condition=jnp.ones(3))
>>> samples.shape
(10, 2)
>>> # Sampling, batching over a condition
>>> samples = cond_dist.sample(key, condition=jnp.ones((5, 3)))
>>> samples.shape
(5, 2)
>>> # Sample 10 times for each of 5 conditioning variables
>>> samples = cond_dist.sample(key, (10,), condition=jnp.ones((5, 3)))
>>> samples.shape
(10, 5, 2)
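
The shape rules in the example above can be reproduced with a plain-JAX sketch. It assumes a hypothetical toy conditional distribution with shape (2,) and cond_shape (1,). The helpers sample_single and sample are illustrative only and are not flowjax's implementation.

```python
import math

import jax
import jax.numpy as jnp
import jax.random as jr

SHAPE = (2,)       # dist.shape of the hypothetical toy distribution
COND_SHAPE = (1,)  # dist.cond_shape


def sample_single(key, condition):
    """Hypothetical conditional sampler for one draw: N(condition[0], 1)."""
    return condition[0] + jr.normal(key, SHAPE)


def sample(key, sample_shape=(), condition=None):
    condition = jnp.asarray(condition)
    # Leading dimensions of the condition beyond cond_shape form the batch shape.
    batch_shape = condition.shape[: condition.ndim - len(COND_SHAPE)]
    out_shape = tuple(sample_shape) + batch_shape
    n = math.prod(out_shape)
    # One key per draw; the condition is broadcast across the sample dimensions.
    keys = jr.split(key, n)
    conditions = jnp.broadcast_to(condition, out_shape + COND_SHAPE)
    draws = jax.vmap(sample_single)(keys, conditions.reshape((n,) + COND_SHAPE))
    return draws.reshape(out_shape + SHAPE)


key = jr.PRNGKey(0)
print(sample(key, (10,), condition=jnp.ones(1)).shape)       # (10, 2)
print(sample(key, condition=jnp.ones((5, 1))).shape)         # (5, 2)
print(sample(key, (10,), condition=jnp.ones((5, 1))).shape)  # (10, 5, 2)
```

Splitting one key per draw keeps the draws independent across both the sample and the batch dimensions.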

sample_and_log_prob(key, sample_shape=(), condition=None)

Sample the distribution and return the samples with their log probabilities.

For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately. Refer to the sample() documentation for more information.

Parameters:
  • key (PRNGKeyArray) – JAX random key.

  • sample_shape (tuple[int, ...]) – Sample shape. Defaults to ().

  • condition (ArrayLike | None) – Conditioning variables. Defaults to None.

Return type:

tuple[Array, Array]
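
One source of the saving can be seen in a plain-JAX sketch for an affine transformation of a standard normal. The sampling path already holds the base draw z, so the log density follows from z and the log-Jacobian without applying the inverse bijection. The function below is hypothetical and is not flowjax's code.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

loc = jnp.array([1.0, -2.0])
scale = jnp.array([0.5, 3.0])


def sample_and_log_prob(key, sample_shape=()):
    # Draw from the base distribution (standard normal).
    z = jr.normal(key, tuple(sample_shape) + loc.shape)
    # Forward (affine) bijection, used for sampling.
    x = loc + scale * z
    # Reuse z: log p(x) = log p_z(z) - log|det dx/dz|, no inverse needed.
    log_prob = norm.logpdf(z).sum(-1) - jnp.log(scale).sum()
    return x, log_prob


x, lp = sample_and_log_prob(jr.PRNGKey(0), (4,))
print(x.shape, lp.shape)  # (4, 2) (4,)
```

For flows whose inverse is expensive (for example, an inverse that has to be computed sequentially), avoiding that inverse is where most of the benefit would come from.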

property cond_ndim: None | int

Number of dimensions of the conditioning variable (length of cond_shape).

property ndim: int

Number of dimensions in the distribution (the length of the shape).

class AbstractLocScaleDistribution

Bases: AbstractTransformed

Abstract distribution class for affine transformed distributions.

property loc

Location of the distribution.

property scale

Scale of the distribution.
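
All loc-scale families share one density formula. If x = loc + scale * z and z has a standard elementwise density, then log p(x) = log p_z((x - loc) / scale) - log(scale), summed over dimensions. A plain-JAX sketch follows. loc_scale_log_prob is a hypothetical helper, not part of flowjax, and it is checked against jax.scipy.stats.

```python
import jax.numpy as jnp
from jax.scipy.stats import cauchy, laplace


def loc_scale_log_prob(base_log_prob, x, loc, scale):
    """Log density of x = loc + scale * z, where z has elementwise log density base_log_prob."""
    z = (x - loc) / scale
    return jnp.sum(base_log_prob(z) - jnp.log(scale), axis=-1)


x = jnp.array([0.3, -1.2, 2.5])
loc = jnp.array([0.0, 1.0, -1.0])
scale = jnp.array([1.0, 2.0, 0.5])

print(loc_scale_log_prob(laplace.logpdf, x, loc, scale))
print(laplace.logpdf(x, loc, scale).sum())  # should match the line above
print(loc_scale_log_prob(cauchy.logpdf, x, loc, scale))
print(cauchy.logpdf(x, loc, scale).sum())   # should match the line above
```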

class AbstractTransformed

Bases: AbstractDistribution

Abstract class representing transformed distributions.

We take the forward bijection for use in sampling, and the inverse for use in density evaluation. See also Transformed.

Concrete implementations should subclass AbstractTransformed, and define the abstract attributes base_dist and bijection. See the source code for Normal as a simple example.
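
The density evaluation described above is the change-of-variables formula: log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|. The following plain-JAX sketch uses an elementwise sinh bijection (chosen only for illustration) on a standard normal base. It computes the full Jacobian with jax.jacfwd for clarity, which a real implementation would avoid for elementwise maps. All names are hypothetical and this is not flowjax's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward(x):
    # Forward bijection, used for sampling: y = sinh(x), elementwise.
    return jnp.sinh(x)


def inverse(y):
    # Inverse bijection, used for density evaluation.
    return jnp.arcsinh(y)


def transformed_log_prob(y):
    """log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)| with a standard normal base."""
    x = inverse(y)
    jac = jax.jacfwd(inverse)(y)  # full Jacobian, for illustration only
    _, log_abs_det = jnp.linalg.slogdet(jac)
    return norm.logpdf(x).sum() + log_abs_det


y = jnp.array([0.5, -2.0, 3.0])
print(transformed_log_prob(y))
print(forward(inverse(y)))  # recovers y
```

Here the inverse has derivative 1 / sqrt(1 + y^2), so the log-Jacobian term equals -0.5 * sum(log(1 + y^2)).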

base_dist

The base distribution.

bijection

The transformation to apply.

merge_transforms()

Unnest nested transformed distributions.

Returns an equivalent distribution in which nested AbstractTransformed distributions are flattened, so that the base distribution of the returned distribution is not an AbstractTransformed instance.
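
The unnesting idea can be sketched in pure Python. The Transformed dataclass, chain and merge_transforms below are hypothetical stand-ins: bijections are plain forward functions, and flowjax's own distribution and bijection objects are not used. The loop walks inward, collecting bijections so that inner ones are applied first when sampling.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Transformed:
    """Hypothetical stand-in: a base distribution plus a forward bijection."""

    base_dist: Any
    bijection: Callable


def chain(*bijections):
    """Compose bijections, applying them in the order given."""

    def composed(x):
        for bijection in bijections:
            x = bijection(x)
        return x

    return composed


def merge_transforms(dist):
    """Flatten nested Transformed objects into a single level."""
    bijections = []
    while isinstance(dist, Transformed):
        bijections.insert(0, dist.bijection)  # inner bijections are applied first
        dist = dist.base_dist
    return Transformed(dist, chain(*bijections))


# "standard_normal" is a placeholder for a real base distribution.
nested = Transformed(Transformed("standard_normal", lambda x: x + 1), lambda x: 2 * x)
merged = merge_transforms(nested)
print(merged.base_dist)     # standard_normal
print(merged.bijection(3))  # 8, i.e. 2 * (3 + 1)
```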

class Cauchy(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Cauchy distribution (https://en.wikipedia.org/wiki/Cauchy_distribution).

loc and scale should broadcast to the shape of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class Exponential(rate=1)

Bases: AbstractTransformed

Exponential distribution.

Parameters:

rate (ArrayLike) – The rate parameter (1 / scale). Defaults to 1.
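
The relation rate = 1 / scale can be checked numerically. In the plain-JAX sketch below (hypothetical helper names, not flowjax's code), samples are standard exponential draws divided by rate, and the log density log(rate) - rate * x is compared with jax.scipy.stats.expon using scale = 1 / rate.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import expon

rate = jnp.array([0.5, 2.0])


def exponential_log_prob(x, rate):
    # log p(x) = log(rate) - rate * x for x >= 0, summed over dimensions.
    return jnp.sum(jnp.log(rate) - rate * x, axis=-1)


def exponential_sample(key, sample_shape, rate):
    # A standard exponential draw divided by rate, i.e. scaled by 1 / rate.
    return jr.exponential(key, tuple(sample_shape) + rate.shape) / rate


x = exponential_sample(jr.PRNGKey(0), (5,), rate)
print(exponential_log_prob(x, rate))
print(expon.logpdf(x, scale=1 / rate).sum(-1))  # should match the line above
```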

class Gumbel(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Gumbel distribution (https://en.wikipedia.org/wiki/Gumbel_distribution).

loc and scale should broadcast to the shape of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class Laplace(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Laplace distribution.

loc and scale should broadcast to the shape of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class LogNormal(loc=0, scale=1)

Bases: AbstractTransformed

Log normal distribution.

loc and scale refer to the underlying normal distribution, i.e. the distribution of the logarithm of the random variable.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.
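
Because loc and scale belong to the underlying normal, a log-normal variable can be sketched as exp of a normal draw, with a log density that follows from the change of variables through log. Plain JAX; the helper names and parameter values are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

loc, scale = 0.5, 0.8


def lognormal_sample(key, sample_shape):
    # Sample the underlying normal, then apply exp.
    return jnp.exp(loc + scale * jr.normal(key, sample_shape))


def lognormal_log_prob(x):
    # Change of variables through log: log p(x) = log N(log x; loc, scale) - log x.
    return norm.logpdf(jnp.log(x), loc, scale) - jnp.log(x)


x = lognormal_sample(jr.PRNGKey(0), (1000,))
print(jnp.log(x).mean())  # close to loc, since log x ~ N(loc, scale)
```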

class Logistic(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Logistic distribution.

loc and scale should broadcast to the shape of the distribution.

Parameters:
  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

class MultivariateNormal(loc, covariance)

Bases: AbstractTransformed

Multivariate normal distribution.

Internally this is parameterised using the Cholesky decomposition of the covariance matrix.

Parameters:
  • loc (Shaped[ArrayLike, '#dim']) – The location/mean parameter vector. If this is a scalar, it is broadcast to the dimension implied by the covariance matrix.

  • covariance (Shaped[Array, 'dim dim']) – Covariance matrix.

property covariance

The covariance matrix.

property loc

Location (mean) of the distribution.
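
A Cholesky parameterisation can be sketched in plain JAX. With covariance = L L^T, samples are loc + L z for standard normal z, and the log density inverts that affine map with a triangular solve, using log|det L| = sum(log diag L). The helper names are hypothetical and this is not flowjax's implementation. The result is compared with jax.scipy.stats.multivariate_normal.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal

loc = jnp.array([1.0, -1.0])
covariance = jnp.array([[2.0, 0.6], [0.6, 1.0]])
chol = jnp.linalg.cholesky(covariance)  # lower triangular, chol @ chol.T == covariance


def mvn_sample(key, sample_shape):
    z = jr.normal(key, tuple(sample_shape) + loc.shape)
    # x = loc + L z for each draw (z @ L.T handles leading batch dimensions).
    return loc + z @ chol.T


def mvn_log_prob(x):
    # Invert the affine map: z = L^{-1} (x - loc), via a triangular solve.
    z = solve_triangular(chol, (x - loc).T, lower=True).T
    log_det = jnp.sum(jnp.log(jnp.diag(chol)))  # log|det L|
    d = loc.shape[0]
    return -0.5 * jnp.sum(z**2, axis=-1) - 0.5 * d * jnp.log(2 * jnp.pi) - log_det


x = mvn_sample(jr.PRNGKey(0), (5,))
print(mvn_log_prob(x))
print(multivariate_normal.logpdf(x, loc, covariance))  # should match the line above
```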

class Normal(loc=0, scale=1)

Bases: AbstractLocScaleDistribution

An independent normal distribution with a mean and standard deviation for each dimension.

loc and scale should broadcast to the desired shape of the distribution.

Parameters:
  • loc (ArrayLike) – Means. Defaults to 0.

  • scale (ArrayLike) – Standard deviations. Defaults to 1.

class StandardNormal(shape=())

Bases: AbstractDistribution

Standard normal distribution.

Note that, unlike Normal, this has no trainable parameters.

Parameters:

shape (tuple[int, ...]) – The shape of the distribution. Defaults to ().

class StudentT(df, loc=0, scale=1)

Bases: AbstractLocScaleDistribution

Student's t distribution (https://en.wikipedia.org/wiki/Student%27s_t-distribution).

df, loc and scale should broadcast to the shape of the distribution.

Parameters:
  • df (ArrayLike) – The degrees of freedom.

  • loc (ArrayLike) – Location parameter. Defaults to 0.

  • scale (ArrayLike) – Scale parameter. Defaults to 1.

property df

The degrees of freedom of the distribution.

class Transformed(base_dist, bijection)

Bases: AbstractTransformed

Form a distribution-like object using a base distribution and a bijection.

We take the forward bijection for use in sampling, and the inverse bijection for use in density evaluation.

Warning

It is currently the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may lead to unexpected results.

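Parameters:
  • base_dist (AbstractDistribution) – Base distribution.

  • bijection (AbstractBijection) – Bijection to transform the base distribution.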

Example

>>> from flowjax.distributions import StandardNormal, Transformed
>>> from flowjax.bijections import Affine
>>> normal = StandardNormal()
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)

class Uniform(minval, maxval)

Bases: AbstractLocScaleDistribution

Uniform distribution.

minval and maxval should broadcast to the desired distribution shape.

Parameters:
  • minval (ArrayLike) – Minimum values.

  • maxval (ArrayLike) – Maximum values.

property maxval

Maximum value of the uniform distribution.

property minval

Minimum value of the uniform distribution.
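
Listing Uniform as an AbstractLocScaleDistribution is consistent with writing it as loc + scale * U(0, 1), with loc = minval and scale = maxval - minval. The following plain-JAX sketch works under that assumption (hypothetical helper names, not flowjax's code) and compares it with jax.scipy.stats.uniform.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import uniform

minval = jnp.array([0.0, -2.0])
maxval = jnp.array([1.0, 4.0])

# Uniform(minval, maxval) as an affine transform of Uniform(0, 1).
loc = minval
scale = maxval - minval


def uniform_sample(key, sample_shape):
    u = jr.uniform(key, tuple(sample_shape) + minval.shape)  # U(0, 1)
    return loc + scale * u


def uniform_log_prob(x):
    # Inside the support the log density is -log(scale) per dimension.
    inside = jnp.all((x >= minval) & (x <= maxval), axis=-1)
    return jnp.where(inside, -jnp.sum(jnp.log(scale)), -jnp.inf)


x = uniform_sample(jr.PRNGKey(0), (5,))
print(uniform_log_prob(x))
print(uniform.logpdf(x, loc, scale).sum(-1))  # should match the line above
```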

class VmapMixture(dist, weights)

Bases: AbstractDistribution

Create a mixture distribution.

Given a distribution in which the arrays have a leading dimension with size matching the number of components, and a set of weights, create a mixture distribution.

Parameters:
  • dist (AbstractDistribution) – Distribution with a leading dimension in arrays with size equal to the number of mixture components. Often it is convenient to construct this with a pattern like eqx.filter_vmap(MyDistribution)(my_params).

  • weights (ArrayLike) – The positive, but possibly unnormalized, component weights.
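
The mixture density can be sketched in plain JAX. This assumes the weights are normalised to sum to one and the component parameters are stacked along a leading axis, mirroring the vmapped dist described above. It evaluates log p(x) = log sum_k w_k p_k(x) stably with logsumexp. The names are hypothetical and this is not flowjax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Component parameters with a leading dimension of size 3 (one per component).
locs = jnp.array([-2.0, 0.0, 3.0])
scales = jnp.array([0.5, 1.0, 2.0])
weights = jnp.array([1.0, 2.0, 1.0])  # positive, not normalised


def component_log_prob(loc, scale, x):
    return norm.logpdf(x, loc, scale)


def mixture_log_prob(x):
    # Evaluate every component at x by vmapping over the leading parameter axis.
    component_lps = jax.vmap(component_log_prob, in_axes=(0, 0, None))(locs, scales, x)
    log_weights = jnp.log(weights) - jnp.log(jnp.sum(weights))  # normalise
    # log p(x) = log sum_k w_k p_k(x), computed stably with logsumexp.
    return logsumexp(log_weights + component_lps)


print(mixture_log_prob(0.7))
```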