
fix(vlm): align n_images_per_sample with batch_size in kimi_k25 collate#2175

Open
khazic wants to merge 2 commits into NVIDIA-NeMo:main from khazic:khazic/fix/kimi-k25-vl-image-counts-mismatch

Conversation

Contributor

@khazic khazic commented May 7, 2026

What

kimi_k25_vl_collate_fn produces a per-batch n_images_per_sample tensor whose length must equal batch_size so that downstream PP _chunk_vlm_media (in recipes/vlm/finetune.py:281-298) can index it by sample position. Currently it is derived from all_grid_thws, a list that is only conditionally appended to:

  1. Text-only samples — never produce grid_thws, so they are skipped.
  2. Truncated-orphan samples — when drop_overlong=False (the default) and max_length cuts into an image region, actual_image_tokens != expected_image_tokens so the grid is also skipped (collate_fns.py:930-936), but the sample itself is kept in all_expanded with orphaned image tokens replaced by pad.

Both cases leave len(all_grid_thws) < len(all_expanded), so:

image_counts = [g.shape[0] for g in all_grid_thws]   # length < batch_size
result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)

produces a tensor shorter than batch_size. Downstream PP code:

cumsum_images = torch.cumsum(n_images_per_sample, dim=0)
samples_per_mb = batch_size // n_microbatches
for mb_idx in range(n_microbatches):
    s_start = mb_idx * samples_per_mb
    s_end = min(s_start + samples_per_mb, batch_size)
    img_end = int(cumsum_images[s_end - 1].item())   # IndexError when s_end > len(cumsum_images)

raises IndexError (or, in the Layout-3 patch-counts branch, indexes the wrong sample silently).
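
For concreteness, a minimal standalone reproduction of the failure mode with toy values (the loop mirrors the _chunk_vlm_media snippet above; the batch shape and image counts are invented for illustration):

import torch

# Toy values: 4 samples in the batch, but the collate only recorded image
# counts for 3 of them (sample index 2 was text-only), so the tensor is short.
batch_size = 4
n_microbatches = 2
n_images_per_sample = torch.tensor([2, 1, 3], dtype=torch.long)  # length 3, should be 4

cumsum_images = torch.cumsum(n_images_per_sample, dim=0)
samples_per_mb = batch_size // n_microbatches
for mb_idx in range(n_microbatches):
    s_start = mb_idx * samples_per_mb
    s_end = min(s_start + samples_per_mb, batch_size)
    img_end = int(cumsum_images[s_end - 1].item())  # mb_idx=1: index 3 into a length-3 tensor -> IndexError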

Trigger conditions

All required:

  1. Processor type KimiK25Processor.
  2. max_length set on the collate.
  3. drop_overlong=False (the default).
  4. Batch contains a mixture of:
    • at least one sample with an intact image (so all_grid_thws is non-empty and n_images_per_sample is set), AND
    • at least one sample without an intact image — text-only, OR an image whose region got partially truncated.
  5. Pipeline parallelism enabled (pp_size > 1) — without PP, _chunk_vlm_media is not called and the wrong-length tensor is unused, so no immediate crash.

Without #5 the bug is dormant; with #5 it crashes loudly.

Fix

Track image count per sample in lockstep with all_expanded. Text-only samples and truncation-orphaned samples contribute 0; samples with an intact image contribute grid_thws.shape[0]. The resulting tensor has length batch_size in all cases.

No behavior change for batches where every sample has an intact image — the new per-sample list then has the same values the old derivation produced.
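
A toy sketch of the counting pattern described above, using the variable names from this PR (per_sample_grids and the placeholder expanded ids are invented for illustration; the real per-sample expansion and truncation handling lives in kimi_k25_vl_collate_fn):

import torch

# Stand-in for the per-sample loop input: a grid_thws tensor (one row per
# image) for samples with an intact image, None for text-only / orphaned ones.
per_sample_grids = [torch.tensor([[1, 4, 4]]), None, torch.tensor([[1, 2, 2], [1, 2, 2]]), None]

all_expanded, all_grid_thws, per_sample_image_count = [], [], []
for grid_thws in per_sample_grids:
    all_expanded.append(torch.zeros(8, dtype=torch.long))  # placeholder for the expanded input_ids
    if grid_thws is not None:
        all_grid_thws.append(grid_thws)
        per_sample_image_count.append(grid_thws.shape[0])
    else:
        per_sample_image_count.append(0)  # text-only or truncation-orphaned sample

n_images_per_sample = torch.tensor(per_sample_image_count, dtype=torch.long)
assert n_images_per_sample.shape[0] == len(all_expanded)  # == batch_size in all cases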

Changelog

  • nemo_automodel/components/datasets/vlm/collate_fns.py: in kimi_k25_vl_collate_fn, introduce per_sample_image_count updated synchronously with all_expanded; use it to build n_images_per_sample instead of deriving from all_grid_thws. Net +14 lines.
  • tests/unit_tests/datasets/vlm/test_collate_fns.py: two regression tests (the core shape check is sketched after this list):
    • test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix
    • test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan
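
Roughly, the invariant both tests pin down (sketched with a hypothetical helper; the real tests build mixed batches through the KimiK25Processor path and call kimi_k25_vl_collate_fn directly):

def assert_image_counts_match_batch(result: dict, batch_size: int) -> None:
    # Core assertion of both regression tests: the per-sample image-count
    # tensor must be indexable by sample position for every sample.
    assert result["n_images_per_sample"].shape[0] == batch_size
    assert result["input_ids"].shape[0] == batch_size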

Pre-checks

  • Two new tests pass locally; all 16 kimi_k25 tests in the suite pass.
  • ruff check clean on touched lines.
  • DCO sign-off present.
  • Branch from latest upstream/main.

…kimi_k25

In kimi_k25_vl_collate_fn, n_images_per_sample was derived from all_grid_thws
only, which is conditionally appended:

* text-only samples have no grid_thws → not appended
* samples whose image region got orphaned by truncation (when
  drop_overlong=False, the default) also skip the append, but the sample
  itself stays in all_expanded with image tokens replaced by pad

Result: len(n_images_per_sample) < batch_size on mixed batches, while
input_ids has shape [batch_size, max_len]. Downstream PP _chunk_vlm_media
indexes cumsum_images by sample index up to batch_size and raises
IndexError when the cumsum is shorter.

Track per-sample image count in lockstep with all_expanded (zeros for
text-only and orphaned samples) so the resulting tensor has length
batch_size in all cases. No behavior change for batches where every sample
has an intact image, since the per-sample count then equals what the
old derivation produced.

Adds two regression tests covering (1) text-only + image mixed batch and
(2) intact-image + truncation-orphaned mixed batch.

Signed-off-by: khazic <khazzz1c@gmail.com>

copy-pr-bot Bot commented May 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor

/ok to test 0c31b25

@HuiyingLi
Contributor

/claude review


@claude claude Bot left a comment


LGTM. The fix correctly keeps per_sample_image_count in lockstep with all_expanded — text-only samples and truncation-orphaned samples contribute 0, and the drop_overlong continue path skips both lists. Regression tests cover both trigger conditions cleanly.

Contributor

@HuiyingLi HuiyingLi left a comment


LGTM thank you so much!

@HuiyingLi HuiyingLi enabled auto-merge (squash) May 8, 2026 00:47
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer (Waiting on the original author to respond) label May 8, 2026
@thomasdhc
Contributor

/ok to test 793232b

@HuiyingLi
Contributor

/ok to test 793232b


Labels

community-request, waiting-on-customer (Waiting on the original author to respond)
