From 2343b3be795cbafd43ffbac30fe157ee29156b4f Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:31:07 +0800 Subject: [PATCH 1/2] Revert "refactor(mtp): extract BaseMTPModel mixin shared by existing MTP draft models (#1337)" This reverts commit e03ef9a6f597d0b140e1b33e336fa976aadb19bd. --- lightllm/models/base_mtp_model.py | 36 ------------------- lightllm/models/deepseek_mtp/model.py | 28 +++++++++++++-- lightllm/models/glm4_moe_lite_mtp/model.py | 23 ++++++++++-- lightllm/models/mistral_mtp/model.py | 28 +++++++++++++-- lightllm/models/qwen3_moe_mtp/model.py | 28 +++++++++++++-- .../models/test_base_mtp_model_mixin.py | 22 ------------ 6 files changed, 99 insertions(+), 66 deletions(-) delete mode 100644 lightllm/models/base_mtp_model.py delete mode 100644 unit_tests/models/test_base_mtp_model_mixin.py diff --git a/lightllm/models/base_mtp_model.py b/lightllm/models/base_mtp_model.py deleted file mode 100644 index 796a34d9ad..0000000000 --- a/lightllm/models/base_mtp_model.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import List - -from lightllm.common.basemodel.basemodel import TpPartBaseModel - - -class BaseMTPModel: - """Shared wiring for MTP draft models: they reuse the main model's req/mem managers and rope - caches, and pop the main_model / previous-draft-models kwargs before the base __init__ (#25). - Mixed in BEFORE the concrete base model so these overrides win via MRO. - - Also carries the is_mtp_draft_model marker consumed by detection sites (#23).""" - - is_mtp_draft_model = True - - def __init__(self, kvargs: dict): - self._pre_init(kvargs) - super().__init__(kvargs) - return - - def _pre_init(self, kvargs: dict): - self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") - return - - def _init_custom(self): - self._cos_cached = self.main_model._cos_cached - self._sin_cached = self.main_model._sin_cached - return - - def _init_req_manager(self): - self.req_manager = self.main_model.req_manager - return - - def _init_mem_manager(self): - self.mem_manager = self.main_model.mem_manager - return diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 369ff1766f..d9ffdb0e31 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -1,14 +1,38 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class Deepseek3MTPModel(BaseMTPModel, Deepseek2TpPartModel): +class Deepseek3MTPModel(Deepseek2TpPartModel): pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None self.pre_post_weight = self.pre_and_post_weight_class( diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py index 039491dc1f..549bf7ce41 100644 --- a/lightllm/models/glm4_moe_lite_mtp/model.py +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -1,17 +1,36 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( Glm4MoeLiteMTPPreAndPostLayerWeight, ) +from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.basemodel.basemodel import load_hf_weights -class Glm4MoeLiteMTPModel(BaseMTPModel, Glm4MoeLiteTpPartModel): +class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + def _init_weights(self, start_layer_index=None): assert start_layer_index is None diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index e1edd9d027..7c64625ca8 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -1,13 +1,14 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.mistral.model import MistralTpPartModel from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight from lightllm.models.mistral_mtp.layer_infer.pre_layer_infer import MistralMTPPreLayerInfer from lightllm.models.mistral_mtp.layer_infer.post_layer_infer import MistralMTPPostLayerInfer from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class MistralMTPModel(BaseMTPModel, MistralTpPartModel): +class MistralMTPModel(MistralTpPartModel): pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight pre_layer_infer_class = MistralMTPPreLayerInfer @@ -17,11 +18,34 @@ class MistralMTPModel(BaseMTPModel, MistralTpPartModel): post_layer_infer_class = MistralMTPPostLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + def _init_some_value(self): super()._init_some_value() self.layers_num = 1 return + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None self.config["n_layer"] = 1 diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index ec27b57070..9f83832a7e 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -1,12 +1,13 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class Qwen3MOEMTPModel(BaseMTPModel, Qwen3MOEModel): +class Qwen3MOEMTPModel(Qwen3MOEModel): pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer @@ -14,6 +15,29 @@ class Qwen3MOEMTPModel(BaseMTPModel, Qwen3MOEModel): transformer_weight_class = Qwen3MOEMTPTransformerLayerWeight transformer_layer_infer_class = Qwen3MOEMTPTransformerLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None mtp_index = len(self.mtp_previous_draft_models) diff --git a/unit_tests/models/test_base_mtp_model_mixin.py b/unit_tests/models/test_base_mtp_model_mixin.py deleted file mode 100644 index 7cddb5179c..0000000000 --- a/unit_tests/models/test_base_mtp_model_mixin.py +++ /dev/null @@ -1,22 +0,0 @@ -import types - - -def test_mixin_pre_init_pops_and_shares_managers(): - from lightllm.models.base_mtp_model import BaseMTPModel - - main = types.SimpleNamespace(_cos_cached="cos", _sin_cached="sin", req_manager="rm", mem_manager="mm") - - obj = BaseMTPModel.__new__(BaseMTPModel) - kvargs = {"main_model": main, "mtp_previous_draft_models": ["d0"], "other": 1} - obj._pre_init(kvargs) - assert obj.main_model is main - assert obj.mtp_previous_draft_models == ["d0"] - assert "main_model" not in kvargs and "mtp_previous_draft_models" not in kvargs - - obj._init_custom() - obj._init_req_manager() - obj._init_mem_manager() - assert obj._cos_cached == "cos" and obj._sin_cached == "sin" - assert obj.req_manager == "rm" and obj.mem_manager == "mm" - - assert BaseMTPModel.is_mtp_draft_model is True From df2d921f804cd292966231927696c0ddabda83a1 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Jun 2026 09:32:23 +0800 Subject: [PATCH 2/2] refactor(mtp): mark MTP draft models with is_mtp_draft_model Re-add the is_mtp_draft_model = True class marker (previously supplied by the reverted BaseMTPModel base) directly on each standalone MTP draft model, so the decode CUDA-graph / padding paths still detect them. Keeps the models self-contained with no shared base class. --- lightllm/models/deepseek_mtp/model.py | 3 +++ lightllm/models/glm4_moe_lite_mtp/model.py | 3 +++ lightllm/models/mistral_mtp/model.py | 3 +++ lightllm/models/qwen3_moe_mtp/model.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index d9ffdb0e31..e2b2a56137 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -7,6 +7,9 @@ class Deepseek3MTPModel(Deepseek2TpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py index 549bf7ce41..2e4ba5c86b 100644 --- a/lightllm/models/glm4_moe_lite_mtp/model.py +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -10,6 +10,9 @@ class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index 7c64625ca8..f17bc0a383 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -10,6 +10,9 @@ class MistralMTPModel(MistralTpPartModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight pre_layer_infer_class = MistralMTPPreLayerInfer diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index 9f83832a7e..d9854250e2 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -9,6 +9,9 @@ class Qwen3MOEMTPModel(Qwen3MOEModel): + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True + pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer