feat(deepseek-v4): add Multi-Token Prediction (MTP) training support #2191
Open
khazic wants to merge 5 commits into NVIDIA-NeMo:main from
Conversation
- Add model-agnostic MTP scaffold (MTPConfig, MTPModule, roll_tensor) under nemo_automodel/components/models/common/mtp/
- Add DeepseekV4MTPSublayer: pre-norm attention + MoE blocks without HC machinery; compress_ratios forced to None to avoid IndexError; rotary embeddings stored as non-registered references via object.__setattr__
- Add build_mtp_config_from_hf and build_deepseek_v4_mtp factory functions
- Add DeepseekV4CausalLMOutput dataclass so forward returns logits + optional mtp_per_depth_h list for MTP loss computation in train_ft.py
- Update DeepseekV4ForCausalLM.__init__ to construct MTP module when num_nextn_predict_layers > 0
- Update state_dict_adapter.py: from_hf splits MTP keys and converts back
- Add calculate_mtp_loss to train_ft.py and wire into _forward_backward_step
- Add 8 unit tests covering config, construction, forward, backward, state dict

Signed-off-by: khazic <khazzz1c@gmail.com>
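For reviewers unfamiliar with the label-rolling trick, here is a minimal sketch of what a roll_tensor helper of this kind usually does; the signature and defaults are assumptions, not the code in components/models/common/mtp/:

```python
import torch

def roll_tensor(tensor: torch.Tensor, shift: int = -1, dim: int = -1) -> torch.Tensor:
    """Shift `tensor` along `dim` and zero out the positions that wrapped around.

    For MTP, labels are rolled left once per prediction depth so that depth k is
    trained against token t + k + 1. Assumes a negative (leftward) shift.
    """
    rolled = torch.roll(tensor, shifts=shift, dims=dim)
    index = [slice(None)] * rolled.dim()
    index[dim] = slice(shift, None)  # the |shift| trailing positions that wrapped
    rolled[tuple(index)] = 0
    return rolled
```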
State-dict adapter:
- from_hf: route MTP layers (layers.{N+k}.*) through dequantize +
aggregate-experts + rename pipeline by renumbering them as layers.{k}.*
and re-prefixing the result to mtp.layers.{k}.*. Previously MTP keys
bypassed dequantization, leaving FP8/FP4 buffers undequantized.
- to_hf: rewrite mtp.layers.{k}.* into model.layers.{N+k}.* and run the
unified split / rename / quantize path; strip the leftover model.
prefix for fusion-only modules (eh_proj, enorm, hnorm, final_layernorm)
that have no entry in the rename table.
- Drop dead _apply_inverse_rename helper.
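
A rough sketch of the from_hf routing described above (renumber into the 0-based MTP index space, run the shared pipeline, re-prefix); the helper name is illustrative and the shared pipeline is elided, so this is not the adapter's actual code:

```python
import re
from typing import Optional

# Prefix-tolerant MTP key matcher: group(2) carries the absolute HF layer index.
_MTP_LAYER_RE = re.compile(r"^(model\.)?layers\.(\d+)\.(.+)$")

def route_mtp_key(hf_key: str, num_hidden_layers: int) -> Optional[str]:
    """Return the converted key for an MTP layer, or None for backbone keys."""
    match = _MTP_LAYER_RE.match(hf_key)
    if match is None:
        return None
    layer_idx = int(match.group(2))
    if layer_idx < num_hidden_layers:
        return None  # ordinary backbone layer; handled by the normal path
    k = layer_idx - num_hidden_layers
    renumbered = f"layers.{k}.{match.group(3)}"
    # The real adapter pushes `renumbered` through the shared dequantize /
    # aggregate-experts / rename pipeline before re-prefixing; elided here.
    return f"mtp.{renumbered}"
```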
Recipe (train_ft.py):
- Add _mtp_is_enabled(cfg, model_parts) helper that detects MTP via
YAML override (model.config.num_nextn_predict_layers) or via an
enabled mtp_config attribute on any constructed submodule.
- Raise NotImplementedError in setup() when PP and MTP are both
enabled. The PP schedule does not aggregate the MTP auxiliary loss,
so the MTP head would silently receive no gradients. PP + MTP
wiring is intentionally deferred to a follow-up PR.
- Add TODO marker in _forward_backward_step PP branch pointing at the
same follow-up.
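
Roughly what the detection helper plus setup-time guard amounts to; this is an illustrative approximation (the attribute access paths are assumptions), not the actual train_ft.py code:

```python
def _mtp_is_enabled(cfg, model_parts) -> bool:
    """Detect MTP via the YAML override or via a constructed submodule's mtp_config."""
    # YAML override path: model.config.num_nextn_predict_layers > 0.
    model_cfg = getattr(getattr(cfg, "model", None), "config", None)
    if getattr(model_cfg, "num_nextn_predict_layers", 0):
        return True
    # Constructed-model path: any submodule carrying an enabled mtp_config.
    for part in model_parts:
        for module in part.modules():
            mtp_cfg = getattr(module, "mtp_config", None)
            if mtp_cfg is not None and getattr(mtp_cfg, "enabled", False):
                return True
    return False


def _guard_pp_plus_mtp(cfg, model_parts, pp_enabled: bool) -> None:
    # The PP schedule does not aggregate the MTP auxiliary loss yet, so the MTP
    # head would silently receive no gradients; fail loudly instead.
    if pp_enabled and _mtp_is_enabled(cfg, model_parts):
        raise NotImplementedError(
            "MTP training with pipeline parallelism is not supported yet; "
            "disable PP or set num_nextn_predict_layers to 0."
        )
```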
Tests:
- Fix test_forward_shape / test_backward to read .logits from the new
DeepseekV4CausalLMOutput dataclass returned by forward.
- Add MTP round-trip coverage: layer rename, FP8 dequantize, expert
aggregation, to_hf rename / split / quantize, and the fusion-only
fallback for both directions.
Signed-off-by: khazic <khazzz1c@gmail.com>
Contributor
/ok to test 3990e0c
DeepSeek-V4 HF safetensors emit MTP layer keys in two forms:
* ``model.layers.{N+k}.*`` for the standard self_attn / mlp / norms
(carries the canonical ``model.`` prefix like every backbone block).
* ``layers.{N+k}.*`` for V4's MTP-only fusion modules (``eh_proj``,
``enorm``, ``hnorm``, ``final_layernorm``) which sit outside the
HF ``model.`` namespace.
The previous split regex (``r"^layers\.(\d+)\."``) only matched the
unprefixed form, so the prefixed self_attn / mlp / norms keys silently
fell into the backbone bucket. They were then renamed by the standard
backbone pipeline and ended up at ``model.layers.{N+k}.*`` in the
converted state dict — but the model only has ``model.layers.{0..N-1}``,
so DCP load dropped them and ``model.mtp.layers[*].*`` started from
random init. End result: MTP-enabled training silently ran without
loading the MTP head weights from the HF checkpoint.
Repro on a tiny config (num_hidden_layers=2, num_nextn_predict_layers=1):
Model expects 38 mtp.* state_dict keys
adapter.from_hf produced 4 mtp.* keys (the 4 unprefixed fusion ones)
35 mtp.* keys MISSING, 24 keys leaked to model.layers.2.* (dropped)
Make the regex prefix-tolerant (``^(model\.)?layers\.(\d+)\.``) and use
the second capture group as the layer index. After the fix, the same
repro produces 0 missing / 0 extra, and a save→load round-trip via
to_hf -> from_hf reconstructs every mtp.* key the model exposes.
Add a regression test ``test_from_hf_renames_mtp_layer_with_model_prefix``
that exercises the prefixed form so this cannot silently regress again.
Signed-off-by: khazic <khazzz1c@gmail.com>
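
A quick demonstration that the prefix-tolerant regex accepts both key forms from the repro above (layer index 2 is the single MTP layer when num_hidden_layers=2):

```python
import re

# The fixed split pattern; the layer index is always in the second capture group.
MTP_SPLIT_RE = re.compile(r"^(model\.)?layers\.(\d+)\.")

prefixed = "model.layers.2.input_layernorm.weight"  # backbone-style MTP key
unprefixed = "layers.2.eh_proj.weight"               # fusion-only MTP key

for key in (prefixed, unprefixed):
    match = MTP_SPLIT_RE.match(key)
    assert match is not None, key
    print(key, "-> MTP layer index", int(match.group(2)))
```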
Contributor
/ok to test c228ec4
Contributor
/ok to test 8af2e5c
8af2e5c to 3c08682
Contributor
/ok to test 3c08682
Summary
Adds Multi-Token Prediction (MTP) training support for DeepSeek V4 (Flash). MTP layers run as standard pre-norm attention + MoE blocks (no HC machinery), with rotary embeddings shared from the main backbone. The auxiliary loss is computed via the recipe-side calculate_mtp_loss and added to the main CE loss in the non-PP training path.

What's in this PR
Model side
- components/models/common/mtp/: model-agnostic scaffold (MTPConfig, MTPModule, roll_tensor).
- components/models/deepseek_v4/mtp.py: V4-specific DeepseekV4MTPSublayer and build_deepseek_v4_mtp factory. compress_ratios is forced to None for MTP attention to avoid IndexError past the backbone layer count; rotary refs are stored via object.__setattr__ so they don't pollute state_dict.
- components/models/deepseek_v4/model.py: DeepseekV4ForCausalLM now constructs self.mtp when num_nextn_predict_layers > 0 and returns a DeepseekV4CausalLMOutput dataclass (logits + optional mtp_per_depth_h), sketched below.
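For reference, the new output type boils down to a two-field dataclass; this sketch carries only the fields named above, anything beyond them is an assumption:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DeepseekV4CausalLMOutput:
    logits: torch.Tensor
    # Hidden states per MTP depth, present only when the MTP module is enabled;
    # consumed by calculate_mtp_loss in the recipe.
    mtp_per_depth_h: Optional[list[torch.Tensor]] = None
```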
State-dict adapter
- from_hf runs MTP layers (layers.{N+k}.*) through the same dequantize / aggregate-experts / rename pipeline as the backbone (renumber to layers.{k}.*, run the pipeline, re-prefix to mtp.layers.{k}.*). Previously MTP keys bypassed dequantization and FP8/FP4 buffers were left raw.
- to_hf rewrites mtp.layers.{k}.* into model.layers.{N+k}.* and runs the unified split / rename / quantize path; an explicit fallback strips the leftover model. prefix for fusion-only modules (eh_proj / enorm / hnorm / final_layernorm) that have no entry in the rename table.

Recipe (recipes/llm/train_ft.py)
- calculate_mtp_loss: per-depth CE through the configured loss class (FusedLinearCE / MaskedCE), summed with loss_scaling_factor / D weighting (see the sketch after this list).
- _forward_backward_step (non-PP branch) reads out.mtp_per_depth_h and adds the MTP loss to the main loss.
- _mtp_is_enabled(cfg, model_parts) + setup-time guard: raises NotImplementedError if pipeline parallelism is enabled together with MTP, since the PP schedule does not currently aggregate the MTP auxiliary loss. PP + MTP is intentionally deferred to a follow-up PR.
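The weighting scheme amounts to the following minimal sketch. Note that the actual helper routes per-depth hidden states through the configured FusedLinearCE / MaskedCE loss class; this sketch substitutes plain cross-entropy over per-depth logits, and the default scaling factor is an assumption:

```python
import torch
import torch.nn.functional as F


def calculate_mtp_loss(
    mtp_logits_per_depth: list[torch.Tensor],  # one [B, T, V] tensor per MTP depth
    labels: torch.Tensor,                      # [B, T], -100 marks ignored positions
    loss_scaling_factor: float = 0.1,          # assumed default; set via the recipe config
) -> torch.Tensor:
    """(loss_scaling_factor / D) * sum over depths of the per-depth CE term."""
    depth = len(mtp_logits_per_depth)
    total = labels.new_zeros((), dtype=torch.float32)
    shifted = labels
    for logits in mtp_logits_per_depth:
        # Depth k targets the token one step further ahead, so roll the labels
        # left once per depth and mask the position that wrapped around.
        shifted = torch.roll(shifted, shifts=-1, dims=-1)
        shifted[..., -1] = -100
        total = total + F.cross_entropy(
            logits.flatten(0, 1).float(), shifted.flatten(), ignore_index=-100
        )
    return (loss_scaling_factor / depth) * total
```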
Tests
- test_deepseek_v4_mtp.py: config / construction / forward / backward / state-dict coverage.
- test_dsv4_state_dict_adapter.py: MTP round-trip for layer rename, FP8 dequantize, expert aggregation, and the fusion-only fallback in both directions.
- test_dsv4_model_smoke.py: updated to read .logits from the new dataclass output.

Overlap with #2161
PR #2161 (Nemotron V3 MTP) introduces the same calculate_mtp_loss helper and the same non-PP integration in _forward_backward_step. Those two regions are byte-identical between the branches. This is intentional: both PRs need the same recipe-side scaffolding, and the model-agnostic MTP base (components/models/common/mtp/) is shared. When #2161 lands first, those duplicated lines will be auto-resolved on rebase, and this PR will reduce to the V4-specific changes (model, MTP sublayer, adapter, PP guard, V4 tests).

Test plan
- tests/unit_tests/models/deepseek_v4/: 67 passed, 7 skipped (CUDA-gated)
- ruff format --check .: clean
- ruff check .: clean

Follow-up