From 4bd7e549163ee2986c5f48c45358bd216f7e4d0c Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Fri, 19 Jun 2026 04:18:31 +0100 Subject: [PATCH 1/2] Update translation: lectures/numpy_vs_numba_vs_jax.md --- lectures/numpy_vs_numba_vs_jax.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 560b836..435dc53 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -467,7 +467,10 @@ Numba 非常高效地处理了这个顺序运算。 ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnames=("n",), device=cpu) +# Pin the input to the CPU, which keeps the whole computation there +x0_cpu = jax.device_put(0.1, cpu) + +@partial(jax.jit, static_argnames=("n",)) def qm_jax_fori(x0, n, α=4.0): x = jnp.empty(n + 1).at[0].set(x0) @@ -481,7 +484,7 @@ def qm_jax_fori(x0, n, α=4.0): ``` * 我们将 `n` 设为静态,因为它影响数组大小,JAX 希望在编译代码中针对其值进行特化处理。 -* 我们通过 `device=cpu` 将计算固定到 CPU,因为这种顺序工作负载由许多小型运算组成,几乎没有机会利用 GPU 并行性。 +* 我们通过 `jax.device_put` 将输入固定到 CPU(从而使整个计算保持在 CPU 上),因为这种顺序工作负载由许多小型运算组成,几乎没有机会利用 GPU 并行性。 重要提示:虽然 `at[t].set` 看起来在每一步都创建了一个新数组,但在 JIT 编译的函数内部,编译器会检测到旧数组不再需要,并就地执行更新! @@ -490,7 +493,7 @@ def qm_jax_fori(x0, n, α=4.0): ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -500,7 +503,7 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -514,7 +517,7 @@ JAX 对于这种顺序运算也相当高效! 这种替代方案可以说更符合 JAX 的函数式风格——尽管语法难以记忆。 ```{code-cell} ipython3 -@partial(jax.jit, static_argnames=("n",), device=cpu) +@partial(jax.jit, static_argnames=("n",)) def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -531,7 +534,7 @@ def qm_jax_scan(x0, n, α=4.0): ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -541,7 +544,7 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` From 8b6a73fc01cccca790a12f86bed0eb223bb22a63 Mon Sep 17 00:00:00 2001 From: Matt McKay Date: Fri, 19 Jun 2026 04:18:31 +0100 Subject: [PATCH 2/2] Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml --- .translate/state/numpy_vs_numba_vs_jax.md.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 66ae445..372dfed 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 -synced-at: "2026-05-14" +source-sha: d37b1d8adbf6e18b17e125cca761a6eb2ccd9041 +synced-at: "2026-06-19" model: claude-sonnet-4-6 mode: UPDATE section-count: 3