Fix TPU kernel OOB memory access #3946
Merged
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Note on Code Coverage (SparseCore Kernels)
You will notice that Codecov reports
Why these lines are not covered in the CI:
Manual Verification: (Note: The Ragged Attention fixes in
bvandermoon reviewed May 20, 2026
clee1994 (Contributor) reviewed May 20, 2026 and left a comment:
LGTM, thanks for fixing this!
NuojCheng reviewed May 22, 2026
NuojCheng reviewed May 22, 2026
NuojCheng approved these changes May 22, 2026
gobbleturk approved these changes May 22, 2026
This commit resolves multiple security and correctness vulnerabilities in MaxText TPU kernels, corresponding to b/510376806.

1. maxtext/kernels/gather_reduce_sc.py:
   - Corrected off-by-one loop boundary in fill_load_offset_tile (CWE-193).
   - Added jnp.clip to clamp user-provided indices and prevent OOB DMA reads (CWE-125).
   - Added assertions for index shape to prevent integer underflow (CWE-191).
   - Added assertion for column size divisibility to prevent uninitialized memory leakage (CWE-908).

2. maxtext/kernels/gather_reduce_pallas.py:
   - Added jnp.clip bounds checking for gather indices (CWE-125).
   - Added divisibility check for column chunk size to prevent uninitialized memory leaks.
   - Added size validation for topk_weights to prevent OOB reads.

3. maxtext/kernels/attention/ragged_attention.py:
   - Clamped last_good_block index to prevent negative prefetch indexing and subsequent TPU crashes (CWE-125).
   - Removed incorrect o = o * l multiplication in ragged_mha and ragged_gqa that corrupted normalized attention states (CWE-682).

4. tests/unit/kernels_test.py:
   - Updated tests to remove the redundant division by denominator, aligning with the corrected attention output.
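The index clamping and divisibility checks listed in the commit message can be illustrated with a small sketch. This is not the code from `gather_reduce_sc.py` or `gather_reduce_pallas.py`; `safe_gather_rows` and its `col_chunk` parameter are hypothetical names used only to show the pattern of validating shapes up front and clamping indices with `jnp.clip` before a gather.

```python
import jax.numpy as jnp


def safe_gather_rows(table, indices, col_chunk):
    """Hypothetical helper: gather rows of `table` with bounds-checked indices."""
    num_rows, num_cols = table.shape
    # Shape checks run at trace time, before any kernel would be launched.
    assert indices.ndim == 1 and indices.shape[0] > 0, "indices must be non-empty 1-D"
    # A column count that is not a multiple of the chunk size would leave a
    # partially written tail chunk in a chunked kernel.
    assert num_cols % col_chunk == 0, "column count must be divisible by col_chunk"
    # Clamp every index into [0, num_rows - 1] so no read can leave the table.
    clamped = jnp.clip(indices, 0, num_rows - 1)
    return table[clamped]


table = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
rows = safe_gather_rows(table, jnp.array([-5, 2, 99]), col_chunk=3)
```

Clamping changes which row an invalid index reads rather than rejecting it, so callers that need to detect bad indices would have to validate them separately.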
Description
This PR fixes a critical TPU hardware crash (Denial of Service / Machine Check Exception) caused by a negative prefetch index in MaxText ragged attention kernels, and establishes full CPU-only CI test coverage using JAX Pallas Interpret mode.
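To make the negative-index failure mode concrete, here is a minimal sketch. It assumes the last block index is derived by ceiling division of the sequence length by the block size, which may not match the actual computation in `ragged_attention.py`; `last_valid_block` is a hypothetical helper, not a MaxText function.

```python
import jax.numpy as jnp


def last_valid_block(lengths, block_size):
    """Hypothetical helper: last block index holding valid tokens, per sequence."""
    # Ceiling division: number of blocks needed to cover each sequence.
    num_blocks = (lengths + block_size - 1) // block_size
    # An empty (fully padded) sequence needs 0 blocks, so this becomes -1.
    unclamped = num_blocks - 1
    # Clamping to 0 keeps the index inside the buffer for empty sequences.
    clamped = jnp.maximum(0, unclamped)
    return unclamped, clamped


lengths = jnp.array([0, 1, 256, 257], dtype=jnp.int32)
unclamped, clamped = last_valid_block(lengths, block_size=256)
```

Only the empty sequence is affected by the clamp; non-empty sequences keep their original block index.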
Proposed Changes & Implementation:
- Clamped `last_good_block` in `ragged_attention.py` to `0` using `jnp.maximum(0, ...)` so that empty sequence paddings access valid bounds, preventing out-of-bounds TPU hardware crashes (CWE-125).
- Kept the `o = o * l` unnormalization inside `ragged_mha` and `ragged_gqa` because the outer MaxText `Attention` layer (`attention_op.py`) explicitly relies on unnormalized outputs to perform multi-chunk prefill/decode cache merging and final normalization. Added explanatory inline comments to document this architectural requirement.
- Exposed an `interpret` parameter in the MHA/GQA wrappers to support CPU-based testing, and added full unit test coverage on CPU.

BUG=510376806
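The reason a caller might want unnormalized outputs can be shown with a generic online-softmax merge. This is a sketch of the general technique, not the logic in `attention_op.py`; the function names and the two-chunk split are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def chunk_attention_unnormalized(q, k, v):
    """Attention over one KV chunk, returning unnormalized output plus stats."""
    s = q @ k.T                               # scores, [q_len, kv_len]
    m = jnp.max(s, axis=-1, keepdims=True)    # per-row max for stability
    p = jnp.exp(s - m)
    l = jnp.sum(p, axis=-1, keepdims=True)    # softmax denominator
    o = p @ v                                 # unnormalized: not divided by l
    return o, l, m


def merge_chunks(o1, l1, m1, o2, l2, m2):
    """Combine two chunks by rescaling both to a shared running max."""
    m = jnp.maximum(m1, m2)
    a1, a2 = jnp.exp(m1 - m), jnp.exp(m2 - m)
    return o1 * a1 + o2 * a2, l1 * a1 + l2 * a2, m


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8))
k = jax.random.normal(kk, (6, 8))
v = jax.random.normal(kv, (6, 8))

o1, l1, m1 = chunk_attention_unnormalized(q, k[:3], v[:3])
o2, l2, m2 = chunk_attention_unnormalized(q, k[3:], v[3:])
o, l, _ = merge_chunks(o1, l1, m1, o2, l2, m2)
merged = o / l                                # normalize once, at the end
reference = jax.nn.softmax(q @ k.T, axis=-1) @ v
```

If each chunk were normalized before merging, the per-chunk denominators would be lost and the combined result would no longer match full attention, which is consistent with the requirement described above.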
Tests
The changes were verified using both live TPU hardware (`tpu7x-8` VM) and CPU-based Pallas Interpret mode.

1. TPU-Based Verification
The unit tests in `tests/unit/kernels_test.py` correctly align scales using the division workaround (`ragged_out / ragged_denom`) to verify unnormalized outputs against reference implementations.
Command run inside the virtual environment on TPU VM:
Result: `3 passed` (TPU execution verified).

2. CPU-Based Verification & Coverage (Ragged Attention)
To ensure code coverage in CPU-only CI environments, we added a new test suite `RaggedAttentionCpuTest` in `tests/unit/kernels_test.py`. This suite forces JAX Pallas to run in Interpret Mode on CPU (bypassing TPU compilation).
Command run locally or in CPU-only CI:
`JAX_PLATFORMS=cpu pytest tests/unit/kernels_test.py -k "Cpu"`
Result: `3 passed` (CPU interpret mode verified). This provides full test coverage for `ragged_attention.py` in standard GitHub CI runners.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.
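For reference, the interpret-mode approach described under Tests can be sketched with a toy kernel. This is not the ragged attention kernel; `scale_kernel` and `scale` are hypothetical names, and the only point is how an `interpret` flag on a wrapper is forwarded to `pl.pallas_call` so the kernel body runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def scale_kernel(x_ref, o_ref):
    # Read the whole input block, double it, and write the output block.
    o_ref[...] = x_ref[...] * 2.0


def scale(x, interpret=False):
    """Hypothetical wrapper that forwards `interpret` to pallas_call."""
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,  # True: run the kernel without TPU compilation
    )(x)


x = jnp.arange(8, dtype=jnp.float32)
y = scale(x, interpret=True)
```

A CPU test can then call the wrapper with `interpret=True` and compare against a plain `jnp` reference, while TPU tests leave the flag at its default.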