Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/tf2/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
from ..common import (
register_dpmodel_mapping,
)
from .base_descriptor import (
BaseDescriptor,
)
from .dpa1 import (
DescrptDPA1,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
pass

Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf2/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
)
from ..utils import exclude_mask as _tf2_exclude_mask # noqa: F401
from ..utils import network as _tf2_network # noqa: F401
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("se_e3")
@BaseDescriptor.register("se_at")
@BaseDescriptor.register("se_a_3be")
@tf2_module
class DescrptSeT(DescrptSeTDP):
pass
4 changes: 4 additions & 0 deletions deepmd/tf2/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from ..utils import exclude_mask as _tf2_exclude_mask # noqa: F401
from ..utils import network as _tf2_network # noqa: F401
from ..utils import type_embed as _tf2_type_embed # noqa: F401
from .base_descriptor import (
BaseDescriptor,
)


@tf2_module
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
pass


@BaseDescriptor.register("se_e3_tebd")
@tf2_module
class DescrptSeTTebd(DescrptSeTTebdDP):
pass
54 changes: 54 additions & 0 deletions source/tests/consistent/test_tf2_descriptor_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Every TF2 descriptor wrapper must register its config type names.

TF2 standard-model construction resolves descriptors through
``BaseDescriptor.get_class_by_type(<type>)``. The wrapper classes exist and are
exported by ``deepmd.tf2.descriptor``, but several used to be defined without the
``@BaseDescriptor.register(...)`` decorators the JAX wrappers carry, so their
config type names could not be resolved and model construction failed with an
unknown-descriptor error.
"""

import unittest

from .common import (
INSTALLED_TF2,
)

if INSTALLED_TF2:
import deepmd.tf2.descriptor # noqa: F401
from deepmd.tf2.descriptor.base_descriptor import (
BaseDescriptor,
)

# type names that must resolve on the TF2 descriptor registry, mirroring the
# JAX wrapper registrations for the same descriptors.
TF2_DESCRIPTOR_TYPES = [
"se_e2_a",
"se_a",
"se_e2_r",
"se_r",
"se_e3", # se_t
"se_at", # se_t
"se_a_3be", # se_t
"se_e3_tebd", # se_t_tebd
"se_atten_v2",
"se_atten", # dpa1
"dpa1",
"dpa2",
"dpa3",
"hybrid",
]


@unittest.skipUnless(INSTALLED_TF2, "TF2 backend is not installed")
class TestTF2DescriptorRegistration(unittest.TestCase):
def test_all_types_resolve(self) -> None:
for descriptor_type in TF2_DESCRIPTOR_TYPES:
with self.subTest(descriptor_type=descriptor_type):
cls = BaseDescriptor.get_class_by_type(descriptor_type)
self.assertTrue(callable(cls))


if __name__ == "__main__":
unittest.main()
Loading