Distributions
Distributions from flowjax.distributions.
Distributions, including the abstract and concrete classes.
- class AbstractDistribution
Bases: Module
Abstract distribution class.
Distributions are registered as jax PyTrees (as they are equinox modules), and as such they are compatible with normal jax operations.
- Concrete subclasses can be implemented as follows:
1. Inherit from AbstractDistribution.
2. Define the abstract attributes shape and cond_shape. cond_shape should be None for unconditional distributions.
3. Define the abstract methods _sample and _log_prob.
See the source code for StandardNormal for a simple concrete example.
- shape
Tuple denoting the shape of a single sample from the distribution.
- cond_shape
Tuple denoting the shape of an instance of the conditioning variable. This should be None for unconditional distributions.
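As a rough illustration of the subclassing contract, the sketch below mirrors it in plain NumPy rather than in flowjax itself. ToyDistribution and ToyStandardNormal are hypothetical names, a NumPy Generator stands in for a JAX PRNG key, and the batching that the public methods add on top of _log_prob and _sample is left out: each method here handles exactly one sample of shape shape.

```python
import math

import numpy as np


class ToyDistribution:
    """Hypothetical stand-in for AbstractDistribution (not flowjax code)."""

    shape: tuple = ()      # shape of a single sample
    cond_shape = None      # None for unconditional distributions

    def _log_prob(self, x, condition=None):
        raise NotImplementedError

    def _sample(self, rng, condition=None):
        raise NotImplementedError


class ToyStandardNormal(ToyDistribution):
    """Parameter-free standard normal, playing the role of StandardNormal."""

    def __init__(self, shape=()):
        self.shape = tuple(shape)

    def _log_prob(self, x, condition=None):
        x = np.asarray(x, dtype=float)
        if x.shape != self.shape:
            raise ValueError(f"expected a single sample of shape {self.shape}")
        # Independent dimensions: sum the univariate log densities.
        return float(np.sum(-0.5 * x**2 - 0.5 * math.log(2 * math.pi)))

    def _sample(self, rng, condition=None):
        # A NumPy Generator stands in for a JAX PRNG key here.
        return rng.standard_normal(self.shape)


dist = ToyStandardNormal((2,))
print(dist.shape, dist.cond_shape)                    # (2,) None
print(dist._sample(np.random.default_rng(0)).shape)   # (2,)
```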
- log_prob(x, condition=None)
Evaluate the log probability.
Uses numpy-like broadcasting if additional leading dimensions are passed.
- Parameters:
x (ArrayLike) – Points at which to evaluate density.
condition (ArrayLike | None) – Conditioning variables. Defaults to None.
- Returns:
Jax array of log probabilities.
- Return type:
Array
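The broadcasting behaviour can be pictured as evaluating a single-sample density at every leading index and keeping those leading dimensions in the output. The NumPy sketch below uses a hypothetical helper, batched_log_prob, and a Python loop for clarity; in JAX one would more typically vectorise with jax.vmap, and this is not a description of flowjax's internals.

```python
import numpy as np


def batched_log_prob(single_log_prob, x, event_shape):
    """Evaluate a single-sample log density over any leading batch dimensions.

    `single_log_prob` maps one array of shape `event_shape` to a scalar.
    """
    x = np.asarray(x, dtype=float)
    event_shape = tuple(event_shape)
    n_event = len(event_shape)
    if x.ndim < n_event or x.shape[x.ndim - n_event:] != event_shape:
        raise ValueError(f"trailing dimensions of x must equal {event_shape}")
    batch_shape = x.shape[: x.ndim - n_event]
    # Flatten the batch dimensions, evaluate each sample, then restore them.
    flat = x.reshape((-1, *event_shape))
    values = np.array([single_log_prob(xi) for xi in flat])
    return values.reshape(batch_shape)


def std_normal_single(v):
    # Log density of one standard normal sample with independent dimensions.
    return float(np.sum(-0.5 * v**2 - 0.5 * np.log(2 * np.pi)))


log_probs = batched_log_prob(std_normal_single, np.zeros((5, 3, 2)), (2,))
print(log_probs.shape)  # (5, 3)
```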
- sample(key, sample_shape=(), condition=None)
Sample from the distribution.
For unconditional distributions, the output will be of shape sample_shape + dist.shape. For conditional distributions, a batch dimension in the condition is supported, and the output shape will be sample_shape + condition_batch_shape + dist.shape. See the example below for more information.
- Parameters:
key – JAX random key.
sample_shape – Leading shape of the returned samples. Defaults to ().
condition (ArrayLike | None) – Conditioning variables. Defaults to None.
- Return type:
Array
Example
The below example shows the behaviour of sampling, for an unconditional and a conditional distribution.
For an unconditional distribution:
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from flowjax.distributions import Normal
>>> key = jr.key(0)
>>> dist = Normal(jnp.zeros(2))
>>> dist.shape
(2,)
>>> samples = dist.sample(key, (10, ))
>>> samples.shape
(10, 2)
For a conditional distribution:
>>> from flowjax.flows import coupling_flow
>>> cond_dist = coupling_flow(key, base_dist=Normal(jnp.zeros(2)), cond_dim=3)
>>> cond_dist.shape
(2,)
>>> cond_dist.cond_shape
(3,)
>>> # Sample 10 times for a particular condition
>>> samples = cond_dist.sample(key, (10,), condition=jnp.ones(3))
>>> samples.shape
(10, 2)
>>> # Sampling, batching over a condition
>>> samples = cond_dist.sample(key, condition=jnp.ones((5, 3)))
>>> samples.shape
(5, 2)
>>> # Sample 10 times for each of 5 conditioning variables
>>> samples = cond_dist.sample(key, (10,), condition=jnp.ones((5, 3)))
>>> samples.shape
(10, 5, 2)
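The shape rule for sample can be written out as a small function. sample_output_shape is a hypothetical helper, and condition_shape stands for the shape of the array passed as condition; the printed shapes match those in the example above.

```python
def sample_output_shape(sample_shape, dist_shape, cond_shape=None, condition_shape=None):
    """Shape returned by `sample`: sample_shape + condition_batch_shape + dist.shape."""
    sample_shape, dist_shape = tuple(sample_shape), tuple(dist_shape)
    if cond_shape is None:
        # Unconditional: no condition batch dimensions.
        return sample_shape + dist_shape
    cond_shape, condition_shape = tuple(cond_shape), tuple(condition_shape)
    n_cond = len(cond_shape)
    if len(condition_shape) < n_cond or condition_shape[len(condition_shape) - n_cond:] != cond_shape:
        raise ValueError("trailing dimensions of condition must equal cond_shape")
    # Whatever precedes cond_shape in the condition is treated as a batch.
    condition_batch_shape = condition_shape[: len(condition_shape) - n_cond]
    return sample_shape + condition_batch_shape + dist_shape


print(sample_output_shape((10,), (2,)))                # (10, 2)
print(sample_output_shape((10,), (2,), (3,), (3,)))    # (10, 2)
print(sample_output_shape((), (2,), (3,), (5, 3)))     # (5, 2)
print(sample_output_shape((10,), (2,), (3,), (5, 3)))  # (10, 5, 2)
```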
- sample_and_log_prob(key, sample_shape=(), condition=None)
Sample the distribution and return the samples with their log probabilities.
For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately. Refer to the sample() documentation for more information.
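To see why a combined call can be cheaper, consider a transformed distribution: the forward pass that produces a sample can also return the log-determinant of its Jacobian, so the density follows without inverting the bijection. The NumPy sketch below uses an elementwise affine bijection over a standard normal base; all function names are hypothetical. For an affine map the saving is negligible, so the benefit shows up mainly in flows where the inverse is the more expensive direction.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def std_normal_log_prob(z):
    return -0.5 * np.sum(z**2, axis=-1) - 0.5 * z.shape[-1] * LOG_2PI


def affine_forward(z, loc, scale):
    """Forward pass y = loc + scale * z, with log|det J| of the forward map."""
    log_det = np.full(z.shape[:-1], np.sum(np.log(np.abs(scale))))
    return loc + scale * z, log_det


def affine_inverse(y, loc, scale):
    """Inverse pass z = (y - loc) / scale, with log|det J| of the inverse map."""
    log_det = np.full(y.shape[:-1], -np.sum(np.log(np.abs(scale))))
    return (y - loc) / scale, log_det


def sample_and_log_prob(rng, n, loc, scale):
    # One forward pass yields both the samples and their log densities.
    z = rng.standard_normal((n, loc.shape[-1]))
    y, log_det = affine_forward(z, loc, scale)
    return y, std_normal_log_prob(z) - log_det


def log_prob(y, loc, scale):
    # Evaluating the density on its own needs the inverse pass.
    z, log_det_inv = affine_inverse(y, loc, scale)
    return std_normal_log_prob(z) + log_det_inv


rng = np.random.default_rng(0)
loc, scale = np.array([1.0, -2.0]), np.array([0.5, 3.0])
y, lp = sample_and_log_prob(rng, 4, loc, scale)
print(np.allclose(lp, log_prob(y, loc, scale)))  # True
```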
- class AbstractLocScaleDistribution
Bases: AbstractTransformed
Abstract distribution class for affine transformed distributions.
- property loc
Location of the distribution.
- property scale
Scale of the distribution.
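Location-scale distributions follow x = loc + scale * z, with z drawn from a fixed base distribution, so log p(x) = log p_base((x - loc) / scale) - log|scale|, summed over dimensions. The NumPy sketch below checks this against the closed-form Laplace log density; the helper names are hypothetical and the choice of a Laplace base is only for illustration.

```python
import numpy as np


def std_laplace_log_prob(z):
    # Elementwise log density of Laplace(0, 1).
    return -np.log(2.0) - np.abs(z)


def loc_scale_log_prob(x, loc, scale, base_log_prob):
    """Log density of x = loc + scale * z, with independent base components z."""
    x = np.asarray(x, dtype=float)
    z = (x - loc) / scale
    # Change of variables: subtract log|scale| for every (broadcast) dimension.
    return np.sum(base_log_prob(z) - np.log(np.abs(scale)), axis=-1)


x = np.array([0.3, -1.2, 2.0])
loc = 0.5                          # a scalar loc broadcasts to the distribution shape
scale = np.array([1.0, 2.0, 0.5])
closed_form = np.sum(-np.log(2 * scale) - np.abs(x - loc) / scale)
print(np.isclose(loc_scale_log_prob(x, loc, scale, std_laplace_log_prob), closed_form))  # True
```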
- class AbstractTransformed
Bases: AbstractDistribution
Abstract class representing transformed distributions.
We take the forward bijection for use in sampling, and the inverse for use in density evaluation. See also Transformed.
Concrete implementations should subclass AbstractTransformed and define the abstract attributes base_dist and bijection. See the source code for Normal as a simple example.
- base_dist
The base distribution.
- bijection
The transformation to apply.
- merge_transforms()
Unnests nested transformed distributions.
Returns an equivalent distribution, but flattens nested AbstractTransformed distributions such that the returned distribution has a base distribution that is not an AbstractTransformed instance.
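Unnesting works because transforming an already transformed distribution is equivalent to transforming the innermost base distribution once by the composition of the bijections. The NumPy sketch below represents each hypothetical bijection by its inverse and the log-determinant of that inverse, and checks that the nested and merged (chained) forms give the same log densities; it illustrates the idea, not flowjax's merge_transforms code.

```python
import numpy as np


def std_normal_log_prob(z):
    return np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=-1)


def exp_inverse(y):
    """Inverse of y = exp(x), plus log|det| of the inverse Jacobian."""
    x = np.log(y)
    return x, -np.sum(x, axis=-1)          # d log(y)/dy = 1/y


def affine_inverse(y, loc=1.0, scale=2.0):
    """Inverse of y = loc + scale * x, plus log|det| of the inverse Jacobian."""
    log_det = np.full(y.shape[:-1], -y.shape[-1] * np.log(abs(scale)))
    return (y - loc) / scale, log_det


def transformed(base_log_prob, inverse):
    """Log density of y = f(x) with x ~ base, given f's inverse."""
    def log_prob(y):
        x, log_det_inv = inverse(y)
        return base_log_prob(x) + log_det_inv
    return log_prob


def chain_inverse(inverses):
    """Inverse of f_n(...f_1(x)), given [f_1^-1, ..., f_n^-1]: undo the last map first."""
    def inverse(y):
        total = 0.0
        for inv in reversed(inverses):
            y, log_det = inv(y)
            total = total + log_det
        return y, total
    return inverse


# Nested: Transformed(Transformed(base, exp), affine).
nested = transformed(transformed(std_normal_log_prob, exp_inverse), affine_inverse)
# Merged: a single transform whose bijection chains exp, then affine.
merged = transformed(std_normal_log_prob, chain_inverse([exp_inverse, affine_inverse]))

y = np.array([[1.5, 3.0], [2.0, 10.0]])   # y > 1 keeps the log well defined
print(np.allclose(nested(y), merged(y)))  # True
```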
- class Cauchy(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Cauchy distribution (https://en.wikipedia.org/wiki/Cauchy_distribution).
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Exponential(rate=1)
Bases: AbstractTransformed
Exponential distribution.
- Parameters:
rate (ArrayLike) – The rate parameter (1 / scale).
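Since the rate is 1 / scale, an exponential variable can be viewed as a standard exponential draw divided by the rate, giving log p(x) = log(rate) - rate * x for x >= 0. The NumPy sketch below derives the density through that scaling; the function names are hypothetical, and it does not assume anything about how flowjax parameterises its own base distribution.

```python
import numpy as np


def exponential_log_prob(x, rate):
    """Exponential(rate) log density via x = e / rate, with e ~ Exponential(1)."""
    x = np.asarray(x, dtype=float)
    e = x * rate                                # inverse of the scaling bijection
    base = np.where(e >= 0, -e, -np.inf)        # log density of Exponential(1)
    return base + np.log(rate)                  # plus log|de/dx| = log(rate)


def exponential_sample(rng, rate, n):
    # Forward direction: scale standard exponential draws by 1 / rate.
    return rng.standard_exponential(n) / rate


rate = 4.0
x = np.array([0.0, 0.5, 2.0])
print(np.allclose(exponential_log_prob(x, rate), np.log(rate) - rate * x))  # True
```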
- class Gumbel(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Gumbel distribution (https://en.wikipedia.org/wiki/Gumbel_distribution).
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Laplace(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Laplace distribution.
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class LogNormal(loc=0, scale=1)
Bases: AbstractTransformed
Log normal distribution.
loc and scale here refer to the underlying normal distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
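Because loc and scale describe the normal variable x rather than y = exp(x), the density of y picks up a -log(y) term from the change of variables. The NumPy sketch below (hypothetical function names) compares this with the textbook log-normal formula; note that exp(loc) is the median of y, not its mean.

```python
import numpy as np


def normal_log_prob(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def lognormal_log_prob(y, loc, scale):
    """y = exp(x) with x ~ Normal(loc, scale); loc and scale belong to x, not y."""
    y = np.asarray(y, dtype=float)
    # The inverse bijection is log, whose log-derivative is -log(y).
    return normal_log_prob(np.log(y), loc, scale) - np.log(y)


def lognormal_sample(rng, loc, scale, n):
    # Sample the underlying normal, then apply the forward bijection exp.
    return np.exp(loc + scale * rng.standard_normal(n))


loc, scale = 0.5, 0.3
y = np.array([0.2, 1.0, 4.0])
textbook = -np.log(y * scale * np.sqrt(2 * np.pi)) - (np.log(y) - loc) ** 2 / (2 * scale**2)
print(np.allclose(lognormal_log_prob(y, loc, scale), textbook))  # True
```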
- class Logistic(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Logistic distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class MultivariateNormal(loc, covariance)
Bases: AbstractTransformed
Multivariate normal distribution.
Internally this is parameterised using the Cholesky decomposition of the covariance matrix.
- Parameters:
loc (Shaped[ArrayLike, '#dim']) – The location/mean parameter vector. If this is scalar it is broadcast to the dimension implied by the covariance matrix.
covariance (Shaped[Array, 'dim dim']) – Covariance matrix.
- property covariance
The covariance matrix.
- property loc
Location (mean) of the distribution.
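The Cholesky parameterisation can be sketched as follows: with covariance = L L^T and L lower triangular, sampling uses loc + L z, and the density needs only a linear solve against L and the diagonal of L. The NumPy sketch below is an illustration under stated assumptions, not flowjax's code: it recomputes L on each call for self-containment, uses np.linalg.solve where scipy.linalg.solve_triangular would exploit the triangular structure, and broadcasts a scalar loc to the dimension implied by the covariance.

```python
import numpy as np


def mvn_sample(rng, loc, covariance, n):
    chol = np.linalg.cholesky(covariance)                   # lower-triangular L, L @ L.T == covariance
    loc = np.broadcast_to(np.asarray(loc, dtype=float), chol.shape[:1])  # scalar loc -> vector
    z = rng.standard_normal((n, chol.shape[0]))
    return loc + z @ chol.T                                 # each row is loc + L @ z_i


def mvn_log_prob(x, loc, covariance):
    chol = np.linalg.cholesky(covariance)
    dim = chol.shape[0]
    loc = np.broadcast_to(np.asarray(loc, dtype=float), (dim,))
    diff = np.atleast_2d(np.asarray(x, dtype=float) - loc)
    # Solve L u = (x - loc); the Mahalanobis term is then |u|^2.
    u = np.linalg.solve(chol, diff.T).T
    log_det_cov = 2.0 * np.sum(np.log(np.diag(chol)))       # log|covariance| from diag(L)
    return -0.5 * (np.sum(u**2, axis=-1) + log_det_cov + dim * np.log(2 * np.pi))


cov = np.array([[2.0, 0.6], [0.6, 1.0]])
loc = np.array([1.0, -1.0])
x = np.array([[0.0, 0.0], [2.0, -0.5]])
print(mvn_log_prob(x, loc, cov).shape)  # (2,)
```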
- class Normal(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
An independent Normal distribution with mean and std for each dimension.
loc and scale should broadcast to the desired shape of the distribution.
- Parameters:
loc (ArrayLike) – Means. Defaults to 0.
scale (ArrayLike) – Standard deviations. Defaults to 1.
- class StandardNormal(shape=())
Bases: AbstractDistribution
Standard normal distribution.
Note that, unlike Normal, this has no trainable parameters.
- class StudentT(df, loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Student T distribution (https://en.wikipedia.org/wiki/Student%27s_t-distribution).
df, loc and scale broadcast to the dimension of the distribution.
- Parameters:
df (ArrayLike) – The degrees of freedom.
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- property df
The degrees of freedom of the distribution.
- class Transformed(base_dist, bijection)
Bases: AbstractTransformed
Form a distribution-like object using a base distribution and a bijection.
We take the forward bijection for use in sampling, and the inverse bijection for use in density evaluation.
Warning
It is currently the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may lead to unexpected results.
- Parameters:
base_dist (AbstractDistribution) – Base distribution.
bijection (AbstractBijection) – Bijection to transform distribution.
Example
>>> from flowjax.distributions import StandardNormal, Transformed
>>> from flowjax.bijections import Affine
>>> normal = StandardNormal()
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)
- class Uniform(minval, maxval)
Bases: AbstractLocScaleDistribution
Uniform distribution.
minval and maxval should broadcast to the desired distribution shape.
- Parameters:
minval (ArrayLike) – Minimum values.
maxval (ArrayLike) – Maximum values.
- property maxval
Maximum value of the uniform distribution.
- property minval
Minimum value of the uniform distribution.
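Since Uniform is a location-scale distribution, one natural reading is loc = minval and scale = maxval - minval applied to Uniform(0, 1), which gives a constant log density of -sum(log(maxval - minval)) on the support. The NumPy sketch below assumes that reading and vector-valued bounds; the function names are hypothetical, and boundary handling may differ from flowjax.

```python
import numpy as np


def uniform_log_prob(x, minval, maxval):
    """Uniform(minval, maxval) as loc = minval, scale = maxval - minval on Uniform(0, 1)."""
    minval, maxval = np.broadcast_arrays(np.asarray(minval, dtype=float),
                                         np.asarray(maxval, dtype=float))
    scale = maxval - minval
    z = (np.asarray(x, dtype=float) - minval) / scale
    inside = np.all((z >= 0) & (z <= 1), axis=-1)
    # Uniform(0, 1) has log density 0 on its support; only the log|scale| term remains.
    return np.where(inside, -np.sum(np.log(scale)), -np.inf)


def uniform_sample(rng, minval, maxval, n):
    minval, maxval = np.broadcast_arrays(np.asarray(minval, dtype=float),
                                         np.asarray(maxval, dtype=float))
    # Forward affine map applied to Uniform(0, 1) draws.
    return minval + (maxval - minval) * rng.random((n, *minval.shape))


minval, maxval = np.array([0.0, -1.0]), np.array([2.0, 3.0])
lp = uniform_log_prob(np.array([[1.0, 0.0], [3.0, 0.0]]), minval, maxval)
print(lp)  # first entry is -log(8); the second point lies outside the support
```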
- class VmapMixture(dist, weights)
Bases: AbstractDistribution
Create a mixture distribution.
Given a distribution in which the arrays have a leading dimension with size matching the number of components, and a set of weights, create a mixture distribution.
- Parameters:
dist (AbstractDistribution) – Distribution with a leading dimension in arrays with size equal to the number of mixture components. Often it is convenient to construct this with a pattern like eqx.filter_vmap(MyDistribution)(my_params).
weights (ArrayLike) – The positive, but possibly unnormalized, component weights.
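The mixture density is log p(x) = logsumexp_k(log w_k + log p_k(x)), with the weights first normalised to sum to one. In the NumPy sketch below, univariate normal components whose parameters carry a leading component axis play the role of the vmapped dist; the function names are hypothetical, and sampling first picks a component according to the weights.

```python
import numpy as np


def logsumexp(a, axis=-1):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def normal_log_prob(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def mixture_log_prob(x, locs, scales, weights):
    """locs and scales have a leading axis of size n_components, like a vmapped dist."""
    weights = np.asarray(weights, dtype=float)
    log_w = np.log(weights) - np.log(np.sum(weights))   # normalise the weights
    # Evaluate every component at every x: shape (..., n_components).
    comp = normal_log_prob(np.asarray(x, dtype=float)[..., None], locs, scales)
    return logsumexp(log_w + comp, axis=-1)


def mixture_sample(rng, locs, scales, weights, n):
    p = np.asarray(weights, dtype=float)
    p = p / p.sum()
    k = rng.choice(len(p), size=n, p=p)                 # component index per sample
    return locs[k] + scales[k] * rng.standard_normal(n)


locs, scales = np.array([-2.0, 3.0]), np.array([0.5, 1.0])
weights = np.array([1.0, 3.0])                           # unnormalised: 0.25 / 0.75
print(mixture_log_prob(np.array([-1.0, 0.5, 3.0]), locs, scales, weights).shape)  # (3,)
```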