fix(vlm): align n_images_per_sample with batch_size in kimi_k25 collate #2175
khazic wants to merge 2 commits into NVIDIA-NeMo:main from
Conversation
…kimi_k25

In kimi_k25_vl_collate_fn, n_images_per_sample was derived from all_grid_thws only, which is conditionally appended:

* text-only samples have no grid_thws → not appended
* samples whose image region got orphaned by truncation (when drop_overlong=False, the default) also skip the append, but the sample itself stays in all_expanded with image tokens replaced by pad

Result: len(n_images_per_sample) < batch_size on mixed batches, while input_ids has shape [batch_size, max_len]. Downstream PP _chunk_vlm_media indexes cumsum_images by sample index up to batch_size and raises IndexError when the cumsum is shorter.

Track per-sample image count in lockstep with all_expanded (zeros for text-only and orphaned samples) so the resulting tensor has length batch_size in all cases. No behavior change for batches where every sample has an intact image, since the per-sample count then equals what the old derivation produced.

Adds two regression tests covering (1) text-only + image mixed batch and (2) intact-image + truncation-orphaned mixed batch.

Signed-off-by: khazic <khazzz1c@gmail.com>
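The lockstep bookkeeping described in the commit message can be sketched as a toy collate loop. This is a minimal, self-contained sketch, not the real `kimi_k25_vl_collate_fn`: plain lists stand in for tensors, and the sample dicts with `overlong`/`orphaned` flags are hypothetical inputs invented for illustration.

```python
def collate_image_counts(samples, drop_overlong=False):
    """Toy collate loop. Each sample is a dict with 'input_ids', an
    optional 'grid_thws' (list of THW triples), and hypothetical
    'overlong'/'orphaned' flags standing in for the real conditions."""
    all_expanded = []
    all_grid_thws = []
    per_sample_image_count = []  # grows in lockstep with all_expanded
    for s in samples:
        if drop_overlong and s.get("overlong"):
            continue  # sample dropped entirely: neither list grows
        all_expanded.append(s["input_ids"])
        grid = s.get("grid_thws")
        if grid is not None and not s.get("orphaned"):
            all_grid_thws.append(grid)
            per_sample_image_count.append(len(grid))  # grid_thws.shape[0]
        else:
            # text-only or truncation-orphaned: count 0, but keep the sample
            per_sample_image_count.append(0)
    # The invariant the fix restores: one count per kept sample.
    assert len(per_sample_image_count) == len(all_expanded)
    return all_expanded, all_grid_thws, per_sample_image_count
```

On a mixed batch of one intact-image, one text-only, and one orphaned sample, this yields counts `[2, 0, 0]` with three kept samples but only one appended grid, which is exactly the divergence the old derivation missed.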
Contributor
/ok to test 0c31b25
Contributor
/claude review
Contributor
LGTM. The fix correctly keeps per_sample_image_count in lockstep with all_expanded — text-only samples and truncation-orphaned samples contribute 0, and the drop_overlong continue path skips both lists. Regression tests cover both trigger conditions cleanly.
HuiyingLi approved these changes May 8, 2026
Contributor
HuiyingLi left a comment
LGTM thank you so much!
Contributor
/ok to test 793232b
What
`kimi_k25_vl_collate_fn` produces a per-batch `n_images_per_sample` tensor whose length must equal `batch_size` so that downstream PP `_chunk_vlm_media` (in `recipes/vlm/finetune.py:281-298`) can index it by sample position. Currently it is derived from `all_grid_thws`, which is conditionally appended:

* Text-only samples have no `grid_thws`, so they are skipped.
* When `drop_overlong=False` (the default) and `max_length` cuts into an image region, `actual_image_tokens != expected_image_tokens`, so the grid is also skipped (`collate_fns.py:930-936`), but the sample itself is kept in `all_expanded` with orphaned image tokens replaced by pad.

Both cases leave `len(all_grid_thws) < len(all_expanded)`, so the derivation produces a tensor shorter than `batch_size`. Downstream PP code raises `IndexError` (or, in the Layout-3 patch-counts branch, silently indexes the wrong sample).

Trigger conditions
All required:
1. The kimi_k25 collate path (`KimiK25Processor`).
2. `max_length` set on the collate.
3. `drop_overlong=False` (the default).
4. A mixed batch: at least one sample with an intact image (so `all_grid_thws` is non-empty and `n_images_per_sample` is set), AND at least one text-only or truncation-orphaned sample.
5. PP (`pp_size > 1`): without PP, `_chunk_vlm_media` is not called and the wrong-length tensor is unused, so no immediate crash.

Without #5 the bug is dormant; with #5 it crashes loudly.
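Why condition #5 turns the mismatch into a crash can be shown with the indexing pattern alone. This is an illustrative sketch: plain lists stand in for tensors, and `chunk_boundaries` is a hypothetical stand-in that mimics only the cumsum indexing of the real `_chunk_vlm_media` in `recipes/vlm/finetune.py`.

```python
from itertools import accumulate

def chunk_boundaries(n_images_per_sample, batch_size):
    """Mimic the PP chunking pattern: cumulative image counts are
    indexed once per sample position, up to batch_size."""
    cumsum = list(accumulate(n_images_per_sample))
    return [cumsum[i] for i in range(batch_size)]

full = [2, 0]   # lockstep derivation: one entry per sample
short = [2]     # old all_grid_thws derivation on the same mixed batch

print(chunk_boundaries(full, batch_size=2))   # works: [2, 2]
try:
    chunk_boundaries(short, batch_size=2)
except IndexError:
    print("IndexError: cumsum shorter than batch_size")
```

Without PP this function is never called, so the short tensor sits unused and the bug stays dormant, matching condition #5 above.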
Fix
Track image count per sample in lockstep with `all_expanded`. Text-only samples and truncation-orphaned samples contribute 0; samples with an intact image contribute `grid_thws.shape[0]`. The resulting tensor has length `batch_size` in all cases.

No behavior change for batches where every sample has an intact image: the new per-sample list then has the same values the old derivation produced.
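The no-behavior-change claim can be spot-checked in isolation. Hypothetical data, with plain lists standing in for tensors: when every image in the batch is intact, the per-sample counts coincide with the old per-grid derivation.

```python
# Every sample here has an intact image, so both derivations must agree.
samples = [
    {"grid_thws": [(1, 2, 2), (1, 4, 4)]},  # 2 images
    {"grid_thws": [(1, 2, 2)]},             # 1 image
]

# Old derivation: one count per appended grid (all_grid_thws-based).
old = [len(s["grid_thws"]) for s in samples]

# New derivation: one count per sample, 0 only when the grid is absent
# or orphaned; no such sample exists in this batch.
new = [len(s["grid_thws"]) if "grid_thws" in s and not s.get("orphaned") else 0
       for s in samples]

assert old == new == [2, 1]
```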
Changelog
* `nemo_automodel/components/datasets/vlm/collate_fns.py`: in `kimi_k25_vl_collate_fn`, introduce `per_sample_image_count` updated synchronously with `all_expanded`; use it to build `n_images_per_sample` instead of deriving from `all_grid_thws`. Net +14 lines.
* `tests/unit_tests/datasets/vlm/test_collate_fns.py`: two regression tests:
  * `test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix`
  * `test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan`

Pre-checks
* All kimi_k25 tests in the suite pass.
* `ruff check` clean on touched lines.
* Rebased on `upstream/main`.