Skip to content

feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1338

Open
sufubao wants to merge 5 commits into
ModelTC:mainfrom
sufubao:qw35_mtp_feature
Open

feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1338
sufubao wants to merge 5 commits into
ModelTC:mainfrom
sufubao:qw35_mtp_feature

Conversation

@sufubao

@sufubao sufubao commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

What

Adds Qwen3.5 / Qwen3.5-MoE Multi-Token Prediction (MTP / speculative decoding) end-to-end, together with the model-agnostic verify-decode machinery it needs. Builds on the BaseMTPModel refactor in #1337 (merged).

Commits

  1. feat(mtp): MTP verify-decode infrastructure — model-agnostic verify dispatch in TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the (mtp_step+1)-expanded verify layout, a shared mtp_verify_extra_state block, and fa3 decode attention narrowed to the verify layout (b_att_seq_len + causal) for fp / fp8 / mla.
  2. feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP draft models — dense and MoE draft packages sharing a weight-retarget mixin (mtp.* head, embeddings shared with the main model) and the MTP pre-layer fuse.
  3. feat(qwen3next): GDN spec-decode verify path + linear-att cache split — a per-sequence spec causal_conv1d kernel, a widened conv working slot split from the committed (narrow) persisted slot, and MTP draft full-attn KV-slot accounting across the linear-att cache config / mem operator / req manager.
  4. feat(scheduler): MTP verify backend + accept-len transport — a single draft-model factory keyed on (model_type, mtp_mode), the verify decode batch, eagle + vanilla draft decode, and per-request accept-length (b_num_accepted_tokens) transport through the chunked-prefill and dp backends.
  5. test(mtp): unit tests + static MTP benchmark — and the .gitignore benchmark-output rule anchored to /benchmark.

Testing

  • pre-commit (black 21.12b0 + flake8 6.1.0) clean.
  • Unit tests under unit_tests/ cover verify-extra-state, decode CUDA-graph layouts, fa3 narrowing, GDN verify equivalence, linear-att conv/SSM split + CPU-cache persistence, and the draft-model factory.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements comprehensive multi-token prediction (MTP) and speculative decoding support for Qwen3.5 and Qwen3Next models, including updates to attention backends, CUDA graph warmup layouts, and Triton kernels for spec-decode updates. Feedback on the changes highlights a performance concern regarding synchronous device-to-host transfers when validating b_num_accepted_tokens on the GPU, as well as a potential division-by-zero error in the causal conv1d update kernel if the batch size is empty.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +100 to 103
assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, (
f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: "
f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling .min() and .max() on the GPU tensor b_num_accepted_tokens and casting to int causes a synchronous device-to-host (D2H) transfer. Since this function is called on the eager decode hot path, this synchronization will stall the CPU and degrade inference performance. Consider performing this validation on the CPU side (e.g., on the list of mtp_accept_len in infer_batch.py before moving it to CUDA) to keep the execution fully asynchronous.

Comment on lines +383 to +385
assert conv_state_indices is not None
batch = conv_state_indices.size(0)
dim = x.size(1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If batch (the size of conv_state_indices) is 0, calculating seqlen = x.size(0) // batch will raise a ZeroDivisionError. To make the function robust against empty batches (which can occur in edge cases or testing), add a defensive guard to return x immediately if batch == 0.

        assert conv_state_indices is not None\n        batch = conv_state_indices.size(0)\n        if batch == 0:\n            return x\n        dim = x.size(1)

@sufubao sufubao force-pushed the qw35_mtp_feature branch from 3b495c6 to 3a2354e Compare June 9, 2026 00:37
@sufubao sufubao marked this pull request as ready for review June 9, 2026 00:37
@sufubao sufubao force-pushed the qw35_mtp_feature branch 4 times, most recently from 493ce5e to ed22f8e Compare June 9, 2026 01:43
sufubao added 5 commits June 9, 2026 14:31
Model-agnostic verify-decode machinery: MTP-verify dispatch in TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the (mtp_step+1)-expanded verify layout, a shared mtp_verify_extra_state block on infer_struct/batch_objs, fa3 decode attention narrowed to the verify layout (b_att_seq_len + causal) for fp/fp8/mla, and env/kv-cache helpers for MTP added-layer accounting.
Self-contained dense (qwen3_5_mtp) and MoE (qwen3_5_moe_mtp) MTP draft packages: each carries its own draft wiring (reuse the main model's req/mem managers + rope caches, is_mtp_draft_model marker) and shares a weight-retarget mixin (mtp.* head, embeddings shared with the main model) plus the MTP pre-layer fuse. No shared model base class.
Gated-delta-net (linear attention) speculative-decode verify path for qwen3next: a per-sequence spec causal_conv1d kernel; a widened conv working slot split from the committed (narrow) persisted slot; MTP draft full-attn KV-slot accounting across the linear-att cache config, mem operator and req manager; and removal of the dead gen_b_req_mtp_start_loc kernel.
Wire the verify path through the inference backends: a single draft-model factory keyed on (model_type, mtp_mode); build the (mtp_step+1)-expanded verify decode batch; run the eagle + vanilla draft decode; verify accepted tokens; and thread per-request accept-lengths (b_num_accepted_tokens) from the chunked-prefill and dp backends into the model verify forward.
Behavioural/CUDA coverage for the subtle MTP paths: verify-extra-state metadata, decode CUDA-graph verify layouts, fa3 fp8 verify narrowing, GDN verify equivalence, the spec causal_conv1d kernel and its prefill->decode roundtrip, and the linear-att conv/SSM widened-slot split + snapshot + CPU-cache persistence. Also extends the static-inference MTP benchmark and anchors the .gitignore benchmark-output rule to /benchmark.
@sufubao sufubao force-pushed the qw35_mtp_feature branch from ed22f8e to 05e0d19 Compare June 9, 2026 07:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant