14 changes: 11 additions & 3 deletions nemo_automodel/components/datasets/vlm/collate_fns.py
@@ -871,6 +871,10 @@ def kimi_k25_vl_collate_fn(
all_expanded = []
all_pixel_values = []
all_grid_thws = []
# Per-sample image counts, kept in lockstep with all_expanded so that
# n_images_per_sample length matches batch_size downstream. Samples that
# are text-only or whose image region was orphaned by truncation get 0.
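# For example, a batch of [1-image sample, text-only sample, sample whose
# image tokens were all truncated away] should end up as [1, 0, 0].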
per_sample_image_count: List[int] = []

for i, conversation in enumerate(conversations):
# Collect medias for this conversation
@@ -923,12 +927,14 @@ def kimi_k25_vl_collate_fn(

# Only include image data if all expanded image tokens survived truncation.
# Partial truncation into image regions would cause a mismatch in the model forward.
sample_image_count = 0
if grid_thws is not None:
merge_h, merge_w = _DEFAULT_MERGE_KERNEL
expected_image_tokens = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
actual_image_tokens = (input_ids == media_token_id).sum().item()
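# Worked example (assuming _DEFAULT_MERGE_KERNEL == (2, 2), which the unit
# tests below rely on): grid_thws == [[1, 8, 8]] expects (8 // 2) * (8 // 2)
# == 16 placeholder tokens; if truncation leaves only 12 of them in
# input_ids, 12 != 16 and the sample's image data is dropped.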
if actual_image_tokens == expected_image_tokens:
all_grid_thws.append(grid_thws)
sample_image_count = int(grid_thws.shape[0])
if "pixel_values" in sample_batch:
all_pixel_values.append(sample_batch["pixel_values"])
else:
@@ -943,6 +949,7 @@
"attention_mask": attention_mask,
}
)
per_sample_image_count.append(sample_image_count)

if not all_expanded:
raise ValueError(
Expand Down Expand Up @@ -990,9 +997,10 @@ def kimi_k25_vl_collate_fn(
result["grid_thws"] = torch.cat(all_grid_thws, dim=0)
# Also add as image_grid_hws for PP chunking in finetune.py
result["image_grid_hws"] = result["grid_thws"][:, 1:] # [N, 3] -> [N, 2] (drop temporal dim, keep H,W)
- # Per-sample image counts for PP chunking
- image_counts = [g.shape[0] for g in all_grid_thws]
- result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)
+ # Per-sample image counts for PP chunking. Length must equal batch_size,
+ # so include zeros for text-only samples and for samples whose image
+ # region was orphaned by truncation.
+ result["n_images_per_sample"] = torch.tensor(per_sample_image_count, dtype=torch.long)

# Build labels
labels = build_labels_from_template(
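To make the regression concrete, here is a minimal, runnable sketch of the downstream failure mode described in the tests below. chunk_by_sample is a hypothetical, simplified stand-in for the PP chunking step (the real helper, _chunk_vlm_media, is only named in the test docstrings), so treat the shape of the logic, not the names, as the point.

import torch

def chunk_by_sample(n_images_per_sample: torch.Tensor, batch_size: int) -> list:
    # PP chunking walks samples by index and looks up how many images
    # precede each sample via a cumulative sum of per-sample counts.
    cumsum_images = torch.cumsum(n_images_per_sample, dim=0)
    return [int(cumsum_images[i]) for i in range(batch_size)]

# Before the fix: counts derived from all_grid_thws alone skip text-only
# samples, so a 2-sample batch yields a length-1 tensor and indexing
# sample 1 falls out of bounds.
try:
    chunk_by_sample(torch.tensor([1]), batch_size=2)
except IndexError:
    print("cumsum_images shorter than batch_size")

# After the fix: text-only and truncation-orphaned samples contribute 0.
print(chunk_by_sample(torch.tensor([0, 1]), batch_size=2))  # [0, 1]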
142 changes: 142 additions & 0 deletions tests/unit_tests/datasets/vlm/test_collate_fns.py
@@ -1410,6 +1410,148 @@ def fake_build_labels(input_ids, conversations, processor_arg):
assert (batch["input_ids"] == MEDIA_TOKEN_ID).sum().item() == 0


def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_text_only_mix(
collate_mod, monkeypatch
):
"""Mixed batch (text-only + image): n_images_per_sample length must equal batch_size.

Regression: previously image_counts was derived from all_grid_thws only, so
text-only samples were skipped and the resulting tensor was shorter than
batch_size. Downstream, PP chunking (_chunk_vlm_media) indexes cumsum_images
by sample index and would raise an IndexError (index out of bounds).
"""
MEDIA_TOKEN_ID = 163605

class MixedProcessor:
def __init__(self):
self.tokenizer = DummyTokenizer(pad_token_id=0)
self.media_placeholder_token_id = MEDIA_TOKEN_ID

def apply_chat_template(self, conversation, **kwargs):
return "chat:processed"

def __call__(self, *, text, return_tensors, medias=None, **kwargs):
if medias:
input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"grid_thws": torch.tensor([[1, 4, 4]]),
"pixel_values": torch.randn(1, 3, 14, 14),
}
input_ids = torch.tensor([[10, 11, 12, 13, 14]])
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask}

processor = MixedProcessor()

def fake_build_labels(input_ids, conversations, processor_arg):
batch_size, seq_len = input_ids.shape
return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)

monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)

text_only = [
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]},
]
with_image = [
{"role": "user", "content": [{"type": "image", "image": "x.jpg"}, {"type": "text", "text": "What?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Cat."}]},
]
examples = [{"conversation": text_only}, {"conversation": with_image}]

batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor)

assert "n_images_per_sample" in batch
assert batch["n_images_per_sample"].shape == (2,), (
f"n_images_per_sample length must equal batch_size=2, "
f"got shape {batch['n_images_per_sample'].shape}"
)
# text-only sample → 0; image sample → 1
assert batch["n_images_per_sample"].tolist() == [0, 1]


def test_kimi_k25_vl_collate_fn_n_images_per_sample_matches_batch_size_truncation_orphan(
collate_mod, monkeypatch
):
"""Mixed batch (truncated image + intact image): n_images_per_sample length must equal batch_size.

Regression: a sample whose image region got orphaned by truncation was
correctly excluded from all_grid_thws but still kept in all_expanded.
Without the fix, n_images_per_sample length would be smaller than the
final batch and downstream PP indexing would crash.
"""
MEDIA_TOKEN_ID = 163605

class MaybeOrphanProcessor:
"""Returns the same large grid for both calls; the second call's tokens
will be truncated past the image region by max_length below."""

def __init__(self):
self.tokenizer = DummyTokenizer(pad_token_id=0)
self.media_placeholder_token_id = MEDIA_TOKEN_ID
self._call_idx = 0

def apply_chat_template(self, conversation, **kwargs):
return "chat:processed"

def __call__(self, *, text, return_tensors, medias=None, **kwargs):
self._call_idx += 1
if self._call_idx == 1:
# Small grid that fits within max_length after expansion
input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4]])
attention_mask = torch.ones_like(input_ids)
grid_thws = torch.tensor([[1, 4, 4]]) # 4 image tokens
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"grid_thws": grid_thws,
"pixel_values": torch.randn(1, 3, 14, 14),
}
# Second sample: 5 text + 16 image tokens = 21 post-expansion;
# max_length=15 truncates into the image region → orphan path.
input_ids = torch.tensor([[1, 2, MEDIA_TOKEN_ID, 3, 4, 5]])
attention_mask = torch.ones_like(input_ids)
grid_thws = torch.tensor([[1, 8, 8]])  # (8 // 2) * (8 // 2) = 16 image tokens after expansion
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"grid_thws": grid_thws,
"pixel_values": torch.randn(1, 3, 64, 64),
}

processor = MaybeOrphanProcessor()

def fake_build_labels(input_ids, conversations, processor_arg):
batch_size, seq_len = input_ids.shape
return torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)

monkeypatch.setattr(collate_mod, "build_labels_from_template", fake_build_labels, raising=True)

conv_intact = [
{"role": "user", "content": [{"type": "image", "image": "a.jpg"}, {"type": "text", "text": "?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "."}]},
]
conv_orphan = [
{"role": "user", "content": [{"type": "image", "image": "b.jpg"}, {"type": "text", "text": "?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "."}]},
]
examples = [{"conversation": conv_intact}, {"conversation": conv_orphan}]

batch = collate_mod.kimi_k25_vl_collate_fn(examples, processor, max_length=15)

assert batch["input_ids"].shape[0] == 2
assert "n_images_per_sample" in batch
assert batch["n_images_per_sample"].shape == (2,), (
f"n_images_per_sample length must equal batch_size=2, "
f"got shape {batch['n_images_per_sample'].shape}"
)
# First sample's image survives → 1; second sample is orphaned → 0
assert batch["n_images_per_sample"].tolist() == [1, 0]


def test_kimi_k25_vl_collate_fn_multiple_examples(collate_mod, monkeypatch):
"""Test kimi_k25_vl_collate_fn handles multiple examples with padding."""
# Processor that produces variable length sequences