14 changes: 11 additions & 3 deletions nemo_automodel/components/datasets/vlm/collate_fns.py
@@ -871,6 +871,10 @@ def kimi_k25_vl_collate_fn(
all_expanded = []
all_pixel_values = []
all_grid_thws = []
# Per-sample image counts, kept in lockstep with all_expanded so that
# n_images_per_sample length matches batch_size downstream. Samples that
# are text-only or whose image region was orphaned by truncation get 0.
per_sample_image_count: List[int] = []

for i, conversation in enumerate(conversations):
# Collect medias for this conversation
@@ -923,12 +927,14 @@ def kimi_k25_vl_collate_fn(

# Only include image data if all expanded image tokens survived truncation.
# Partial truncation into image regions would cause a mismatch in the model forward.
sample_image_count = 0
if grid_thws is not None:
merge_h, merge_w = _DEFAULT_MERGE_KERNEL
expected_image_tokens = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
actual_image_tokens = (input_ids == media_token_id).sum().item()
if actual_image_tokens == expected_image_tokens:
all_grid_thws.append(grid_thws)
sample_image_count = int(grid_thws.shape[0])
if "pixel_values" in sample_batch:
all_pixel_values.append(sample_batch["pixel_values"])
else:
@@ -943,6 +949,7 @@ def kimi_k25_vl_collate_fn(
"attention_mask": attention_mask,
}
)
per_sample_image_count.append(sample_image_count)

if not all_expanded:
raise ValueError(
@@ -990,9 +997,10 @@ def kimi_k25_vl_collate_fn(
result["grid_thws"] = torch.cat(all_grid_thws, dim=0)
# Also add as image_grid_hws for PP chunking in finetune.py
result["image_grid_hws"] = result["grid_thws"][:, 1:] # [N, 3] -> [N, 2] (drop temporal dim, keep H,W)
# Per-sample image counts for PP chunking
image_counts = [g.shape[0] for g in all_grid_thws]
result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)
# Per-sample image counts for PP chunking. Length must equal batch_size,
# so include zeros for text-only samples and for samples whose image
# region was orphaned by truncation.
result["n_images_per_sample"] = torch.tensor(per_sample_image_count, dtype=torch.long)

# Build labels
labels = build_labels_from_template(
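For context, a minimal illustrative sketch of the invariant this collate change enforces: images are counted for a sample only when every expanded image token survived truncation, and the count list has one entry per sample so its length always equals batch_size. The merge kernel, media-token id, and helper below are toy stand-ins, not the real kimi_k25_vl_collate_fn.

# Illustrative only: toy merge kernel and media-token id, not the actual collate_fn.
from typing import Optional
import torch

MERGE_H, MERGE_W = 2, 2        # stand-in for _DEFAULT_MERGE_KERNEL
MEDIA_TOKEN_ID = 151655        # stand-in for the processor's media token id

def per_sample_count(input_ids: torch.Tensor, grid_thws: Optional[torch.Tensor]) -> int:
    # Count images only if every expanded image token survived truncation;
    # text-only samples and samples with a truncated image region report 0.
    if grid_thws is None:
        return 0
    expected = sum(int((h // MERGE_H) * (w // MERGE_W)) for _, h, w in grid_thws.tolist())
    actual = int((input_ids == MEDIA_TOKEN_ID).sum().item())
    return int(grid_thws.shape[0]) if actual == expected else 0

counts = [
    per_sample_count(torch.tensor([MEDIA_TOKEN_ID] * 4 + [1, 2]), torch.tensor([[1, 4, 4]])),  # intact image -> 1
    per_sample_count(torch.tensor([1, 2, 3]), None),                                           # text-only -> 0
]
n_images_per_sample = torch.tensor(counts, dtype=torch.long)  # len(counts) == batch_size
print(n_images_per_sample)  # tensor([1, 0])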
58 changes: 38 additions & 20 deletions nemo_automodel/components/models/qwen3_5_moe/model.py
@@ -181,9 +181,9 @@ def forward(
else:
raise ValueError("inputs_embeds must be provided for pipeline stages without embed_tokens")

# If we have pixel values and a vision encoder, go through the full HF
# If we have image or video pixel values and a vision encoder, go through the full HF
# VL forward (vision encoding + multimodal scatter + text).
if pixel_values is not None and self.visual is not None:
if (pixel_values is not None or pixel_values_videos is not None) and self.visual is not None:
return super().forward(
input_ids=None,
attention_mask=attention_mask,
@@ -491,23 +491,25 @@ def forward(
):
# PP VLM support: retrieve pixel_values from stored chunks if not passed
pixel_values = kwargs.get("pixel_values", None)
pixel_values_videos = kwargs.get("pixel_values_videos", None)
image_grid_thw = kwargs.get("image_grid_thw", None)
if (
pixel_values is None
and hasattr(self, "_vlm_pixel_values_chunks")
and self._vlm_pixel_values_chunks is not None
):
image_token_id = self.config.image_token_id
vision_start_token_id = self.config.vision_start_token_id
has_media_tokens = input_ids is not None and (
(input_ids == image_token_id).any() or (input_ids == vision_start_token_id).any()
)
video_grid_thw = kwargs.get("video_grid_thw", None)
image_token_id = self.config.image_token_id
vision_start_token_id = self.config.vision_start_token_id
has_media_tokens = input_ids is not None and (
(input_ids == image_token_id).any() or (input_ids == vision_start_token_id).any()
)

chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
consumed_vlm_chunk = False

if has_media_tokens:
chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
if chunk_idx < len(self._vlm_pixel_values_chunks):
pixel_values = self._vlm_pixel_values_chunks[chunk_idx]
image_grid_hws = self._vlm_image_grid_hws_chunks[chunk_idx]
if pixel_values is None and has_media_tokens:
image_chunks = getattr(self, "_vlm_pixel_values_chunks", None)
if image_chunks is not None and chunk_idx < len(image_chunks):
pixel_values = image_chunks[chunk_idx]
image_grid_chunks = getattr(self, "_vlm_image_grid_hws_chunks", None)
if image_grid_chunks is not None and chunk_idx < len(image_grid_chunks):
image_grid_hws = image_grid_chunks[chunk_idx]
if image_grid_hws is not None and image_grid_hws.numel() > 0:
if image_grid_hws.shape[-1] == 2:
ones = torch.ones(
@@ -516,9 +518,25 @@ def forward(
image_grid_thw = torch.cat([ones, image_grid_hws], dim=-1)
else:
image_grid_thw = image_grid_hws
kwargs["pixel_values"] = pixel_values
kwargs["image_grid_thw"] = image_grid_thw
self._vlm_chunk_idx = chunk_idx + 1
kwargs["pixel_values"] = pixel_values
kwargs["image_grid_thw"] = image_grid_thw
consumed_vlm_chunk = True

if pixel_values_videos is None and has_media_tokens:
video_chunks = getattr(self, "_vlm_pixel_values_videos_chunks", None)
if video_chunks is not None and chunk_idx < len(video_chunks):
video_chunk = video_chunks[chunk_idx]
if video_chunk.numel() > 0:
pixel_values_videos = video_chunk
video_grid_chunks = getattr(self, "_vlm_video_grid_thw_chunks", None)
if video_grid_chunks is not None and chunk_idx < len(video_grid_chunks):
video_grid_thw = video_grid_chunks[chunk_idx]
kwargs["pixel_values_videos"] = pixel_values_videos
kwargs["video_grid_thw"] = video_grid_thw
consumed_vlm_chunk = True

if consumed_vlm_chunk:
self._vlm_chunk_idx = chunk_idx + 1

if "qkv_format" in kwargs and kwargs["qkv_format"] == "thd":
input_ids, position_ids, padding_mask, kwargs = squeeze_input_for_thd(
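As a reading aid, here is a simplified sketch of the chunk-cursor pattern these model changes implement: the first pipeline stage stores per-microbatch media chunks, each forward call consumes at most one chunk per modality, and a single cursor advances once per call because images and videos are chunked along the same microbatch boundaries. The class and method names below are hypothetical; this is not the model's actual forward().

# Simplified sketch; class/method names are hypothetical, not the real model code.
import torch

class StageWithVLMChunks:
    def consume_media_chunks(self, kwargs: dict, has_media_tokens: bool) -> dict:
        chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
        consumed = False

        if kwargs.get("pixel_values") is None and has_media_tokens:
            image_chunks = getattr(self, "_vlm_pixel_values_chunks", None)
            if image_chunks is not None and chunk_idx < len(image_chunks):
                kwargs["pixel_values"] = image_chunks[chunk_idx]
                consumed = True

        if kwargs.get("pixel_values_videos") is None and has_media_tokens:
            video_chunks = getattr(self, "_vlm_pixel_values_videos_chunks", None)
            if video_chunks is not None and chunk_idx < len(video_chunks):
                chunk = video_chunks[chunk_idx]
                if chunk.numel() > 0:  # empty placeholder means no video in this microbatch
                    kwargs["pixel_values_videos"] = chunk
                consumed = True

        # One shared cursor for both modalities: advance it once per forward that
        # consumed any chunk, since both lists use the same microbatch boundaries.
        if consumed:
            self._vlm_chunk_idx = chunk_idx + 1
        return kwargs

Between batches, the recipe resets these attributes (see the finetune.py changes below), so each training step starts with a fresh cursor.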
36 changes: 36 additions & 0 deletions nemo_automodel/components/models/qwen3_omni_moe/model.py
@@ -315,6 +315,42 @@ def forward(
)
attention_mask = None

chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
consumed_vlm_chunk = False

if pixel_values is None:
image_chunks = getattr(self, "_vlm_pixel_values_chunks", None)
if image_chunks is not None and chunk_idx < len(image_chunks):
image_chunk = image_chunks[chunk_idx]
if image_chunk.numel() > 0:
pixel_values = image_chunk
image_grid_chunks = getattr(self, "_vlm_image_grid_hws_chunks", None)
if image_grid_chunks is not None and chunk_idx < len(image_grid_chunks):
image_grid = image_grid_chunks[chunk_idx]
if image_grid is not None and image_grid.numel() > 0:
if image_grid.shape[-1] == 2:
ones = torch.ones(
image_grid.shape[0], 1, dtype=image_grid.dtype, device=image_grid.device
)
image_grid_thw = torch.cat([ones, image_grid], dim=-1)
else:
image_grid_thw = image_grid
consumed_vlm_chunk = True

if pixel_values_videos is None:
video_chunks = getattr(self, "_vlm_pixel_values_videos_chunks", None)
if video_chunks is not None and chunk_idx < len(video_chunks):
video_chunk = video_chunks[chunk_idx]
if video_chunk.numel() > 0:
pixel_values_videos = video_chunk
video_grid_chunks = getattr(self, "_vlm_video_grid_thw_chunks", None)
if video_grid_chunks is not None and chunk_idx < len(video_grid_chunks):
video_grid_thw = video_grid_chunks[chunk_idx]
consumed_vlm_chunk = True

if consumed_vlm_chunk:
self._vlm_chunk_idx = chunk_idx + 1

# 1. Get input embeddings
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
53 changes: 36 additions & 17 deletions nemo_automodel/components/models/qwen3_vl_moe/model.py
@@ -526,21 +526,26 @@ def forward(
):
# PP VLM support: retrieve pixel_values from stored chunks if not passed directly
pixel_values = kwargs.get("pixel_values", None)
pixel_values_videos = kwargs.get("pixel_values_videos", None)
image_grid_thw = kwargs.get("image_grid_thw", None)
if (
pixel_values is None
and hasattr(self, "_vlm_pixel_values_chunks")
and self._vlm_pixel_values_chunks is not None
):
# Check if we have media tokens in input_ids
# 151655 = <|image_pad|>, 151656 = <|video_pad|>
has_media_tokens = input_ids is not None and ((input_ids == 151655).any() or (input_ids == 151656).any())

if has_media_tokens:
chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
if chunk_idx < len(self._vlm_pixel_values_chunks):
pixel_values = self._vlm_pixel_values_chunks[chunk_idx]
image_grid_hws = self._vlm_image_grid_hws_chunks[chunk_idx]
video_grid_thw = kwargs.get("video_grid_thw", None)
if input_ids is not None:
has_image_tokens = (input_ids == 151655).any()
has_video_tokens = (input_ids == 151656).any()
else:
has_image_tokens = False
has_video_tokens = False

chunk_idx = getattr(self, "_vlm_chunk_idx", 0)
consumed_vlm_chunk = False

if pixel_values is None and has_image_tokens:
image_chunks = getattr(self, "_vlm_pixel_values_chunks", None)
if image_chunks is not None and chunk_idx < len(image_chunks):
pixel_values = image_chunks[chunk_idx]
image_grid_chunks = getattr(self, "_vlm_image_grid_hws_chunks", None)
if image_grid_chunks is not None and chunk_idx < len(image_grid_chunks):
image_grid_hws = image_grid_chunks[chunk_idx]
# Convert image_grid_hws [N, 2] to image_grid_thw [N, 3] by prepending T=1
if image_grid_hws is not None and image_grid_hws.numel() > 0:
if image_grid_hws.shape[-1] == 2:
@@ -550,9 +555,23 @@ def forward(
image_grid_thw = torch.cat([ones, image_grid_hws], dim=-1)
else:
image_grid_thw = image_grid_hws
kwargs["pixel_values"] = pixel_values
kwargs["image_grid_thw"] = image_grid_thw
self._vlm_chunk_idx = chunk_idx + 1
kwargs["pixel_values"] = pixel_values
kwargs["image_grid_thw"] = image_grid_thw
consumed_vlm_chunk = True

if pixel_values_videos is None and has_video_tokens:
video_chunks = getattr(self, "_vlm_pixel_values_videos_chunks", None)
if video_chunks is not None and chunk_idx < len(video_chunks):
pixel_values_videos = video_chunks[chunk_idx]
video_grid_chunks = getattr(self, "_vlm_video_grid_thw_chunks", None)
if video_grid_chunks is not None and chunk_idx < len(video_grid_chunks):
video_grid_thw = video_grid_chunks[chunk_idx]
kwargs["pixel_values_videos"] = pixel_values_videos
kwargs["video_grid_thw"] = video_grid_thw
consumed_vlm_chunk = True

if consumed_vlm_chunk:
self._vlm_chunk_idx = chunk_idx + 1

# With pipeline parallelism, attention_mask (from batch kwargs) can have a
# different sequence length than inputs_embeds (hidden states from prev stage).
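The [N, 2] to [N, 3] conversion that recurs in these model files can be summarized with a small stand-alone sketch: image_grid_hws carries (H, W) per image, while the HF VL forward expects (T, H, W), so T=1 is prepended for still images. The helper name is invented for illustration.

# Illustration only; the helper name is made up.
import torch

def hws_to_thw(image_grid_hws: torch.Tensor) -> torch.Tensor:
    # Convert an [N, 2] (H, W) grid into an [N, 3] (T, H, W) grid with T=1.
    if image_grid_hws.shape[-1] == 2:
        ones = torch.ones(
            image_grid_hws.shape[0], 1, dtype=image_grid_hws.dtype, device=image_grid_hws.device
        )
        return torch.cat([ones, image_grid_hws], dim=-1)
    return image_grid_hws  # already [N, 3]

print(hws_to_thw(torch.tensor([[32, 48], [16, 16]])))  # tensor([[ 1, 32, 48], [ 1, 16, 16]])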
32 changes: 32 additions & 0 deletions nemo_automodel/recipes/vlm/finetune.py
@@ -1005,6 +1005,9 @@ def _forward_backward_step(
image_sizes = batch.pop("image_sizes", None)
image_position_ids = batch.pop("image_position_ids", None)
n_images_per_sample = batch.pop("n_images_per_sample", None)
pixel_values_videos = batch.pop("pixel_values_videos", None)
video_grid_thw = batch.pop("video_grid_thw", None)
n_videos_per_sample = batch.pop("n_videos_per_sample", None)

image_grid = image_grid_hws if image_grid_hws is not None else image_grid_thw
if image_grid is None and image_sizes is not None:
@@ -1029,6 +1032,30 @@ def _forward_backward_step(
stage0_model._vlm_pixel_values_chunks = pixel_values_chunks
stage0_model._vlm_image_grid_hws_chunks = image_grid_chunks
stage0_model._vlm_chunk_idx = 0
elif pixel_values is not None:
batch["pixel_values"] = pixel_values

if self.pp.info.has_first_stage and pixel_values_videos is not None and video_grid_thw is not None:
stage0_model = self.model_parts[0]
n_microbatches = self.pp._info.schedule._n_microbatches
batch_size = input_ids.shape[0]

pixel_values_videos_chunks, video_grid_thw_chunks = _chunk_vlm_media(
pixel_values_videos,
video_grid_thw,
batch_size,
n_microbatches,
n_images_per_sample=n_videos_per_sample,
)

stage0_model._vlm_pixel_values_videos_chunks = pixel_values_videos_chunks
stage0_model._vlm_video_grid_thw_chunks = video_grid_thw_chunks
stage0_model._vlm_chunk_idx = 0
else:
if pixel_values_videos is not None:
batch["pixel_values_videos"] = pixel_values_videos
if video_grid_thw is not None:
batch["video_grid_thw"] = video_grid_thw

if self.pp.info.has_first_stage:
self.pp.info.schedule.step(input_ids, target=targets, losses=losses, **batch)
Contributor

LGTM — clean fix. The video chunking mirrors the existing image path correctly, the shared _vlm_chunk_idx cursor works because both modalities are chunked by the same microbatch boundaries, and the cleanup is properly guarded. Test coverage for video-only, mixed image+video, and the n_images_per_sample batch-size regression is solid.

@@ -1041,6 +1068,11 @@ def _forward_backward_step(
stage0_model._vlm_pixel_values_chunks = None
stage0_model._vlm_image_grid_hws_chunks = None
stage0_model._vlm_chunk_idx = None
if self.pp.info.has_first_stage and pixel_values_videos is not None and video_grid_thw is not None:
stage0_model = self.model_parts[0]
stage0_model._vlm_pixel_values_videos_chunks = None
stage0_model._vlm_video_grid_thw_chunks = None
stage0_model._vlm_chunk_idx = None

if self.pp.info.has_last_stage:
local_loss = torch.sum(torch.stack(losses))
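To make the recipe-side chunking easier to follow, here is a hypothetical sketch of how flat media tensors could be split into per-microbatch chunks using per-sample counts, in the spirit of _chunk_vlm_media (whose internals are not shown in this diff). It assumes the batch divides evenly across microbatches, media rows are ordered by sample, and each media item contributes T*H*W rows of pixel_values; none of that is confirmed by the PR.

# Hypothetical helper, not the real _chunk_vlm_media; assumptions noted above.
from typing import List, Tuple
import torch

def chunk_media_by_sample_counts(
    pixel_values: torch.Tensor,        # [total_rows, ...] flat across the whole batch
    grid_thw: torch.Tensor,            # [total_media, 3] (T, H, W) per image/video
    n_media_per_sample: torch.Tensor,  # [batch_size] media count per sample
    batch_size: int,
    n_microbatches: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    samples_per_mb = batch_size // n_microbatches
    rows_per_media = grid_thw.prod(dim=-1)  # assumed rows contributed by each media item
    pv_chunks, grid_chunks = [], []
    media_start = row_start = 0
    for mb in range(n_microbatches):
        counts = n_media_per_sample[mb * samples_per_mb:(mb + 1) * samples_per_mb]
        n_media = int(counts.sum().item())
        n_rows = int(rows_per_media[media_start:media_start + n_media].sum().item())
        pv_chunks.append(pixel_values[row_start:row_start + n_rows])
        grid_chunks.append(grid_thw[media_start:media_start + n_media])
        media_start += n_media
        row_start += n_rows
    return pv_chunks, grid_chunks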