fix(pp): preserve VLM forward when class opts in via _pp_keep_self_forward #2192
khazic wants to merge 3 commits into NVIDIA-NeMo:main from
Conversation
…rward

patch_hf_model_for_pp unconditionally replaced model.forward with the generic CausalLM forward for any VLM that is not Gemma4 or Mistral3. Chunk-aware VLMs (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) keep the per-microbatch pixel_values in self._vlm_pixel_values_chunks and fetch them inside their own forward; the generic forward never reads that attribute, so vision_tower silently never ran under PP and image tokens were trained against placeholder text embeddings.

Add an opt-in class attribute _pp_keep_self_forward; when True the patch is a no-op for both inner and outer modules, matching the behaviour the existing recipes had to enforce by setting both patch_*: false flags by hand. validate_hf_model_for_pipeline_support now refuses to start a PP job for a VLM that has neither a dedicated forward (gemma4/mistral3) nor _pp_keep_self_forward, so the same misconfiguration cannot reach training silently in the future.

Mark Qwen3VLMoeForConditionalGeneration, Qwen3_5MoeForConditionalGeneration, KimiVLForConditionalGeneration, and KimiK25VLForConditionalGeneration as keepers; add unit tests covering both the patch escape hatch and the new validation paths.

Signed-off-by: khazic <khazzz1c@gmail.com>
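A minimal sketch of the escape hatch this commit describes; the real `patch_hf_model_for_pp` in `pipelining/hf_utils.py` takes more arguments and chooses between several patch flavours, so the trimmed signature and helper name below are illustrative only:

```python
def _keeps_self_forward(model) -> bool:
    # Class-level opt-in set by chunk-aware VLMs (Qwen3-VL-MoE, Qwen3.5-MoE,
    # KimiVL, Kimi-K2.5-VL): their own forward reads self._vlm_pixel_values_chunks.
    return bool(getattr(type(model), "_pp_keep_self_forward", False))


def patch_hf_model_for_pp(model):  # simplified signature for illustration
    if _keeps_self_forward(model):
        # No-op for both inner and outer modules: replacing forward here would
        # mean the vision tower silently never runs under PP.
        return model
    # ... existing behaviour: install the gemma4 / mistral3 / generic
    # causal-LM pipeline forward ...
    return model
```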
/ok to test 2a8f924

/ok to test d28e393
Hi @khazic, we have two paths for model implementation, HF implementation vs. custom implementation (mostly MoE). The flag
Per @HuiyingLi's review feedback: the decision of whether a model needs ``patch_hf_model_for_pp`` belongs at the pipeline-build call site, not inside the patcher. The patcher should keep a single concern -- pick the right patch flavour for HF-impl / Gemma4 / Mistral3 -- and the call site should decide whether to invoke it at all.

- Move the early-return that consults ``_pp_keep_self_forward`` out of ``patch_hf_model_for_pp`` and into ``_build_stage_from_modules`` in ``functional.py`` so custom-impl models that own their PP-aware forward (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) skip the patcher entirely instead of entering it just to bail.
- Expose the helper as ``model_keeps_self_forward(model)`` so the call site has a named, testable predicate rather than an inlined ``getattr``.
- Restructure the existing regression test around the helper instead of the patcher's no-op behaviour.

The validation guard added in the prior commit (Mistral3 / Gemma4 dedicated-forward whitelist) is unchanged and still catches future custom VLMs that forget to declare the flag. The class-level ``_pp_keep_self_forward = True`` markers remain on the four chunk-aware VLMs because Mistral3 (also a custom-impl class) does need ``patch_hf_model_for_pp`` for its dedicated forward, so a blanket "custom impl skips patch" rule based on registry membership would misroute Mistral3 -- the per-class flag is still the right signal.

Signed-off-by: khazic <khazzz1c@gmail.com>
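A rough sketch of the restructured call site described above -- not the actual code: ``_build_stage_from_modules`` in ``functional.py`` takes different arguments and does far more than shown, and only the predicate name ``model_keeps_self_forward`` comes from the commit message:

```python
def model_keeps_self_forward(model) -> bool:
    """Named, testable predicate: does this class own its PP-aware forward?"""
    return bool(getattr(type(model), "_pp_keep_self_forward", False))


def _build_stage_from_modules(stage_modules, model, *args, **kwargs):  # simplified
    if not model_keeps_self_forward(model):
        # Only HF-impl / Gemma4 / Mistral3 models enter the patcher; chunk-aware
        # custom-impl VLMs keep their own forward and skip it entirely.
        patch_hf_model_for_pp(model)  # from pipelining/hf_utils.py
    # ... existing stage-construction logic continues unchanged ...
```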
|
Thanks @HuiyingLi — that framing is right and I've restructured the PR accordingly. The decision of whether to patch now lives at the call site (`_build_stage_from_modules` in `functional.py`), and the patcher keeps its single concern of picking the right patch flavour.

One detail worth flagging: I kept the per-class `_pp_keep_self_forward` marker rather than a blanket "custom impl skips patch" rule, because Mistral3 is also a custom-impl class but does need `patch_hf_model_for_pp` for its dedicated forward.

The validation guard added earlier in the PR (catches a future custom-impl VLM that forgets the marker) is unchanged. New SHA for
/ok to test 1a09f17 |
What
`patch_hf_model_for_pp` unconditionally replaced `model.forward` with the generic `pipeline_forward_causal_lm` for any VLM that is not Gemma4 or Mistral3. Chunk-aware VLMs (Qwen3-VL-MoE, Qwen3.5-MoE, KimiVL, Kimi-K2.5-VL) keep the per-microbatch `pixel_values` in `self._vlm_pixel_values_chunks` and fetch them inside their own forward; the generic forward never reads that attribute, so under PP the vision tower silently never ran and image tokens were trained against placeholder text embeddings.

The existing recipes worked around this by setting `patch_inner_model: false` and `patch_causal_lm_model: false` in their YAMLs, but the framework had no way to enforce it — any new PP+VLM recipe that forgot the two flags would silently corrupt training.
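To make the failure mode concrete, here is a toy sketch (invented module names and shapes, not the real model code) of the pattern the chunk-aware VLMs rely on: only the model's own forward pops from `_vlm_pixel_values_chunks`, so swapping in the generic text-only forward silently skips the vision path.

```python
import torch
import torch.nn as nn


class ChunkAwareVLMSketch(nn.Module):
    """Toy stand-in for a chunk-aware VLM; not the real implementation."""

    def __init__(self, vocab=16, hidden=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.vision_tower = nn.Linear(4, hidden)   # placeholder vision encoder
        self._vlm_pixel_values_chunks = []         # filled per microbatch by the PP wrapper

    def forward(self, input_ids):
        inputs_embeds = self.embed_tokens(input_ids)
        if self._vlm_pixel_values_chunks:
            # Only this forward reads the stashed chunk; the generic
            # pipeline_forward_causal_lm never does, so once it replaces
            # this method the vision tower never runs.
            pixel_values = self._vlm_pixel_values_chunks.pop(0)
            image_embeds = self.vision_tower(pixel_values).unsqueeze(1)
            inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
        return inputs_embeds


model = ChunkAwareVLMSketch()
model._vlm_pixel_values_chunks.append(torch.randn(2, 4))
out = model(torch.tensor([[1, 2, 3], [4, 5, 6]]))   # vision path runs: shape (2, 4, 8)
```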
Repro (pre-fix)

`grep _vlm_pixel_values_chunks pipelining/hf_utils.py` confirms only the dedicated `gemma4_vlm` and `mistral3_vlm` forwards read the chunks.
Changelog

- New opt-in class attribute `_pp_keep_self_forward`. When `True`, `patch_hf_model_for_pp` is a no-op for both inner and outer modules — preserves the model's chunk-aware forward.
- `validate_hf_model_for_pipeline_support` now refuses to start a PP job for a VLM that has neither a dedicated forward (gemma4 / mistral3) nor `_pp_keep_self_forward`, so the same misconfiguration cannot reach training silently again (sketched after this list).
- Mark `Qwen3VLMoeForConditionalGeneration`, `Qwen3_5MoeForConditionalGeneration`, `KimiVLForConditionalGeneration`, `KimiK25VLForConditionalGeneration` with `_pp_keep_self_forward = True`.
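A hedged sketch of the two guard-rail pieces above; the real `validate_hf_model_for_pipeline_support` checks more conditions, and the VLM-detection heuristic and stub class below are invented purely for illustration:

```python
_DEDICATED_PP_FORWARD = ("Gemma4", "Mistral3")   # families with their own pipeline forward


class KimiVLForConditionalGeneration:            # illustrative stub, not the real class
    _pp_keep_self_forward = True                 # the one-line class opt-in added by this PR


def validate_hf_model_for_pipeline_support(model):
    """Refuse PP for a VLM with neither a dedicated forward nor the opt-in flag."""
    is_vlm = any(hasattr(model, attr) for attr in ("vision_tower", "visual"))
    has_dedicated_forward = any(n in type(model).__name__ for n in _DEDICATED_PP_FORWARD)
    keeps_own_forward = getattr(type(model), "_pp_keep_self_forward", False)
    if is_vlm and not (has_dedicated_forward or keeps_own_forward):
        raise ValueError(
            f"{type(model).__name__} has no PP-aware forward: add a dedicated pipeline "
            "forward or set `_pp_keep_self_forward = True` on the class."
        )
```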
Test plan

- `pytest tests/unit_tests/distributed/pipelining/test_hf_utils.py` — 58 passed
- `pytest tests/unit_tests/models/{qwen3_vl_moe,qwen3_5_moe,kimivl,kimi_k25_vl}/` — 489 passed (3 pre-existing flash-attn-env failures unrelated to this change)
- `pixel_values` now reaches the inner model under default `patch_*=True` for chunk-aware VLMs