"""Distributions from flowjax.distributions."""
import inspect
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
from math import prod
from typing import ClassVar
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from equinox import AbstractVar
from jax import dtypes
from jax.nn import log_softmax, softplus
from jax.numpy import linalg
from jax.scipy import stats as jstats
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jaxtyping import Array, ArrayLike, PRNGKeyArray, Shaped
from flowjax.bijections import (
AbstractBijection,
Affine,
Chain,
Exp,
Scale,
TriangularAffine,
)
from flowjax.utils import (
_get_ufunc_signature,
arraylike_to_array,
inv_softplus,
merge_cond_shapes,
)
from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap
class AbstractDistribution(eqx.Module):
"""Abstract distribution class.
Distributions are registered as JAX PyTrees (as they are equinox modules), and as
such they are compatible with normal JAX operations.
Concrete subclasses can be implemented as follows:
- Inherit from :class:`AbstractDistribution`.
- Define the abstract attributes ``shape`` and ``cond_shape``.
``cond_shape`` should be ``None`` for unconditional distributions.
- Define the abstract method ``_sample`` which returns a single sample
with shape ``dist.shape``, (given a single conditioning variable, if needed).
- Define the abstract method ``_log_prob``, returning a scalar log probability
of a single sample, (given a single conditioning variable, if needed).
The abstract class then defines vectorized versions with shape checking for the
public API. See the source code for :class:`StandardNormal` for a simple concrete
example.
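    Example:
        A minimal concrete subclass, mirroring :class:`StandardNormal` (the name
        ``MyStandardNormal`` is illustrative only):

        .. doctest::

            >>> from typing import ClassVar
            >>> import jax.random as jr
            >>> from jax.scipy import stats as jstats
            >>> from flowjax.distributions import AbstractDistribution
            >>> class MyStandardNormal(AbstractDistribution):
            ...     shape: tuple[int, ...] = ()
            ...     cond_shape: ClassVar[None] = None
            ...     def _log_prob(self, x, condition=None):
            ...         return jstats.norm.logpdf(x).sum()
            ...     def _sample(self, key, condition=None):
            ...         return jr.normal(key, self.shape)
            >>> dist = MyStandardNormal((2,))
            >>> dist.sample(jr.key(0)).shape
            (2,)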
Attributes:
shape: Tuple denoting the shape of a single sample from the distribution.
cond_shape: Tuple denoting the shape of an instance of the conditioning
variable. This should be None for unconditional distributions.
"""
shape: AbstractVar[tuple[int, ...]]
cond_shape: AbstractVar[tuple[int, ...] | None]
@abstractmethod
def _log_prob(self, x: Array, condition: Array | None = None) -> Array:
"""Evaluate the log probability of point x.
        This method should be valid for inputs with shapes matching
``distribution.shape`` and ``distribution.cond_shape`` for conditional
distributions (i.e. it defines the method for unbatched inputs).
"""
@abstractmethod
def _sample(self, key: PRNGKeyArray, condition: Array | None = None) -> Array:
"""Sample a point from the distribution.
This method should return a single sample with shape matching
``distribution.shape``.
"""
def _sample_and_log_prob(self, key: PRNGKeyArray, condition: Array | None = None):
"""Sample a point from the distribution, and return its log probability."""
x = self._sample(key, condition)
return x, self._log_prob(x, condition)
def log_prob(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array:
"""Evaluate the log probability.
Uses numpy-like broadcasting if additional leading dimensions are passed.
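        Example:
            An illustrative case of the broadcasting (shapes only; values are
            arbitrary):

            .. doctest::

                >>> import jax.numpy as jnp
                >>> from flowjax.distributions import StandardNormal
                >>> dist = StandardNormal((2,))
                >>> dist.log_prob(jnp.zeros((5, 2))).shape  # leading (5,) batch
                (5,)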
Args:
x: Points at which to evaluate density.
condition: Conditioning variables. Defaults to None.
Returns:
Array: Jax array of log probabilities.
"""
self = unwrap(self)
x = arraylike_to_array(x, err_name="x", dtype=float)
if self.cond_shape is not None:
condition = arraylike_to_array(condition, err_name="condition", dtype=float)
lps = self._vectorize(self._log_prob)(x, condition)
return jnp.where(jnp.isnan(lps), -jnp.inf, lps)
def sample(
self,
key: PRNGKeyArray,
sample_shape: tuple[int, ...] = (),
condition: ArrayLike | None = None,
) -> Array:
"""Sample from the distribution.
For unconditional distributions, the output will be of shape
``sample_shape + dist.shape``. For conditional distributions, batch dimensions
        in the condition are supported, and the output will have shape
``sample_shape + condition_batch_shape + dist.shape``.
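        Example:
            A sketch of the unconditional case (shapes only; values are
            arbitrary):

            .. doctest::

                >>> import jax.random as jr
                >>> from flowjax.distributions import StandardNormal
                >>> dist = StandardNormal((2,))
                >>> dist.sample(jr.key(0), sample_shape=(5,)).shape
                (5, 2)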
Args:
key: Jax random key.
            sample_shape: Sample shape. Defaults to ().
            condition: Conditioning variables. Defaults to None.
"""
self = unwrap(self)
if self.cond_shape is not None:
condition = arraylike_to_array(condition, err_name="condition")
keys = self._get_sample_keys(key, sample_shape, condition)
return self._vectorize(self._sample)(keys, condition)
def sample_and_log_prob(
self,
key: PRNGKeyArray,
sample_shape: tuple[int, ...] = (),
condition: ArrayLike | None = None,
) -> tuple[Array, Array]:
"""Sample the distribution and return the samples with their log probabilities.
For transformed distributions (especially flows), this will generally be more
        efficient than calling the methods separately. Refer to the
:py:meth:`~flowjax.distributions.AbstractDistribution.sample` documentation for
more information.
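        Example:
            Shapes for an unconditional distribution (illustrative values):

            .. doctest::

                >>> import jax.random as jr
                >>> from flowjax.distributions import StandardNormal
                >>> dist = StandardNormal((2,))
                >>> x, log_probs = dist.sample_and_log_prob(jr.key(0), sample_shape=(5,))
                >>> x.shape, log_probs.shape
                ((5, 2), (5,))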
Args:
key: Jax random key.
            sample_shape: Sample shape. Defaults to ().
            condition: Conditioning variables. Defaults to None.
"""
self = unwrap(self)
if self.cond_shape is not None:
condition = arraylike_to_array(condition, err_name="condition")
keys = self._get_sample_keys(key, sample_shape, condition)
return self._vectorize(self._sample_and_log_prob)(keys, condition)
@property
def ndim(self) -> int:
"""Number of dimensions in the distribution (the length of the shape)."""
return len(self.shape)
@property
def cond_ndim(self) -> None | int:
"""Number of dimensions of the conditioning variable (length of cond_shape)."""
return None if self.cond_shape is None else len(self.cond_shape)
def _vectorize(self, method: Callable) -> Callable:
"""Returns a vectorized version of the distribution method."""
        # Get shapes without broadcasting - note the () corresponds to the
        # (scalar) typed key arrays.
maybe_cond = [] if self.cond_shape is None else [self.cond_shape]
in_shapes = {
"_sample_and_log_prob": [()] + maybe_cond,
"_sample": [()] + maybe_cond,
"_log_prob": [self.shape] + maybe_cond,
}
out_shapes = {
"_sample_and_log_prob": [self.shape, ()],
"_sample": [self.shape],
"_log_prob": [()],
}
in_shapes, out_shapes = in_shapes[method.__name__], out_shapes[method.__name__]
def _check_shapes(method):
# Wraps unvectorised method with shape checking
@wraps(method)
def _wrapper(*args, **kwargs):
bound = inspect.signature(method).bind(*args, **kwargs)
for in_shape, (name, arg) in zip(
in_shapes,
bound.arguments.items(),
strict=False,
):
if arg.shape != in_shape:
raise ValueError(
f"Expected trailing dimensions matching {in_shape} for "
f"{name}; got {arg.shape}.",
)
return method(*args, **kwargs)
return _wrapper
signature = _get_ufunc_signature(in_shapes, out_shapes)
        # For unconditional distributions, exclude the None condition (position 1)
        # from vectorization.
        ex = frozenset([1]) if self.cond_shape is None else frozenset()
return jnp.vectorize(_check_shapes(method), signature=signature, excluded=ex)
def _get_sample_keys(
self,
key: PRNGKeyArray,
sample_shape: tuple[int, ...],
condition,
):
if not dtypes.issubdtype(key.dtype, dtypes.prng_key):
raise TypeError("New-style typed JAX PRNG keys required.")
        if self.cond_ndim is not None:
            # `or None` handles cond_ndim == 0, where [:None] keeps all leading dims.
            leading_cond_shape = condition.shape[: -self.cond_ndim or None]
else:
leading_cond_shape = ()
key_shape = sample_shape + leading_cond_shape
        key_size = prod(key_shape)  # note: prod(()) == 1, so works for scalar samples
return jr.split(key, key_size).reshape(key_shape)
class AbstractLocScaleDistribution(AbstractTransformed):
"""Abstract distribution class for affine transformed distributions."""
base_dist: AbstractVar[AbstractDistribution]
bijection: AbstractVar[Affine]
@property
def loc(self):
"""Location of the distribution."""
return self.bijection.loc
@property
def scale(self):
"""Scale of the distribution."""
return unwrap(self.bijection.scale)
class StandardNormal(AbstractDistribution):
"""Standard normal distribution.
Note unlike :class:`Normal`, this has no trainable parameters.
Args:
shape: The shape of the distribution. Defaults to ().
"""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return jstats.norm.logpdf(x).sum()
def _sample(self, key, condition=None):
return jr.normal(key, self.shape)
class Normal(AbstractLocScaleDistribution):
"""An independent Normal distribution with mean and std for each dimension.
``loc`` and ``scale`` should broadcast to the desired shape of the distribution.
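    Example:
        The distribution shape follows from broadcasting ``loc`` and ``scale``
        (values are arbitrary):

        .. doctest::

            >>> import jax.numpy as jnp
            >>> from flowjax.distributions import Normal
            >>> dist = Normal(loc=jnp.zeros(3), scale=2.0)
            >>> dist.shape
            (3,)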
Args:
        loc: Means. Defaults to 0.
scale: Standard deviations. Defaults to 1.
"""
base_dist: StandardNormal
bijection: Affine
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
self.base_dist = StandardNormal(
jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)),
)
self.bijection = Affine(loc=loc, scale=scale)
class LogNormal(AbstractTransformed):
"""Log normal distribution.
    ``loc`` and ``scale`` here refer to the underlying normal distribution.
Args:
        loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: Normal
bijection: Exp
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
self.base_dist = Normal(loc, scale)
self.bijection = Exp(self.base_dist.shape)
class MultivariateNormal(AbstractTransformed):
"""Multivariate normal distribution.
Internally this is parameterised using the Cholesky decomposition of the covariance
matrix.
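    Example:
        A two-dimensional example (values are arbitrary):

        .. doctest::

            >>> import jax.numpy as jnp
            >>> from flowjax.distributions import MultivariateNormal
            >>> dist = MultivariateNormal(jnp.zeros(2), jnp.eye(2))
            >>> dist.shape
            (2,)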
Args:
loc: The location/mean parameter vector. If this is scalar it is broadcast to
the dimension implied by the covariance matrix.
covariance: Covariance matrix.
"""
base_dist: StandardNormal
bijection: TriangularAffine
def __init__(
self,
loc: Shaped[ArrayLike, "#dim"],
covariance: Shaped[Array, "dim dim"],
):
self.bijection = TriangularAffine(loc, linalg.cholesky(covariance))
self.base_dist = StandardNormal(self.bijection.shape)
@property
def loc(self):
"""Location (mean) of the distribution."""
return self.bijection.loc
@property
def covariance(self):
"""The covariance matrix."""
cholesky = unwrap(self.bijection.triangular)
return cholesky @ cholesky.T
class _StandardUniform(AbstractDistribution):
r"""Standard Uniform distribution."""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return jstats.uniform.logpdf(x).sum()
def _sample(self, key, condition=None):
return jr.uniform(key, shape=self.shape)
class _StandardGumbel(AbstractDistribution):
"""Standard gumbel distribution."""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return -(x + jnp.exp(-x)).sum()
def _sample(self, key, condition=None):
return jr.gumbel(key, shape=self.shape)
class Gumbel(AbstractLocScaleDistribution):
"""Gumbel distribution.
``loc`` and ``scale`` should broadcast to the dimension of the distribution.
Args:
        loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: _StandardGumbel
bijection: Affine
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
self.base_dist = _StandardGumbel(
jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)),
)
self.bijection = Affine(loc, scale)
class _StandardCauchy(AbstractDistribution):
"""Implements standard cauchy distribution (loc=0, scale=1)."""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return jstats.cauchy.logpdf(x).sum()
def _sample(self, key, condition=None):
return jr.cauchy(key, shape=self.shape)
class Cauchy(AbstractLocScaleDistribution):
"""Cauchy distribution.
``loc`` and ``scale`` should broadcast to the dimension of the distribution.
Args:
        loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: _StandardCauchy
bijection: Affine
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
self.base_dist = _StandardCauchy(
jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)),
)
self.bijection = Affine(loc, scale)
class _StandardStudentT(AbstractDistribution):
"""Implements student T distribution with specified degrees of freedom."""
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
df: Array | AbstractUnwrappable[Array]
def __init__(self, df: ArrayLike):
df = arraylike_to_array(df, dtype=float)
df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.")
self.shape = jnp.shape(df)
self.df = Parameterize(softplus, inv_softplus(df))
def _log_prob(self, x, condition=None):
return jstats.t.logpdf(x, df=self.df).sum()
def _sample(self, key, condition=None):
return jr.t(key, df=self.df, shape=self.shape)
class StudentT(AbstractLocScaleDistribution):
"""Student T distribution.
``df``, ``loc`` and ``scale`` broadcast to the dimension of the distribution.
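    Example:
        The parameters broadcast to give the distribution shape (values are
        arbitrary):

        .. doctest::

            >>> import jax.numpy as jnp
            >>> from flowjax.distributions import StudentT
            >>> dist = StudentT(df=jnp.ones(3), loc=0, scale=1)
            >>> dist.shape
            (3,)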
Args:
df: The degrees of freedom.
loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: _StandardStudentT
bijection: Affine
def __init__(self, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1):
df, loc, scale = jnp.broadcast_arrays(df, loc, scale)
self.base_dist = _StandardStudentT(df)
self.bijection = Affine(loc, scale)
@property
def df(self):
"""The degrees of freedom of the distribution."""
return unwrap(self.base_dist.df)
class _StandardLaplace(AbstractDistribution):
"""Implements standard laplace distribution (loc=0, scale=1)."""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return jstats.laplace.logpdf(x).sum()
def _sample(self, key, condition=None):
return jr.laplace(key, shape=self.shape)
class Laplace(AbstractLocScaleDistribution):
"""Laplace distribution.
``loc`` and ``scale`` should broadcast to the dimension of the distribution.
Args:
        loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: _StandardLaplace
bijection: Affine
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
self.base_dist = _StandardLaplace(shape)
self.bijection = Affine(loc, scale)
class _StandardExponential(AbstractDistribution):
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _log_prob(self, x, condition=None):
return jstats.expon.logpdf(x).sum()
def _sample(self, key, condition=None):
return jr.exponential(key, shape=self.shape)
class Exponential(AbstractTransformed):
"""Exponential distribution.
Args:
rate: The rate parameter (1 / scale).
"""
base_dist: _StandardExponential
bijection: Scale
def __init__(self, rate: ArrayLike = 1):
self.base_dist = _StandardExponential(jnp.shape(rate))
self.bijection = Scale(1 / rate)
    @property
    def rate(self):
        """The rate parameter (1 / scale) of the distribution."""
        return 1 / unwrap(self.bijection.scale)
class _StandardLogistic(AbstractDistribution):
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
def _sample(self, key, condition=None):
return jr.logistic(key, self.shape)
def _log_prob(self, x, condition=None):
return jstats.logistic.logpdf(x).sum()
class Logistic(AbstractLocScaleDistribution):
"""Logistic distribution.
``loc`` and ``scale`` should broadcast to the shape of the distribution.
Args:
loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
"""
base_dist: _StandardLogistic
bijection: Affine
def __init__(self, loc: ArrayLike = 0, scale: ArrayLike = 1):
self.base_dist = _StandardLogistic(
shape=jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)),
)
self.bijection = Affine(loc=loc, scale=scale)
class VmapMixture(AbstractDistribution):
"""Create a mixture distribution.
    Given a distribution whose arrays have a leading dimension with size equal to
    the number of components, and a set of weights, this creates the corresponding
    mixture distribution.
Example:
.. doctest::
>>> # Creating a 3 component, 2D gaussian mixture
>>> from flowjax.distributions import Normal, VmapMixture
>>> import equinox as eqx
>>> import jax.numpy as jnp
>>> normals = eqx.filter_vmap(Normal)(jnp.zeros((3, 2)))
>>> mixture = VmapMixture(normals, weights=jnp.ones(3))
>>> mixture.shape
(2,)
Args:
dist: Distribution with a leading dimension in arrays with size equal to the
            number of mixture components. Often it is convenient to construct this
            with a pattern like ``eqx.filter_vmap(MyDistribution)(my_params)``.
weights: The positive, but possibly unnormalized component weights.
"""
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None
log_normalized_weights: Array | AbstractUnwrappable[Array]
dist: AbstractDistribution
def __init__(
self,
dist: AbstractDistribution,
weights: ArrayLike,
):
weights = eqx.error_if(weights, weights <= 0, "Weights must be positive.")
self.dist = dist
self.log_normalized_weights = Parameterize(log_softmax, jnp.log(weights))
self.shape = dist.shape
self.cond_shape = dist.cond_shape
def _log_prob(self, x, condition=None):
log_probs = eqx.filter_vmap(lambda d: d._log_prob(x, condition))(self.dist)
return logsumexp(log_probs + self.log_normalized_weights)
def _sample(self, key, condition=None):
key1, key2 = jr.split(key)
component = jr.categorical(key1, self.log_normalized_weights)
component_dist = tree_map(
lambda leaf: leaf[component] if isinstance(leaf, Array) else leaf,
tree=self.dist,
)
return component_dist._sample(key2, condition)
class _StandardGamma(AbstractDistribution):
concentration: Array | AbstractUnwrappable[Array]
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
def __init__(self, concentration: ArrayLike):
self.concentration = Parameterize(softplus, inv_softplus(concentration))
self.shape = jnp.shape(concentration)
def _sample(self, key, condition=None):
return jr.gamma(key, self.concentration)
def _log_prob(self, x, condition=None):
return jstats.gamma.logpdf(x, self.concentration).sum()
class Gamma(AbstractTransformed):
"""Gamma distribution.
Args:
concentration: Positive concentration parameter.
scale: The scale (inverse of rate) parameter.
"""
base_dist: _StandardGamma
bijection: Scale
def __init__(self, concentration: ArrayLike, scale: ArrayLike):
concentration, scale = jnp.broadcast_arrays(concentration, scale)
self.base_dist = _StandardGamma(concentration)
self.bijection = Scale(scale)
class Beta(AbstractDistribution):
"""Beta distribution.
Args:
alpha: The alpha shape parameter.
beta: The beta shape parameter.
"""
alpha: Array | AbstractUnwrappable[Array]
beta: Array | AbstractUnwrappable[Array]
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
def __init__(self, alpha: ArrayLike, beta: ArrayLike):
alpha, beta = jnp.broadcast_arrays(
arraylike_to_array(alpha, dtype=float),
arraylike_to_array(beta, dtype=float),
)
self.alpha = Parameterize(softplus, inv_softplus(alpha))
self.beta = Parameterize(softplus, inv_softplus(beta))
self.shape = alpha.shape
def _sample(self, key, condition=None):
return jr.beta(key, self.alpha, self.beta)
def _log_prob(self, x, condition=None):
return jstats.beta.logpdf(x, self.alpha, self.beta).sum()