Skip to content

Fix TPU kernel OOB memory access #3946

Merged
copybara-service[bot] merged 1 commit into
mainfrom
fix-security-vulns
May 22, 2026
Merged

Fix TPU kernel OOB memory access #3946
copybara-service[bot] merged 1 commit into
mainfrom
fix-security-vulns

Conversation

@darisoy
Copy link
Copy Markdown
Collaborator

@darisoy darisoy commented May 19, 2026

Description

This PR fixes a critical TPU hardware crash (Denial of Service / Machine Check Exception) caused by a negative prefetch index in MaxText ragged attention kernels, and establishes full CPU-only CI test coverage using JAX Pallas Interpret mode.

Proposed Changes & Implementation:

  1. Negative Prefetch Index Prevention: Clamped last_good_block in ragged_attention.py to 0 using jnp.maximum(0, ...) to ensure empty sequence paddings safely access valid bounds, preventing out-of-bounds TPU hardware crashes (CWE-125).
  2. Preserved Intermediate Scaling & Documented: Retained the intermediate o = o * l unnormalization inside ragged_mha and ragged_gqa because the outer MaxText Attention layer (attention_op.py) explicitly relies on unnormalized outputs to perform multi-chunk prefill/decode cache merging and final normalization. Added clear explanatory inline comments to document this architectural requirement.
  3. Unified CPU Interpret Mode: Exposed interpret parameter in MHA/GQA wrappers to support CPU-based testing and added full unit test coverage on CPU.

BUG=510376806


Tests

The changes were verified using both live TPU hardware (tpu7x-8 VM) and CPU-based Pallas Interpret mode.

1. TPU-Based Verification

The unit tests in tests/unit/kernels_test.py correctly align scales using the division workaround (ragged_out / ragged_denom) to verify unnormalized outputs against reference implementations.

Command run inside the virtual environment on TPU VM:

source maxtext_venv/bin/activate
pytest tests/unit/kernels_test.py -k "not Cpu"

Result: 3 passed (TPU execution verified).

2. CPU-Based Verification & Coverage (Ragged Attention)

To ensure code coverage in CPU-only CI environments, we added a new test suite RaggedAttentionCpuTest in tests/unit/kernels_test.py. This suite forces JAX Pallas to run in Interpret Mode on CPU (bypassing TPU compilation).

Command run locally or in CPU-only CI:

JAX_PLATFORMS=cpu pytest tests/unit/kernels_test.py -k "Cpu"

Result: 3 passed (CPU interpret mode verified). This provides full test coverage for ragged_attention.py in standard GitHub CI runners.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@darisoy darisoy force-pushed the fix-security-vulns branch from e184879 to e29c89f Compare May 19, 2026 18:48
@codecov
Copy link
Copy Markdown

codecov Bot commented May 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@darisoy
Copy link
Copy Markdown
Collaborator Author

darisoy commented May 19, 2026

Note on Code Coverage (SparseCore Kernels)

You will notice that Codecov reports 0% Patch Coverage for the newly added security and bounds-checking lines in the following two files:

  • src/maxtext/kernels/gather_reduce_sc.py (Mosaic SparseCore)
  • src/maxtext/kernels/gather_reduce_pallas.py (Pallas SparseCore)

Why these lines are not covered in the CI:

  1. Hardware Requirements: These SparseCore gather-reduce kernels rely on low-level Mosaic MLIR dialects and SparseCore ASIC hardware, which are only supported on actual TPU instances (TPU v5e/v6e/v7+).
  2. CPU CI Limitation: The GitHub CI Code Coverage job runs on standard CPU-only runners. Mosaic MLIR has no CPU interpreter, and Pallas CPU interpret mode does not support or emulate SparseCore hardware features.
  3. Skipped Tests: Consequently, the SparseCore unit tests (tests/gather_reduce_sc_test.py) have hardware checks that automatically skip them on CPU, meaning these files are never executed during the CI coverage runs.

Manual Verification:
All SparseCore clamping, off-by-one, and underflow bounds-checking fixes have been manually verified and compiled successfully on a live tpu7x-8 VM.

(Note: The Ragged Attention fixes in ragged_attention.py have been fully covered and verified in CI using JAX Pallas's CPU Interpret Mode in tests/unit/kernels_test.py).

@darisoy darisoy force-pushed the fix-security-vulns branch 2 times, most recently from 931ff28 to 6d5cd16 Compare May 20, 2026 17:16
Comment thread codecov.yml Outdated
@darisoy darisoy force-pushed the fix-security-vulns branch 3 times, most recently from 6d5cd16 to 1048d22 Compare May 20, 2026 22:27
Copy link
Copy Markdown
Contributor

@clee1994 clee1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for fixing this!

Comment thread src/maxtext/kernels/gather_reduce_sc.py Outdated
@darisoy darisoy force-pushed the fix-security-vulns branch from 1048d22 to b62dd6f Compare May 20, 2026 23:22
Comment thread src/maxtext/kernels/attention/ragged_attention.py
Comment thread src/maxtext/kernels/gather_reduce_sc.py Outdated
@darisoy darisoy force-pushed the fix-security-vulns branch 2 times, most recently from a92e5a6 to 35c9231 Compare May 22, 2026 16:34
@darisoy darisoy changed the title Fix TPU kernel OOB memory access and correctness issues Fix TPU kernel OOB memory access May 22, 2026
@darisoy darisoy force-pushed the fix-security-vulns branch from 35c9231 to 8b07ed8 Compare May 22, 2026 19:07
This commit resolves multiple security and correctness vulnerabilities in MaxText TPU kernels, corresponding to b/510376806.

1. maxtext/kernels/gather_reduce_sc.py:
   - Corrected off-by-one loop boundary in fill_load_offset_tile (CWE-193).
   - Added jnp.clip to clamp user-provided indices and prevent OOB DMA reads (CWE-125).
   - Added assertions for index shape to prevent integer underflow (CWE-191).
   - Added assertion for column size divisibility to prevent uninitialized memory leakage (CWE-908).

2. maxtext/kernels/gather_reduce_pallas.py:
   - Added jnp.clip bounds checking for gather indices (CWE-125).
   - Added divisibility check for column chunk size to prevent uninitialized memory leaks.
   - Added size validation for topk_weights to prevent OOB reads.

3. maxtext/kernels/attention/ragged_attention.py:
   - Clamped last_good_block index to prevent negative prefetch indexing and subsequent TPU crashes (CWE-125).
   - Removed incorrect o = o * l multiplication in ragged_mha and ragged_gqa that corrupted normalized attention states (CWE-682).

4. tests/unit/kernels_test.py:
   - Updated tests to remove the redundant division by denominator, aligning with the corrected attention output.
@darisoy darisoy force-pushed the fix-security-vulns branch from 8b07ed8 to 211ec5e Compare May 22, 2026 19:23
@copybara-service copybara-service Bot merged commit bd85909 into main May 22, 2026
47 checks passed
@copybara-service copybara-service Bot deleted the fix-security-vulns branch May 22, 2026 20:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants