Distributions
For an introduction to using distributions in FlowJAX, see Getting started.
All distributions inherit from AbstractDistribution.
The distributions below are available from flowjax.distributions.
- class AbstractDistribution
Bases: Module
Abstract distribution class.
Distributions are registered as JAX PyTrees (as they are equinox modules), and as such they are compatible with normal JAX operations.
Concrete subclasses can be implemented as follows:
1. Inherit from AbstractDistribution.
2. Define the abstract attributes shape and cond_shape. cond_shape should be None for unconditional distributions.
3. Define the abstract method _sample, which returns a single sample with shape dist.shape (given a single conditioning variable, if needed).
4. Define the abstract method _log_prob, which returns a scalar log probability of a single sample (given a single conditioning variable, if needed).
The abstract class then defines vectorized versions with shape checking for the public API. See the source code for StandardNormal for a simple concrete example.
- Variables:
shape – Tuple denoting the shape of a single sample from the distribution.
cond_shape – Tuple denoting the shape of an instance of the conditioning variable. This should be None for unconditional distributions.
- log_prob(x, condition=None)
Evaluate the log probability.
Uses numpy-like broadcasting if additional leading dimensions are passed.
- Parameters:
x (ArrayLike) – Points at which to evaluate the density.
condition (ArrayLike | None) – Conditioning variables. Defaults to None.
- Returns:
Jax array of log probabilities.
- Return type:
Array
- sample(key, sample_shape=(), condition=None)
Sample from the distribution.
For unconditional distributions, the output will have shape sample_shape + dist.shape. For conditional distributions, batch dimensions in the condition are supported, and the output will have shape sample_shape + condition_batch_shape + dist.shape.
- sample_and_log_prob(key, sample_shape=(), condition=None)
Sample the distribution and return the samples with their log probabilities.
For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately. Refer to the sample() documentation for more information.
- class AbstractLocScaleDistribution
Bases: AbstractTransformed
Abstract distribution class for affine-transformed distributions.
- property loc
Location of the distribution.
- property scale
Scale of the distribution.
- class AbstractTransformed
Bases: AbstractDistribution
Abstract class representing transformed distributions.
We take the forward bijection for use in sampling, and the inverse for use in density evaluation. See also Transformed. Concrete implementations should subclass AbstractTransformed and define the abstract attributes base_dist and bijection. See the source code for Normal as a simple example.
Warning
It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.
- Variables:
base_dist – The base distribution.
bijection – The transformation to apply.
- merge_transforms()
Unnests nested transformed distributions.
Returns an equivalent distribution, but with nested AbstractTransformed distributions flattened, such that the returned distribution has a base distribution that is not an AbstractTransformed instance.
- class Beta(alpha, beta)
Bases: AbstractDistribution
Beta distribution.
- Parameters:
alpha (ArrayLike) – The alpha shape parameter.
beta (ArrayLike) – The beta shape parameter.
- class Cauchy(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Cauchy distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Exponential(rate=1)
Bases: AbstractTransformed
Exponential distribution.
- Parameters:
rate (ArrayLike) – The rate parameter (1 / scale).
- class Gamma(concentration, scale)
Bases: AbstractTransformed
Gamma distribution.
- Parameters:
concentration (ArrayLike) – Positive concentration parameter.
scale (ArrayLike) – The scale (inverse of rate) parameter.
- class Gumbel(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Gumbel distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Laplace(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Laplace distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class LogNormal(loc=0, scale=1)
Bases: AbstractTransformed
Log normal distribution.
loc and scale here refer to the underlying normal distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Logistic(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Logistic distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class MultivariateNormal(loc, covariance)
Bases: AbstractTransformed
Multivariate normal distribution.
Internally this is parameterised using the Cholesky decomposition of the covariance matrix.
- Parameters:
loc (Shaped[ArrayLike, '#dim']) – The location/mean parameter vector. If this is scalar, it is broadcast to the dimension implied by the covariance matrix.
covariance (Shaped[Array, 'dim dim']) – Covariance matrix.
- property covariance
The covariance matrix.
- property loc
Location (mean) of the distribution.
- class Normal(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
An independent Normal distribution with mean and std for each dimension.
loc and scale should broadcast to the desired shape of the distribution.
- Parameters:
loc (ArrayLike) – Means. Defaults to 0.
scale (ArrayLike) – Standard deviations. Defaults to 1.
- class StandardNormal(shape=())
Bases: AbstractDistribution
Standard normal distribution.
Note that, unlike Normal, this has no trainable parameters.
- class StudentT(df, loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Student's t distribution.
df, loc and scale should broadcast to the shape of the distribution.
- Parameters:
df (ArrayLike) – The degrees of freedom.
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- property df
The degrees of freedom of the distribution.
- class Transformed(base_dist, bijection)
Bases: AbstractTransformed
Form a distribution-like object using a base distribution and a bijection.
We take the forward bijection for use in sampling, and the inverse bijection for use in density evaluation.
Warning
It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.
- Parameters:
base_dist (AbstractDistribution) – Base distribution.
bijection (AbstractBijection) – Bijection used to transform the base distribution.
Example
>>> from flowjax.distributions import StandardNormal, Transformed
>>> from flowjax.bijections import Affine
>>> normal = StandardNormal()
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)
- class Uniform(minval, maxval)
Bases: AbstractLocScaleDistribution
Uniform distribution.
minval and maxval should broadcast to the desired distribution shape.
- Parameters:
minval (ArrayLike) – Minimum values.
maxval (ArrayLike) – Maximum values.
- property maxval
Maximum value of the uniform distribution.
- property minval
Minimum value of the uniform distribution.
- class VmapMixture(dist, weights)
Bases: AbstractDistribution
Create a mixture distribution.
Given a distribution in which the arrays have a leading dimension with size matching the number of components, and a set of weights, create a mixture distribution.
Example
>>> # Creating a 3-component, 2D Gaussian mixture
>>> from flowjax.distributions import Normal, VmapMixture
>>> import equinox as eqx
>>> import jax.numpy as jnp
>>> normals = eqx.filter_vmap(Normal)(jnp.zeros((3, 2)))
>>> mixture = VmapMixture(normals, weights=jnp.ones(3))
>>> mixture.shape
(2,)
- Parameters:
dist (AbstractDistribution) – Distribution with a leading dimension in arrays with size equal to the number of mixture components. Often it is convenient to construct this with a pattern like eqx.filter_vmap(MyDistribution)(my_params).
weights (ArrayLike) – The positive, but possibly unnormalized component weights.