diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 66ae445..372dfed 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 -synced-at: "2026-05-14" +source-sha: d37b1d8adbf6e18b17e125cca761a6eb2ccd9041 +synced-at: "2026-06-19" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 1c5f849..0d918d2 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -461,7 +461,10 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnames=("n",), device=cpu) +# Pin the input to the CPU, which keeps the whole computation there +x0_cpu = jax.device_put(0.1, cpu) + +@partial(jax.jit, static_argnames=("n",)) def qm_jax_fori(x0, n, α=4.0): x = jnp.empty(n + 1).at[0].set(x0) @@ -475,7 +478,7 @@ def qm_jax_fori(x0, n, α=4.0): ``` * ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود. -* ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد. +* ما ورودی را با `jax.device_put` به CPU متصل می‌کنیم (که کل محاسبات را روی CPU نگه می‌دارد) زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد. مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد! @@ -484,7 +487,7 @@ def qm_jax_fori(x0, n, α=4.0): ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -494,7 +497,7 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_fori(0.1, n) + x_jax = qm_jax_fori(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -508,7 +511,7 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد. ```{code-cell} ipython3 -@partial(jax.jit, static_argnames=("n",), device=cpu) +@partial(jax.jit, static_argnames=("n",)) def qm_jax_scan(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -525,7 +528,7 @@ def qm_jax_scan(x0, n, α=4.0): ```{code-cell} ipython3 with qe.Timer(): # First run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ``` @@ -535,7 +538,7 @@ with qe.Timer(): ```{code-cell} ipython3 with qe.Timer(): # Second run - x_jax = qm_jax_scan(0.1, n) + x_jax = qm_jax_scan(x0_cpu, n) # Hold interpreter x_jax.block_until_ready() ```