Conditional density estimation

This example shows how to perform conditional density estimation with normalising flows. Here we use a block_neural_autoregressive_flow, although other flows are available and all of them support conditional density estimation (see flowjax.flows). We consider a two-dimensional model in which the upper limit of each target uniform distribution is itself a uniform random variable:

\[u_i \sim \text{Uniform}(0,5) \quad \text{for}\ i\ \text{in}\ 1,2\]
\[x_i \sim \text{Uniform}(0, u_i), \quad \text{for}\ i\ \text{in}\ 1,2\]

We will try to infer the conditional distribution \(p(x|u)\) using samples from the model.
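Because the model is so simple, the target conditional density is available in closed form. Given \(u\), the components are independent, so \(p(x|u) = \prod_{i=1}^{2} \frac{1}{u_i}\,\mathbb{1}[0 \le x_i \le u_i]\). The following minimal sketch in JAX gives a reference log density for checking the flow against. The helper name `true_log_prob` is our own and is not part of flowjax.

```python
import jax.numpy as jnp


def true_log_prob(x, u):
    """Analytic log p(x | u) for the toy model (hypothetical helper, not flowjax).

    x, u: arrays whose last axis has size 2; leading axes broadcast.
    """
    # Each x_i is Uniform(0, u_i), so its density is 1/u_i inside [0, u_i].
    inside = jnp.all((x >= 0) & (x <= u), axis=-1)
    # The joint log density is the sum of the per-component log densities.
    log_density = -jnp.sum(jnp.log(u), axis=-1)
    # Outside the support the density is zero, i.e. the log density is -inf.
    return jnp.where(inside, log_density, -jnp.inf)


u = jnp.array([1.0, 3.0])
print(true_log_prob(jnp.array([0.5, 1.0]), u))  # log(1/3)
print(true_log_prob(jnp.array([1.5, 1.0]), u))  # -inf, since x_1 > u_1
```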

Importing the required libraries.

[1]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from flowjax.distributions import Normal
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.train import fit_to_data

Generating the toy data.

[2]:
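# Split a fixed seed into keys; `key` is kept for the flow and training below.
key, x_key, cond_key = jr.split(jr.PRNGKey(0), 3)
# Conditioning variables: 10,000 draws of u with u_i ~ Uniform(0, 5).
u = jr.uniform(cond_key, (10000, 2), minval=0, maxval=5)
# Targets: x_i ~ Uniform(0, u_i); maxval is applied elementwise from u.
x = jr.uniform(x_key, shape=u.shape, maxval=u)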
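Since \(x_i \mid u_i \sim \text{Uniform}(0, u_i)\), every sample should satisfy \(0 \le x_i \le u_i\), and the ratio \(x_i / u_i\) should look like \(\text{Uniform}(0, 1)\) whatever the value of \(u_i\). The following quick sanity check is our own addition and is not part of the original example. It assumes the same seeds as the cell above.

```python
import jax.numpy as jnp
import jax.random as jr

# Regenerate the toy data exactly as in the cell above.
key, x_key, cond_key = jr.split(jr.PRNGKey(0), 3)
u = jr.uniform(cond_key, (10000, 2), minval=0, maxval=5)
x = jr.uniform(x_key, shape=u.shape, maxval=u)

# Every x_i must lie inside its conditional support [0, u_i].
in_support = bool(jnp.all((x >= 0) & (x <= u)))

# x_i / u_i should be roughly Uniform(0, 1): mean near 0.5, variance near 1/12.
ratio = x / u
print(in_support, ratio.mean(axis=0), ratio.var(axis=0))
```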

Creating and training the flow.

[5]:
# Continue from the key above rather than re-seeding with PRNGKey(0).
key, subkey = jr.split(key)

flow = block_neural_autoregressive_flow(
    key=subkey,
    base_dist=Normal(jnp.zeros(x.shape[1])),
    cond_dim=u.shape[1],
)

key, subkey = jr.split(key)
flow, losses = fit_to_data(
    key=subkey,
    dist=flow,
    x=x,
    condition=u,
    learning_rate=5e-2,
    max_patience=10,
)
 50%|█████     | 50/100 [00:13<00:13,  3.67it/s, train=1.318699, val=1.4249771 (Max patience reached)]
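The progress bar stops at epoch 50 of 100 because of early stopping. With `max_patience=10`, training ends once the validation loss has gone roughly 10 epochs without improving on its best value. The sketch below shows that patience rule applied to a plain list of validation losses. `stopping_epoch` is a hypothetical helper written for illustration. It is not flowjax's internal implementation, whose exact off-by-one convention may differ.

```python
def stopping_epoch(val_losses, max_patience):
    """Return the 1-based epoch at which patience-based early stopping triggers.

    Hypothetical helper: training stops once the number of epochs since the
    best (lowest) validation loss exceeds ``max_patience``. Returns None if
    the rule never triggers within ``val_losses``.
    """
    best = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        # Record a new best validation loss and the epoch it occurred in.
        if loss < best:
            best, best_epoch = loss, epoch
        # Too many epochs without improvement: stop here.
        if epoch - best_epoch > max_patience:
            return epoch
    return None


# Validation loss improves for 3 epochs, then plateaus without a new best.
val = [2.0, 1.6, 1.45, 1.5, 1.47, 1.46, 1.48, 1.5, 1.49, 1.46, 1.47, 1.5, 1.52, 1.51]
print(stopping_epoch(val, max_patience=10))  # 14
```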

We can now visualise the learned density. Let’s condition on \(u=[1,3]'\), in which case we expect \(x_1 \sim \text{Uniform}(0, 1)\) and \(x_2 \sim \text{Uniform}(0, 3)\).

[6]:
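resolution = 200
test_u = jnp.array([1.0, 3.0])

# Build a resolution x resolution grid over [-1, 4]^2, covering both supports.
xgrid, ygrid = jnp.meshgrid(
    jnp.linspace(-1, 4, resolution), jnp.linspace(-1, 4, resolution),
)
# Flatten the grid into (x_1, x_2) points with shape (resolution**2, 2).
xyinput = jnp.column_stack((xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)))
# Evaluate the conditional density at every grid point; test_u is broadcast across points.
zgrid = jnp.exp(flow.log_prob(xyinput, test_u).reshape(resolution, resolution))
plt.contourf(xgrid, ygrid, zgrid, levels=50)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.show()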
[Figure: filled contour plot of the learned density \(p(x|u=[1,3]')\) over \(x_1, x_2 \in [-1, 4]\).]
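For comparison, the true conditional density at \(u=[1,3]'\) equals \(1/3\) on the rectangle \([0,1]\times[0,3]\) and zero elsewhere. A well-trained flow should therefore show a roughly flat plateau over that rectangle. The following minimal sketch evaluates the analytic density on the same grid. The variable name `true_zgrid` is our own. As a check, a Riemann sum over the grid should come out close to 1.

```python
import jax.numpy as jnp

resolution = 200
test_u = jnp.array([1.0, 3.0])

# Same evaluation grid as in the plotting cell above.
xgrid, ygrid = jnp.meshgrid(
    jnp.linspace(-1, 4, resolution), jnp.linspace(-1, 4, resolution),
)

# p(x | u) = 1 / (u_1 * u_2) inside [0, u_1] x [0, u_2], and zero outside.
inside = (xgrid >= 0) & (xgrid <= test_u[0]) & (ygrid >= 0) & (ygrid <= test_u[1])
true_zgrid = jnp.where(inside, 1.0 / jnp.prod(test_u), 0.0)

# Riemann-sum check that the density integrates to roughly 1 over the grid.
cell_area = (5 / (resolution - 1)) ** 2
print(true_zgrid.sum() * cell_area)
```

`true_zgrid` can be passed to `plt.contourf(xgrid, ygrid, true_zgrid)` in the same way as `zgrid`, giving a side-by-side visual comparison.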