From 3b738d9422060f1eef5502f45e7a42412b15fc6e Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 18 Jun 2026 16:55:12 +1000 Subject: [PATCH 1/7] Rewrite the Non-Conjugate Priors lecture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete rewrite teaching the same material more cleanly: - Remove the two NamedTuple classes (BayesianInference, BayesianInferencePlot) and the string-dispatch indirection in favor of small functions: one reusable binomial_model(prior, k, n), a run_nuts helper, and a plot_prior_posterior helper. - Reorganize into: (1) MCMC reproduces the conjugate beta posterior (validated against the analytical result from prob_meaning, with arviz R-hat/trace diagnostics); (2) non-conjugate priors worked through one at a time (uniform incl. a boundary-exclusion example, truncated log-normal, truncated Laplace); (3) variational inference via an AutoNormal autoguide, compared against the NUTS posterior. - Continuity with prob_meaning: same coin-flip DGP with true θ=0.4, and a deliberately small sample so the prior visibly matters. - Switch plotting to matplotlib + arviz (drop seaborn); keep the full ELBO derivation. Verified end-to-end via jupytext export + headless execution. Co-Authored-By: Claude Opus 4.8 (1M context) --- lectures/bayes_nonconj.md | 1313 ++++++++----------------------------- 1 file changed, 260 insertions(+), 1053 deletions(-) diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md index ab78d35b2..98027d8c6 100644 --- a/lectures/bayes_nonconj.md +++ b/lectures/bayes_nonconj.md @@ -16,1232 +16,439 @@ kernelspec: ```{include} _admonition/gpu.md ``` +In addition to what's in Anaconda, this lecture will need the following libraries: + ```{code-cell} ipython3 :tags: [hide-output] -!pip install numpyro jax +!pip install numpyro jax arviz ``` -This lecture is a sequel to the {doc}`prob_meaning`. +## Overview + +This lecture is a sequel to {doc}`prob_meaning`. 
-That lecture offers a Bayesian interpretation of probability in a setting in which the likelihood function and the prior distribution -over parameters just happened to form a **conjugate** pair in which +In that lecture we adopted a **beta** prior for the unknown probability $\theta$ of a coin landing heads, together with a **binomial** likelihood. -- application of Bayes' Law produces a posterior distribution that has the same functional form as the prior +That prior and likelihood form a **conjugate pair**: applying Bayes' law returns a posterior of the *same* family as the prior — again a beta distribution. -Having a likelihood and prior that are conjugate can simplify calculation of a posterior, facilitating analytical or nearly analytical calculations. +Conjugacy is convenient because it delivers a posterior in closed form. -But in many situations the likelihood and prior need not form a conjugate pair. +But a person's prior beliefs are their own business, and in general they will not happen to be conjugate to the likelihood. -- after all, a person's prior is his or her own business and would take a form conjugate to a likelihood only by remote coincidence +When the prior and likelihood are **not** conjugate, the posterior usually has no closed form, and we must approximate it numerically. -In these situations, computing a posterior can become very challenging. +This lecture introduces two widely used ways to do that, both implemented in the probabilistic programming library [NumPyro](https://num.pyro.ai/en/stable/getting_started.html): -In this lecture, we illustrate how modern Bayesians confront non-conjugate priors by using Monte Carlo techniques that involve +* **Markov chain Monte Carlo (MCMC)** — construct a Markov chain whose stationary distribution is the posterior, then sample from it. We use the **No-U-Turn Sampler (NUTS)**, a state-of-the-art form of Hamiltonian Monte Carlo. 
-- first cleverly forming a Markov chain whose invariant distribution is the posterior distribution we want -- simulating the Markov chain until it has converged and then sampling from the invariant distribution to approximate the posterior +* **Variational inference (VI)** — replace sampling with optimization: search within a tractable family of distributions for the member closest to the posterior. -We shall illustrate the approach by deploying a powerful Python library, [NumPyro](https://num.pyro.ai/en/stable/getting_started.html) that implements this approach. +Our plan is: -As usual, we begin by importing some Python code. +1. Confirm that MCMC reproduces the *conjugate* beta posterior that we can compute analytically — this validates the machinery on a problem whose answer we already know. +2. Replace the beta prior with several **non-conjugate** priors and approximate each posterior with MCMC. +3. Introduce variational inference and compare it with MCMC. + +Let us start with some imports. ```{code-cell} ipython3 import numpy as np -import seaborn as sns import matplotlib.pyplot as plt import scipy.stats as st -from typing import NamedTuple, Sequence import jax.numpy as jnp from jax import random import numpyro -from numpyro import distributions as dist -import numpyro.distributions.constraints as constraints +import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO +from numpyro.infer.autoguide import AutoNormal from numpyro.optim import Adam -``` - -## Unleashing MCMC on a binomial likelihood - -This lecture begins with the binomial example in the {doc}`prob_meaning`. - -That lecture computed a posterior - -- analytically via choosing the conjugate priors, - -This lecture instead computes posteriors - -- numerically by sampling from the posterior distribution through MCMC methods, and -- using a variational inference (VI) approximation. - -We use `numpyro` with assistance from `jax` to approximate a posterior distribution. 
- -We use several alternative prior distributions. - -We compare computed posteriors with ones associated with a conjugate prior as described in {doc}`prob_meaning`. - -### Analytical posterior - -Assume that the random variable $X\sim Binom\left(n,\theta\right)$. - -This defines a likelihood function -$$ -L\left(Y\vert\theta\right) = \textrm{Prob}(X = k | \theta) = -\left(\frac{n!}{k! (n-k)!} \right) \theta^k (1-\theta)^{n-k} -$$ - -where $Y=k$ is an observed data point. +import arviz as az +``` -We view $\theta$ as a random variable for which we assign a prior distribution having density $f(\theta)$. +To draw posterior samples from several Markov chains in parallel, we tell NumPyro how many CPU devices to use. -We will try alternative priors later, but for now, suppose the prior is distributed as $\theta\sim Beta\left(\alpha,\beta\right)$, i.e., +```{code-cell} ipython3 +numpyro.set_host_device_count(4) +``` -$$ -f(\theta) = \textrm{Prob}(\theta) = \frac{\theta^{\alpha - 1} (1 - \theta)^{\beta - 1}}{B(\alpha, \beta)} -$$ +## The coin-flipping model -We choose this as our prior for now because we know that a conjugate prior for the binomial likelihood function is a beta distribution. +As in {doc}`prob_meaning`, a coin lands heads ($Y=1$) with probability $\theta$ and tails ($Y=0$) with probability $1-\theta$. -After observing $k$ successes among $N$ sample observations, the posterior probability distribution of $ \theta $ is +If we flip the coin $n$ times, the number of heads $k$ has the **binomial** distribution $$ -\textrm{Prob}(\theta|k) = \frac{\textrm{Prob}(\theta,k)}{\textrm{Prob}(k)}=\frac{\textrm{Prob}(k|\theta)\textrm{Prob}(\theta)}{\textrm{Prob}(k)}=\frac{\textrm{Prob}(k|\theta) \textrm{Prob}(\theta)}{\int_0^1 \textrm{Prob}(k|\theta)\textrm{Prob}(\theta) d\theta} +p(k \mid \theta) = \binom{n}{k}\, \theta^k (1-\theta)^{n-k} . 
$$ -$$ -=\frac{{N \choose k} (1 - \theta)^{N-k} \theta^k \frac{\theta^{\alpha - 1} (1 - \theta)^{\beta - 1}}{B(\alpha, \beta)}}{\int_0^1 {N \choose k} (1 - \theta)^{N-k} \theta^k\frac{\theta^{\alpha - 1} (1 - \theta)^{\beta - 1}}{B(\alpha, \beta)} d\theta} -$$ +We treat $\theta$ as a random variable with a prior density $p(\theta)$, and we want the posterior $$ -=\frac{(1 -\theta)^{\beta+N-k-1} \theta^{\alpha+k-1}}{\int_0^1 (1 - \theta)^{\beta+N-k-1} \theta^{\alpha+k-1} d\theta} . +p(\theta \mid k) \propto p(k \mid \theta)\, p(\theta) . $$ -Thus, +### Generating data -$$ -\textrm{Prob}(\theta|k) \sim {Beta}(\alpha + k, \beta+N-k) -$$ - -The analytical posterior for a given conjugate beta prior is coded in the following +We simulate a sequence of coin flips from a coin whose true (but unknown to the analyst) probability of heads is $\theta = 0.4$. ```{code-cell} ipython3 -def simulate_draw(θ, n): - """Draws a Bernoulli sample of size n with probability P(Y=1) = θ""" - rand_draw = np.random.rand(n) - draw = (rand_draw < θ).astype(int) - return draw - - -def analytical_beta_posterior(data, α0, β0): - """ - Computes analytically the posterior distribution - with beta prior parametrized by (α, β) - given # num observations - - Parameters - --------- - num : int. - the number of observations after which we calculate the posterior - α0, β0 : float. - the parameters for the beta distribution as a prior - - Returns - --------- - The posterior beta distribution - """ - num = len(data) - up_num = data.sum() - down_num = num - up_num - return st.beta(α0 + up_num, β0 + down_num) -``` - -### Two ways to approximate posteriors +def simulate_coin_flips(θ=0.4, n=20, seed=1234): + "Flip a coin n times; return an array of 0s (tails) and 1s (heads)." + rng = np.random.default_rng(seed) + return (rng.random(n) < θ).astype(int) -Suppose that we don't have a conjugate prior. - -Then we can't compute posteriors analytically. 
- -Instead, we use computational tools to approximate the posterior distribution for a set of alternative prior distributions using `numpyro`. - -We first use the **Markov Chain Monte Carlo** (MCMC) algorithm. - -We implement the NUTS sampler to sample from the posterior. - -In that way we construct a sampling distribution that approximates the posterior. - -After doing that we deploy another procedure called **Variational Inference** (VI). - -In particular, we implement Stochastic Variational Inference (SVI) machinery in `numpyro`. - -The MCMC algorithm supposedly generates a more accurate approximation since in principle it directly samples from the posterior distribution. - -But it can be computationally expensive, especially when dimension is large. - -A VI approach can be cheaper, but it is likely to produce an inferior approximation to the posterior, for the simple reason that it requires guessing a parametric **guide functional form** that we use to approximate a posterior. - -This guide function is likely at best to be an imperfect approximation. - -By paying the cost of restricting the putative posterior to have a restricted functional form, -the problem of approximating a posterior is transformed to a well-posed optimization problem that seeks parameters of the putative posterior that minimize -a Kullback-Leibler (KL) divergence between true posterior and the putative posterior distribution. - -- minimizing the KL divergence is equivalent to maximizing a criterion called the **Evidence Lower Bound** (ELBO), as we shall verify soon. - -## Prior distributions - -In order to be able to apply MCMC sampling or VI, `numpyro` requires that a prior distribution satisfy special properties: - -- we must be able to sample from it; -- we must be able to compute the log pdf pointwise; -- the pdf must be differentiable with respect to the parameters. 
+data = simulate_coin_flips() +k, n = int(data.sum()), len(data) +k, n +``` -We'll want to define a distribution `class`. +We deliberately use a **small** sample ($n = 20$). -We will use the following priors: +The reason is that the prior matters most when data are scarce. -- a uniform distribution on $[\underline \theta, \overline \theta]$, where $0 \leq \underline \theta < \overline \theta \leq 1$. +With a large sample the likelihood dominates and almost any reasonable prior leads to the same posterior — exactly the concentration we saw in {doc}`prob_meaning`. -- a truncated log-normal distribution with support on $[0,1]$ with parameters $(\mu,\sigma)$. +A modest $n$ keeps the influence of the prior visible, which is what we want to study here. - - To implement this, let $Z\sim N(\mu,\sigma)$ and $\tilde{Z}$ be truncated normal with support $[-\infty,\log(1)]$, then $\exp(Z)$ has a log normal distribution with bounded support $[0,1]$. This can be easily coded since `numpyro` has a built-in truncated normal distribution, and `numpyro`'s `TransformedDistribution` class that includes an exponential transformation. +### One model, many priors -- a shifted von Mises distribution that has support confined to $[0,1]$ with parameter $(\mu,\kappa)$. +In NumPyro a model is an ordinary Python function that uses `numpyro.sample` to declare random variables. - - Let $X\sim vonMises(0,\kappa)$. We know that $X$ has bounded support $[-\pi, \pi]$. We can define a shifted von Mises random variable $\tilde{X}=a+bX$ where $a=0.5, b=1/(2 \pi)$ so that $\tilde{X}$ is supported on $[0,1]$. +We write a *single* model that takes the prior distribution as an argument. - - This can be implemented using `numpyro`'s `TransformedDistribution` class with its `AffineTransform` method. +This lets us reuse it unchanged for every prior we consider — conjugate or not. -- a truncated Laplace distribution. 
+```{code-cell} ipython3 +def binomial_model(prior, k, n): + "Binomial likelihood with a caller-supplied prior on θ." + θ = numpyro.sample("θ", prior) + numpyro.sample("k", dist.Binomial(n, θ), obs=k) +``` - - We also considered a truncated Laplace distribution because its density comes in a piece-wise non-smooth form and has a distinctive spiked shape. +The first `sample` statement draws $\theta$ from the prior; the second ties the observed count `k` to the binomial likelihood through `obs=k`. - - The truncated Laplace can be created using `numpyro`'s `TruncatedDistribution` class. +We also write a small helper that runs NUTS and returns the fitted sampler. ```{code-cell} ipython3 -def truncated_log_normal_trans(loc, scale): - """ - Obtains the truncated log normal distribution - using numpyro's TruncatedNormal and ExpTransform - """ - base_dist = dist.TruncatedNormal( - low=-jnp.inf, high=jnp.log(1), loc=loc, scale=scale - ) - return dist.TransformedDistribution( - base_dist, dist.transforms.ExpTransform() - ) - - -def shifted_von_mises(κ): - """Obtains the shifted von Mises distribution using AffineTransform""" - base_dist = dist.VonMises(0, κ) - return dist.TransformedDistribution( - base_dist, - dist.transforms.AffineTransform(loc=0.5, scale=1 / (2 * jnp.pi)) +def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=4): + "Sample the posterior of θ with the NUTS sampler." 
+ mcmc = MCMC( + NUTS(binomial_model), + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=False, ) - - -def truncated_laplace(loc, scale): - """Obtains the truncated Laplace distribution on [0,1]""" - base_dist = dist.Laplace(loc, scale) - return dist.TruncatedDistribution(base_dist, low=0.0, high=1.0) + mcmc.run(random.PRNGKey(seed), prior, k, n) + return mcmc ``` -### Variational inference - -Instead of directly sampling from the posterior, the **variational inference** method approximates an unknown posterior distribution with a family of tractable distributions/densities. +## MCMC reproduces the conjugate posterior -It then seeks to minimize a measure of statistical discrepancy between the approximating and true posteriors. +Before trusting MCMC on hard problems, let us check it on an easy one. -Thus variational inference (VI) approximates a posterior by solving a minimization problem. - -Let the latent parameter/variable that we want to infer be $\theta$. - -Let the prior be $p(\theta)$ and the likelihood be $p\left(Y\vert\theta\right)$. - -We want $p\left(\theta\vert Y\right)$. - -Bayes' rule implies - -$$ -p\left(\theta\vert Y\right)=\frac{p\left(Y,\theta\right)}{p\left(Y\right)}=\frac{p\left(Y\vert\theta\right)p\left(\theta\right)}{p\left(Y\right)} -$$ - -where - -$$ -p\left(Y\right)=\int p\left(Y\mid\theta\right)p\left(\theta\right) d\theta. -$$ (eq:intchallenge) - -The integral on the right side of {eq}`eq:intchallenge` is typically difficult to compute. - -Consider a **guide distribution** $q_{\phi}(\theta)$ parameterized by $\phi$ that we'll use to approximate the posterior. 
- -We choose parameters $\phi$ of the guide distribution to minimize a Kullback-Leibler (KL) divergence between the approximate posterior $q_{\phi}(\theta)$ and the posterior: +With a $\text{Beta}(\alpha_0, \beta_0)$ prior the posterior is known analytically (see {doc}`prob_meaning`): $$ - D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) \equiv -\int q(\theta;\phi)\log\frac{p(\theta\mid Y)}{q(\theta;\phi)} d\theta +\theta \mid k \sim \text{Beta}(\alpha_0 + k,\ \beta_0 + n - k) . $$ -Thus, we want a **variational distribution** $q$ that solves +We take $\alpha_0 = \beta_0 = 2$ and sample the posterior with NUTS. -$$ -\min_{\phi}\quad D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) -$$ - -Note that - -$$ -\begin{aligned}D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y)) & =-\int q(\theta;\phi)\log\frac{P(\theta\mid Y)}{q(\theta;\phi)} d\theta\\ - & =-\int q(\theta)\log\frac{\frac{p(\theta,Y)}{p(Y)}}{q(\theta)} d\theta\\ - & =-\int q(\theta)\log\frac{p(\theta,Y)}{q(\theta)p(Y)} d\theta\\ - & =-\int q(\theta)\left[\log\frac{p(\theta,Y)}{q(\theta)}-\log p(Y)\right] d\theta\\ - & =-\int q(\theta)\log\frac{p(\theta,Y)}{q(\theta)}+\int q(\theta)\log p(Y) d\theta\\ - & =-\int q(\theta)\log\frac{p(\theta,Y)}{q(\theta)} d\theta+\log p(Y)\\ -\log p(Y)&=D_{KL}(q(\theta;\phi)\;\|\;p(\theta\mid Y))+\int q_{\phi}(\theta)\log\frac{p(\theta,Y)}{q_{\phi}(\theta)} d\theta -\end{aligned} -$$ - -For observed data $Y$, $p(\theta,Y)$ is a constant, so minimizing KL divergence is equivalent to maximizing - -$$ -ELBO\equiv\int q_{\phi}(\theta)\log\frac{p(\theta,Y)}{q_{\phi}(\theta)} d\theta=\mathbb{E}_{q_{\phi}(\theta)}\left[\log p(\theta,Y)-\log q_{\phi}(\theta)\right] -$$ (eq:ELBO) - -Formula {eq}`eq:ELBO` is called the evidence lower bound (ELBO). +```{code-cell} ipython3 +α0, β0 = 2.0, 2.0 +mcmc = run_nuts(dist.Beta(α0, β0), k, n) +``` -A standard optimization routine can be used to search for the optimal $\phi$ in our parametrized distribution $q_{\phi}(\theta)$. 
+Before looking at the posterior we check that the sampler converged. -The parameterized distribution $q_{\phi}(\theta)$ is called the **variational distribution**. +ArviZ reads NumPyro's output directly and reports standard diagnostics. -We can implement Stochastic Variational Inference (SVI) in numpyro using the `Adam` gradient descent algorithm to approximate the posterior. +```{code-cell} ipython3 +idata = az.from_numpyro(mcmc) +az.summary(idata, var_names=["θ"]) +``` -We use two sets of variational distributions: Beta and TruncatedNormal with support $[0,1]$ +The potential scale reduction factor `r_hat` is essentially $1.0$ and the effective sample sizes are large, both signs that the chains have mixed well. -- Learnable parameters for the Beta distribution are ($\alpha$, $\beta$), both of which are positive. -- Learnable parameters for the Truncated Normal distribution are (loc, scale). +The trace plot tells the same story: the four chains overlap and look like stationary noise. -```{note} -We restrict the truncated Normal parameter 'loc' to be in the interval $[0,1]$ +```{code-cell} ipython3 +az.plot_trace(idata, var_names=["θ"]) +plt.tight_layout() +plt.show() ``` -## Implementation - -We have constructed a Python class `BayesianInference` that requires the following arguments to be initialized: - -- `param`: a tuple/scalar of parameters dependent on distribution types -- `name_dist`: a string that specifies distribution names +Now we compare the MCMC posterior with the analytical beta posterior. 
-The (`param`, `name_dist`) pair includes: -- ($\alpha$, $\beta$, 'beta') +```{code-cell} ipython3 +θ_grid = np.linspace(0.001, 0.999, 500) +samples = np.asarray(mcmc.get_samples()["θ"]) -- (lower_bound, upper_bound, 'uniform') +fig, ax = plt.subplots() +ax.hist(samples, bins=50, density=True, alpha=0.4, + label="MCMC posterior") +ax.plot(θ_grid, st.beta(α0 + k, β0 + n - k).pdf(θ_grid), + 'k-', lw=2, label="analytical posterior") +ax.plot(θ_grid, st.beta(α0, β0).pdf(θ_grid), + 'C1--', lw=2, label="prior") +ax.set_xlabel(r"$\theta$") +ax.legend() +plt.show() +``` -- (loc, scale, 'lognormal') - - Note: This is the truncated log normal. +The histogram of MCMC draws sits right on top of the analytical posterior density. -- ($\kappa$, 'vonMises'), where $\kappa$ denotes concentration parameter, and center location is set to $0.5$. Using `numpyro`, this is the **shifted** distribution. +The sampler works, so we can rely on it for priors that have no closed-form posterior. -- (loc, scale, 'laplace') - - Note: This is the truncated Laplace +## Non-conjugate priors -The class `BayesianInference` has several key methods : -- `sample_prior`: - - This can be used to draw a single sample from the given prior distribution. +We now keep the binomial likelihood and the same data, but replace the beta prior with priors that are **not** conjugate to it. -- `show_prior`: - - Plots the approximate prior distribution by repeatedly drawing samples and fitting a kernel density curve. +For each prior the recipe is identical: -- `mcmc_sampling`: - - INPUT: (data, num_samples, num_warmup=1000) - - Takes a `jnp.array` data and generates MCMC sampling of posterior of size `num_samples`. +1. describe the prior and build it as a NumPyro distribution, +2. pass it to `binomial_model` and run NUTS, +3. plot the prior against the resulting posterior. 
-- `svi_run`: - - INPUT: (data, guide_dist, n_steps=10000) - - guide_dist = 'normal' - use a **truncated** normal distribution as the parametrized guide - - guide_dist = 'beta' - use a beta distribution as the parametrized guide - - RETURN: (params, losses) - the learned parameters in a `dict` and the vector of loss at each step. +The following helper draws a prior density and the posterior samples on the same axes. ```{code-cell} ipython3 -class BayesianInference(NamedTuple): - """ - Parameters - --------- - param : tuple. - a tuple object that contains all relevant parameters for the distribution - name_dist : str. - name of the distribution - 'beta', 'uniform', 'lognormal', 'vonMises', 'laplace' - rng_key : jax.random.PRNGKey - PRNG key for random number generation. - """ - param: tuple - name_dist: str - rng_key: random.PRNGKey - - -def create_bayesian_inference( - param: tuple, - name_dist: str, - seed: int = 0 -) -> BayesianInference: - """Factory function to create a BayesianInference instance""" - - rng_key = random.PRNGKey(seed) - - return BayesianInference( - param=param, - name_dist=name_dist, - rng_key=rng_key - ) - - -def sample_prior(model: BayesianInference): - """Define the prior distribution to sample from in numpyro models.""" - if model.name_dist == "beta": - # unpack parameters - α0, β0 = model.param - sample = numpyro.sample( - "theta", dist.Beta(α0, β0), rng_key=model.rng_key - ) - - elif model.name_dist == "uniform": - # unpack parameters - lb, ub = model.param - sample = numpyro.sample( - "theta", dist.Uniform(lb, ub), rng_key=model.rng_key - ) - - elif model.name_dist == "lognormal": - # unpack parameters - loc, scale = model.param - sample = numpyro.sample( - "theta", - truncated_log_normal_trans(loc, scale), - rng_key=model.rng_key - ) - - elif model.name_dist == "vonMises": - # unpack parameters - κ = model.param - sample = numpyro.sample( - "theta", shifted_von_mises(κ), rng_key=model.rng_key - ) - - elif model.name_dist == "laplace": - 
# unpack parameters - loc, scale = model.param - sample = numpyro.sample( - "theta", truncated_laplace(loc, scale), rng_key=model.rng_key - ) - - return sample - - -def show_prior( - model: BayesianInference, size=1e5, bins=20, disp_plot=1 -): - """ - Visualizes prior distribution by sampling from prior - and plots the approximated sampling distribution - """ - with numpyro.plate("show_prior", size=size): - sample = sample_prior(model) - # to JAX array - sample_array = jnp.asarray(sample) - - # plot histogram and kernel density - if disp_plot == 1: - sns.displot( - sample_array, - kde=True, - stat="density", - bins=bins, - height=5, - aspect=1.5 - ) - plt.xlim(0, 1) - plt.show() - else: - return sample_array - - -def set_model(model: BayesianInference, data): - """ - Define the probabilistic model by specifying prior, - conditional likelihood, and data conditioning - """ - theta = sample_prior(model) - output = numpyro.sample( - "obs", dist.Binomial(len(data), theta), obs=jnp.sum(data) - ) - - -def mcmc_sampling( - model: BayesianInference, data, num_samples, num_warmup=1000 -): - """ - Computes numerically the posterior distribution - with beta prior parametrized by (α0, β0) - given data using MCMC - """ - data = jnp.array(data, dtype=float) - nuts_kernel = NUTS(set_model) - mcmc = MCMC( - nuts_kernel, - num_samples=num_samples, - num_warmup=num_warmup, - progress_bar=False, - ) - mcmc.run(model.rng_key, model=model, data=data) - - samples = mcmc.get_samples()["theta"] - return samples - - -# arguments in this function are used to align with the arguments in set_model() -# this is required by svi.run() -def beta_guide(model: BayesianInference, data): - """ - Defines the candidate parametrized variational distribution - that we train to approximate posterior with numpyro - Here we use parameterized beta - """ - α_q = numpyro.param("alpha_q", 10, constraint=constraints.positive) - β_q = numpyro.param("beta_q", 10, constraint=constraints.positive) - - 
numpyro.sample("theta", dist.Beta(α_q, β_q)) - - -# similar with beta_guide() -def truncnormal_guide(model: BayesianInference, data): - """ - Defines the candidate parametrized variational distribution - that we train to approximate posterior with numpyro - Here we use truncated normal on [0,1] - """ - loc = numpyro.param("loc", 0.5, constraint=constraints.interval(0.0, 1.0)) - scale = numpyro.param("scale", 1, constraint=constraints.positive) - numpyro.sample( - "theta", - dist.TruncatedNormal(loc, scale, low=0.0, high=1.0) - ) - - -def svi_init(model: BayesianInference, guide_dist, lr=0.0005): - """Initiate SVI training mode with Adam optimizer""" - adam_params = {"lr": lr} - - if guide_dist == "beta": - optimizer = Adam(step_size=lr) - svi = SVI( - set_model, beta_guide, optimizer, loss=Trace_ELBO() - ) - elif guide_dist == "normal": - optimizer = Adam(step_size=lr) - svi = SVI( - set_model, truncnormal_guide, optimizer, loss=Trace_ELBO() - ) - else: - print("WARNING: Please input either 'beta' or 'normal'") - svi = None - - return svi - - -def svi_run(model: BayesianInference, data, guide_dist, n_steps=10000): - """ - Runs SVI and returns optimized parameters and losses - - Returns - -------- - params : the learned parameters for guide - losses : a vector of loss at each step - """ - - # initiate SVI - svi = svi_init(model, guide_dist) - - data = jnp.array(data, dtype=float) - result = svi.run( - model.rng_key, n_steps, model, data, progress_bar=False - ) - params = dict( - (key, jnp.asarray(value)) for key, value in result.params.items() - ) - losses = jnp.asarray(result.losses) +def plot_prior_posterior(prior, samples, title=""): + "Overlay a prior density and posterior MCMC draws for θ on [0, 1]." 
+ grid = jnp.linspace(0.001, 0.999, 500) + prior_pdf = np.exp(np.asarray(prior.log_prob(grid))) - return params, losses + fig, ax = plt.subplots() + ax.hist(np.asarray(samples), bins=50, density=True, alpha=0.4, + label="posterior (MCMC)") + ax.plot(np.asarray(grid), prior_pdf, 'C1--', lw=2, label="prior") + ax.set_xlabel(r"$\theta$") + ax.set_xlim(0, 1) + ax.legend() + if title: + ax.set_title(title) + plt.show() ``` -## Alternative prior distributions +### A uniform prior -Let's see how well our sampling algorithm does in approximating +The simplest non-conjugate prior is **uniform**: the analyst regards every value of $\theta$ in some interval as equally likely. -- a log normal distribution -- a uniform distribution +A uniform prior on all of $[0, 1]$ expresses indifference. -To examine our alternative prior distributions, we'll plot approximate prior distributions below by calling the `show_prior` method. +Because its density is constant, the posterior is then proportional to the likelihood alone. ```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Truncated log normal distribution - name: fig_lognormal_dist ---- -# truncated log normal -example_ln = create_bayesian_inference(param=(0, 2), name_dist="lognormal") -show_prior(example_ln, size=100000, bins=20) +mcmc_flat = run_nuts(dist.Uniform(0.0, 1.0), k, n) +plot_prior_posterior(dist.Uniform(0.0, 1.0), + mcmc_flat.get_samples()["θ"], + title="flat uniform prior") ``` -```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Truncated uniform distribution - name: fig_uniform_dist ---- -# truncated uniform -example_un = create_bayesian_inference(param=(0.1, 0.8), name_dist="uniform") -show_prior(example_un, size=100000, bins=20) -``` - -The above graphs show that sampling seems to work well with both distributions. - -Now let's see how well things work with von Mises distributions. 
- -```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Shifted von Mises distribution - name: fig_vonmises_dist ---- -# shifted von Mises -example_vm = create_bayesian_inference(param=10, name_dist="vonMises") -show_prior(example_vm, size=100000, bins=20) -``` +The posterior is centered near the sample frequency $k/n$, just as the likelihood is. -The graphs look good too. +Now suppose instead that the analyst is convinced the coin favors heads, and places a uniform prior on $[0.5, 0.95]$. -Now let's try with a Laplace distribution. +This prior assigns *zero* density to the region around the true value $\theta = 0.4$. ```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Truncated Laplace distribution - name: fig_laplace_dist ---- -# truncated Laplace -example_lp = create_bayesian_inference(param=(0.5, 0.05), name_dist="laplace") -show_prior(example_lp, size=100000, bins=20) +mcmc_restr = run_nuts(dist.Uniform(0.5, 0.95), k, n) +plot_prior_posterior(dist.Uniform(0.5, 0.95), + mcmc_restr.get_samples()["θ"], + title="restrictive uniform prior") ``` -Having assured ourselves that our sampler seems to do a good job, let's put it to work in using MCMC to compute posterior probabilities. +The posterior cannot put mass where the prior is zero, so it piles up against the lower boundary $0.5$ — as close to the data as the prior permits. -## Posteriors via MCMC and VI +This is a vivid warning: a prior that rules out the truth can never be overturned by data, no matter how much we collect. -We construct a class `BayesianInferencePlot` to implement MCMC or VI algorithms and plot multiple posteriors for different updating data sizes and different possible priors. +### A truncated log-normal prior -This class takes as inputs the true data generating parameter `θ`, a list of updating data sizes for multiple posterior plotting, and a defined and parametrized `BayesianInference` class. +A uniform prior is flat. A more realistic prior is smooth and asymmetric. 
-It has two key methods: +A convenient choice on $[0, 1]$ is a **truncated log-normal**: take $Z \sim N(\mu, \sigma)$ truncated to $Z \le 0$, and set $\theta = e^{Z}$, which then lies in $(0, 1]$. -- `BayesianInferencePlot.mcmc_plot()` takes desired MCMC sample size as input and plots the output posteriors together with the prior defined in `BayesianInference` class. - -- `BayesianInferencePlot.svi_plot()` takes desired VI distribution class ('beta' or 'normal') as input and plots the posteriors together with the prior. +NumPyro builds this by feeding a `TruncatedNormal` through an `ExpTransform`. ```{code-cell} ipython3 -class BayesianInferencePlot(NamedTuple): - """ - Easily implement the MCMC and VI inference for a given instance of - BayesianInference class and plot the prior together with multiple posteriors - - Parameters - ---------- - θ : float. - the true DGP parameter - N_list : list. - a list of sample size - bayesian_model : BayesianInference. - a class initiated using create_bayesian_inference() - binwidth : float. - plotting parameter for histogram bin width - linewidth : float. - plotting parameter for line width - colorlist : list. - list of colors for plotting - N_max : int. - maximum sample size - data : np.ndarray. 
-        generated data array
-    """
-    θ: float
-    N_list: Sequence[int]
-    bayesian_model: BayesianInference
-    binwidth: float
-    linewidth: float
-    colorlist: list
-    N_max: int
-    data: np.ndarray
-
-
-def create_bayesian_inference_plot(
-    θ: float,
-    N_list: Sequence[int],
-    bayesian_model: BayesianInference,
-    *,
-    binwidth: float = 0.02,
-    linewidth: float = 0.05,
-) -> BayesianInferencePlot:
-    """Factory function to create a BayesianInferencePlot instance"""
-
-    colorlist = sns.color_palette(n_colors=len(N_list))
-    N_max = int(max(N_list))
-    data = simulate_draw(θ, N_max)
-    return BayesianInferencePlot(
-        θ=θ,
-        N_list=list(map(int, N_list)),
-        bayesian_model=bayesian_model,
-        binwidth=binwidth,
-        linewidth=linewidth,
-        colorlist=colorlist,
-        N_max=N_max,
-        data=data,
-    )
-
-
-def mcmc_plot(
-    plot_model: BayesianInferencePlot, num_samples, num_warmup=1000
-):
-    fig, ax = plt.subplots()
-
-    # plot prior
-    prior_sample = show_prior(
-        plot_model.bayesian_model, disp_plot=0
-    )
-    sns.histplot(
-        data=prior_sample,
-        kde=True,
-        stat="density",
-        binwidth=plot_model.binwidth,
-        color="#4C4E52",
-        linewidth=plot_model.linewidth,
-        alpha=0.1,
-        ax=ax,
-        label="Prior distribution",
-    )
-
-    # plot posteriors
-    for id, n in enumerate(plot_model.N_list):
-        samples = mcmc_sampling(
-            plot_model.bayesian_model,
-            plot_model.data[:n],
-            num_samples,
-            num_warmup
-        )
-        sns.histplot(
-            samples,
-            kde=True,
-            stat="density",
-            binwidth=plot_model.binwidth,
-            linewidth=plot_model.linewidth,
-            alpha=0.2,
-            color=plot_model.colorlist[id - 1],
-            label=f"Posterior with $n={n}$",
-        )
-    ax.legend(loc="upper left")
-    plt.xlim(0, 1)
-    plt.show()
+def truncated_lognormal(μ, σ):
+    "Log-normal distribution truncated to the unit interval (0, 1]."
+    base = dist.TruncatedNormal(loc=μ, scale=σ, low=-jnp.inf, high=0.0)
+    return dist.TransformedDistribution(base, dist.transforms.ExpTransform())

+prior_ln = truncated_lognormal(0.0, 1.0)
+mcmc_ln = run_nuts(prior_ln, k, n)
+plot_prior_posterior(prior_ln, mcmc_ln.get_samples()["θ"],
+                     title="truncated log-normal prior")
+```

-def svi_fitting(guide_dist, params):
-    """Fit the beta/truncnormal curve using parameters trained by SVI."""
-    # create x axis
-    xaxis = jnp.linspace(0, 1, 1000)
-    if guide_dist == "beta":
-        y = st.beta.pdf(xaxis, a=params["alpha_q"], b=params["beta_q"])
+The prior favors smaller values of $\theta$, but with $\sigma = 1$ it is diffuse, so the likelihood pulls the posterior toward the sample frequency.

-    elif guide_dist == "normal":
-        # rescale upper/lower bound. See Scipy's truncnorm doc
-        lower, upper = (0, 1)
-        loc, scale = params["loc"], params["scale"]
-        a, b = (lower - loc) / scale, (upper - loc) / scale
+We keep `mcmc_ln` — we will compare it with variational inference below.

-        y = st.truncnorm.pdf(
-            xaxis, a=a, b=b, loc=loc, scale=scale
-        )
-    return (xaxis, y)
+### A truncated Laplace prior

+Our final prior has a sharp, non-smooth peak.

-def svi_plot(
-    plot_model: BayesianInferencePlot, guide_dist, n_steps=2000
-):
-    fig, ax = plt.subplots()
+A **Laplace** density $\propto e^{-|\theta - \mu| / b}$ has a kink at its center $\mu$, expressing a strong belief that $\theta$ sits near $\mu$ while still allowing for surprises in the tails.

-    # plot prior
-    prior_sample = show_prior(plot_model.bayesian_model, disp_plot=0)
-    sns.histplot(
-        data=prior_sample,
-        kde=True,
-        stat="density",
-        binwidth=plot_model.binwidth,
-        color="#4C4E52",
-        linewidth=plot_model.linewidth,
-        alpha=0.1,
-        ax=ax,
-        label="Prior distribution",
-    )
-
-    # plot posteriors
-    for id, n in enumerate(plot_model.N_list):
-        (params, losses) = svi_run(
-            plot_model.bayesian_model, plot_model.data[:n], guide_dist, n_steps
-        )
-        x, y = svi_fitting(guide_dist, params)
-        ax.plot(
-            x,
-            y,
-            alpha=1,
-            color=plot_model.colorlist[id - 1],
-            label=f"Posterior with $n={n}$",
-        )
-    ax.legend(loc="upper left")
-    plt.xlim(0, 1)
-    plt.show()
-```
-
-Let's set some parameters that we'll use in all of the examples below.
-
-To save computer time at first, notice that we'll set `mcmc_num_samples = 2000` and `svi_num_steps = 5000`.
-
-(Later, to increase accuracy of approximations, we'll want to increase these.)
+We truncate it to $[0, 1]$ and center it at $0.5$.

 ```{code-cell} ipython3
-num_list = [5, 10, 50, 100, 1000]
-mcmc_num_samples = 2000
-svi_num_steps = 5000
+def truncated_laplace(μ, b):
+    "Laplace distribution truncated to the unit interval [0, 1]."
+    return dist.TruncatedDistribution(dist.Laplace(μ, b), low=0.0, high=1.0)

-# θ is the data generating process
-true_θ = 0.8
+prior_lp = truncated_laplace(0.5, 0.1)
+mcmc_lp = run_nuts(prior_lp, k, n)
+plot_prior_posterior(prior_lp, mcmc_lp.get_samples()["θ"],
+                     title="truncated Laplace prior")
 ```

-### Beta prior and posteriors:
+The spiked prior tugs the posterior toward $0.5$, away from the sample frequency near $0.4$.

-Let's compare outcomes when we use a Beta prior.
+The pull is gentle here because the prior, though peaked, is not very tight; with a smaller $b$ it would dominate the modest sample.

-For the same Beta prior, we shall
+NUTS handles the kink in the prior without any special tuning — a practical advantage of gradient-based samplers paired with automatic differentiation.

-* compute posteriors analytically
-* compute posteriors using MCMC using `numpyro`.
-* compute posteriors using VI using `numpyro`.
+## Variational inference

-Let's start with the analytical method that we described in this {doc}`prob_meaning`
+MCMC approximates the posterior by *sampling* from it.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      Analytical density (Beta prior)
-    name: fig_analytical
----
-# first examine Beta prior
-beta = create_bayesian_inference(param=(5, 5), name_dist="beta")
+**Variational inference (VI)** takes a different route: it turns posterior approximation into an *optimization* problem.

-beta_plot = create_bayesian_inference_plot(true_θ, num_list, beta)
+We restrict attention to a tractable family of densities $q_\phi(\theta)$ — the **guide** — indexed by parameters $\phi$, and we search for the member of that family closest to the posterior.

-# plot analytical Beta prior and posteriors
-xaxis = jnp.linspace(0, 1, 1000)
-y_prior = st.beta.pdf(xaxis, 5, 5)
+### The evidence lower bound

-fig, ax = plt.subplots()
-# plot analytical beta prior
-ax.plot(xaxis, y_prior, label="Analytical Beta prior", color="#4C4E52")
-
-data, colorlist, N_list = beta_plot.data, beta_plot.colorlist, beta_plot.N_list
-
-# Plot analytical beta posteriors
-for id, n in enumerate(N_list):
-    func = analytical_beta_posterior(data[:n], α0=5, β0=5)
-    y_posterior = func.pdf(xaxis)
-    ax.plot(
-        xaxis,
-        y_posterior,
-        color=colorlist[id - 1],
-        label=f"Analytical Beta posterior with $n={n}$",
-    )
-ax.legend(loc="upper left")
-plt.xlim(0, 1)
-plt.show()
-```
+Let the prior be $p(\theta)$ and the likelihood be $p(Y \mid \theta)$.

-Now let's use MCMC while still using a beta prior.
+By Bayes' rule,

-We'll do this for both MCMC and VI.
+$$
+p(\theta \mid Y) = \frac{p(Y, \theta)}{p(Y)} = \frac{p(Y \mid \theta)\, p(\theta)}{p(Y)},
+$$

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (Beta prior)
-    name: fig_mcmc_beta
----
+where

-mcmc_plot(
-    beta_plot, num_samples=mcmc_num_samples
-)
-```
+$$
+p(Y) = \int p(Y \mid \theta)\, p(\theta)\, d\theta .
+$$ (eq:intchallenge)

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (Beta prior, Beta guide)
-    name: fig_svi_beta_beta
----
+The integral in {eq}`eq:intchallenge` is the troublesome one: in the non-conjugate case it has no closed form.

-svi_plot(
-    beta_plot, guide_dist="beta", n_steps=svi_num_steps
-)
-```
+We measure the discrepancy between the guide $q_\phi(\theta)$ and the posterior with the **Kullback–Leibler (KL) divergence**

-Here the MCMC approximation looks good.
+$$
+D_{KL}\big(q_\phi(\theta)\ \|\ p(\theta \mid Y)\big)
+= -\int q_\phi(\theta)\, \log \frac{p(\theta \mid Y)}{q_\phi(\theta)}\, d\theta ,
+$$

-But the VI approximation doesn't look so good.
+and we choose $\phi$ to minimize it.

-* even though we use the beta distribution as our guide, the VI approximated posterior distributions do not closely resemble the posteriors that we had just computed analytically.
+The KL divergence still involves the intractable posterior, but we can rearrange it. Using $p(\theta \mid Y) = p(\theta, Y) / p(Y)$,

-(Here, our initial parameter for Beta guide is (0.5, 0.5).)
+$$
+\begin{aligned}
+D_{KL}\big(q_\phi \,\|\, p(\theta \mid Y)\big)
+    & = -\int q_\phi(\theta)\, \log \frac{p(\theta, Y) / p(Y)}{q_\phi(\theta)}\, d\theta \\
+    & = -\int q_\phi(\theta) \left[\log \frac{p(\theta, Y)}{q_\phi(\theta)} - \log p(Y)\right] d\theta \\
+    & = -\int q_\phi(\theta)\, \log \frac{p(\theta, Y)}{q_\phi(\theta)}\, d\theta + \log p(Y) ,
+\end{aligned}
+$$

-But if we increase the number of steps from 5000 to 100000 in VI as we now shall do, we'll get VI-approximated posteriors
-that will be more accurate, as we shall see next.
+where the last line uses $\int q_\phi(\theta)\, d\theta = 1$. Rearranging,

-(Increasing the step size increases computational time though).
+$$
+\log p(Y) = D_{KL}\big(q_\phi \,\|\, p(\theta \mid Y)\big)
+    + \underbrace{\int q_\phi(\theta)\, \log \frac{p(\theta, Y)}{q_\phi(\theta)}\, d\theta}_{\text{ELBO}} .
+$$

-```{code-cell} ipython3
-svi_plot(
-    beta_plot, guide_dist="beta", n_steps=100000
-)
-```
+The marginal likelihood $\log p(Y)$ on the left does not depend on $\phi$.

-## Non-conjugate prior distributions
+Hence **minimizing** the KL divergence is equivalent to **maximizing** the second term, the **evidence lower bound (ELBO)**:

-Having assured ourselves that our MCMC and VI methods can work well when we have a conjugate prior and so can also compute analytically, we
-next proceed to situations in which our prior is not a beta distribution, so we don't have a conjugate prior.
+$$
+\text{ELBO}(\phi) \equiv \int q_\phi(\theta)\, \log \frac{p(\theta, Y)}{q_\phi(\theta)}\, d\theta
+= \mathbb{E}_{q_\phi(\theta)}\big[\log p(\theta, Y) - \log q_\phi(\theta)\big] .
+$$ (eq:ELBO)

-So we will have non-conjugate priors and are cast into situations in which we can't calculate posteriors analytically.
+Because $D_{KL} \ge 0$, the ELBO is a lower bound on $\log p(Y)$ — hence its name.

-### Markov chain Monte Carlo
+Crucially, {eq}`eq:ELBO` involves only the *joint* density $p(\theta, Y) = p(Y \mid \theta)\, p(\theta)$, which we can evaluate, not the intractable normalizing constant $p(Y)$.

-First, we implement and display MCMC.
+The expectation can be estimated by sampling from $q_\phi$, and $\phi$ improved by gradient ascent — this is **stochastic variational inference (SVI)**.

-We first initialize the `BayesianInference` classes and then can directly call `BayesianInferencePlot` to plot both MCMC and SVI approximating posteriors.
+### Implementing SVI in NumPyro

-```{code-cell} ipython3
-# Initialize BayesianInference classes
-# Try uniform
-std_uniform = create_bayesian_inference(param=(0, 1), name_dist="uniform")
-uniform = create_bayesian_inference(param=(0.2, 0.7), name_dist="uniform")
+We need a guide $q_\phi$.

-# Try truncated log normal
-lognormal = create_bayesian_inference(param=(0, 2), name_dist="lognormal")
+The simplest choice is an **autoguide**: NumPyro inspects the model and automatically constructs a guide for us.

-# Try Von Mises
-vonmises = create_bayesian_inference(param=10, name_dist="vonMises")
-
-# Try Laplace
-laplace = create_bayesian_inference(param=(0.5, 0.07), name_dist="laplace")
-```
-
-To conduct our experiments more concisely, here we define two experiment functions that will print the model information and plot the result.
+`AutoNormal` places an independent normal distribution on each latent variable, transformed to respect its support — here, to keep $\theta$ inside $(0, 1)$.

+We apply SVI to the truncated log-normal model from above and maximize the ELBO with the Adam optimizer.

 ```{code-cell} ipython3
-def plot_mcmc_experiment(
-    bayesian_model: BayesianInference,
-    true_θ: float,
-    num_list: Sequence[int],
-    num_samples: int,
-    num_warmup: int = 1000,
-    description: str = ""
-):
-    """
-    Helper function to run and plot MCMC experiments for a given Bayesian model
-    """
-    print(
-        f"=======INFO=======\n"
-        f"Parameters: {bayesian_model.param}\n"
-        f"Prior Dist: {bayesian_model.name_dist}"
-    )
-    if description:
-        print(description)
-
-    plot_model = create_bayesian_inference_plot(
-        true_θ, num_list, bayesian_model
-    )
-    mcmc_plot(plot_model, num_samples=num_samples, num_warmup=num_warmup)
-
-
-def plot_svi_experiment(
-    bayesian_model: BayesianInference,
-    true_θ: float,
-    num_list: Sequence[int],
-    guide_dist: str,
-    n_steps: int,
-    description: str = ""
-):
-    """
-    Helper function to run and plot SVI experiments for a given Bayesian model
-    """
-    print(
-        f"=======INFO=======\n"
-        f"Parameters: {bayesian_model.param}\n"
-        f"Prior Dist: {bayesian_model.name_dist}"
-    )
-    if description:
-        print(description)
+guide = AutoNormal(binomial_model)
+optimizer = Adam(step_size=0.01)
+svi = SVI(binomial_model, guide, optimizer, loss=Trace_ELBO())

-    plot_model = create_bayesian_inference_plot(
-        true_θ, num_list, bayesian_model
-    )
-    svi_plot(plot_model, guide_dist=guide_dist, n_steps=n_steps)
+svi_result = svi.run(random.PRNGKey(0), 5000, prior_ln, k, n, progress_bar=False)
 ```

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (uniform prior)
-    name: fig_mcmc_stduniform
----
-# Uniform
-plot_mcmc_experiment(
-    std_uniform,
-    true_θ,
-    num_list,
-    mcmc_num_samples
-)
-```
+SVI maximizes the ELBO; equivalently, it minimizes its negative, which is the reported loss.
+
+A loss curve that flattens out indicates convergence.

 ```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (uniform prior)
-    name: fig_mcmc_uniform
----
-plot_mcmc_experiment(
-    uniform,
-    true_θ,
-    num_list,
-    mcmc_num_samples
-)
+fig, ax = plt.subplots()
+ax.plot(svi_result.losses)
+ax.set_xlabel("step")
+ax.set_ylabel("negative ELBO")
+ax.set_title("SVI convergence")
+plt.show()
 ```

-In the situation depicted above, we have assumed a $Uniform(\underline{\theta}, \overline{\theta})$ prior that puts zero probability outside a bounded support that excludes the true value.
-
-Consequently, the posterior cannot put positive probability above $\overline{\theta}$ or below $\underline{\theta}$.
+### Comparing VI with MCMC

-Note how when the true data-generating $\theta$ is located at $0.8$ as it is here, when $n$ gets large, the posterior concentrates on the upper bound of the support of the prior, $0.7$ here.
+To assess the approximation, we draw samples from the fitted guide and compare them with the NUTS posterior for the same (log-normal-prior) model.

 ```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (log normal prior)
-    name: fig_mcmc_lognormal
----
-# log normal
-plot_mcmc_experiment(
-    lognormal,
-    true_θ,
-    num_list,
-    mcmc_num_samples
-)
-```
-
-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (von Mises prior)
-    name: fig_mcmc_vonmises
----
-# von Mises
-plot_mcmc_experiment(
-    vonmises,
-    true_θ,
-    num_list,
-    mcmc_num_samples,
-    description="\nNOTE: Shifted von Mises"
-)
-```
+vi_samples = guide.sample_posterior(
+    random.PRNGKey(1), svi_result.params, sample_shape=(4000,)
+)["θ"]
+nuts_samples = mcmc_ln.get_samples()["θ"]

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      MCMC density (Laplace prior)
-    name: fig_mcmc_laplace
----
-# Laplace
-plot_mcmc_experiment(
-    laplace,
-    true_θ,
-    num_list,
-    mcmc_num_samples
-)
+fig, ax = plt.subplots()
+ax.hist(np.asarray(nuts_samples), bins=50, density=True, alpha=0.4,
+        label="MCMC (NUTS)")
+ax.hist(np.asarray(vi_samples), bins=50, density=True, alpha=0.4,
+        label="VI (AutoNormal)")
+ax.set_xlabel(r"$\theta$")
+ax.legend()
+plt.show()
 ```

-### Variational inference
-
-To get more accuracy we will now increase the number of steps for Variational Inference (VI)
-
-```{code-cell} ipython3
-svi_num_steps = 50000
-```
-#### VI with a truncated normal guide
+The two approximations broadly agree on the location and spread of the posterior.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (uniform prior, normal guide)
-    name: fig_svi_uniform_normal
----
-# Uniform
-plot_svi_experiment(
-    create_bayesian_inference(param=(0, 1), name_dist="uniform"),
-    true_θ,
-    num_list,
-    "normal",
-    svi_num_steps
-)
-```
+They need not agree perfectly.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (log normal prior, normal guide)
-    name: fig_svi_lognormal_normal
----
-# log normal
-plot_svi_experiment(
-    lognormal,
-    true_θ,
-    num_list,
-    "normal",
-    svi_num_steps
-)
-```
+MCMC samples the true posterior (up to Monte Carlo error), whereas VI reports the best fit *within its guide family*.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (Laplace prior, normal guide)
-    name: fig_svi_laplace_normal
----
-# Laplace
-plot_svi_experiment(
-    laplace,
-    true_θ,
-    num_list,
-    "normal",
-    svi_num_steps
-)
-```
+A mean-field normal guide is symmetric on the transformed scale and can miss skewness or heavy tails in the true posterior.

-#### Variational inference with a Beta guide distribution
+The trade-off is one of cost against fidelity: VI replaces sampling with optimization and is often much faster in high dimensions, but it delivers an approximation whose quality is capped by the flexibility of the guide.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (uniform prior, Beta guide)
-    name: fig_svi_uniform_beta
----
-# uniform
-plot_svi_experiment(
-    std_uniform,
-    true_θ,
-    num_list,
-    "beta",
-    svi_num_steps
-)
-```
+## Where to next

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (log normal prior, Beta guide)
-    name: fig_svi_lognormal_beta
----
-# log normal
-plot_svi_experiment(
-    lognormal,
-    true_θ,
-    num_list,
-    "beta",
-    svi_num_steps
-)
-```
+This lecture showed how to compute posteriors when prior and likelihood are not conjugate, using NUTS and stochastic variational inference in NumPyro.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (von Mises prior, Beta guide)
-    name: fig_svi_vonmises_beta
----
-# von Mises
-plot_svi_experiment(
-    vonmises,
-    true_θ,
-    num_list,
-    "beta",
-    svi_num_steps,
-    description="Shifted von Mises"
-)
-```
+The same tools carry over to richer models.

-```{code-cell} ipython3
----
-mystnb:
-  figure:
-    caption: |
-      SVI density (Laplace prior, Beta guide)
-    name: fig_svi_laplace_beta
----
-# Laplace
-plot_svi_experiment(
-    laplace,
-    true_θ,
-    num_list,
-    "beta",
-    svi_num_steps
-)
-```
+The lectures {doc}`ar1_bayes` and {doc}`ar1_turningpts` apply NumPyro to Bayesian estimation and forecasting of autoregressive time series, where the parameter is a vector and conjugate analysis is unavailable.

From 23ec75589fb791256506895d2ed55f611c64c55f Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Thu, 18 Jun 2026 17:03:01 +1000
Subject: [PATCH 2/7] Run MCMC chains vectorized instead of setting host device count
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

set_host_device_count(4) only helps CPU parallel chains; on the GPU build it is inert and the default parallel method falls back to running chains sequentially with a warning. Drop it and use chain_method="vectorized", which runs all four chains on a single device — efficient on one GPU and portable to CPU users, while still yielding multiple chains for the R-hat diagnostic.

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 lectures/bayes_nonconj.md | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md
index 98027d8c6..1a7ac1ece 100644
--- a/lectures/bayes_nonconj.md
+++ b/lectures/bayes_nonconj.md
@@ -69,12 +69,6 @@ from numpyro.optim import Adam
 import arviz as az
 ```

-To draw posterior samples from several Markov chains in parallel, we tell NumPyro how many CPU devices to use.
-
-```{code-cell} ipython3
-numpyro.set_host_device_count(4)
-```
-
 ## The coin-flipping model

 As in {doc}`prob_meaning`, a coin lands heads ($Y=1$) with probability $\theta$ and tails ($Y=0$) with probability $1-\theta$.
@@ -133,6 +127,8 @@ The first `sample` statement draws $\theta$ from the prior; the second ties the

 We also write a small helper that runs NUTS and returns the fitted sampler.

+We request four chains so that we can check convergence below, and run them with `chain_method="vectorized"`, which evaluates all chains together on a single device — so the same code runs unchanged on a CPU or a GPU.
+
 ```{code-cell} ipython3
 def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=4):
     "Sample the posterior of θ with the NUTS sampler."
@@ -141,6 +137,7 @@ def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=
         num_warmup=num_warmup,
         num_samples=num_samples,
         num_chains=num_chains,
+        chain_method="vectorized",
         progress_bar=False,
     )
     mcmc.run(random.PRNGKey(seed), prior, k, n)

From 8f2b0cdb5f411f09c3ecd2d2fc3e0a55bbe69a1d Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Thu, 18 Jun 2026 17:10:56 +1000
Subject: [PATCH 3/7] Explain what a NumPyro model represents for first-time readers

NumPyro's style is idiosyncratic for someone new to it. Add a conceptual on-ramp before the model: a model is a *declaration* of the generative story (not a computation, returns nothing, never called directly) that an inference engine traces; the obs keyword decides whether a sample site is latent or observed (the likelihood); the string site names are the engine's handles. Also add a short note on JAX PRNG keys, explaining why data uses NumPy's generator while NUTS uses random.PRNGKey. Prose only; code cells unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 lectures/bayes_nonconj.md | 32 +++++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md
index 1a7ac1ece..5d113aa19 100644
--- a/lectures/bayes_nonconj.md
+++ b/lectures/bayes_nonconj.md
@@ -108,13 +108,27 @@ With a large sample the likelihood dominates and almost any reasonable prior lea

 A modest $n$ keeps the influence of the prior visible, which is what we want to study here.

-### One model, many priors
+### Specifying the model in NumPyro

-In NumPyro a model is an ordinary Python function that uses `numpyro.sample` to declare random variables.
+For most readers this will be a first encounter with NumPyro, whose style takes some getting used to.

-We write a *single* model that takes the prior distribution as an argument.
+To use it we describe our probability model as a Python function — which, a little confusingly, NumPyro calls a **model**.

-This lets us reuse it unchanged for every prior we consider — conjugate or not.
+Such a function does not *compute* anything when called, and it does not return the posterior.
+
+Instead it is a *declaration* of the generative story for the data: which quantities are random, how they are distributed, and how the data depend on them.
+
+An inference algorithm — such as the NUTS sampler below — then *reads* this declaration and works out the posterior for us.
+
+Inside a model, every random quantity is introduced by a call to `numpyro.sample`, and the keyword `obs` decides its role:
+
+* `numpyro.sample("θ", prior)` introduces a **latent** (unobserved) variable named `"θ"`, drawn from `prior` — a quantity we wish to infer.
+
+* `numpyro.sample("k", dist.Binomial(n, θ), obs=k)` introduces an **observed** variable: the keyword `obs=k` pins it to the data, which is how the likelihood $p(k \mid \theta)$ enters.
+
+The string names (`"θ"` and `"k"`) are the labels NumPyro uses to keep track of the variables; we will use them later to pull the posterior draws back out.
+
+We write a *single* model that takes the prior distribution as an argument, so we can reuse it unchanged for every prior we consider — conjugate or not.

 ```{code-cell} ipython3
 def binomial_model(prior, k, n):
@@ -123,7 +137,9 @@ def binomial_model(prior, k, n):
     numpyro.sample("k", dist.Binomial(n, θ), obs=k)
 ```

-The first `sample` statement draws $\theta$ from the prior; the second ties the observed count `k` to the binomial likelihood through `obs=k`.
+Notice that `binomial_model` returns nothing, and that we never call it ourselves.
+
+Instead we hand it to an inference algorithm, which supplies the arguments and traces the two `sample` statements to assemble the posterior.

 We also write a small helper that runs NUTS and returns the fitted sampler.

@@ -144,6 +160,12 @@ def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=
     return mcmc
 ```

+NumPyro is built on [JAX](https://docs.jax.dev), which treats randomness explicitly: rather than relying on a global random state, each run needs its own **PRNG key**, created here with `random.PRNGKey(seed)`.
+
+(This is why we used NumPy's generator to make the data above but JAX keys here.)
+
+The remaining arguments to `mcmc.run` — `prior`, `k`, `n` — are simply forwarded to our model.
+
 ## MCMC reproduces the conjugate posterior

 Before trusting MCMC on hard problems, let us check it on an easy one.

From 1d67d5c98b8bf0466c8102fd6c3702be129b9a63 Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Thu, 18 Jun 2026 17:21:35 +1000
Subject: [PATCH 4/7] Use jax.random.key instead of the legacy PRNGKey

random.PRNGKey is the legacy key constructor; JAX now recommends random.key, which returns a typed key. NumPyro accepts it throughout (MCMC, SVI, sample_posterior). Switch all four uses. Not deprecated yet, but this future-proofs the lecture.

Verified end-to-end with jupytext export + headless execution.

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 lectures/bayes_nonconj.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md
index 5d113aa19..dca32e471 100644
--- a/lectures/bayes_nonconj.md
+++ b/lectures/bayes_nonconj.md
@@ -156,11 +156,11 @@ def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=
         chain_method="vectorized",
         progress_bar=False,
     )
-    mcmc.run(random.PRNGKey(seed), prior, k, n)
+    mcmc.run(random.key(seed), prior, k, n)
     return mcmc
 ```

-NumPyro is built on [JAX](https://docs.jax.dev), which treats randomness explicitly: rather than relying on a global random state, each run needs its own **PRNG key**, created here with `random.PRNGKey(seed)`.
+NumPyro is built on [JAX](https://docs.jax.dev), which treats randomness explicitly: rather than relying on a global random state, each run needs its own **PRNG key**, created here with `random.key(seed)`.

 (This is why we used NumPy's generator to make the data above but JAX keys here.)

@@ -418,7 +418,7 @@ guide = AutoNormal(binomial_model)
 optimizer = Adam(step_size=0.01)
 svi = SVI(binomial_model, guide, optimizer, loss=Trace_ELBO())

-svi_result = svi.run(random.PRNGKey(0), 5000, prior_ln, k, n, progress_bar=False)
+svi_result = svi.run(random.key(0), 5000, prior_ln, k, n, progress_bar=False)
 ```

 SVI maximizes the ELBO; equivalently, it minimizes its negative, which is the reported loss.
@@ -440,7 +440,7 @@ To assess the approximation, we draw samples from the fitted guide and compare t

 ```{code-cell} ipython3
 vi_samples = guide.sample_posterior(
-    random.PRNGKey(1), svi_result.params, sample_shape=(4000,)
+    random.key(1), svi_result.params, sample_shape=(4000,)
 )["θ"]
 nuts_samples = mcmc_ln.get_samples()["θ"]

From 9de50fc374f3e6c9798262c64779689b6b8d8ebd Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Thu, 18 Jun 2026 17:42:27 +1000
Subject: [PATCH 5/7] Make run_nuts model-agnostic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Previously run_nuts hard-coded binomial_model as a global and took a prior argument that duplicated the model's own prior argument (inviting confusion about whether they could differ — they couldn't, since the prior was forwarded). Pass the model explicitly and forward *args to it, so there is no hidden global, the name is honest, and the prior is supplied exactly once. Call sites become run_nuts(binomial_model, ...).

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 lectures/bayes_nonconj.md | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md
index dca32e471..8a2e74964 100644
--- a/lectures/bayes_nonconj.md
+++ b/lectures/bayes_nonconj.md
@@ -141,22 +141,22 @@ Notice that `binomial_model` returns nothing, and that we never call it ourselve

 Instead we hand it to an inference algorithm, which supplies the arguments and traces the two `sample` statements to assemble the posterior.

-We also write a small helper that runs NUTS and returns the fitted sampler.
+We also write a small helper that runs NUTS on a given model and returns the fitted sampler.

 We request four chains so that we can check convergence below, and run them with `chain_method="vectorized"`, which evaluates all chains together on a single device — so the same code runs unchanged on a CPU or a GPU.

 ```{code-cell} ipython3
-def run_nuts(prior, k, n, seed=0, num_warmup=1000, num_samples=4000, num_chains=4):
-    "Sample the posterior of θ with the NUTS sampler."
+def run_nuts(model, *args, seed=0, num_warmup=1000, num_samples=4000, num_chains=4):
+    "Sample a NumPyro model with the NUTS sampler."
     mcmc = MCMC(
-        NUTS(binomial_model),
+        NUTS(model),
         num_warmup=num_warmup,
         num_samples=num_samples,
         num_chains=num_chains,
         chain_method="vectorized",
         progress_bar=False,
     )
-    mcmc.run(random.key(seed), prior, k, n)
+    mcmc.run(random.key(seed), *args)
     return mcmc
 ```

@@ -164,7 +164,7 @@ NumPyro is built on [JAX](https://docs.jax.dev), which treats randomness explici

 (This is why we used NumPy's generator to make the data above but JAX keys here.)

-The remaining arguments to `mcmc.run` — `prior`, `k`, `n` — are simply forwarded to our model.
+`run_nuts` is deliberately generic: it samples whatever model we pass and forwards the extra arguments (`*args`) on to that model through `mcmc.run`. We always call it as `run_nuts(binomial_model, prior, k, n)`, so `prior`, `k`, and `n` reach `binomial_model` unchanged — there is only ever the one prior.

 ## MCMC reproduces the conjugate posterior

@@ -180,7 +180,7 @@ We take $\alpha_0 = \beta_0 = 2$ and sample the posterior with NUTS.

 ```{code-cell} ipython3
 α0, β0 = 2.0, 2.0
-mcmc = run_nuts(dist.Beta(α0, β0), k, n)
+mcmc = run_nuts(binomial_model, dist.Beta(α0, β0), k, n)
 ```

 Before looking at the posterior we check that the sampler converged.
@@ -263,7 +263,7 @@ A uniform prior on all of $[0, 1]$ expresses indifference.
 Because its density is constant, the posterior is then proportional to the likelihood alone.

 ```{code-cell} ipython3
-mcmc_flat = run_nuts(dist.Uniform(0.0, 1.0), k, n)
+mcmc_flat = run_nuts(binomial_model, dist.Uniform(0.0, 1.0), k, n)
 plot_prior_posterior(dist.Uniform(0.0, 1.0),
                      mcmc_flat.get_samples()["θ"],
                      title="flat uniform prior")
@@ -276,7 +276,7 @@ Now suppose instead that the analyst is convinced the coin favors heads, and pla
 This prior assigns *zero* density to the region around the true value $\theta = 0.4$.

 ```{code-cell} ipython3
-mcmc_restr = run_nuts(dist.Uniform(0.5, 0.95), k, n)
+mcmc_restr = run_nuts(binomial_model, dist.Uniform(0.5, 0.95), k, n)
 plot_prior_posterior(dist.Uniform(0.5, 0.95),
                      mcmc_restr.get_samples()["θ"],
                      title="restrictive uniform prior")
@@ -301,7 +301,7 @@ def truncated_lognormal(μ, σ):
     return dist.TransformedDistribution(base, dist.transforms.ExpTransform())

 prior_ln = truncated_lognormal(0.0, 1.0)
-mcmc_ln = run_nuts(prior_ln, k, n)
+mcmc_ln = run_nuts(binomial_model, prior_ln, k, n)
 plot_prior_posterior(prior_ln, mcmc_ln.get_samples()["θ"],
                      title="truncated log-normal prior")
 ```
@@ -324,7 +324,7 @@ def truncated_laplace(μ, b):
     return dist.TruncatedDistribution(dist.Laplace(μ, b), low=0.0, high=1.0)

 prior_lp = truncated_laplace(0.5, 0.1)
-mcmc_lp = run_nuts(prior_lp, k, n)
+mcmc_lp = run_nuts(binomial_model, prior_lp, k, n)
 plot_prior_posterior(prior_lp, mcmc_lp.get_samples()["θ"],
                      title="truncated Laplace prior")
 ```

From d06957150ce61c6644a2a1c53d0ef2f1f29c6ac2 Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Thu, 18 Jun 2026 17:47:03 +1000
Subject: [PATCH 6/7] Tie down Y notation in the ELBO derivation

The lecture denotes the data by k throughout, but the ELBO derivation uses generic Y. Add a clause defining Y as the observed data (the count k) so the switch doesn't trip readers.

Co-Authored-By: Claude Opus 4.8 (1M context) --- lectures/bayes_nonconj.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md index 8a2e74964..6b09cefc1 100644 --- a/lectures/bayes_nonconj.md +++ b/lectures/bayes_nonconj.md @@ -345,7 +345,7 @@ We restrict attention to a tractable family of densities $q_\phi(\theta)$ — th ### The evidence lower bound -Let the prior be $p(\theta)$ and the likelihood be $p(Y \mid \theta)$. +Let the prior be $p(\theta)$ and the likelihood be $p(Y \mid \theta)$, where $Y$ denotes the observed data (here the head count $k$). By Bayes' rule, From d2a0821efdf8a3fcf2fb328d25dbb41d38eb5f33 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 18 Jun 2026 18:33:39 +1000 Subject: [PATCH 7/7] Fix prior-density plot to respect the prior's support numpyro's dist.Uniform.log_prob returns its constant value everywhere, ignoring [low, high], so the restrictive-uniform example plotted a flat prior across all of [0, 1] instead of a box on [0.5, 0.95]. Mask the plotted density to prior.support(grid). Sampling was already correct (NUTS respects the support); this only affects the prior curve. Co-Authored-By: Claude Opus 4.8 (1M context) --- lectures/bayes_nonconj.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lectures/bayes_nonconj.md b/lectures/bayes_nonconj.md index 6b09cefc1..61438401c 100644 --- a/lectures/bayes_nonconj.md +++ b/lectures/bayes_nonconj.md @@ -240,7 +240,10 @@ The following helper draws a prior density and the posterior samples on the same def plot_prior_posterior(prior, samples, title=""): "Overlay a prior density and posterior MCMC draws for θ on [0, 1]." 
     grid = jnp.linspace(0.001, 0.999, 500)
-    prior_pdf = np.exp(np.asarray(prior.log_prob(grid)))
+    # mask the density to the prior's support: dist.Uniform.log_prob
+    # returns its constant value even outside [low, high]
+    in_support = np.asarray(prior.support(grid))
+    prior_pdf = np.where(in_support, np.exp(np.asarray(prior.log_prob(grid))), 0.0)

     fig, ax = plt.subplots()
     ax.hist(np.asarray(samples), bins=50, density=True, alpha=0.4,