diff --git a/lightllm/models/base_mtp_model.py b/lightllm/models/base_mtp_model.py deleted file mode 100644 index 796a34d9a..000000000 --- a/lightllm/models/base_mtp_model.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import List - -from lightllm.common.basemodel.basemodel import TpPartBaseModel - - -class BaseMTPModel: - """Shared wiring for MTP draft models: they reuse the main model's req/mem managers and rope - caches, and pop the main_model / previous-draft-models kwargs before the base __init__ (#25). - Mixed in BEFORE the concrete base model so these overrides win via MRO. - - Also carries the is_mtp_draft_model marker consumed by detection sites (#23).""" - - is_mtp_draft_model = True - - def __init__(self, kvargs: dict): - self._pre_init(kvargs) - super().__init__(kvargs) - return - - def _pre_init(self, kvargs: dict): - self.main_model: TpPartBaseModel = kvargs.pop("main_model") - self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") - return - - def _init_custom(self): - self._cos_cached = self.main_model._cos_cached - self._sin_cached = self.main_model._sin_cached - return - - def _init_req_manager(self): - self.req_manager = self.main_model.req_manager - return - - def _init_mem_manager(self): - self.mem_manager = self.main_model.mem_manager - return diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py index 369ff1766..e2b2a5613 100644 --- a/lightllm/models/deepseek_mtp/model.py +++ b/lightllm/models/deepseek_mtp/model.py @@ -1,14 +1,41 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class Deepseek3MTPModel(BaseMTPModel, Deepseek2TpPartModel): +class Deepseek3MTPModel(Deepseek2TpPartModel): + + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None self.pre_post_weight = self.pre_and_post_weight_class( diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py index 039491dc1..2e4ba5c86 100644 --- a/lightllm/models/glm4_moe_lite_mtp/model.py +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -1,17 +1,39 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( Glm4MoeLiteMTPPreAndPostLayerWeight, ) +from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.basemodel.basemodel import load_hf_weights -class Glm4MoeLiteMTPModel(BaseMTPModel, Glm4MoeLiteTpPartModel): +class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): + + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + def _init_weights(self, start_layer_index=None): assert start_layer_index is None diff --git a/lightllm/models/mistral_mtp/model.py b/lightllm/models/mistral_mtp/model.py index e1edd9d02..f17bc0a38 100644 --- a/lightllm/models/mistral_mtp/model.py +++ b/lightllm/models/mistral_mtp/model.py @@ -1,13 +1,17 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.mistral.model import MistralTpPartModel from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight from lightllm.models.mistral_mtp.layer_infer.pre_layer_infer import MistralMTPPreLayerInfer from lightllm.models.mistral_mtp.layer_infer.post_layer_infer import MistralMTPPostLayerInfer from lightllm.models.mistral_mtp.layer_infer.transformer_layer_infer import MistralMTPTransformerLayerInfer from lightllm.models.mistral_mtp.layer_weights.transformer_layer_weight import MistralMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class MistralMTPModel(BaseMTPModel, MistralTpPartModel): +class MistralMTPModel(MistralTpPartModel): + + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True pre_and_post_weight_class = MistralMTPPreAndPostLayerWeight pre_layer_infer_class = MistralMTPPreLayerInfer @@ -17,11 +21,34 @@ class MistralMTPModel(BaseMTPModel, MistralTpPartModel): post_layer_infer_class = MistralMTPPostLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + def _init_some_value(self): super()._init_some_value() self.layers_num = 1 return + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None self.config["n_layer"] = 1 diff --git a/lightllm/models/qwen3_moe_mtp/model.py b/lightllm/models/qwen3_moe_mtp/model.py index ec27b5707..d9854250e 100644 --- a/lightllm/models/qwen3_moe_mtp/model.py +++ b/lightllm/models/qwen3_moe_mtp/model.py @@ -1,12 +1,16 @@ -from lightllm.models.base_mtp_model import BaseMTPModel +from typing import List from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3_moe_mtp.layer_weights.pre_and_post_layer_weight import Qwen3MOEMTPPreAndPostLayerWeight from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.qwen3_moe_mtp.layer_infer.transformer_layer_infer import Qwen3MOEMTPTransformerLayerInfer from lightllm.models.qwen3_moe_mtp.layer_weights.transformer_layer_weight import Qwen3MOEMTPTransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel -class Qwen3MOEMTPModel(BaseMTPModel, Qwen3MOEModel): +class Qwen3MOEMTPModel(Qwen3MOEModel): + + # MTP draft model marker (consumed by the decode CUDA-graph / padding paths). + is_mtp_draft_model = True pre_and_post_weight_class = Qwen3MOEMTPPreAndPostLayerWeight pre_layer_infer_class = Deepseek3MTPPreLayerInfer @@ -14,6 +18,29 @@ class Qwen3MOEMTPModel(BaseMTPModel, Qwen3MOEModel): transformer_weight_class = Qwen3MOEMTPTransformerLayerWeight transformer_layer_infer_class = Qwen3MOEMTPTransformerLayerInfer + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + return + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + return + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + return + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + return + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + return + def _init_weights(self, start_layer_index=None): assert start_layer_index is None mtp_index = len(self.mtp_previous_draft_models) diff --git a/unit_tests/models/test_base_mtp_model_mixin.py b/unit_tests/models/test_base_mtp_model_mixin.py deleted file mode 100644 index 7cddb5179..000000000 --- a/unit_tests/models/test_base_mtp_model_mixin.py +++ /dev/null @@ -1,22 +0,0 @@ -import types - - -def test_mixin_pre_init_pops_and_shares_managers(): - from lightllm.models.base_mtp_model import BaseMTPModel - - main = types.SimpleNamespace(_cos_cached="cos", _sin_cached="sin", req_manager="rm", mem_manager="mm") - - obj = BaseMTPModel.__new__(BaseMTPModel) - kvargs = {"main_model": main, "mtp_previous_draft_models": ["d0"], "other": 1} - obj._pre_init(kvargs) - assert obj.main_model is main - assert obj.mtp_previous_draft_models == ["d0"] - assert "main_model" not in kvargs and "mtp_previous_draft_models" not in kvargs - - obj._init_custom() - obj._init_req_manager() - obj._init_mem_manager() - assert obj._cos_cached == "cos" and obj._sin_cached == "sin" - assert obj.req_manager == "rm" and obj.mem_manager == "mm" - - assert BaseMTPModel.is_mtp_draft_model is True