diff --git a/nemo_automodel/components/datasets/vlm/collate_fns.py b/nemo_automodel/components/datasets/vlm/collate_fns.py
index a72e3f413..51a96ca15 100644
--- a/nemo_automodel/components/datasets/vlm/collate_fns.py
+++ b/nemo_automodel/components/datasets/vlm/collate_fns.py
@@ -871,6 +871,10 @@ def kimi_k25_vl_collate_fn(
     all_expanded = []
     all_pixel_values = []
     all_grid_thws = []
+    # Per-sample image counts, kept in lockstep with all_expanded so that
+    # n_images_per_sample length matches batch_size downstream. Samples that
+    # are text-only or whose image region was orphaned by truncation get 0.
+    per_sample_image_count: List[int] = []
 
     for i, conversation in enumerate(conversations):
         # Collect medias for this conversation
@@ -923,12 +927,14 @@ def kimi_k25_vl_collate_fn(
 
         # Only include image data if all expanded image tokens survived truncation.
        # Partial truncation into image regions would cause a mismatch in the model forward.
+        sample_image_count = 0
         if grid_thws is not None:
             merge_h, merge_w = _DEFAULT_MERGE_KERNEL
             expected_image_tokens = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
             actual_image_tokens = (input_ids == media_token_id).sum().item()
             if actual_image_tokens == expected_image_tokens:
                 all_grid_thws.append(grid_thws)
+                sample_image_count = int(grid_thws.shape[0])
                 if "pixel_values" in sample_batch:
                     all_pixel_values.append(sample_batch["pixel_values"])
             else:
@@ -943,6 +949,7 @@ def kimi_k25_vl_collate_fn(
                 "attention_mask": attention_mask,
             }
         )
+        per_sample_image_count.append(sample_image_count)
 
     if not all_expanded:
         raise ValueError(
@@ -990,9 +997,10 @@ def kimi_k25_vl_collate_fn(
         result["grid_thws"] = torch.cat(all_grid_thws, dim=0)
         # Also add as image_grid_hws for PP chunking in finetune.py
         result["image_grid_hws"] = result["grid_thws"][:, 1:]  # [N, 3] -> [N, 2] (drop temporal dim, keep H,W)
-        # Per-sample image counts for PP chunking
-        image_counts = [g.shape[0] for g in all_grid_thws]
-        result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)
+        # Per-sample image counts for PP chunking. Length must equal batch_size,
+        # so include zeros for text-only samples and for samples whose image
+        # region was orphaned by truncation.
+        result["n_images_per_sample"] = torch.tensor(per_sample_image_count, dtype=torch.long)
 
     # Build labels
     labels = build_labels_from_template(
diff --git a/tests/unit_tests/datasets/vlm/test_collate_fns.py b/tests/unit_tests/datasets/vlm/test_collate_fns.py
index e52c1f094..0cf925eac 100644
--- a/tests/unit_tests/datasets/vlm/test_collate_fns.py
+++ b/tests/unit_tests/datasets/vlm/test_collate_fns.py
@@ -1410,6 +1410,148 @@ def fake_build_labels(input_ids, conversations, processor_arg):
     assert (batch["input_ids"] == MEDIA_TOKEN_ID).sum().item() == 0
 
 
+def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix(
+    collate_mod, monkeypatch
+):
+    """Mixed batch (text-only + image): n_images_per_sample length must equal batch_size.
+
+    Regression: previously image_counts was derived from all_grid_thws only, so
+    text-only samples were skipped and the resulting tensor was shorter than
+    batch_size. Downstream PP _chunk_vlm_media indexes cumsum_images by
+    sample index and would go out of bounds, raising an IndexError.
+    """
+    MEDIA_TOKEN_ID = 163605
+
+    class MixedProcessor:
+        def __init__(self):
+            self.tokenizer = DummyTokenizer(pad_token_id=0)
+            self.media_placeholder_token_id = MEDIA_TOKEN_ID
+
+        def apply_chat_template(self, conversation, **kwargs):
+            return "chat:processed"
+
+        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
+            if medias:
+                input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
+                attention_mask = torch.ones_like(input_ids)
+                return {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "grid_thws": torch.tensor([[1, 4, 4]]),
+                    "pixel_values": torch.randn(1, 3, 14, 14),
+                }
+            input_ids = torch.tensor([[10, 11, 12, 13, 14]])
+            attention_mask = torch.ones_like(input_ids)
+            return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    processor = MixedProcessor()
+
+    def fake_build_labels(input_ids, conversations, processor_arg):
+        batch_size, seq_len = input_ids.shape
+        return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)
+
+    text_only = [
+        {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
+    ]
+    with_image = [
+        {"role": "user", "content": [{"type": "image", "image": "x.jpg"}, {"type": "text", "text": "What?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "Cat."}]},
+    ]
+    examples = [{"conversation": text_only}, {"conversation": with_image}]
+
+    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor)
+
+    assert "n_images_per_sample" in batch
+    assert batch["n_images_per_sample"].shape == (2,), (
+        f"n_images_per_sample length must equal batch_size=2, "
+        f"got shape {batch['n_images_per_sample'].shape}"
+    )
+    # text-only sample → 0; image sample → 1
+    assert batch["n_images_per_sample"].tolist() == [0, 1]
+
+
+def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan(
+    collate_mod, monkeypatch
+):
+    """Mixed batch (intact image + truncated image): n_images_per_sample length must equal batch_size.
+
+    Regression: a sample whose image region got orphaned by truncation was
+    correctly excluded from all_grid_thws but still kept in all_expanded.
+    Without the fix, n_images_per_sample length would be smaller than the
+    final batch and downstream PP indexing would crash.
+    """
+    MEDIA_TOKEN_ID = 163605
+
+    class MaybeOrphanProcessor:
+        """Returns a small grid on the first call and a large grid on the
+        second; the second sample's tokens are truncated inside its image
+        region by the max_length used below."""
+
+        def __init__(self):
+            self.tokenizer = DummyTokenizer(pad_token_id=0)
+            self.media_placeholder_token_id = MEDIA_TOKEN_ID
+            self._call_idx = 0
+
+        def apply_chat_template(self, conversation, **kwargs):
+            return "chat:processed"
+
+        def __call__(self, *, text, return_tensors, medias=None, **kwargs):
+            self._call_idx += 1
+            if self._call_idx == 1:
+                # Small grid that fits within max_length after expansion
+                input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
+                attention_mask = torch.ones_like(input_ids)
+                grid_thws = torch.tensor([[1, 4, 4]])  # 4 image tokens
+                return {
+                    "input_ids": input_ids,
+                    "attention_mask": attention_mask,
+                    "grid_thws": grid_thws,
+                    "pixel_values": torch.randn(1, 3, 14, 14),
+                }
+            # Second sample: 5 text + 16 image tokens = 21 post-expansion;
+            # max_length=15 truncates into the image region → orphan path.
+            input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4, 5]])
+            attention_mask = torch.ones_like(input_ids)
+            grid_thws = torch.tensor([[1, 8, 8]])  # 16 image tokens after expansion
+            return {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "grid_thws": grid_thws,
+                "pixel_values": torch.randn(1, 3, 64, 64),
+            }
+
+    processor = MaybeOrphanProcessor()
+
+    def fake_build_labels(input_ids, conversations, processor_arg):
+        batch_size, seq_len = input_ids.shape
+        return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)
+
+    conv_intact = [
+        {"role": "user", "content": [{"type": "image", "image": "a.jpg"}, {"type": "text", "text": "?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
+    ]
+    conv_orphan = [
+        {"role": "user", "content": [{"type": "image", "image": "b.jpg"}, {"type": "text", "text": "?"}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "."}]},
+    ]
+    examples = [{"conversation": conv_intact}, {"conversation": conv_orphan}]
+
+    batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor, max_length=15)
+
+    assert batch["input_ids"].shape[0] == 2
+    assert "n_images_per_sample" in batch
+    assert batch["n_images_per_sample"].shape == (2,), (
+        f"n_images_per_sample length must equal batch_size=2, "
+        f"got shape {batch['n_images_per_sample'].shape}"
+    )
+    # First sample's image survives → 1; second sample is orphaned → 0
+    assert batch["n_images_per_sample"].tolist() == [1, 0]
+
+
 def test_kimi_k25_vl_collate_fn_multiple_examples(collate_mod, monkeypatch):
     """Test kimi_k25_vl_collate_fn handles multiple examples with padding."""
     # Processor that produces variable length sequences
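Reviewer note: the invariant both regression tests pin down is `len(n_images_per_sample) == batch_size`, because the PP path slices the flattened per-image tensors by a cumulative sum of per-sample counts. The sketch below is a hypothetical stand-in for `_chunk_vlm_media` in finetune.py (its real signature does not appear in this diff), with `pixel_values` simplified to one row per image, to illustrate why a counts tensor shorter than the batch crashes:

```python
import torch


def chunk_vlm_media_sketch(pixel_values, grid_thws, n_images_per_sample, sample_slice):
    """Hypothetical stand-in for the PP chunker: slice the flattened
    per-image tensors belonging to the samples in sample_slice."""
    # cumsum_images[i] = number of images in samples 0..i-1, so sample i's
    # images occupy rows cumsum_images[i] : cumsum_images[i + 1]. This only
    # works if len(n_images_per_sample) == batch_size.
    cumsum_images = torch.cat(
        [torch.zeros(1, dtype=torch.long), n_images_per_sample.cumsum(0)]
    )
    start = cumsum_images[sample_slice.start].item()
    stop = cumsum_images[sample_slice.stop].item()  # out of bounds if counts are short
    return pixel_values[start:stop], grid_thws[start:stop]


# Batch of 2: sample 0 is text-only, sample 1 carries one image.
pixel_values = torch.randn(1, 3, 14, 14)
grid_thws = torch.tensor([[1, 4, 4]])
fixed = torch.tensor([0, 1])  # what the collate_fn emits after this PR
broken = torch.tensor([1])    # pre-fix: the text-only sample was skipped

# A chunk holding only sample 1 (e.g. the second PP microbatch) works post-fix...
pv, gt = chunk_vlm_media_sketch(pixel_values, grid_thws, fixed, slice(1, 2))
assert gt.shape[0] == 1

# ...and reproduces the IndexError the regression tests describe pre-fix.
try:
    chunk_vlm_media_sketch(pixel_values, grid_thws, broken, slice(1, 2))
except IndexError:
    print("short n_images_per_sample indexes cumsum_images out of bounds")
```

Under this sketch's indexing, the pre-fix counts are doubly wrong: besides the crash on the last sample, the shifted cumulative sums attribute images to earlier samples, so even non-crashing chunks would pair pixel_values rows with the wrong conversation.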