Rewrite the Non-Conjugate Priors lecture #913

Merged
jstac merged 7 commits into main from rewrite-bayes-nonconj
Jun 18, 2026

jstac commented Jun 18, 2026

What

A complete, from-scratch rewrite of bayes_nonconj.md that teaches the same material — MCMC and variational inference for non-conjugate priors in NumPyro — more clearly and with much less code (260 lines vs ~1050).

Why

The previous version leaned on two NamedTuple classes (BayesianInference, BayesianInferencePlot) with string dispatch on a name_dist field, plus a sprawling MCMC × VI × guide matrix at the end. It was hard to follow which model was being fit where, used seaborn KDEs, and reported no convergence diagnostics.

New structure

  1. MCMC reproduces the conjugate posterior — one reusable binomial_model(prior, k, n), run with NUTS, validated against the analytical Beta(α+k, β+n−k) from prob_meaning, with arviz r_hat / trace diagnostics.
  2. Non-conjugate priors, one at a time — uniform (including a boundary-exclusion example where the posterior piles against the prior's edge), truncated log-normal, truncated Laplace. Same model each time, just a different prior argument.
  3. Variational inference — full ELBO derivation, then SVI with an AutoNormal autoguide, compared against the NUTS posterior.
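
The validation idea in item 1 can be sketched without any sampler: normalize prior times likelihood on a grid and compare it with the analytical Beta(α+k, β+n−k) density. This is a minimal standard-library sketch, not the lecture's code. The Beta(2, 2) prior and the count k=8 are illustrative assumptions (n=20 matches the PR), and `beta_pdf` and `grid_posterior` are hypothetical helper names.

```python
import math

def beta_pdf(theta, a, b):
    """Density of Beta(a, b) at theta in (0, 1)."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(theta) + (b - 1) * math.log(1 - theta))

def grid_posterior(prior_pdf, k, n, num_points=2000):
    """Normalize prior(theta) * theta^k * (1 - theta)^(n - k) on a midpoint grid."""
    width = 1.0 / num_points
    grid = [(i + 0.5) * width for i in range(num_points)]
    unnorm = [prior_pdf(t) * t**k * (1 - t) ** (n - k) for t in grid]
    total = sum(unnorm) * width  # midpoint-rule estimate of the evidence
    return grid, [u / total for u in unnorm]

# Assumed values for illustration: Beta(2, 2) prior, k = 8 successes in n = 20 flips.
alpha, beta, k, n = 2.0, 2.0, 8, 20
grid, post = grid_posterior(lambda t: beta_pdf(t, alpha, beta), k, n)

# Analytical conjugate posterior evaluated on the same grid.
exact = [beta_pdf(t, alpha + k, beta + n - k) for t in grid]
max_abs_error = max(abs(p - e) for p, e in zip(post, exact))
print(f"largest pointwise gap between grid and analytical posterior: {max_abs_error:.2e}")
```

Passing a different `prior_pdf` reuses the same grid routine for a non-conjugate prior, which mirrors the "same model, different prior argument" structure described in item 2.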

Design choices

  • No classes — small, individually-described functions only.
  • Continuity with prob_meaning — same coin-flip DGP, true θ=0.4; a deliberately small sample (n=20) so the prior visibly matters (large n would wash out all prior differences).
  • matplotlib + arviz (drop seaborn); arviz for light diagnostics.
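
The small-sample design choice can be checked with closed-form conjugate updates. The sketch below compares posterior means under two Beta priors at n=20 and n=2000, holding the success fraction at the true θ=0.4. The two priors are assumptions chosen for illustration (the lecture's non-conjugate priors have no closed form), and `posterior_mean` and `prior_gap` are hypothetical names.

```python
def posterior_mean(alpha, beta, k, n):
    """Mean of the conjugate posterior Beta(alpha + k, beta + n - k)."""
    return (alpha + k) / (alpha + beta + n)

# Two assumed priors that disagree strongly about theta.
priors = {"flat Beta(1, 1)": (1.0, 1.0), "skewed Beta(10, 2)": (10.0, 2.0)}

def prior_gap(n, theta_true=0.4):
    """Absolute difference in posterior means when k equals theta_true * n."""
    k = round(theta_true * n)
    means = [posterior_mean(a, b, k, n) for a, b in priors.values()]
    return abs(means[0] - means[1])

gap_small = prior_gap(20)    # priors still pull the posterior apart
gap_large = prior_gap(2000)  # data dominate, priors barely matter
print(f"n=20: gap {gap_small:.4f}   n=2000: gap {gap_large:.4f}")
```

With n=20 the two posterior means differ by roughly 0.15, while at n=2000 the gap falls below 0.01, which is why a small sample keeps the prior comparison visible.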

Validation

Verified end-to-end via jupytext --to py + headless execution (exit 0); diagnostics populate (r_hat≈1.0, large ESS).

Note: this PR and #912 both touch bayes_nonconj.md; since this is a full-file rewrite, any merge conflict resolves trivially in favor of this version.

🤖 Generated with Claude Code

jstac and others added 2 commits June 18, 2026 16:55
Complete rewrite teaching the same material more cleanly:

- Remove the two NamedTuple classes (BayesianInference,
  BayesianInferencePlot) and the string-dispatch indirection in favor
  of small functions: one reusable binomial_model(prior, k, n), a
  run_nuts helper, and a plot_prior_posterior helper.
- Reorganize into: (1) MCMC reproduces the conjugate beta posterior
  (validated against the analytical result from prob_meaning, with
  arviz R-hat/trace diagnostics); (2) non-conjugate priors worked
  through one at a time (uniform incl. a boundary-exclusion example,
  truncated log-normal, truncated Laplace); (3) variational inference
  via an AutoNormal autoguide, compared against the NUTS posterior.
- Continuity with prob_meaning: same coin-flip DGP with true θ=0.4,
  and a deliberately small sample so the prior visibly matters.
- Switch plotting to matplotlib + arviz (drop seaborn); keep the full
  ELBO derivation.

Verified end-to-end via jupytext export + headless execution.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

set_host_device_count(4) only helps CPU parallel chains; on the GPU
build it is inert and the default parallel method falls back to running
chains sequentially with a warning. Drop it and use
chain_method="vectorized", which runs all four chains on a single
device — efficient on one GPU and portable to CPU users, while still
yielding multiple chains for the R-hat diagnostic.
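
To illustrate why several chains are needed for R-hat, here is a sketch of the classic Gelman-Rubin statistic in plain Python. It is not what ArviZ computes internally (ArviZ reports a refined rank-normalized, split-chain variant, so values will not match exactly), and `classic_r_hat` is a hypothetical helper. The chains are synthetic Gaussian draws, not NUTS output.

```python
import random
import statistics

def classic_r_hat(chains):
    """Classic Gelman-Rubin potential scale reduction for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])
    between = n * statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between / n  # pooled variance estimate
    return (pooled / within) ** 0.5

rng = random.Random(0)
# Four synthetic chains drawn from the same distribution.
mixed = [[rng.gauss(0.4, 0.1) for _ in range(1000)] for _ in range(4)]
# Four synthetic chains centered at different locations (poor mixing).
stuck = [[rng.gauss(0.2 + 0.1 * j, 0.1) for _ in range(1000)] for j in range(4)]

r_hat_mixed = classic_r_hat(mixed)
r_hat_stuck = classic_r_hat(stuck)
print(f"mixed chains: {r_hat_mixed:.3f}   stuck chains: {r_hat_stuck:.3f}")
```

The statistic compares between-chain and within-chain variance, so it is undefined for a single chain; that is the reason the commit keeps four chains even on one device.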

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

NumPyro's style is idiosyncratic for someone new to it. Add a conceptual
on-ramp before the model: a model is a *declaration* of the generative
story (not a computation, returns nothing, never called directly) that
an inference engine traces; the obs keyword decides whether a sample
site is latent or observed (the likelihood); the string site names are
the engine's handles. Also add a short note on JAX PRNG keys, explaining
why data uses NumPy's generator while NUTS uses random.PRNGKey.

Prose only; code cells unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
github-actions Bot commented Jun 18, 2026

📖 Netlify Preview Ready!

Preview URL: https://pr-913--sunny-cactus-210e3e.netlify.app

Commit: d2a0821

random.PRNGKey is the legacy key constructor; JAX now recommends
random.key, which returns a typed key. NumPyro accepts it throughout
(MCMC, SVI, sample_posterior). Switch all four uses. Not deprecated yet,
but this future-proofs the lecture.

Verified end-to-end with jupytext export + headless execution.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

jstac and others added 3 commits June 18, 2026 17:42
Previously run_nuts hard-coded binomial_model as a global and took a
prior argument that duplicated the model's own prior argument (inviting
confusion about whether they could differ — they couldn't, since the
prior was forwarded). Pass the model explicitly and forward *args to it,
so there is no hidden global, the name is honest, and the prior is
supplied exactly once. Call sites become run_nuts(binomial_model, ...).
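
The calling convention described here (model passed explicitly, remaining arguments forwarded) can be sketched with a stand-in sampler. This is not the PR's run_nuts: `run_metropolis` is a hypothetical helper that uses random-walk Metropolis instead of NUTS, and the "model" is a plain log-density function rather than a NumPyro model. The data k=8, n=20 and the flat prior are assumptions for illustration.

```python
import math
import random

def binomial_log_post(theta, log_prior, k, n):
    """Unnormalized log posterior for k successes in n coin flips."""
    if not 0.0 < theta < 1.0:
        return -math.inf
    return log_prior(theta) + k * math.log(theta) + (n - k) * math.log(1.0 - theta)

def run_metropolis(model, *model_args, num_draws=20_000, step=0.2, seed=0):
    """Random-walk Metropolis; model(theta, *model_args) returns a log density."""
    rng = random.Random(seed)
    theta = 0.5
    current = model(theta, *model_args)
    draws = []
    for _ in range(num_draws):
        proposal = theta + rng.gauss(0.0, step)
        candidate = model(proposal, *model_args)
        # Accept with probability min(1, exp(candidate - current)).
        if rng.random() < math.exp(min(0.0, candidate - current)):
            theta, current = proposal, candidate
        draws.append(theta)
    return draws

def flat_log_prior(theta):
    """Log density of a Uniform(0, 1) prior."""
    return 0.0

# The prior is supplied exactly once, as one of the forwarded model arguments.
draws = run_metropolis(binomial_log_post, flat_log_prior, 8, 20)
posterior_mean = sum(draws) / len(draws)
print(f"posterior mean estimate: {posterior_mean:.3f}")
```

Because the helper never names a particular model, swapping in a different model or prior changes only the call site, which is the property the commit is after.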

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

The lecture denotes the data by k throughout, but the ELBO derivation
uses generic Y. Add a clause defining Y as the observed data (the count
k) so the switch doesn't trip readers.
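
For orientation, the standard identity behind the ELBO is stated below in generic notation, with Y the observed data (here the count k), θ the latent parameter, and q(θ) the variational approximation. The lecture's own derivation may use different symbols or steps; this is only the textbook decomposition.

```latex
\log p(Y)
  = \underbrace{\mathbb{E}_{q(\theta)}\bigl[\log p(Y, \theta) - \log q(\theta)\bigr]}_{\mathrm{ELBO}(q)}
  + \mathrm{KL}\bigl(q(\theta) \,\|\, p(\theta \mid Y)\bigr),
\qquad
p(Y, \theta) = p(Y \mid \theta)\, p(\theta).
```

Since the KL term is nonnegative, ELBO(q) is a lower bound on log p(Y), and maximizing it over q is equivalent to minimizing the KL divergence from q to the posterior.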

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

numpyro's dist.Uniform.log_prob returns its constant value everywhere,
ignoring [low, high], so the restrictive-uniform example plotted a flat
prior across all of [0, 1] instead of a box on [0.5, 0.95]. Mask the
plotted density to prior.support(grid). Sampling was already correct
(NUTS respects the support); this only affects the prior curve.
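
The masking fix can be illustrated in plain Python. The function `unmasked_log_prob` below is a hypothetical stand-in that mimics the behavior described in this commit (a constant log density regardless of the argument); it does not call NumPyro. The support check plays the role of prior.support(grid).

```python
import math

LOW, HIGH = 0.5, 0.95  # support of the restrictive uniform prior from the commit

def unmasked_log_prob(theta):
    """Stand-in for a log_prob that returns the constant value for any theta."""
    return -math.log(HIGH - LOW)

def in_support(theta):
    """True when theta lies inside [LOW, HIGH]."""
    return LOW <= theta <= HIGH

def plotted_density(theta):
    """Density to plot: exp(log_prob) inside the support, zero outside it."""
    return math.exp(unmasked_log_prob(theta)) if in_support(theta) else 0.0

grid = [i / 100 for i in range(101)]
unmasked = [math.exp(unmasked_log_prob(t)) for t in grid]  # flat over all of [0, 1]
masked = [plotted_density(t) for t in grid]                # box on [0.5, 0.95]
print(f"theta=0.2: unmasked {unmasked[20]:.3f}, masked {masked[20]:.3f}")
```

Only the plotted curve changes: the masked version is zero outside [0.5, 0.95] and equals 1/(0.95 − 0.5) inside it.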

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

jstac merged commit 2f12e54 into main Jun 18, 2026
1 of 2 checks passed
jstac deleted the rewrite-bayes-nonconj branch June 18, 2026 09:45