From c95c79e282b33b26252320be5ce17c2d9f600137 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 3 Jul 2026 15:04:46 +0800 Subject: [PATCH] fix(tf2): register se_t, se_t_tebd, and se_atten_v2 descriptors TF2 standard-model construction resolves descriptor classes through BaseDescriptor.get_class_by_type() on the TF2 descriptor registry, which is a separate registry from the dpmodel and JAX ones. The se_t, se_t_tebd, and se_atten_v2 wrappers were defined without the @BaseDescriptor.register(...) decorators their JAX counterparts carry, so their config type names (se_e3, se_at, se_a_3be, se_e3_tebd, se_atten_v2) could not be resolved and TF2 model construction failed with an unknown-descriptor error, even though the wrapper classes exist and are exported by deepmd.tf2.descriptor. Add the missing decorators, matching the JAX registrations for the same descriptor names. Adds a registration test asserting every TF2 descriptor type name resolves via get_class_by_type; it fails on master for the five unregistered names and passes with the fix. The test is gated on INSTALLED_TF2 and run with DEEPMD_TEST_TF2=1. Existing TF2 consistency tests never caught this because they instantiate the wrapper classes directly rather than through the registry. Fix #5677 --- deepmd/tf2/descriptor/se_atten_v2.py | 4 ++ deepmd/tf2/descriptor/se_t.py | 6 +++ deepmd/tf2/descriptor/se_t_tebd.py | 4 ++ .../test_tf2_descriptor_registration.py | 54 +++++++++++++++++++ 4 files changed, 68 insertions(+) create mode 100644 source/tests/consistent/test_tf2_descriptor_registration.py diff --git a/deepmd/tf2/descriptor/se_atten_v2.py b/deepmd/tf2/descriptor/se_atten_v2.py index 84db19fc59..d343226da9 100644 --- a/deepmd/tf2/descriptor/se_atten_v2.py +++ b/deepmd/tf2/descriptor/se_atten_v2.py @@ -4,11 +4,15 @@ from ..common import ( register_dpmodel_mapping, ) +from .base_descriptor import ( + BaseDescriptor, +) from .dpa1 import ( DescrptDPA1, ) +@BaseDescriptor.register("se_atten_v2") class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP): pass diff --git a/deepmd/tf2/descriptor/se_t.py b/deepmd/tf2/descriptor/se_t.py index 98e142de6c..1636778486 100644 --- a/deepmd/tf2/descriptor/se_t.py +++ b/deepmd/tf2/descriptor/se_t.py @@ -6,8 +6,14 @@ ) from ..utils import exclude_mask as _tf2_exclude_mask # noqa: F401 from ..utils import network as _tf2_network # noqa: F401 +from .base_descriptor import ( + BaseDescriptor, +) +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") @tf2_module class DescrptSeT(DescrptSeTDP): pass diff --git a/deepmd/tf2/descriptor/se_t_tebd.py b/deepmd/tf2/descriptor/se_t_tebd.py index 4b7e65c50f..f7be203692 100644 --- a/deepmd/tf2/descriptor/se_t_tebd.py +++ b/deepmd/tf2/descriptor/se_t_tebd.py @@ -10,6 +10,9 @@ from ..utils import exclude_mask as _tf2_exclude_mask # noqa: F401 from ..utils import network as _tf2_network # noqa: F401 from ..utils import type_embed as _tf2_type_embed # noqa: F401 +from .base_descriptor import ( + BaseDescriptor, +) @tf2_module @@ -17,6 +20,7 @@ class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): pass +@BaseDescriptor.register("se_e3_tebd") @tf2_module class DescrptSeTTebd(DescrptSeTTebdDP): pass diff --git a/source/tests/consistent/test_tf2_descriptor_registration.py b/source/tests/consistent/test_tf2_descriptor_registration.py new file mode 100644 index 0000000000..42968efa2b --- /dev/null +++ b/source/tests/consistent/test_tf2_descriptor_registration.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Every TF2 descriptor wrapper must register its config type names. + +TF2 standard-model construction resolves descriptors through +``BaseDescriptor.get_class_by_type()``. The wrapper classes exist and are +exported by ``deepmd.tf2.descriptor``, but several used to be defined without the +``@BaseDescriptor.register(...)`` decorators the JAX wrappers carry, so their +config type names could not be resolved and model construction failed with an +unknown-descriptor error. +""" + +import unittest + +from .common import ( + INSTALLED_TF2, +) + +if INSTALLED_TF2: + import deepmd.tf2.descriptor # noqa: F401 + from deepmd.tf2.descriptor.base_descriptor import ( + BaseDescriptor, + ) + +# type names that must resolve on the TF2 descriptor registry, mirroring the +# JAX wrapper registrations for the same descriptors. +TF2_DESCRIPTOR_TYPES = [ + "se_e2_a", + "se_a", + "se_e2_r", + "se_r", + "se_e3", # se_t + "se_at", # se_t + "se_a_3be", # se_t + "se_e3_tebd", # se_t_tebd + "se_atten_v2", + "se_atten", # dpa1 + "dpa1", + "dpa2", + "dpa3", + "hybrid", +] + + +@unittest.skipUnless(INSTALLED_TF2, "TF2 backend is not installed") +class TestTF2DescriptorRegistration(unittest.TestCase): + def test_all_types_resolve(self) -> None: + for descriptor_type in TF2_DESCRIPTOR_TYPES: + with self.subTest(descriptor_type=descriptor_type): + cls = BaseDescriptor.get_class_by_type(descriptor_type) + self.assertTrue(callable(cls)) + + +if __name__ == "__main__": + unittest.main()