FlowJax
Examples
Unconditional density estimation
Conditional density estimation
Variational inference
Bounded flows
Sequential neural posterior estimation
API
Bijections
Distributions
Experimental
Flows
Loss functions
Training
Wrappers
Miscellaneous
FAQ
FlowJax
Index
Index
A
|
B
|
C
|
D
|
E
|
F
|
G
|
I
|
K
|
L
|
M
|
N
|
P
|
R
|
S
|
T
|
U
|
V
|
W
|
X
|
Y
A
AbstractBijection (class in flowjax.bijections)
AbstractDistribution (class in flowjax.distributions)
AbstractLocScaleDistribution (class in flowjax.distributions)
AbstractTransformed (class in flowjax.distributions)
AbstractUnwrappable (class in flowjax.wrappers)
activation (BlockAutoregressiveNetwork attribute)
AdditiveCondition (class in flowjax.bijections)
Affine (class in flowjax.bijections)
args (Lambda attribute)
arr (BijectionReparam attribute)
axis (Concatenate attribute)
(Stack attribute)
axis_size (Vmap attribute)
B
base_dist (AbstractTransformed attribute)
bijection (AbstractTransformed attribute)
(BijectionReparam attribute)
(EmbedCondition attribute)
(Invert attribute)
(Partial attribute)
(Reshape attribute)
(Scan attribute)
(Vmap attribute)
BijectionReparam (class in flowjax.wrappers)
bijections (Chain attribute)
(Concatenate attribute)
(Stack attribute)
block_dim (BlockAutoregressiveNetwork attribute)
block_neural_autoregressive_flow() (in module flowjax.flows)
BlockAutoregressiveNetwork (class in flowjax.bijections)
C
Cauchy (class in flowjax.distributions)
Chain (class in flowjax.bijections)
Concatenate (class in flowjax.bijections)
cond (Where attribute)
cond_linear (BlockAutoregressiveNetwork attribute)
cond_ndim (AbstractDistribution property)
cond_shape (AbstractBijection attribute)
(AbstractDistribution attribute)
(AdditiveCondition attribute)
(Affine attribute)
(BlockAutoregressiveNetwork attribute)
(Chain attribute)
(Concatenate attribute)
(Coupling attribute)
(EmbedCondition attribute)
(Exp attribute)
(Flip attribute)
(Identity attribute)
(Invert property)
(LeakyTanh attribute)
(Loc attribute)
(MaskedAutoregressive attribute)
(Partial property)
(Permute attribute)
(Planar attribute)
(RationalQuadraticSpline attribute)
(Reshape attribute)
(Scale attribute)
(Scan property)
(SoftPlus attribute)
(Stack attribute)
(Tanh attribute)
(TriangularAffine attribute)
(Vmap attribute)
conditioner (Coupling attribute)
(Planar attribute)
ContrastiveLoss (class in flowjax.train.losses)
Coupling (class in flowjax.bijections)
coupling_flow() (in module flowjax.flows)
covariance (MultivariateNormal property)
D
depth (BlockAutoregressiveNetwork attribute)
derivative() (RationalQuadraticSpline method)
derivatives (RationalQuadraticSpline attribute)
df (StudentT property)
dim (Coupling attribute)
distribution_to_numpyro() (in module flowjax.experimental.numpyro)
E
ElboLoss (class in flowjax.train.losses)
EmbedCondition (class in flowjax.bijections)
embedding_net (EmbedCondition attribute)
Exp (class in flowjax.bijections)
Exponential (class in flowjax.distributions)
F
fit_to_data() (in module flowjax.train)
fit_to_variational_target() (in module flowjax.train)
Flip (class in flowjax.bijections)
flowjax.bijections
module
flowjax.distributions
module
flowjax.experimental.numpyro
module
flowjax.flows
module
flowjax.train.losses
module
flowjax.wrappers
module
fn (Lambda attribute)
G
get_cond_shape() (Vmap method)
get_planar() (Planar method)
Gumbel (class in flowjax.distributions)
I
Identity (class in flowjax.bijections)
idxs (Partial attribute)
if_false (Where attribute)
if_true (Where attribute)
in_axes (Vmap attribute)
intercept (LeakyTanh attribute)
interval (RationalQuadraticSpline attribute)
inv_scan_fn() (MaskedAutoregressive method)
inverse() (AbstractBijection method)
(AdditiveCondition method)
(Affine method)
(BlockAutoregressiveNetwork method)
(Chain method)
(Concatenate method)
(Coupling method)
(EmbedCondition method)
(Exp method)
(Flip method)
(Identity method)
(Invert method)
(LeakyTanh method)
(Loc method)
(MaskedAutoregressive method)
(Partial method)
(Permute method)
(Planar method)
(RationalQuadraticSpline method)
(Reshape method)
(Scale method)
(Scan method)
(SoftPlus method)
(Stack method)
(Tanh method)
(TriangularAffine method)
(Vmap method)
inverse_and_log_det() (AbstractBijection method)
(AdditiveCondition method)
(Affine method)
(BlockAutoregressiveNetwork method)
(Chain method)
(Concatenate method)
(Coupling method)
(EmbedCondition method)
(Exp method)
(Flip method)
(Identity method)
(Invert method)
(LeakyTanh method)
(Loc method)
(MaskedAutoregressive method)
(Partial method)
(Permute method)
(Planar method)
(RationalQuadraticSpline method)
(Reshape method)
(Scale method)
(Scan method)
(SoftPlus method)
(Stack method)
(Tanh method)
(TriangularAffine method)
(Vmap method)
inverse_permutation (Permute attribute)
Invert (class in flowjax.bijections)
inverter (BlockAutoregressiveNetwork attribute)
K
knots (RationalQuadraticSpline attribute)
kwargs (Lambda attribute)
L
Lambda (class in flowjax.wrappers)
Laplace (class in flowjax.distributions)
layers (BlockAutoregressiveNetwork attribute)
LeakyTanh (class in flowjax.bijections)
linear_grad (LeakyTanh attribute)
loc (AbstractLocScaleDistribution property)
(Affine attribute)
Loc (class in flowjax.bijections)
loc (Loc attribute)
(MultivariateNormal property)
(TriangularAffine attribute)
log_prob() (AbstractDistribution method)
Logistic (class in flowjax.distributions)
LogNormal (class in flowjax.distributions)
lower (TriangularAffine attribute)
M
masked_autoregressive_flow() (in module flowjax.flows)
masked_autoregressive_mlp (MaskedAutoregressive attribute)
MaskedAutoregressive (class in flowjax.bijections)
max_val (LeakyTanh attribute)
MaximumLikelihoodLoss (class in flowjax.train.losses)
maxval (Uniform property)
merge_chains() (Chain method)
merge_transforms() (AbstractTransformed method)
min_derivative (RationalQuadraticSpline attribute)
minval (Uniform property)
module
flowjax.bijections
flowjax.distributions
flowjax.experimental.numpyro
flowjax.flows
flowjax.train.losses
flowjax.wrappers
module (AdditiveCondition attribute)
MultivariateNormal (class in flowjax.distributions)
N
ndim (AbstractDistribution property)
NonTrainable (class in flowjax.wrappers)
Normal (class in flowjax.distributions)
num_samples (ElboLoss attribute)
P
params (Planar attribute)
Partial (class in flowjax.bijections)
permutation (Permute attribute)
Permute (class in flowjax.bijections)
Planar (class in flowjax.bijections)
planar_flow() (in module flowjax.flows)
R
RationalQuadraticSpline (class in flowjax.bijections)
recursive_unwrap() (AbstractUnwrappable method)
register_params() (in module flowjax.experimental.numpyro)
Reshape (class in flowjax.bijections)
S
sample() (AbstractDistribution method)
(in module flowjax.experimental.numpyro)
sample_and_log_prob() (AbstractDistribution method)
scale (AbstractLocScaleDistribution property)
(Affine attribute)
Scale (class in flowjax.bijections)
scale (Scale attribute)
(WeightNormalization attribute)
Scan (class in flowjax.bijections)
shape (AbstractBijection attribute)
(AbstractDistribution attribute)
(AdditiveCondition attribute)
(Affine attribute)
(BlockAutoregressiveNetwork attribute)
(Chain attribute)
(Concatenate attribute)
(Coupling attribute)
(EmbedCondition property)
(Exp attribute)
(Flip attribute)
(Identity attribute)
(Invert property)
(LeakyTanh attribute)
(Loc attribute)
(MaskedAutoregressive attribute)
(Partial attribute)
(Permute attribute)
(Planar attribute)
(RationalQuadraticSpline attribute)
(Reshape attribute)
(Scale attribute)
(Scan property)
(SoftPlus attribute)
(Stack attribute)
(Tanh attribute)
(TriangularAffine attribute)
(Vmap property)
softmax_adjust (RationalQuadraticSpline attribute)
SoftPlus (class in flowjax.bijections)
split_idxs (Concatenate attribute)
Stack (class in flowjax.bijections)
StandardNormal (class in flowjax.distributions)
stick_the_landing (ElboLoss attribute)
StudentT (class in flowjax.distributions)
T
Tanh (class in flowjax.bijections)
target (ElboLoss attribute)
transform() (AbstractBijection method)
(AdditiveCondition method)
(Affine method)
(BlockAutoregressiveNetwork method)
(Chain method)
(Concatenate method)
(Coupling method)
(EmbedCondition method)
(Exp method)
(Flip method)
(Identity method)
(Invert method)
(LeakyTanh method)
(Loc method)
(MaskedAutoregressive method)
(Partial method)
(Permute method)
(Planar method)
(RationalQuadraticSpline method)
(Reshape method)
(Scale method)
(Scan method)
(SoftPlus method)
(Stack method)
(Tanh method)
(TriangularAffine method)
(Vmap method)
transform_and_log_det() (AbstractBijection method)
(AdditiveCondition method)
(Affine method)
(BlockAutoregressiveNetwork method)
(Chain method)
(Concatenate method)
(Coupling method)
(EmbedCondition method)
(Exp method)
(Flip method)
(Identity method)
(Invert method)
(LeakyTanh method)
(Loc method)
(MaskedAutoregressive method)
(Partial method)
(Permute method)
(Planar method)
(RationalQuadraticSpline method)
(Reshape method)
(Scale method)
(Scan method)
(SoftPlus method)
(Stack method)
(Tanh method)
(TriangularAffine method)
(Vmap method)
Transformed (class in flowjax.distributions)
transformer_constructor (Coupling attribute)
(MaskedAutoregressive attribute)
tree (NonTrainable attribute)
triangular (TriangularAffine attribute)
triangular_spline_flow() (in module flowjax.flows)
TriangularAffine (class in flowjax.bijections)
U
Uniform (class in flowjax.distributions)
untransformed_dim (Coupling attribute)
unwrap() (AbstractUnwrappable method)
(BijectionReparam method)
(in module flowjax.wrappers)
(Lambda method)
(NonTrainable method)
(WeightNormalization method)
(Where method)
V
Vmap (class in flowjax.bijections)
vmap() (Vmap method)
VmapMixture (class in flowjax.distributions)
W
weight (WeightNormalization attribute)
WeightNormalization (class in flowjax.wrappers)
Where (class in flowjax.wrappers)
X
x_pos (RationalQuadraticSpline attribute)
Y
y_pos (RationalQuadraticSpline attribute)