Source code for flowjax.bijections.rational_quadratic_spline

"""Rational quadratic spline bijections (https://arxiv.org/abs/1906.04032)."""

from functools import partial
from typing import ClassVar

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from flowjax.bijections.bijection import AbstractBijection
from flowjax.utils import inv_softplus
from flowjax.wrappers import AbstractUnwrappable, Parameterize


def _real_to_increasing_on_interval(
    arr: Float[Array, " dim"],
    interval: tuple[int | float, int | float],
    softmax_adjust: float = 1e-2,
    *,
    pad_with_ends: bool = True,
):
    """Transform unconstrained vector to monotonically increasing positions on [-B, B].

    Args:
        arr: Parameter vector.
        interval: Interval to transform output. Defaults to 1.
        softmax_adjust : Rescales softmax output using
            ``(widths + softmax_adjust/widths.size) / (1 + softmax_adjust)``. e.g.
            0=no adjustment, 1=average softmax output with evenly spaced widths, >1
            promotes more evenly spaced widths.
        pad_with_ends: Whether to pad the with -interval and interval. Defaults to True.
    """
    if softmax_adjust < 0:
        raise ValueError("softmax_adjust should be >= 0.")

    widths = jax.nn.softmax(arr)
    widths = (widths + softmax_adjust / widths.size) / (1 + softmax_adjust)
    widths = widths.at[0].set(widths[0] / 2)
    scale = interval[1] - interval[0]
    pos = interval[0] + scale * jnp.cumsum(widths)

    if pad_with_ends:
        pos = jnp.pad(pos, pad_width=1, constant_values=interval)

    return pos


[docs] class RationalQuadraticSpline(AbstractBijection): """Scalar RationalQuadraticSpline transformation (https://arxiv.org/abs/1906.04032). Args: knots: Number of knots. interval: Interval to transform, if a scalar value, uses [-interval, interval], if a tuple, uses [interval[0], interval[1]] min_derivative: Minimum dervivative. Defaults to 1e-3. softmax_adjust: Controls minimum bin width and height by rescaling softmax output, e.g. 0=no adjustment, 1=average softmax output with evenly spaced widths, >1 promotes more evenly spaced widths. See ``real_to_increasing_on_interval``. Defaults to 1e-2. """ knots: int interval: tuple[int | float, int | float] softmax_adjust: float | int min_derivative: float x_pos: Array | AbstractUnwrappable[Array] y_pos: Array | AbstractUnwrappable[Array] derivatives: Array | AbstractUnwrappable[Array] shape: ClassVar[tuple] = () cond_shape: ClassVar[None] = None def __init__( self, *, knots: int, interval: float | int | tuple[int | float, int | float], min_derivative: float = 1e-3, softmax_adjust: float | int = 1e-2, ): self.knots = knots interval = interval if isinstance(interval, tuple) else (-interval, interval) self.interval = interval self.softmax_adjust = softmax_adjust self.min_derivative = min_derivative # Inexact arrays pos_parameterization = partial( _real_to_increasing_on_interval, interval=interval, softmax_adjust=softmax_adjust, ) self.x_pos = Parameterize(pos_parameterization, jnp.zeros(knots)) self.y_pos = Parameterize(pos_parameterization, jnp.zeros(knots)) self.derivatives = Parameterize( lambda arr: jax.nn.softplus(arr) + self.min_derivative, jnp.full(knots + 2, inv_softplus(1 - min_derivative)), ) def transform(self, x, condition=None): # Following notation from the paper x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1]) x_robust = jnp.where(in_bounds, x, 0) # To avoid nans k = jnp.searchsorted(x_pos, x_robust) - 1 # k is bin number xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k]) sk = (y_pos[k + 1] - y_pos[k]) / (x_pos[k + 1] - x_pos[k]) dk, dk1, yk, yk1 = derivatives[k], derivatives[k + 1], y_pos[k], y_pos[k + 1] num = (yk1 - yk) * (sk * xi**2 + dk * xi * (1 - xi)) den = sk + (dk1 + dk - 2 * sk) * xi * (1 - xi) y = yk + num / den # eq. 4 # avoid numerical precision issues transforming from in -> out of bounds y = jnp.clip(y, self.interval[0], self.interval[1]) return jnp.where(in_bounds, y, x) def transform_and_log_det(self, x, condition=None): y = self.transform(x) derivative = self.derivative(x) return y, jnp.log(derivative).sum() def inverse(self, y, condition=None): # Following notation from the paper x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives in_bounds = jnp.logical_and(y >= self.interval[0], y <= self.interval[1]) y_robust = jnp.where(in_bounds, y, 0) # To avoid nans k = jnp.searchsorted(y_pos, y_robust) - 1 xk, xk1, yk, yk1 = x_pos[k], x_pos[k + 1], y_pos[k], y_pos[k + 1] sk = (yk1 - yk) / (xk1 - xk) y_delta_s_term = (y_robust - yk) * ( derivatives[k + 1] + derivatives[k] - 2 * sk ) a = (yk1 - yk) * (sk - derivatives[k]) + y_delta_s_term b = (yk1 - yk) * derivatives[k] - y_delta_s_term c = -sk * (y_robust - yk) sqrt_term = jnp.sqrt(b**2 - 4 * a * c) xi = (2 * c) / (-b - sqrt_term) x = xi * (xk1 - xk) + xk # avoid numerical precision issues transforming from in -> out of bounds x = jnp.clip(x, self.interval[0], self.interval[1]) return jnp.where(in_bounds, x, y) def inverse_and_log_det(self, y, condition=None): x = self.inverse(y) derivative = self.derivative(x) return x, -jnp.log(derivative).sum()
[docs] def derivative(self, x) -> Array: """The derivative dy/dx of the forward transformation.""" # Following notation from the paper (eq. 5) x_pos, y_pos, derivatives = self.x_pos, self.y_pos, self.derivatives in_bounds = jnp.logical_and(x >= self.interval[0], x <= self.interval[1]) x_robust = jnp.where(in_bounds, x, 0) # To avoid nans k = jnp.searchsorted(x_pos, x_robust) - 1 xi = (x_robust - x_pos[k]) / (x_pos[k + 1] - x_pos[k]) sk = (y_pos[k + 1] - y_pos[k]) / (x_pos[k + 1] - x_pos[k]) dk, dk1 = derivatives[k], derivatives[k + 1] num = sk**2 * (dk1 * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2) den = (sk + (dk1 + dk - 2 * sk) * xi * (1 - xi)) ** 2 derivative = num / den return jnp.where(in_bounds, derivative, 1.0)