fix(vlm): forward get_rope_index to neat packing for mRoPE models #2172
Open
khazic wants to merge 2 commits into NVIDIA-NeMo:main from khazic/fix/vlm-packed-mrope-position-ids
Conversation
The VLM recipe never passed the model's `get_rope_index` callable to `neat_pack_dataset_vlm`. With it absent, `PackedDatasetWrapper` sets `has_mrope=False` and emits 1D `position_ids` per pack. The collater then forwards 2D `[B, L]` `position_ids` to the model, which short-circuits `get_rope_index` inside `model.forward`, and the language model expands the same 1D positions across all 3 mRoPE channels.

Net effect: packed Qwen2.5-VL / Qwen3-VL / Qwen3-VL-MoE / Qwen3-Omni training silently degraded mRoPE to plain 1D rotary, losing image spatial/temporal positional information. Non-packed and non-mRoPE VLMs were unaffected.

Plumbing only: extract `get_rope_index` via `getattr(model_parts[0], ...)` in the recipe and forward it through `build_dataloader` to `neat_pack_dataset_vlm`. Models without the method (Mistral3, LLaVA-OV, KimiVL, Gemma4-VLM) keep the prior behavior since `getattr` returns `None`. Adds two unit tests guarding the wiring against regression.

Signed-off-by: khazic <khazzz1c@gmail.com>
Contributor
/ok to test 768bb46
HuiyingLi approved these changes on May 8, 2026
HuiyingLi (Contributor) left a comment:
Thank you so much for capturing this
Contributor
/ok to test 7fc98d0
What
When VLM neat-packing is enabled for an mRoPE-aware model (Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE, Qwen3-Omni), the recipe never passes `model.get_rope_index` to `neat_pack_dataset_vlm`. With it absent, packed `position_ids` are emitted as 1D `range(seq_len)` per sample and the model's internal mRoPE auto-fill is skipped.

This PR plumbs `get_rope_index` from the model through `build_dataloader` to `neat_pack_dataset_vlm`. Models without the method (Mistral3, LLaVA-OV, KimiVL, Gemma4-VLM) keep prior behavior since `getattr(model, "get_rope_index", None)` returns `None`.
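A minimal sketch of that plumbing, assuming a `get_rope_index` keyword on `build_dataloader`; the exact parameter names in the recipe may differ:

```python
# Sketch of the new plumbing (simplified; parameter names are assumptions).
get_rope_index_fn = getattr(model_parts[0], "get_rope_index", None)  # None for non-mRoPE VLMs

dataloader = build_dataloader(
    ...,                                  # existing dataloader arguments unchanged
    get_rope_index=get_rope_index_fn,     # hypothetical kwarg; threaded down to the packer
)

# Inside build_dataloader (conceptually):
#   neat_pack_dataset_vlm(dataset, ..., get_rope_index=get_rope_index_fn)
#   -> PackedDatasetWrapper sets has_mrope = (get_rope_index_fn is not None)
```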
Exact trigger conditions

All of the following must hold simultaneously:
- The VLM finetune recipe (`recipes/vlm/finetune.py`) — not LLM/diffusion/retrieval.
- `packed_sequence.pack_size > 0` (or legacy `dataset.packing.enabled`).
- `packed_sequence.pretokenize: true` (so the path goes through `PreTokenizedDatasetWrapper` + `neat_pack_dataset_vlm`).
- The model class defines `get_rope_index` — i.e. mRoPE-aware: `Qwen2_5_VLForConditionalGeneration`, `Qwen3VLForConditionalGeneration`, `Qwen3VLMoeForConditionalGeneration`, `Qwen3OmniMoeForConditionalGeneration`.
- No custom `post_tokenize_hook_fn` that pre-computes 3D `position_ids`.

The path is independent of CP. CP-specific mRoPE handling (PR #1482, `cp_utils.py:294`) is correct given 3D `position_ids` enter CP — that fix shards correctly on `dim=2` for `ndim==3`. This PR is about how 3D `position_ids` originate in the first place under packing.

Unaffected combinations (no behavior change):

- Non-packed VLM: `default_collate_fn` deliberately omits `position_ids` so the model auto-calls `get_rope_index` (see comment at `collate_fns.py:1277-1281`).
- Non-mRoPE VLM under packing: `getattr(model, "get_rope_index", None)` returns `None`, packing falls through the existing 1D path, identical to current behavior.
- LLM packing uses `neat_pack_dataset` (not `_vlm`), unrelated.

Code path (failure case before this PR)
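A minimal runnable sketch of the failing shape handling, using toy tensors rather than the actual collater or model code:

```python
import torch

# Toy illustration of the pre-fix packed path (not the literal collater/model code).
B, S = 2, 8                                               # toy batch of packed sequences
pos_1d = torch.arange(S)                                  # 1D positions emitted per pack
position_ids = pos_1d.unsqueeze(0).expand(B, S)           # collater output: [B, L]

# Because position_ids is already present, the model skips get_rope_index, and the
# language model broadcasts the same positions across the 3 mRoPE channels:
mrope_pos = position_ids.unsqueeze(0).expand(3, B, S)     # [3, B, S], all three channels identical
assert (mrope_pos[0] == mrope_pos[1]).all() and (mrope_pos[1] == mrope_pos[2]).all()
```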
The 2D `[B, L]` `position_ids` reach the language model, where the `expand(3, B, S)` broadcast produces three identical position channels, so the t/h/w mRoPE chunks rotate by the same angle in image regions. This is mathematically not equivalent to standard mRoPE, which requires distinct temporal/height/width coordinates for image tokens.

With this PR

`get_rope_index` is extracted from the model and forwarded through `build_dataloader` to `neat_pack_dataset_vlm`, so the packer computes 3D mRoPE `position_ids` per pack and image tokens keep distinct temporal/height/width coordinates under packing.
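A rough, hypothetical illustration of the per-sample packing step this enables; `pack_position_ids` and the `get_rope_index` call shape are assumptions modeled on the HF Qwen-VL API, not the actual `PackedDatasetWrapper` code:

```python
import torch

def pack_position_ids(get_rope_index_fn, sample):
    """Hypothetical per-sample packing step (not the actual PackedDatasetWrapper code).

    With get_rope_index forwarded, each packed sample gets distinct t/h/w mRoPE
    coordinates; without it, the packer can only fall back to 1D positions.
    """
    if get_rope_index_fn is not None:  # mRoPE-aware model (Qwen-VL family)
        # Assumes an HF Qwen-VL-style signature returning ([3, B, L] positions, deltas).
        pos, _ = get_rope_index_fn(
            sample["input_ids"].unsqueeze(0),
            image_grid_thw=sample.get("image_grid_thw"),
            attention_mask=None,
        )
        return pos.squeeze(1)  # [3, L] mRoPE positions for this sample
    # Pre-fix behavior / non-mRoPE models: plain 1D positions.
    return torch.arange(sample["input_ids"].numel())
```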
Open question for reviewers
A NeMo CP researcher mentioned mRoPE is "already taken care of" in the codebase. I traced this to `cp_utils.py:287-295` (PR #1482), which shards 3D mRoPE correctly. I could not find a path that generates 3D mRoPE under packing without `get_rope_index` being passed through. If such a path exists, please point me to it and I will close this PR. Otherwise this fix is plumbing-only and safe.

What I have not measured: the empirical convergence delta between degenerate mRoPE and proper mRoPE on packed Qwen-VL training. The fix is a correctness improvement for the packed path; the magnitude requires an A/B convergence run.
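For context, the CP-side handling referenced above amounts to picking the shard dimension from the tensor rank; a minimal sketch assuming a chunk-based shard, not the actual `cp_utils.py` code:

```python
import torch

def shard_position_ids_for_cp(position_ids: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    """Illustrative only: shard along the sequence dim, which is dim 2 for 3D mRoPE
    position_ids ([3, B, S]) and dim 1 for plain 2D position_ids ([B, S])."""
    seq_dim = 2 if position_ids.ndim == 3 else 1
    return position_ids.chunk(cp_size, dim=seq_dim)[cp_rank]
```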
Changelog
- `recipes/vlm/finetune.py`: extract `get_rope_index` via `getattr(self.model_parts[0], "get_rope_index", None)` and forward it through `build_dataloader` → `neat_pack_dataset_vlm`. 14 lines added.
- `tests/unit_tests/recipes/test_finetune_vlm_helpers.py`: two unit tests verifying the wiring (sentinel forwarded; default `None` for non-mRoPE models; sketched below).
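The two tests are wiring checks along these lines; this is a sketch with stand-in helpers, not the actual test code in `tests/unit_tests/recipes/test_finetune_vlm_helpers.py`:

```python
from unittest.mock import MagicMock

# Stand-in for the recipe-side extraction; the real tests exercise the actual recipe helpers.
def extract_rope_fn(model):
    return getattr(model, "get_rope_index", None)

def test_sentinel_forwarded_for_mrope_models():
    sentinel = object()
    model = MagicMock()
    model.get_rope_index = sentinel
    packer = MagicMock()                                   # stands in for neat_pack_dataset_vlm
    packer(dataset=None, get_rope_index=extract_rope_fn(model))
    assert packer.call_args.kwargs["get_rope_index"] is sentinel

def test_none_for_non_mrope_models():
    assert extract_rope_fn(object()) is None               # e.g. Mistral3 / LLaVA-OV style models
```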
Pre-checks

- `ruff format` and `ruff check` clean on touched lines.
- `pytest tests/unit_tests/recipes/test_finetune_vlm_helpers.py -k get_rope_index` passes.
- Branch `khazic/fix/vlm-packed-mrope-position-ids` created from latest `upstream/main`.

Affected example configs
- `examples/vlm_finetune/qwen3/qwen3_vl_4b_neat_packing.yaml`
- `examples/vlm_finetune/qwen3/qwen3_vl_moe_30b_neat_packing.yaml`
- `examples/vlm_finetune/qwen3_5_moe/qwen3_5_35b_neat_packing.yaml`
- `examples/vlm_finetune/qwen3_5/qwen3_5_4b_neat_packing.yaml`

All four match all 5 trigger conditions above with no `post_tokenize_hook_fn` override.