fix(vlm): chunk video inputs for pipeline parallelism #2177

Open

khazic wants to merge 7 commits into NVIDIA-NeMo:main from khazic:khazic/fix/vlm-pp-video-chunking

Conversation

@khazic
Contributor

@khazic khazic commented May 7, 2026

Summary

VLM pipeline parallelism previously chunked only image tensors per microbatch; video tensors (pixel_values_videos, video_grid_thw, n_videos_per_sample) were left in the batch and sliced along dim 0 by schedule.step. That slicing is wrong because dim 0 of those tensors is the total patch count / video count, not the batch size. Result: VLM + PP + video either crashes on a shape mismatch or silently miscorrelates patches with text tokens.

This PR mirrors the existing image chunking path for video.

Changelog

  • fix(vlm): pop pixel_values_videos / video_grid_thw / n_videos_per_sample in _forward_backward_step (PP branch) and pre-chunk them onto stage 0 via _chunk_vlm_media, then consume them per microbatch inside forward() of qwen3_5_moe, qwen3_omni_moe, and qwen3_vl_moe. A shared _vlm_chunk_idx cursor is incremented once per call via a consumed_vlm_chunk flag, so images and videos coexisting on stage 0 advance the cursor correctly.
  • test(vlm): add unit coverage for video-only chunking (_chunk_vlm_media + end-to-end _forward_backward_step) and image+video mixed chunking, asserting popped kwargs, per-microbatch shapes, single shared cursor at 0, and post-step cleanup.
  • style(vlm): run ruff format on tests/unit_tests/recipes/test_finetune_vlm_helpers.py to clean up pre-existing whitespace/quote drift. No behavioral changes.
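As a rough illustration of why the media tensors cannot be sliced by batch size, here is a minimal sketch of per-sample-count chunking. The function name `chunk_vlm_media` and the argument names are illustrative stand-ins for the real helper, and it assumes the batch size divides evenly across microbatches:

```python
import torch

def chunk_vlm_media(pixel_values, grid_thw, n_per_sample, num_microbatches):
    """Split flattened media tensors into one chunk per microbatch.

    pixel_values: [total_patches, hidden] -- dim 0 is NOT the batch size
    grid_thw:     [total_media, 3]        -- one (t, h, w) row per media item
    n_per_sample: [batch]                 -- media count per sample
    """
    batch = n_per_sample.numel()
    assert batch % num_microbatches == 0, "sketch assumes even microbatching"
    mb = batch // num_microbatches

    # Patches per media item = t * h * w; cumulative media counts give the
    # slice points between microbatches.
    patches_per_media = grid_thw.prod(dim=-1)
    media_cum = torch.cat([torch.tensor([0]), n_per_sample.cumsum(0)])
    pv_chunks, grid_chunks = [], []
    for i in range(num_microbatches):
        m0 = int(media_cum[i * mb])        # first media item of this microbatch
        m1 = int(media_cum[(i + 1) * mb])  # one past the last media item
        grid_chunks.append(grid_thw[m0:m1])
        p0 = int(patches_per_media[:m0].sum())
        p1 = p0 + int(patches_per_media[m0:m1].sum())
        pv_chunks.append(pixel_values[p0:p1])
    return pv_chunks, grid_chunks
```

For example, a batch of 4 samples with per-sample video counts [1, 0, 2, 1] split into 2 microbatches yields one chunk with 1 grid row and one with 3, each pixel-value chunk sized by the sum of t*h*w over its own media items, which naive dim-0 slicing by batch size cannot reproduce.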

Test plan

  • uv run pytest tests/unit_tests/recipes/test_finetune_vlm_helpers.py — 68 passed, 3 skipped (skips are pre-existing fused_linear_ce GPU cases)
  • uv run ruff format --check on the touched files — clean
  • PP=2 functional run on a Qwen3-VL-MoE recipe with video data (requires multi-GPU host)

@copy-pr-bot

copy-pr-bot Bot commented May 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

khazic added 4 commits May 7, 2026 18:41
…kimi_k25

In kimi_k25_vl_collate_fn, n_images_per_sample was derived solely from all_grid_thws, which is only conditionally appended to:

* text-only samples have no grid_thws → not appended
* samples whose image region got orphaned by truncation (when
  drop_overlong=False, the default) also skip the append, but the sample
  itself stays in all_expanded with image tokens replaced by pad

Result: len(n_images_per_sample) < batch_size on mixed batches, while
input_ids has shape [batch_size, max_len]. Downstream PP _chunk_vlm_media
indexes cumsum_images by sample index up to batch_size and raises
IndexError when the cumsum is shorter.

Track per-sample image count in lockstep with all_expanded (zeros for
text-only and orphaned samples) so the resulting tensor has length
batch_size in all cases. No behavior change for batches where every sample
has an intact image, since the per-sample count then equals what the
old derivation produced.

Adds two regression tests covering (1) text-only + image mixed batch and
(2) intact-image + truncation-orphaned mixed batch.
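The lockstep fix can be sketched roughly as follows, using hypothetical sample dicts in place of the real collate inputs (the keys "grid_thw" and "orphaned" are illustrative, not the actual field names):

```python
import torch

def collate_image_counts(samples):
    """Track per-sample image counts in lockstep with the expanded batch.

    Appends 0 for text-only samples and for samples whose image region was
    orphaned by truncation, so the counts tensor always has length batch_size.
    """
    all_grid_thws = []
    n_images_per_sample = []
    for s in samples:
        grid = s.get("grid_thw")            # None for text-only samples
        orphaned = s.get("orphaned", False)  # image region lost to truncation
        if grid is not None and not orphaned:
            all_grid_thws.append(grid)
            n_images_per_sample.append(len(grid))
        else:
            # The buggy derivation skipped this append entirely, leaving the
            # counts list shorter than batch_size on mixed batches.
            n_images_per_sample.append(0)
    return torch.tensor(n_images_per_sample)
```

A batch of [image sample, text-only sample, orphaned sample] then yields counts [1, 0, 0] of length 3, so a downstream cumsum indexed by sample index stays in bounds.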

Signed-off-by: khazic <khazzz1c@gmail.com>
Signed-off-by: khazic <khazzz1c@gmail.com>
Cover the path where a single batch contains both pixel_values and
pixel_values_videos under PP. Verifies that image and video are popped
from schedule kwargs, both chunk arrays are sized by their per-sample
counts, the shared _vlm_chunk_idx is initialized once at 0, and both
streams plus the cursor are cleared after the step.

Signed-off-by: khazic <khazzz1c@gmail.com>
Apply ruff format to clean up pre-existing whitespace drift in the
file. No behavioral changes.

Signed-off-by: khazic <khazzz1c@gmail.com>
@khazic khazic force-pushed the khazic/fix/vlm-pp-video-chunking branch from 8821cc6 to dca45ea on May 7, 2026 10:42
@HuiyingLi
Contributor

/ok to test dca45ea

@HuiyingLi
Contributor

/claude review

batch["video_grid_thw"] = video_grid_thw

if self.pp.info.has_first_stage:
    self.pp.info.schedule.step(input_ids, target=targets, losses=losses, **batch)
Contributor


LGTM — clean fix. The video chunking mirrors the existing image path correctly, the shared _vlm_chunk_idx cursor works because both modalities are chunked by the same microbatch boundaries, and the cleanup is properly guarded. Test coverage for video-only, mixed image+video, and the n_images_per_sample batch-size regression is solid.

HuiyingLi
HuiyingLi previously approved these changes May 8, 2026
The PP video chunking refactor in this PR routes image and video chunks
through separate model attributes (``_vlm_pixel_values_chunks`` /
``_vlm_image_grid_hws_chunks`` for images vs.
``_vlm_pixel_values_videos_chunks`` / ``_vlm_video_grid_thw_chunks`` for
videos). The forward path is also strict: a video token in input_ids only
consumes the video chunk arrays.

The GPU PP-guard test still set up image chunks (legacy attribute names)
while feeding only the video token 151656, so the new strict consumption
path took neither branch and ``_vlm_chunk_idx`` stayed at 0 — the assertion
``assert model._vlm_chunk_idx == 1`` therefore failed in L0_Unit_Tests_GPU.

Update the fixture to set ``_vlm_pixel_values_videos_chunks`` /
``_vlm_video_grid_thw_chunks`` so the video-token branch fires. The shapes
mirror the existing CPU equivalent test
``TestQwen3VLMoeForConditionalGenerationPpGuardCpu::test_chunked_pixel_values_videos_consumed_for_video_token``.

Signed-off-by: khazic <khazzz1c@gmail.com>
@HuiyingLi
Contributor

/ok to test d2e13ae

@HuiyingLi
Contributor

/ok to test 62f7695
