Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
synced-at: "2026-05-14"
source-sha: d37b1d8adbf6e18b17e125cca761a6eb2ccd9041
synced-at: "2026-06-19"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 3
Expand Down
17 changes: 10 additions & 7 deletions lectures/numpy_vs_numba_vs_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,10 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
```{code-cell} ipython3
cpu = jax.devices("cpu")[0]

@partial(jax.jit, static_argnames=("n",), device=cpu)
# Pin the input to the CPU, which keeps the whole computation there
x0_cpu = jax.device_put(0.1, cpu)

@partial(jax.jit, static_argnames=("n",))
def qm_jax_fori(x0, n, α=4.0):

x = jnp.empty(n + 1).at[0].set(x0)
Expand All @@ -475,7 +478,7 @@ def qm_jax_fori(x0, n, α=4.0):
```

* ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
* ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد.
* ما ورودی را با `jax.device_put` به CPU متصل می‌کنیم (که کل محاسبات را روی CPU نگه می‌دارد) زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد.

مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد!

Expand All @@ -484,7 +487,7 @@ def qm_jax_fori(x0, n, α=4.0):
```{code-cell} ipython3
with qe.Timer():
# First run
x_jax = qm_jax_fori(0.1, n)
x_jax = qm_jax_fori(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -494,7 +497,7 @@ with qe.Timer():
```{code-cell} ipython3
with qe.Timer():
# Second run
x_jax = qm_jax_fori(0.1, n)
x_jax = qm_jax_fori(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -508,7 +511,7 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد.

```{code-cell} ipython3
@partial(jax.jit, static_argnames=("n",), device=cpu)
@partial(jax.jit, static_argnames=("n",))
def qm_jax_scan(x0, n, α=4.0):
def update(x, t):
x_new = α * x * (1 - x)
Expand All @@ -525,7 +528,7 @@ def qm_jax_scan(x0, n, α=4.0):
```{code-cell} ipython3
with qe.Timer():
# First run
x_jax = qm_jax_scan(0.1, n)
x_jax = qm_jax_scan(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -535,7 +538,7 @@ with qe.Timer():
```{code-cell} ipython3
with qe.Timer():
# Second run
x_jax = qm_jax_scan(0.1, n)
x_jax = qm_jax_scan(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand Down
Loading