
fix(vlm): forward get_rope_index to neat packing for mRoPE models#2172

Open
khazic wants to merge 2 commits into NVIDIA-NeMo:main from khazic:khazic/fix/vlm-packed-mrope-position-ids

Conversation

Contributor

@khazic khazic commented May 7, 2026

What

When VLM neat-packing is enabled for an mRoPE-aware model (Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE, Qwen3-Omni), the recipe never passes model.get_rope_index to neat_pack_dataset_vlm. With it absent, packed position_ids are emitted as 1D range(seq_len) per sample and the model's internal mRoPE auto-fill is skipped.

This PR plumbs get_rope_index from the model through build_dataloader to neat_pack_dataset_vlm. Models without the method (Mistral3, LLaVA-OV, KimiVL, Gemma4-VLM) keep prior behavior since getattr(model, "get_rope_index", None) returns None.
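
A minimal sketch of the plumbing, with deliberately simplified signatures (the real build_dataloader takes more arguments, and the import for neat_pack_dataset_vlm is assumed):

    def build_dataloader(dataset, model, pack_size):
        # getattr yields None for models without the method (Mistral3,
        # LLaVA-OV, KimiVL, Gemma4-VLM), so the existing 1D packing path
        # is taken unchanged for them.
        get_rope_index = getattr(model, "get_rope_index", None)
        return neat_pack_dataset_vlm(
            dataset,
            pack_size=pack_size,
            get_rope_index=get_rope_index,  # previously never forwarded
        )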

Exact trigger conditions

All of the following must hold simultaneously:

  1. VLM recipe (recipes/vlm/finetune.py) — not LLM/diffusion/retrieval.
  2. Packing enabled: packed_sequence.pack_size > 0 (or legacy dataset.packing.enabled).
  3. Pretokenize on: packed_sequence.pretokenize: true (so the path goes through PreTokenizedDatasetWrapper + neat_pack_dataset_vlm); see the config fragment after this list.
  4. Model implements get_rope_index — i.e. mRoPE-aware: Qwen2_5_VLForConditionalGeneration, Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration, Qwen3OmniMoeForConditionalGeneration.
  5. No user-supplied post_tokenize_hook_fn that pre-computes 3D position_ids.
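
Illustrative config fragment satisfying conditions 2-4. Only the packed_sequence keys are the ones named above; the surrounding layout and the model key are assumed for illustration rather than copied from a repo file:

    packed_sequence:
      pack_size: 4096      # condition 2: packing enabled
      pretokenize: true    # condition 3: routes through neat_pack_dataset_vlm
    model:
      pretrained_model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct  # condition 4: mRoPE-aware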

The path is independent of CP. CP-specific mRoPE handling (PR #1482, cp_utils.py:294) is correct provided 3D position_ids actually enter CP: that fix shards on dim=2 when ndim == 3. This PR is about how 3D position_ids originate in the first place under packing.
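
A toy shape check of that expectation (arbitrary sizes, not cp_utils code):

    import torch

    # 3D [3, B, S] mRoPE position_ids must shard along the sequence dim
    # (dim=2) so every CP rank keeps all three coordinate channels.
    pos = torch.arange(16).expand(3, 1, 16)   # [3, B=1, S=16]
    shard_a, shard_b = pos.chunk(2, dim=2)    # each [3, 1, 8]
    assert shard_a.shape == (3, 1, 8)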

Unaffected combinations (no behavior change):

  • Non-packing VLM training — default_collate_fn deliberately omits position_ids so the model auto-calls get_rope_index (see comment at collate_fns.py:1277-1281, and the sketch after this list).
  • Non-mRoPE VLMs — getattr(model, "get_rope_index", None) returns None, packing falls through the existing 1D path, identical to current behavior.
  • LLM neat packing — uses neat_pack_dataset (not _vlm), unrelated.
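
For the first bullet, the relevant pattern inside the model looks roughly like this (a simplified sketch of the HF-style mRoPE forward, not the exact Qwen code):

    def forward_sketch(model, input_ids, position_ids=None, attention_mask=None):
        # Non-packed path: the collater omitted position_ids, so the model
        # computes real 3D t/h/w coordinates itself.
        if position_ids is None:
            position_ids, _deltas = model.get_rope_index(
                input_ids, attention_mask=attention_mask
            )
        # Packed path before this PR: a 2D [B, L] tensor arrives here and
        # skips the branch above entirely.
        ...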

Code path (failure case before this PR)

recipe `build_dataloader` → neat_pack_dataset_vlm(get_rope_index=None)
  → PackedDatasetWrapper.has_mrope = False                  # neat_packing_vlm.py:450
  → _build_packed_vlm_sample(has_mrope=False)
      → all_position_ids_1d.extend(range(seq_len))          # line 367-370
      → packed["position_ids"] = 1D [total_len]             # line 400
  → neat_packed_vlm_collater
      → 1D path: stack to [B, max_len]                      # collate_fns.py:1528
  → batch sent to model.forward via filter_forward_kwargs
  → Qwen3VLMoe.forward (model.py:230)
      → if position_ids is None: ... # FALSE (we passed 2D), skipped
  → language_model.forward (model.py:354-357)
      → elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, B, S)  # 3 identical channels

The expand(3, B, S) produces three identical position channels, so the t/h/w mRoPE chunks rotate by the same angle in image regions. This is not mathematically equivalent to standard mRoPE, which requires distinct temporal/height/width coordinates for image tokens.
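
A self-contained demonstration of the degeneracy:

    import torch

    B, S = 2, 8
    position_ids = torch.arange(S).unsqueeze(0).expand(B, S)  # 1D-per-sample, [B, S]
    expanded = position_ids[None, ...].expand(3, B, S)        # [3, B, S]

    # All three mRoPE channels are identical, so t/h/w sections rotate
    # by the same angle; distinct coordinates are lost.
    assert torch.equal(expanded[0], expanded[1])
    assert torch.equal(expanded[1], expanded[2])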

With this PR

recipe `build_dataloader` → neat_pack_dataset_vlm(get_rope_index=model.get_rope_index)
  → PackedDatasetWrapper.has_mrope = True
  → __getitem__: _compute_mrope_position_ids(sample, get_rope_index)
      → returns 3D [3, seq_len] per sample
  → _build_packed_vlm_sample(has_mrope=True)
      → packed["position_ids"] = torch.cat(mrope_list, dim=1)  # [3, total_len]
  → collater stacks to [3, B, max_len]                          # collate_fns.py:1525
  → CP path (if enabled) shards on dim=2 (PR #1482)
  → model.forward sees 3D position_ids, uses as-is
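
Sketch of the packing-side computation, assuming an HF-style get_rope_index(input_ids, image_grid_thw=..., attention_mask=...) that returns ([3, B, S] position_ids, deltas); the helper name follows the trace above, the bodies are simplified:

    import torch

    def _compute_mrope_position_ids(sample, get_rope_index):
        # One constituent sample at a time: add a batch dim, ask the model
        # for real t/h/w coordinates, then drop the batch dim again.
        pos, _deltas = get_rope_index(
            sample["input_ids"].unsqueeze(0),                # [1, seq_len]
            image_grid_thw=sample.get("image_grid_thw"),
            attention_mask=sample["attention_mask"].unsqueeze(0),
        )
        return pos.squeeze(1)                                # [3, seq_len]

    def _pack_position_ids(samples, get_rope_index):
        # Concatenate per-sample [3, seq_len] blocks on the sequence dim:
        # the packed [3, total_len] tensor the collater stacks to [3, B, max_len].
        return torch.cat(
            [_compute_mrope_position_ids(s, get_rope_index) for s in samples],
            dim=1,
        )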

Open question for reviewers

A NeMo CP researcher mentioned mRoPE is "already taken care of" in the codebase. I traced this to cp_utils.py:287-295 (PR #1482) which shards 3D mRoPE correctly. I could not find a path that generates 3D mRoPE under packing without get_rope_index being passed through. If such a path exists, please point me to it and I will close this PR. Otherwise this fix is plumbing-only and safe.

What I have not measured: the empirical convergence delta between degenerate-mRoPE and proper-mRoPE on packed Qwen-VL training. The fix is a correctness improvement for the packed path; magnitude requires an A/B convergence run.

Changelog

  • recipes/vlm/finetune.py: extract get_rope_index via getattr(self.model_parts[0], "get_rope_index", None) and forward it through build_dataloader → neat_pack_dataset_vlm. 14 lines added.
  • tests/unit_tests/recipes/test_finetune_vlm_helpers.py: two unit tests verifying the wiring (sentinel forwarded; default None for non-mRoPE models).
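
Hedged sketch of the two wiring tests (hypothetical helper and module names mirroring the changelog entry; the committed tests may differ):

    from unittest.mock import MagicMock, patch

    from recipes.vlm.finetune import build_dataloader  # import path assumed

    def test_get_rope_index_is_forwarded():
        sentinel = object()
        model = MagicMock()
        model.get_rope_index = sentinel
        with patch("recipes.vlm.finetune.neat_pack_dataset_vlm") as packer:
            build_dataloader(MagicMock(), model, pack_size=64)
        assert packer.call_args.kwargs["get_rope_index"] is sentinel

    def test_non_mrope_model_defaults_to_none():
        model = object()  # no get_rope_index attribute -> getattr yields None
        with patch("recipes.vlm.finetune.neat_pack_dataset_vlm") as packer:
            build_dataloader(MagicMock(), model, pack_size=64)
        assert packer.call_args.kwargs["get_rope_index"] is None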

Pre-checks

  • ruff format and ruff check clean on touched lines.
  • Two new unit tests pass locally: pytest tests/unit_tests/recipes/test_finetune_vlm_helpers.py -k get_rope_index.
  • DCO sign-off present.
  • Branch khazic/fix/vlm-packed-mrope-position-ids from latest upstream/main.

Affected example configs

  • examples/vlm_finetune/qwen3/qwen3_vl_4b_neat_packing.yaml
  • examples/vlm_finetune/qwen3/qwen3_vl_moe_30b_neat_packing.yaml
  • examples/vlm_finetune/qwen3_5_moe/qwen3_5_35b_neat_packing.yaml
  • examples/vlm_finetune/qwen3_5/qwen3_5_4b_neat_packing.yaml

All four match all 5 trigger conditions above with no post_tokenize_hook_fn override.

The VLM recipe never passed the model's get_rope_index callable to
neat_pack_dataset_vlm. With it absent, PackedDatasetWrapper sets
has_mrope=False and emits 1D position_ids per pack. The collater then
forwards 2D [B, L] position_ids to the model, which short-circuits
get_rope_index inside model.forward and the language model expands the
same 1D positions across all 3 mRoPE channels.

Net effect: packed Qwen2.5-VL / Qwen3-VL / Qwen3-VL-MoE / Qwen3-Omni
training silently degraded mRoPE to plain 1D rotary, losing image
spatial/temporal positional information. Non-packed and non-mRoPE VLMs
were unaffected.

Plumbing only: extract get_rope_index via getattr(model_parts[0], ...)
in the recipe and forward it through build_dataloader to
neat_pack_dataset_vlm. Models without the method (Mistral3, LLaVA-OV,
KimiVL, Gemma4-VLM) keep the prior behavior since getattr returns None.

Adds two unit tests guarding the wiring against regression.

Signed-off-by: khazic <khazzz1c@gmail.com>

copy-pr-bot Bot commented May 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@khazic khazic changed the title from "[recipe][vlm] fix: forward get_rope_index to neat packing for mRoPE models" to "fix(vlm): forward get_rope_index to neat packing for mRoPE models" May 7, 2026
@HuiyingLi
Contributor

/ok to test 768bb46

Contributor

@HuiyingLi HuiyingLi left a comment

Thank you so much for capturing this

@HuiyingLi HuiyingLi enabled auto-merge (squash) May 8, 2026 06:04
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer (Waiting on the original author to respond) label May 8, 2026
@HuiyingLi HuiyingLi disabled auto-merge May 8, 2026 15:44
@HuiyingLi HuiyingLi enabled auto-merge (squash) May 8, 2026 15:44
@HuiyingLi
Contributor

/ok to test 7fc98d0


Labels

community-request, waiting-on-customer (Waiting on the original author to respond)
