
fix(pp): preserve VLM forward when class opts in via _pp_keep_self_forward #2192

Open
khazic wants to merge 3 commits into NVIDIA-NeMo:main from khazic:khazic/fix/pp-vlm-preserve-forward

Conversation

@khazic
Contributor

@khazic khazic commented May 8, 2026

What

patch_hf_model_for_pp unconditionally replaced model.forward with the generic pipeline_forward_causal_lm for any VLM that is not Gemma4 or Mistral3. Chunk-aware VLMs (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) keep the per-microbatch pixel_values in self._vlm_pixel_values_chunks and fetch them inside their own forward; the generic forward never reads that attribute, so under PP the vision tower silently never ran and image tokens were trained against placeholder text embeddings.
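
For context, the chunk-aware pattern looks roughly like the sketch below (the class name and helpers such as embed_tokens / merge_image_embeds are illustrative, not the real model code); the generic pipeline_forward_causal_lm never reads _vlm_pixel_values_chunks, so swapping it in severs the vision branch marked in the sketch.

import torch.nn as nn

class ChunkAwareVLMSketch(nn.Module):
    # Sketch only: vision_tower, language_model, embed_tokens, and merge_image_embeds
    # are assumed submodules/helpers, not the actual attribute names.
    def forward(self, input_ids, attention_mask=None, **kwargs):
        # The pipelining wrapper stashes per-microbatch pixel values on the module.
        chunks = getattr(self, "_vlm_pixel_values_chunks", None)
        if chunks:
            pixel_values = chunks.pop(0)
            image_embeds = self.vision_tower(pixel_values)          # vision tower runs here
            inputs_embeds = self.merge_image_embeds(input_ids, image_embeds)
        else:
            inputs_embeds = self.embed_tokens(input_ids)            # text-only fallback
        return self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)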

The existing recipes worked around this by setting patch_inner_model: false and patch_causal_lm_model: false in their YAMLs, but the framework had no way to enforce it — any new PP+VLM recipe that forgot the two flags would silently corrupt training.

Repro (pre-fix)

from nemo_automodel.components.distributed.pipelining.hf_utils import patch_hf_model_for_pp
# build a chunk-aware VLM, attach _vlm_pixel_values_chunks, call patch with defaults
patch_hf_model_for_pp(model, patch_inner_model=True, patch_causal_lm_model=True)
# model.forward is now generic causal_lm; pixel_values never reaches inner model

grep _vlm_pixel_values_chunks pipelining/hf_utils.py confirms only the dedicated gemma4_vlm and mistral3_vlm forwards read the chunks.

Changelog

  • Add an opt-in class attribute _pp_keep_self_forward. When True, patch_hf_model_for_pp is a no-op for both inner and outer modules, preserving the model's chunk-aware forward (see the sketch after this list).
  • validate_hf_model_for_pipeline_support now refuses to start a PP job for a VLM that has neither a dedicated forward (gemma4 / mistral3) nor _pp_keep_self_forward, so the same misconfiguration cannot reach training silently again.
  • Mark Qwen3VLMoeForConditionalGeneration, Qwen3_5MoeForConditionalGeneration, KimiVLForConditionalGeneration, KimiK25VLForConditionalGeneration with _pp_keep_self_forward = True.
  • Unit tests for: the patch escape hatch, validation passing for chunk-aware VLMs, validation passing for dedicated-forward VLMs, validation raising for unsupported VLM-PP combos.
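
A condensed sketch of how the opt-in marker and the validation guard interact (is_vlm and DEDICATED_PP_FORWARD_CLASSES are placeholder names; the actual detection logic and error message in hf_utils.py may differ):

# Sketch of the decision shape only, not the exact implementation.
class Qwen3VLMoeForConditionalGeneration:  # base classes elided in this sketch
    _pp_keep_self_forward = True  # opt-in: the PP patcher must not replace this forward

def validate_hf_model_for_pipeline_support(model):
    has_dedicated_pp_forward = type(model).__name__ in DEDICATED_PP_FORWARD_CLASSES  # gemma4 / mistral3
    keeps_self_forward = getattr(model, "_pp_keep_self_forward", False)
    if is_vlm(model) and not (has_dedicated_pp_forward or keeps_self_forward):
        raise ValueError(
            "PP would silently drop this VLM's vision path; mark the class with "
            "_pp_keep_self_forward = True or add a dedicated pipeline forward."
        )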

Test plan

  • pytest tests/unit_tests/distributed/pipelining/test_hf_utils.py — 58 passed
  • pytest tests/unit_tests/models/{qwen3_vl_moe,qwen3_5_moe,kimivl,kimi_k25_vl}/ — 489 passed (3 pre-existing flash-attn-env failures unrelated to this change)
  • End-to-end repro: confirmed pixel_values now reaches inner model under default patch_*=True for chunk-aware VLMs
  • CI: L0 + L2 VLM PP suites

…rward

patch_hf_model_for_pp unconditionally replaced model.forward with the
generic CausalLM forward for any VLM that is not Gemma4 or Mistral3.
Chunk-aware VLMs (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) keep
the per-microbatch pixel_values in self._vlm_pixel_values_chunks and
fetch them inside their own forward; the generic forward never reads
that attribute, so vision_tower silently never ran under PP and image
tokens were trained against placeholder text embeddings.

Add an opt-in class attribute _pp_keep_self_forward; when True the patch
is a no-op for both inner and outer modules, matching the behaviour the
existing recipes had to enforce by setting both patch_*: false flags by
hand. validate_hf_model_for_pipeline_support now refuses to start a PP
job for a VLM that has neither a dedicated forward (gemma4/mistral3)
nor _pp_keep_self_forward, so the same misconfiguration cannot reach
training silently in the future.

Mark Qwen3VLMoeForConditionalGeneration, Qwen3_5MoeForConditionalGeneration,
KimiVLForConditionalGeneration, and KimiK25VLForConditionalGeneration as
keepers; add unit tests covering both the patch escape hatch and the new
validation paths.

Signed-off-by: khazic <khazzz1c@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor

/ok to test 2a8f924

@HuiyingLi
Contributor

/ok to test d28e393

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label May 8, 2026
@HuiyingLi
Contributor

Hi @khazic, we have two paths for model implementation, HF implementation vs. custom implementation (mostly MoE). The flag patch_hf_model_for_pp is for HF-implementation models. For custom MoE models I don't think we need the patch since we have the model forward.

Per @HuiyingLi's review feedback: the decision of whether a model needs
``patch_hf_model_for_pp`` belongs at the pipeline-build call site, not
inside the patcher. The patcher should keep a single concern -- pick the
right patch flavour for HF-impl / Gemma4 / Mistral3 -- and the call site
should decide whether to invoke it at all.

- Move the early-return that consults ``_pp_keep_self_forward`` out of
  ``patch_hf_model_for_pp`` and into ``_build_stage_from_modules`` in
  ``functional.py`` so custom-impl models that own their PP-aware forward
  (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) skip the patcher
  entirely instead of entering it just to bail.
- Expose the helper as ``model_keeps_self_forward(model)`` so the call site
  has a named, testable predicate rather than an inlined ``getattr``.
- Restructure the existing regression test around the helper instead of
  the patcher's no-op behaviour. The validation guard added in the prior
  commit (Mistral3 / Gemma4 dedicated-forward whitelist) is unchanged and
  still catches future custom VLMs that forget to declare the flag.

The class-level ``_pp_keep_self_forward = True`` markers remain on the
four chunk-aware VLMs because Mistral3 (also a custom-impl class) does
need ``patch_hf_model_for_pp`` for its dedicated forward, so a blanket
"custom impl skips patch" rule based on registry membership would
misroute Mistral3 -- the per-class flag is still the right signal.

Signed-off-by: khazic <khazzz1c@gmail.com>
@khazic
Contributor Author

khazic commented May 9, 2026

Thanks @HuiyingLi — that framing is right and I've restructured the PR accordingly in 1a09f17.

The decision of whether to patch now lives at the call site (functional._build_stage_from_modules), and patch_hf_model_for_pp is back to a single concern (pick the right patch flavour for HF / Gemma4 / Mistral3). The custom-impl MoE classes that own their PP-aware forward — Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL — now skip the patcher entirely instead of entering it and bailing.

One detail worth flagging: I kept the per-class _pp_keep_self_forward = True marker rather than deciding "skip patch iff custom-impl" purely from MODEL_ARCH_MAPPING membership. Mistral3 is also a custom-impl class registered in NeMo, but it actually does need patch_hf_model_for_pp for its dedicated pipeline_forward_mistral3_vlm. A blanket "custom = skip" rule would silently misroute Mistral3, so the flag stays as the precise signal. The call site reads it via the new model_keeps_self_forward(model) helper.
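
Roughly, the predicate and its call-site use look like this (the _build_stage_from_modules surroundings and the patch keyword defaults are paraphrased, not the literal diff):

def model_keeps_self_forward(model) -> bool:
    # True when the class opted in to keep its own PP-aware forward.
    return bool(getattr(model, "_pp_keep_self_forward", False))

# Inside functional._build_stage_from_modules (paraphrased):
if model_keeps_self_forward(model):
    pass  # chunk-aware custom VLM owns its forward; skip the patcher entirely
else:
    patch_hf_model_for_pp(model, patch_inner_model=True, patch_causal_lm_model=True)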

The validation guard added earlier in the PR (catches a future custom-impl VLM that forgets the marker) is unchanged.

New SHA for /ok to test: 1a09f173296d5635b9d95365a790bed2d69cf461.

@HuiyingLi
Contributor

/ok to test 1a09f17

