diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index f45df634..b5142c04 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -471,7 +471,10 @@ We'll apply a `lax.fori_loop`, which is a version of a for loop that can be comp
 ```{code-cell} ipython3
 cpu = jax.devices("cpu")[0]
 
-@partial(jax.jit, static_argnames=("n",), device=cpu)
+# Pin the input to the CPU, which keeps the whole computation there
+x0_cpu = jax.device_put(0.1, cpu)
+
+@partial(jax.jit, static_argnames=("n",))
 def qm_jax_fori(x0, n, α=4.0):
 
     x = jnp.empty(n + 1).at[0].set(x0)
@@ -485,7 +488,7 @@ def qm_jax_fori(x0, n, α=4.0):
 ```
 
 * We hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.
-* We pin to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism.
+* We pin the input to the CPU with `jax.device_put` (which keeps the whole computation on the CPU) because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism.
 
 Important: Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place!
 
@@ -494,7 +497,7 @@ Let's time it with the same parameters:
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax_fori(0.1, n)
+    x_jax = qm_jax_fori(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -504,7 +507,7 @@ Let's run it again to eliminate compilation overhead:
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax_fori(0.1, n)
+    x_jax = qm_jax_fori(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -521,7 +524,7 @@ although the syntax is difficult to remember.
 
 
 ```{code-cell} ipython3
-@partial(jax.jit, static_argnames=("n",), device=cpu)
+@partial(jax.jit, static_argnames=("n",))
 def qm_jax_scan(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -538,7 +541,7 @@ Let's time it with the same parameters:
 ```{code-cell} ipython3
 with qe.Timer():
     # First run
-    x_jax = qm_jax_scan(0.1, n)
+    x_jax = qm_jax_scan(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
@@ -548,7 +551,7 @@ Let's run it again to eliminate compilation overhead:
 ```{code-cell} ipython3
 with qe.Timer():
     # Second run
-    x_jax = qm_jax_scan(0.1, n)
+    x_jax = qm_jax_scan(x0_cpu, n)
     # Hold interpreter
     x_jax.block_until_ready()
 ```
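
One way to sanity-check the change in this diff is to confirm that a jitted function called with a `jax.device_put` input returns its result on that same device. The sketch below is a minimal, self-contained version of that check. The name `qm_sketch`, the small value `n = 5`, and the use of `Array.devices()` to inspect placement are assumptions made for illustration and are not part of the lecture code.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

cpu = jax.devices("cpu")[0]

# Commit the initial condition to the CPU; a jitted function is expected
# to run where its committed inputs live
x0_cpu = jax.device_put(0.1, cpu)

@partial(jax.jit, static_argnames=("n",))
def qm_sketch(x0, n, α=4.0):
    # Hypothetical stand-in for qm_jax_fori: iterate x -> α x (1 - x)
    x = jnp.empty(n + 1).at[0].set(x0)

    def update(t, x):
        # Write the next iterate into slot t + 1
        return x.at[t + 1].set(α * x[t] * (1 - x[t]))

    return lax.fori_loop(0, n, update, x)

x_path = qm_sketch(x0_cpu, 5)
x_path.block_until_ready()

# Set of devices holding the result
result_devices = x_path.devices()
```

If the placement works as the diff describes, `result_devices` should contain only the `cpu` device, even on a machine where a GPU is the default backend. On a CPU-only machine the check passes trivially, so it is only informative where an accelerator is present.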