from typing import Tuple
from datetime import date
from functools import partial
from warnings import filterwarnings
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.preprocessing import scale
import blackjax
import tensorflow_probability.substrates.jax.distributions as tfd
from sklearn.model_selection import train_test_split
"ignore")
filterwarnings(= jr.key(int(date.today().strftime("%Y%m%d")))
key = mpl.colormaps["RdYlGn"] cmap
Bayesian Modeling is a very suitable choice if you want to obtain the uncertainty associated with the predictions of a model. Typically, a Markov Chain Monte Carlo (MCMC) sampler is used: it explores the stationary distribution of interest and recovers (asymptotically) consistent estimates. These samplers are of primary interest because we can essentially (re-)construct any distribution with them. That distribution can also be the joint distribution of the weights of a Neural Network! This is incredibly promising, since we can combine the strengths of statistical modeling techniques with the universal function approximation capabilities of Neural Networks. To this end, there are recent voices arguing why Bayesian Deep Learning is a promising avenue.
In Bayesian Modeling, Hierarchical Bayesian Modeling is a special kind of model specification that helps the sampler explore the distribution of interest. It is in fact so powerful that once you know about it, you can't unsee applications of it (primarily in the Sciences). Hierarchical Modeling can be used whenever you have some grouped structure in your dataset, e.g. if products can be assigned to clusters that share some properties. More on this technique in Section 2.
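To make this concrete before we turn to Neural Networks, here is a minimal, generic sketch of a hierarchical specification for grouped observations \(y_{i,g}\) (a textbook example, not part of this notebook's model): each group \(g = 1, \dots, G\) gets its own parameter \(\theta_g\), and all group parameters are partially pooled through a shared population distribution, with \(\sigma_y\) denoting the observation noise:

\[
\begin{align}
\mu &\sim \mathcal{N}(0, 1) \\
\sigma &\sim \mathcal{N}^{+}(1) \\
\theta_g &\sim \mathcal{N}(\mu, \sigma) \\
y_{i,g} &\sim \mathcal{N}(\theta_g, \sigma_y)
\end{align}
\]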
There are some very useful blog entries and notebooks out there (e.g. by Thomas Wiecki using Theano and PyMC3 and this repo using a more recent version of JAX). However, those examples only work under the critical assumption that all groups are of the same size. In reality, this is rarely the case, of course.
Here, I will show you how you can implement a Hierarchical Bayesian Neural Network
irrespective of the group sizes you observe in your dataset.
The notebook is structured as follows:
- ✍️ Create a dataset: Binary classification with unequal observations per group
- 🧠 What’s the modelling approach and why does it work?
- 👾 Code the model
- 🚀 Run the model and evaluate
1 Setup and dummy data generation
Let’s first import the libraries we’ll need:
Throughout the notebook, we'll use the standard two-moons dataset, a binary classification problem. Figure 1 shows what the dataset looks like with some training examples.
noise = 0.3
X, Y = make_moons(noise=noise, n_samples=2000)

for i in range(2):
    plt.scatter(X[Y == i, 0], X[Y == i, 1], color=cmap(float(i)), label=f"Class {i}", alpha=0.8)
plt.legend()
plt.show()

Next, let's choose some values for the data generation of our grouped dataset. We'll create several groups with a random number of samples each, choose some settings for our Neural Network implementation, and set two parameters for the MCMC algorithm: the number of 'warmup' samples (which will be discarded once model fitting has finished) and the number of sampling steps.
# Data
n_groups = 16
n_grps_sq = int(np.sqrt(n_groups))
n_samples = np.random.randint(10, 200, size=n_groups)

# MLP params
data_dim = 2
hidden_layer_width = 8
n_hidden_layers = 3

# Sampling
num_warmup = 1000
num_samples = 2000
We then write a function which rotates a dataset in 2-D space, generate the group datasets, and store them in lists:
def rotate(X, deg):
    theta = np.radians(deg)
    c, s = np.cos(theta), np.sin(theta)
    R = np.matrix([[c, -s], [s, c]])
    X = X.dot(R)
    return np.asarray(X)


np.random.seed(31)

Xs, Ys, gs = [], [], []
Xs_train, Ys_train, gs_train, Xs_test, Ys_test, gs_test = [], [], [], [], [], []

for i in range(n_groups):
    # Generate data with 2 classes that are not linearly separable
    X, Y = make_moons(noise=noise, n_samples=n_samples[i])
    X = scale(X)

    # Rotate the points randomly for each category
    rotate_by = np.random.randn() * 90.0
    X = rotate(X, rotate_by)

    Xs.append(X)
    Ys.append(Y)
    gs.append(X.shape[0])

    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=31)

    Xs_train.append(X_train)
    Ys_train.append(Y_train)
    gs_train.append(X_train.shape[0])

    Xs_test.append(X_test)
    Ys_test.append(Y_test)
    gs_test.append(X_test.shape[0])
Next, we pad the entries in our list of datasets such that all entries have the same shape: the shape of the largest dataset. We also create a mask marking which elements of each entry were padded. Padding works here because we can disregard the masked positions in the log-likelihood function.
def pad_arrays(arrays, fill_value):
    max_size = max(array.shape[0] for array in arrays)
    padded_arrays = []
    for array in arrays:
        if array.ndim == 1:
            padding = (0, max_size - array.shape[0])
            padded_array = jnp.pad(array, padding, mode="constant", constant_values=fill_value)
            padded_arrays.append(padded_array[:, np.newaxis])
        else:
            padding = ((0, max_size - array.shape[0]), (0, 0))
            padded_array = jnp.pad(array, padding, mode="constant", constant_values=fill_value)
            padded_arrays.append(padded_array)
    return padded_arrays


# Stack group arrays and create a mask
fill_value = 1e5
Xs_train = jnp.stack(pad_arrays(Xs_train, fill_value))
Ys_train = jnp.stack(pad_arrays(Ys_train, fill_value)).squeeze(axis=2)
Xs_test = jnp.stack(pad_arrays(Xs_test, fill_value))
Ys_test = jnp.stack(pad_arrays(Ys_test, fill_value)).squeeze(axis=2)

mask_train = jnp.where(Xs_train == fill_value, fill_value, 1)
mask_test = jnp.where(Xs_test == fill_value, fill_value, 1)
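As a quick sanity check of the padding (a small sketch reusing the variables defined above; the exact maximum group size depends on the random draws in n_samples), every group should now share the shape of the largest training set, and the mask should flag exactly the padded rows:

# Every group is padded to the size of the largest training set
print(Xs_train.shape)  # (n_groups, max_train_size, 2)
print(Ys_train.shape)  # (n_groups, max_train_size)

# Count the non-padded rows per group; this should match gs_train
real_rows_per_group = (mask_train[:, :, 0] == 1).sum(axis=1)
print(real_rows_per_group)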
In Figure 2 you can see what the group datasets look like and how many training examples each group got:
# utility function for plotting
def closest_factors(n):
    a, b = 1, n
    for i in range(1, int(n**0.5) + 1):
        if n % i == 0:
            a, b = i, n // i
    return a, b


# Number of rows and columns for subplots
n_cols, n_rows = closest_factors(n_groups)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 2), sharex=True, sharey=True)

# Flatten axes array for easy iteration
axes = axes.flatten()

for i, (X, Y, group_size, ax) in enumerate(zip(Xs_train, Ys_train, gs_train, axes)):
    for c in range(2):
        ax.scatter(X[Y == c, 0], X[Y == c, 1], color=cmap(float(c)), label=f"Class {c}", alpha=0.8)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(frameon=False)
    ax.set(title=f"Category {i + 1}, N_training = {group_size}")

# Hide any unused subplots
for j in range(n_groups, len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout()
plt.show()

2 A Bayesian Hierarchical Neural Network
Now, let’s dig into why modeling this dataset with a hierarchical
Neural Network might make sense. Beware: we’ll use some vocabulary from statistical modeling.
The prominent probabilistic modeler Michael Betancourt provides an in-depth introduction to the foundations of hierarchical modeling. Conceptually, hierarchical modeling is the right approach when there is a latent population to which we can couple context-dependent parameters. In our dataset, we assume there is some homogeneous structure within each group, whereas the groups may differ from one another. This means we can actually share the weights of the groups' individual Neural Networks across all networks, because the task (binary classification with some Z-shape) is similar, even though the individual groups are all oriented differently in the 2-D space (the difference between the groups).
Summarizing, an HBNN requires that:
- you can make use of a grouping structure in your dataset
- the data generating process of the individual groups is similar
It is strongest if you have a small-ish number of observations per group (probably in relation to the difficulty of the learning task), since in this case 'traditional' approaches will fail (see Tip 1).
Since it's covered in the HBNN tutorial from the BlackJax Sampling Book, I will just copy-paste the code here for reference and for the curious (and because I don't want to lose this example if the Sampling Book disappears at some point).
The key takeaway is this: if you fit the models separately, there will not be enough training examples for each Neural Network to capture the nonlinear relationship separating the two classes.
import matplotlib.pyplot as plt
import jax
from datetime import date
from functools import partial
from warnings import filterwarnings
from flax import linen as nn
from flax.linen.initializers import ones
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd
from sklearn.datasets import make_moons
from sklearn.preprocessing import scale
import blackjax
"ignore")
filterwarnings(
import matplotlib as mpl
"axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams[
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

n_groups_tut = 18
n_grps_sq_tut = int(np.sqrt(n_groups_tut))
n_samples_tut = 100
def rotate_tut(X, deg):
    theta = np.radians(deg)
    c, s = np.cos(theta), np.sin(theta)
    R = np.matrix([[c, -s], [s, c]])
    X = X.dot(R)
    return np.asarray(X)
np.random.seed(31)

Xs_tut, Ys_tut = [], []
for i in range(n_groups_tut):
    # Generate data with 2 classes that are not linearly separable
    X, Y = make_moons(noise=0.3, n_samples=n_samples_tut)
    X = scale(X)

    # Rotate the points randomly for each category
    rotate_by = np.random.randn() * 90.0
    X = rotate_tut(X, rotate_by)

    Xs_tut.append(X)
    Ys_tut.append(Y)

Xs_tut = jnp.stack(Xs_tut)
Ys_tut = jnp.stack(Ys_tut)
Xs_tut_train = Xs_tut[:, : n_samples_tut // 2, :]
Xs_tut_test = Xs_tut[:, n_samples_tut // 2 :, :]
Ys_tut_train = Ys_tut[:, : n_samples_tut // 2]
Ys_tut_test = Ys_tut[:, n_samples_tut // 2 :]

grid = jnp.mgrid[-3:3:100j, -3:3:100j].reshape((2, -1)).T
grid_3d = jnp.repeat(grid[None, ...], n_groups_tut, axis=0)
def inference_loop_tut(rng_key, step_fn, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = step_fn(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
def get_predictions_tut(model, samples, X, rng_key):
    vectorized_apply = jax.vmap(model.apply, in_axes=(0, None), out_axes=0)
    z = vectorized_apply(samples, X)
    predictions = tfd.Bernoulli(logits=z).sample(seed=rng_key)
    return predictions.squeeze(-1)


def get_mean_predictions_tut(predictions, threshold=0.5):
    # compute mean prediction and confidence interval around median
    mean_prediction = jnp.mean(predictions, axis=0)
    return mean_prediction > threshold
def fit_and_eval_tut(
    rng_key,
    model,
    logdensity_fn,
    X_train,
    Y_train,
    X_test,
    grid,
    n_groups=None,
    num_warmup=1000,
    num_samples=2000,
):
    (
        init_key,
        warmup_key,
        inference_key,
        train_key,
        test_key,
        grid_key,
    ) = jax.random.split(rng_key, 6)

    if n_groups is None:
        initial_position = model.init(init_key, jnp.ones(X_train.shape[-1]))
    else:
        initial_position = model.init(init_key, jnp.ones(X_train.shape))

    # initialization
    logprob = partial(logdensity_fn, X=X_train, Y=Y_train, model=model)

    # warm up
    adapt = blackjax.window_adaptation(blackjax.nuts, logprob)
    (final_state, params), _ = adapt.run(warmup_key, initial_position, num_warmup)
    step_fn = blackjax.nuts(logprob, **params).step

    # inference
    states = inference_loop_tut(inference_key, step_fn, final_state, num_samples)
    samples = states.position

    # evaluation
    predictions = get_predictions_tut(model, samples, X_train, train_key)
    Y_pred_train = get_mean_predictions_tut(predictions)

    predictions = get_predictions_tut(model, samples, X_test, test_key)
    Y_pred_test = get_mean_predictions_tut(predictions)

    pred_grid = get_predictions_tut(model, samples, grid, grid_key)

    return Y_pred_train, Y_pred_test, pred_grid
# MLP params
hidden_layer_width_tut = 5
n_hidden_layers_tut = 2


class NN(nn.Module):
    n_hidden_layers: int
    layer_width: int

    @nn.compact
    def __call__(self, x):
        for i in range(self.n_hidden_layers):
            x = nn.Dense(features=self.layer_width)(x)
            x = nn.tanh(x)
        return nn.Dense(features=1)(x)


bnn = NN(n_hidden_layers_tut, hidden_layer_width_tut)
def logprior_fn_tut(params):
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return jnp.sum(tfd.Normal(0, 1).log_prob(flat_params))


def loglikelihood_fn_tut(params, X, Y, model):
    logits = jnp.ravel(model.apply(params, X))
    return jnp.sum(tfd.Bernoulli(logits).log_prob(Y))


def logdensity_fn_of_bnn_tut(params, X, Y, model):
    return logprior_fn_tut(params) + loglikelihood_fn_tut(params, X, Y, model)
rng_key, eval_key = jax.random.split(rng_key)
keys = jax.random.split(eval_key, n_groups_tut)


def fit_and_eval_single_mlp_tut(key, X_train, Y_train, X_test):
    return fit_and_eval_tut(
        key, bnn, logdensity_fn_of_bnn_tut, X_train, Y_train, X_test, grid, n_groups=None
    )


Ys_pred_train, Ys_pred_test, ppc_grid_single = jax.vmap(fit_and_eval_single_mlp_tut)(
    keys, Xs_tut_train, Ys_tut_train, Xs_tut_test
)
def plot_decision_surfaces_non_hierarchical_tut(nrows=2, ncols=2):
    fig, axes = plt.subplots(
        figsize=(15, 12), nrows=nrows, ncols=ncols, sharex=True, sharey=True
    )
    axes = axes.flatten()
    for i, (X, Y_pred, Y_true, ax) in enumerate(
        zip(Xs_tut_train, Ys_pred_train, Ys_tut_train, axes)
    ):
        ax.contourf(
            grid[:, 0].reshape(100, 100),
            grid[:, 1].reshape(100, 100),
            ppc_grid_single[i, ...].mean(axis=0).reshape(100, 100),
            cmap=cmap,
        )
        for c in range(2):
            ax.scatter(
                X[Y_true == c, 0], X[Y_true == c, 1],
                color=cmap(float(c)), label=f"Class {c}", alpha=0.8)
        ax.legend()


plot_decision_surfaces_non_hierarchical_tut(nrows=n_grps_sq_tut, ncols=n_grps_sq_tut)
3 Coding the model
Now we're ready for the modeling code. To get an overview of the model we're about to create, have a look at Figure 3. There, we have groups \(g=1:G\) with the respective weight matrices \(w^g_{l}\) for the input, hidden and output layers \(l\).
The group weights are drawn from a Normal distribution, which in non-centered form means
\[ w^g_l = \mu_{l} + \epsilon^g_{l} \sigma_l \]
with
\[
\begin{align}
\mu_l &\sim \mathcal{N}(0,1) \\
\epsilon^g_l &\sim \mathcal{N}(0,1) \\
\sigma_l &\sim \mathcal{N}^{+}(1)
\end{align}
\]
This non-centered formulation simplifies the space to be explored by our sampler.
We don't model the individual weights directly. That 'centered formulation' would mean \(w^g_l \sim \mathcal{N}(\mu_l, \sigma_l)\), usually with a \(\mathcal{N}\) prior for \(\mu_l\) and a \(\mathcal{N}^+\) prior for \(\sigma_l\).
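To see that the two formulations describe the same distribution, here is a tiny illustrative sketch (the *_demo names are hypothetical and not part of the model code): drawing \(\epsilon \sim \mathcal{N}(0,1)\) and computing \(\mu + \epsilon \sigma\) yields draws from \(\mathcal{N}(\mu, \sigma)\), while the sampler only ever has to explore the well-behaved standard-normal \(\epsilon\):

demo_key = jr.key(0)
mu_demo, sigma_demo = 2.0, 0.5

# Non-centered: sample standard-normal epsilon, then shift and scale deterministically
eps_demo = jr.normal(demo_key, (10_000,))
w_demo = mu_demo + eps_demo * sigma_demo

# The transformed draws follow N(mu_demo, sigma_demo), i.e. the centered specification
print(w_demo.mean(), w_demo.std())  # approximately 2.0 and 0.5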
In the following code block, we write the model using Equinox, which in turn uses JAX for all the numerical routines (e.g. autodiff).
class NonCentredLinear(eqx.Module):
    mu: jax.Array
    eps: jax.Array
    std: jax.Array

    def __init__(self, in_size, out_size, n_groups, *, key):
        self.mu = jr.normal(key, (in_size, out_size))
        self.eps = jr.normal(key, (n_groups, in_size, out_size))
        self.std = jnp.ones((1,))

    def __call__(self, x):
        w = self.mu + self.eps * self.std
        return x @ w


class HNN(eqx.Module):
    layers: Tuple[NonCentredLinear]
    out: eqx.nn.Linear

    def __init__(self, data_dim, layer_width, n_layers, n_groups, *, key):
        dims = [data_dim] + [layer_width] * n_layers
        layers = []
        for n, (_in, _out) in enumerate(zip(dims[:-1], dims[1:])):
            layer = NonCentredLinear(_in, _out, n_groups, key=jr.fold_in(key, n))
            layers += [layer]
        self.layers = tuple(layers)
        self.out = eqx.nn.Linear(layer_width, 1, key=key)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
            x = jax.nn.tanh(x)
        # Vmap over groups and samples
        o = jax.vmap(jax.vmap(self.out))(x)
        return o
1. Write the non-centered layer as an Equinox module.
2. Initialize the weights in \(\mu\) as standard Normal, one set for each layer.
3. Initialize the weights in \(\epsilon\) as standard Normal, one set for each layer and group.
4. Initialize \(\sigma\) as 1.
5. Non-centered combination of the matrices and the dot product of \(x\) and \(w\).
6. Write the Hierarchical Neural Network as an Equinox module.
7. In this implementation, all hidden layers have the same width.
8. Create all hidden layers.
9. Final linear output layer.
10. Choose \(\tanh\) as the activation function.
Next, we instantiate the HNN model and write some code that Equinox needs.
def get_init_apply_fns(model):
    params, static = eqx.partition(model, eqx.is_inexact_array)

    def init_fn():
        return params

    def apply_fn(_params, x):
        model = eqx.combine(_params, static)
        return model(x)

    return init_fn, apply_fn


hnn = HNN(data_dim, hidden_layer_width, n_hidden_layers, n_groups, key=key)
init_fn, apply_fn = get_init_apply_fns(hnn)
params = init_fn()


def inference_loop(key, step_fn, initial_state, num_samples):
    def one_step(state, key):
        state, _ = step_fn(key, state)
        return state, state

    keys = jr.split(key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
1. Instantiate the HNN model.
2. In JAX, we can use `jax.lax.scan` to iterate a step function over a sequence of inputs while carrying state between steps (a tiny example follows below).
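If you haven't seen scan before, here is a tiny standalone illustration (unrelated to the model itself): `jax.lax.scan` threads a carry through repeated applications of a step function and stacks the per-step outputs, which is exactly the pattern `inference_loop` uses with the sampler's `step_fn`:

def cumulative_sum_step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, output stored for this step)

final_carry, partial_sums = jax.lax.scan(cumulative_sum_step, jnp.array(0.0), jnp.arange(1.0, 6.0))
print(final_carry)   # 15.0
print(partial_sums)  # [ 1.  3.  6. 10. 15.]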
Next, we write the (log)-prior function for the parameters of the model, as well as the log-likelihood.
def logprior_fn(params):
    normal = tfd.Normal(0.0, 1.0)
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return jnp.sum(normal.log_prob(flat_params))


def logprior_fn_of_hnn(params, model):
    """p(w) where w is NN(X; w)"""
    lp = 0.0
    half_normal = tfd.HalfNormal(1.0)
    normal = tfd.Normal(0.0, 1.0)
    for layer in params.layers:
        lp += normal.log_prob(layer.mu).sum()
        lp += normal.log_prob(layer.eps).sum()
        lp += half_normal.log_prob(layer.std).sum()
    lp += logprior_fn(params.out)
    return lp


def loglikelihood_fn(params, X, Y, mask, fill_value, model):
    """p(Y|Y_=NN(X; w))"""
    logits = jnp.ravel(apply_fn(params, X))
    logits = jnp.where(jnp.ravel(mask[:, :, 0]) == fill_value, 0, logits)
    return jnp.sum(tfd.Bernoulli(logits).log_prob(jnp.ravel(Y)))


def logdensity_fn_of_hnn(params, X, Y, mask, fill_value, model):
    return logprior_fn_of_hnn(params, model) + loglikelihood_fn(params, X, Y, mask, fill_value, model)
1. Apply the mask: where the mask holds the fill value, the corresponding logits are set to zero (see the toy sketch below).
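To see why this removes the influence of the padded positions, consider a small toy sketch (hypothetical values, not the notebook's data): with a logit of 0, the Bernoulli log-probability is the same constant no matter what the target is, so the padded entries only shift the log-likelihood by a parameter-independent constant and leave the posterior unchanged:

toy_logits = jnp.array([0.0, 0.0])
toy_targets = jnp.array([1.0, 1e5])  # a real label and a padded target (fill value)
# Both entries evaluate to -log(2), up to float32 rounding
print(tfd.Bernoulli(logits=toy_logits).log_prob(toy_targets))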
And here are some utility functions for extracting the model predictions:
def get_predictions(model, samples, X, mask, fill_value, key):
    vectorized_apply = jax.vmap(apply_fn, in_axes=(0, None), out_axes=0)
    z = vectorized_apply(samples, X)
    predictions = tfd.Bernoulli(logits=z).sample(seed=key)
    mask_reshaped = jnp.broadcast_to(
        jnp.mean(mask, axis=-1).reshape(mask.shape[0], mask.shape[1], 1), predictions.shape
    )
    predictions = jnp.where(mask_reshaped == fill_value, jnp.nan, predictions)
    return predictions.squeeze(-1)


def get_mean_predictions(predictions, threshold=0.5):
    # Compute mean prediction and confidence interval around median
    mean_prediction = jnp.nanmean(predictions, axis=0)
    return mean_prediction > threshold


def fit_and_eval(
    key,
    initial_position,  # Passed from `init_fn` of init/apply function conversion of Equinox NN
    model,
    logdensity_fn,
    X_train,
    Y_train,
    mask_train,
    fill_value,
    X_test,
    grid,
    num_warmup=20,
    num_samples=10,
):
    (
        warmup_key,
        inference_key,
        train_key,
        test_key,
        grid_key,
    ) = jr.split(key, 5)

    # Initialization
    logprob = partial(logdensity_fn, X=X_train, Y=Y_train, mask=mask_train, fill_value=fill_value, model=model)

    # Warm up
    adapt = blackjax.window_adaptation(blackjax.nuts, logprob)
    (final_state, params), _ = adapt.run(warmup_key, initial_position, num_warmup)
    step_fn = blackjax.nuts(logprob, **params).step

    # Inference
    states = inference_loop(inference_key, step_fn, final_state, num_samples)
    samples = states.position

    # Evaluation
    predictions = get_predictions(model, samples, X_train, mask_train, fill_value, train_key)
    Y_pred_train = get_mean_predictions(predictions)

    predictions = get_predictions(model, samples, X_test, mask_test, fill_value, test_key)
    Y_pred_test = get_mean_predictions(predictions)

    pred_grid = get_predictions(model, samples, grid, mask_grid, fill_value, grid_key)

    return Y_pred_train, Y_pred_test, pred_grid


def reverse_mask(targets, predictions, mask, fill_value):
    targets, predictions, mask = jnp.ravel(targets), jnp.ravel(predictions), jnp.ravel(mask[:, :, 0])
    positions_to_omit = jnp.where(mask == fill_value)[0]
    filtered_targets, filtered_predictions = jnp.delete(targets, positions_to_omit), jnp.delete(predictions, positions_to_omit)
    return filtered_targets, filtered_predictions
We create a 3-D grid of shape (n_groups, 100*100, 2) on which to get the model's predictions (with the respective probabilities), and run the model:
grid = jnp.mgrid[-3:3:100j, -3:3:100j].reshape((2, -1)).T
grid_3d = jnp.repeat(grid[None, ...], n_groups, axis=0)
mask_grid = jnp.ones(grid_3d.shape)

key, inference_key = jr.split(key)

(Ys_hierarchical_pred_train, Ys_hierarchical_pred_test, ppc_grid) = fit_and_eval(
    inference_key,
    params,
    hnn,
    logdensity_fn_of_hnn,
    Xs_train,
    Ys_train,
    mask_train,
    fill_value,
    Xs_test,
    grid_3d,
    num_warmup=num_warmup,
    num_samples=num_samples,
)

filtered_Ys_train, filtered_Ys_hierarchical_pred_train = reverse_mask(
    Ys_train, Ys_hierarchical_pred_train, mask_train, fill_value
)
filtered_Ys_test, filtered_Ys_hierarchical_pred_test = reverse_mask(
    Ys_test, Ys_hierarchical_pred_test, mask_test, fill_value
)
print("Train accuracy = {:.2f}%".format(100 * jnp.mean(filtered_Ys_hierarchical_pred_train == filtered_Ys_train)))
Train accuracy = 91.91%
print("Test accuracy = {:.2f}%".format(100 * jnp.mean(filtered_Ys_hierarchical_pred_test == filtered_Ys_test)))
Test accuracy = 91.89%
def plot_decision_surfaces_hierarchical(nrows=2, ncols=2):
    fig, axes = plt.subplots(figsize=(15, 12), nrows=nrows, ncols=ncols, sharex=True, sharey=True)

    for i, (X, Y_pred, Y_true, ax) in enumerate(zip(Xs_train, Ys_hierarchical_pred_train, Ys_train, axes.flatten())):
        ax.contourf(
            grid[:, 0].reshape((100, 100)),
            grid[:, 1].reshape((100, 100)),
            ppc_grid[:, i, :].mean(axis=0).reshape(100, 100),
            cmap=cmap,
            zorder=0,
        )
        for c in range(2):
            ax.scatter(X[Y_true == c, 0], X[Y_true == c, 1], color="w", alpha=0.8, s=20.0, zorder=1)
            ax.scatter(
                X[Y_true == c, 0],
                X[Y_true == c, 1],
                color=cmap(float(c)),
                label=f"Class {c}",
                alpha=0.8,
                s=10.0,
                zorder=2,
            )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.legend(frameon=False)


plot_decision_surfaces_hierarchical(nrows=n_grps_sq, ncols=n_grps_sq)
plt.savefig("figures/hbnn_decision_boundaries.png", bbox_inches="tight")
plt.show()
4 Summary
Great! We successfully wrote a model that can work with varying input shapes by using padding and masking. Written in this way, the Bayesian Hierarchical Neural Network is much more generally applicable than an implementation that assumes groups of equal size.
from print_versions import print_versions
print_versions(globals())
jax==0.6.1
equinox==0.12.2
numpy==2.3.0
matplotlib==3.10.3
blackjax==1.2.5
jaxlib==0.6.1