Distributions
For an introduction to using distributions in FlowJAX, see Getting started.
All distributions inherit from AbstractDistribution.
Distributions from flowjax.distributions.
- class AbstractDistribution
Bases: Module
Abstract distribution class.
Distributions are registered as JAX PyTrees (as they are equinox modules), and as such they are compatible with normal JAX operations.
Concrete subclasses can be implemented as follows:
- Inherit from AbstractDistribution.
- Define the abstract attributes shape and cond_shape. cond_shape should be None for unconditional distributions.
- Define the abstract method _sample, which returns a single sample with shape dist.shape (given a single conditioning variable, if needed).
- Define the abstract method _log_prob, which returns a scalar log probability of a single sample (given a single conditioning variable, if needed).
The abstract class then defines vectorized versions with shape checking for the public API. See the source code for StandardNormal for a simple concrete example.
- Variables:
shape – Tuple denoting the shape of a single sample from the distribution.
cond_shape – Tuple denoting the shape of an instance of the conditioning variable. This should be None for unconditional distributions.
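The split between single-sample private methods and a vectorized public API can be illustrated in plain JAX. The sketch below is not FlowJAX's implementation: `single_log_prob` and `batched_log_prob` are hypothetical names, and `jnp.vectorize` is only one possible way to add leading batch dimensions to a function written for a single sample.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def single_log_prob(x):
    """Scalar log density of ONE sample of shape (3,) under a standard normal."""
    return norm.logpdf(x).sum()


# Lift the single-sample function so it accepts any leading batch dimensions.
# The signature reads: consume one core dimension of size d, return a scalar.
batched_log_prob = jnp.vectorize(single_log_prob, signature="(d)->()")

x = jnp.zeros((5, 4, 3))  # batch dimensions (5, 4), sample shape (3,)
log_probs = batched_log_prob(x)
print(log_probs.shape)  # (5, 4)
```

Writing the density for one sample only keeps the subclass code free of batching logic; the batch handling lives in one place.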
- log_prob(x, condition=None)
Evaluate the log probability.
Uses NumPy-like broadcasting if additional leading dimensions are passed.
- Parameters:
x (ArrayLike) – Points at which to evaluate density.
condition (ArrayLike | None) – Conditioning variables. Defaults to None.
- Returns:
JAX array of log probabilities.
- Return type:
Array
- sample(key, sample_shape=(), condition=None)
Sample from the distribution.
For unconditional distributions, the output will be of shape sample_shape + dist.shape. For conditional distributions, batch dimensions in the condition are supported, and the output will have shape sample_shape + condition_batch_shape + dist.shape.
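The shape rule for conditional sampling can be made concrete with a small plain-JAX sketch. Everything here is an assumption for illustration: `DIST_SHAPE`, `COND_SHAPE`, `sample_one` and this `sample` function are hypothetical stand-ins, not FlowJAX code, and only the conditional case is handled.

```python
import math

import jax
import jax.numpy as jnp
import jax.random as jr

DIST_SHAPE = (2,)  # shape of one sample (hypothetical)
COND_SHAPE = (2,)  # shape of one conditioning variable (hypothetical)


def sample_one(key, condition):
    """One draw given ONE conditioning variable: noise centred on the condition."""
    return condition + jr.normal(key, DIST_SHAPE)


def sample(key, sample_shape, condition):
    # Leading dimensions of `condition` beyond COND_SHAPE are batch dimensions.
    cond_batch_shape = condition.shape[: condition.ndim - len(COND_SHAPE)]
    flat_condition = condition.reshape(-1, *COND_SHAPE)
    n_draws = math.prod(sample_shape)
    n_conds = flat_condition.shape[0]
    # One independent key per (draw, condition) pair.
    keys = jr.split(key, n_draws * n_conds)
    keys = keys.reshape(n_draws, n_conds, *keys.shape[1:])
    # Outer vmap runs over draws, inner vmap runs over conditioning variables.
    draws = jax.vmap(jax.vmap(sample_one), in_axes=(0, None))(keys, flat_condition)
    return draws.reshape(*sample_shape, *cond_batch_shape, *DIST_SHAPE)


condition = jnp.zeros((3, 2))  # batch of 3 conditioning variables
samples = sample(jr.PRNGKey(0), (4,), condition)
print(samples.shape)  # (4, 3, 2)
```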
- sample_and_log_prob(key, sample_shape=(), condition=None)
Sample the distribution and return the samples with their log probabilities.
For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately. Refer to the sample() documentation for more information.
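To see why the joint computation can be cheaper, consider a toy affine flow written in plain JAX. This is a sketch under stated assumptions, not FlowJAX code: `loc`, `scale` and both functions are hypothetical. When sampling, the base draw `z` is already known, so the log density follows without inverting the transform; evaluating the density of a given `x` separately requires the inverse pass.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

loc, scale = 1.0, 2.0  # hypothetical affine flow: x = loc + scale * z


def sample_and_log_prob(key, n):
    z = jr.normal(key, (n,))
    x = loc + scale * z  # forward transform
    # The base density at z and the log-determinant are available here,
    # so no inverse pass is needed.
    log_prob = norm.logpdf(z) - jnp.log(scale)
    return x, log_prob


def log_prob(x):
    z = (x - loc) / scale  # inverse transform, needed when only x is given
    return norm.logpdf(z) - jnp.log(scale)


x, lp_joint = sample_and_log_prob(jr.PRNGKey(0), 5)
lp_separate = log_prob(x)
print(jnp.allclose(lp_joint, lp_separate, atol=1e-4))
```

For an affine map the inverse is trivial, but for flows whose inverse is expensive the saving can be substantial.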
- class AbstractLocScaleDistribution
Bases: AbstractTransformed
Abstract distribution class for affine transformed distributions.
- property loc
Location of the distribution.
- property scale
Scale of the distribution.
- class AbstractTransformed
Bases: AbstractDistribution
Abstract class representing transformed distributions.
We take the forward bijection for use in sampling, and the inverse for use in density evaluation. See also Transformed. Concrete implementations should subclass AbstractTransformed and define the abstract attributes base_dist and bijection. See the source code for Normal as a simple example.
Warning
It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.
- Variables:
base_dist – The base distribution.
bijection – The transformation to apply.
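The roles of the forward and inverse directions follow from the change of variables formula. The symbols below are introduced here for illustration only: $p_Z$ is the density of base_dist, $T$ is the forward map of bijection, and $J_{T^{-1}}$ is the Jacobian of its inverse.

```latex
x = T(z), \quad z \sim p_Z
\qquad\text{and}\qquad
\log p_X(x) = \log p_Z\!\left(T^{-1}(x)\right)
            + \log \left| \det J_{T^{-1}}(x) \right|
```

If $T$ is not a bijection on the whole support of $p_Z$, the second expression no longer describes a normalized density, which is the situation the warning above refers to.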
- merge_transforms()
Unnests nested transformed distributions.
Returns an equivalent distribution in which nested AbstractTransformed distributions are merged, such that the returned distribution has a base distribution that is not an AbstractTransformed instance.
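The equivalence behind unnesting can be checked numerically in plain JAX. The sketch below does not call merge_transforms and makes no claim about how FlowJAX composes bijections; the two affine maps and all function names are hypothetical. It only shows that a transformed distribution whose base is itself transformed has the same density as a single composed transform applied to the innermost base.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Two hypothetical affine bijections, each given as (loc, scale).
inner = (1.0, 2.0)  # y = 1 + 2 * z
outer = (-3.0, 0.5)  # x = -3 + 0.5 * y


def affine_inverse_and_log_det(x, loc, scale):
    """Inverse of x = loc + scale * z and the log-determinant of that inverse."""
    return (x - loc) / scale, -jnp.log(jnp.abs(scale))


def nested_log_prob(x):
    # Outer transformed distribution whose base is itself transformed.
    y, log_det_outer = affine_inverse_and_log_det(x, *outer)
    z, log_det_inner = affine_inverse_and_log_det(y, *inner)
    return norm.logpdf(z) + log_det_inner + log_det_outer


def merged_log_prob(x):
    # Single composed bijection applied directly to the standard normal base.
    loc = outer[0] + outer[1] * inner[0]
    scale = outer[1] * inner[1]
    z, log_det = affine_inverse_and_log_det(x, loc, scale)
    return norm.logpdf(z) + log_det


x = jnp.linspace(-5.0, 5.0, 7)
print(jnp.allclose(nested_log_prob(x), merged_log_prob(x), atol=1e-5))
```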
- class Beta(alpha, beta)
Bases: AbstractDistribution
Beta distribution.
- Parameters:
alpha (ArrayLike) – The alpha shape parameter.
beta (ArrayLike) – The beta shape parameter.
- class Cauchy(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Cauchy distribution.
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Exponential(rate=1)
Bases: AbstractTransformed
Exponential distribution.
- Parameters:
rate (ArrayLike) – The rate parameter (1 / scale).
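The relationship between the rate and scale parameterisations can be checked in plain JAX. This sketch does not use the Exponential class; the value of `rate` is an arbitrary choice for illustration.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import expon

rate = 4.0
x = jnp.array([0.1, 0.5, 2.0])

# Density written directly in terms of the rate: log p(x) = log(rate) - rate * x.
log_prob_rate = jnp.log(rate) - rate * x
# The same density in the scale parameterisation, with scale = 1 / rate.
log_prob_scale = expon.logpdf(x, scale=1 / rate)
print(jnp.allclose(log_prob_rate, log_prob_scale, atol=1e-5))

# Sampling: a unit-rate exponential draw divided by the rate.
samples = jr.exponential(jr.PRNGKey(0), (10_000,)) / rate
print(samples.mean())  # close to 1 / rate = 0.25
```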
- class Gamma(concentration, scale)
Bases: AbstractTransformed
Gamma distribution.
- Parameters:
concentration (ArrayLike) – Positive concentration parameter.
scale (ArrayLike) – The scale (inverse of rate) parameter.
- class Gumbel(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Gumbel distribution.
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class Laplace(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Laplace distribution.
loc and scale should broadcast to the dimension of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class LogNormal(loc=0, scale=1)
Bases: AbstractTransformed
Log normal distribution.
loc and scale here refer to the underlying normal distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
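Because loc and scale describe the underlying normal rather than the log-normal variable itself, the mean of the samples is not loc. The plain-JAX sketch below illustrates this with the standard change of variables for the exponential map; it does not use the LogNormal class, and the parameter values and the name `log_normal_log_prob` are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

loc, scale = 0.5, 0.25  # parameters of the underlying normal, not of y itself


def log_normal_log_prob(y):
    # Change of variables for y = exp(z): log p(y) = log N(log y; loc, scale) - log y.
    return norm.logpdf(jnp.log(y), loc, scale) - jnp.log(y)


z = loc + scale * jr.normal(jr.PRNGKey(0), (10_000,))
y = jnp.exp(z)  # log-normal samples

print(jnp.log(y).mean(), jnp.log(y).std())  # close to loc and scale
print(y.mean())  # close to exp(loc + scale**2 / 2), not to loc
```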
- class Logistic(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Logistic distribution.
loc and scale should broadcast to the shape of the distribution.
- Parameters:
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- class MultivariateNormal(loc, covariance)
Bases: AbstractTransformed
Multivariate normal distribution.
Internally this is parameterised using the Cholesky decomposition of the covariance matrix.
- Parameters:
loc (Shaped[ArrayLike, '#dim']) – The location/mean parameter vector. If this is scalar it is broadcast to the dimension implied by the covariance matrix.
covariance (Shaped[Array, 'dim dim']) – Covariance matrix.
- property covariance
The covariance matrix.
- property loc
Location (mean) of the distribution.
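A Cholesky parameterisation can be sketched in plain JAX as an affine transform of standard normal noise. This is an illustration under assumptions, not the MultivariateNormal source: `chol`, `sample` and `log_prob` are hypothetical names, and the example values are arbitrary.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal, norm

loc = jnp.array([1.0, -1.0])
covariance = jnp.array([[2.0, 0.6], [0.6, 1.0]])
chol = jnp.linalg.cholesky(covariance)  # lower triangular, chol @ chol.T == covariance


def sample(key):
    # Affine transform of standard normal noise.
    return loc + chol @ jr.normal(key, loc.shape)


def log_prob(x):
    # Invert the affine map with a triangular solve, then correct for the volume change.
    z = solve_triangular(chol, x - loc, lower=True)
    return norm.logpdf(z).sum() - jnp.log(jnp.diag(chol)).sum()


x = sample(jr.PRNGKey(0))
reference = multivariate_normal.logpdf(x, loc, covariance)
print(jnp.allclose(log_prob(x), reference, atol=1e-4))
```

Working with the triangular factor avoids forming an explicit matrix inverse, and the log-determinant reduces to a sum over the diagonal of the factor.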
- class Normal(loc=0, scale=1)
Bases: AbstractLocScaleDistribution
An independent normal distribution with a mean and standard deviation for each dimension.
loc and scale should broadcast to the desired shape of the distribution.
- Parameters:
loc (ArrayLike) – Means. Defaults to 0.
scale (ArrayLike) – Standard deviations. Defaults to 1.
- class StandardNormal(shape=())
Bases: AbstractDistribution
Standard normal distribution.
Note that, unlike Normal, this has no trainable parameters.
- class StudentT(df, loc=0, scale=1)
Bases: AbstractLocScaleDistribution
Student T distribution.
df, loc and scale broadcast to the dimension of the distribution.
- Parameters:
df (ArrayLike) – The degrees of freedom.
loc (ArrayLike) – Location parameter. Defaults to 0.
scale (ArrayLike) – Scale parameter. Defaults to 1.
- property df
The degrees of freedom of the distribution.
- class Transformed(base_dist, bijection)
Bases: AbstractTransformed
Form a distribution-like object using a base distribution and a bijection.
We take the forward bijection for use in sampling, and the inverse bijection for use in density evaluation.
Warning
It is the user's responsibility to ensure the bijection is valid across the entire support of the distribution. Failure to do so may result in non-finite values or incorrectly normalized densities.
- Parameters:
base_dist (AbstractDistribution) – Base distribution.
bijection (AbstractBijection) – Bijection to transform distribution.
Example
>>> from flowjax.distributions import StandardNormal, Transformed
>>> from flowjax.bijections import Affine
>>> normal = StandardNormal()
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)
- class Uniform(minval, maxval)
Bases: AbstractLocScaleDistribution
Uniform distribution.
minval and maxval should broadcast to the desired distribution shape.
- Parameters:
minval (ArrayLike) – Minimum values.
maxval (ArrayLike) – Maximum values.
- property maxval
Maximum value of the uniform distribution.
- property minval
Minimum value of the uniform distribution.
- class VmapMixture(dist, weights)
Bases: AbstractDistribution
Create a mixture distribution.
Given a distribution in which the arrays have a leading dimension with size matching the number of components, and a set of weights, create a mixture distribution.
Example
>>> # Creating a 3 component, 2D Gaussian mixture
>>> from flowjax.distributions import Normal, VmapMixture
>>> import equinox as eqx
>>> import jax.numpy as jnp
>>> normals = eqx.filter_vmap(Normal)(jnp.zeros((3, 2)))
>>> mixture = VmapMixture(normals, weights=jnp.ones(3))
>>> mixture.shape
(2,)
- Parameters:
dist (AbstractDistribution) – Distribution with a leading dimension in arrays with size equal to the number of mixture components. Often it is convenient to construct this with a pattern like eqx.filter_vmap(MyDistribution)(my_params).
weights (ArrayLike) – The positive, but possibly unnormalized component weights.
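How a mixture density can be evaluated from stacked component parameters and unnormalized weights is sketched below in plain JAX. This is not the VmapMixture implementation: the component family (unit-variance normals), the values of `means` and `weights`, and the function names are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

means = jnp.array([[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]])  # 3 components, 2D
weights = jnp.array([1.0, 1.0, 2.0])  # positive, unnormalized


def component_log_prob(mean, x):
    """Log density of one 2D unit-variance normal component at a single point."""
    return norm.logpdf(x, loc=mean).sum()


def mixture_log_prob(x):
    # Normalize the weights in log space.
    log_weights = jnp.log(weights) - jnp.log(weights.sum())
    # vmap over the leading (component) axis of the parameters.
    log_probs = jax.vmap(component_log_prob, in_axes=(0, None))(means, x)
    # log sum_k w_k p_k(x), computed stably.
    return logsumexp(log_weights + log_probs)


x = jnp.array([0.5, -0.5])
print(mixture_log_prob(x))
```

Because the weights are normalized inside the density, multiplying all of them by a positive constant leaves the result unchanged.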