29 changes: 21 additions & 8 deletions monai/networks/blocks/transformerblock.py
@@ -46,11 +46,20 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.
causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
external context tensor. When False, cross_attn is set to nn.Identity() so that the attribute
always exists for typing and checkpoint compatibility. Defaults to False.
use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention mechanism
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
include_fc: whether to include the final linear layer. Defaults to True.
use_combined_linear: whether to use a single linear layer for qkv projection. Defaults to True.

Raises:
ValueError: if dropout_rate is not in [0, 1].
ValueError: if hidden_size is not divisible by num_heads.

"""

super().__init__()
@@ -79,14 +88,18 @@ def __init__(
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)
self.cross_attn: CrossAttentionBlock | nn.Identity
if with_cross_attention:
Member
We've had this issue with other classes with conditional components. If we don't take this branch then the members aren't created, which causes issues with typing, TorchScript (though this isn't so much a concern anymore), and loading weights. The saved weights for this class will include cross_attn parameters even when it isn't used, so loading them with this updated version of the class will raise exceptions about unexpected keys. We've had to work around this with methods to load old state dicts like this.

You would need to look at where this class is used and see if such adaptation is needed, but either way the norm_cross_attn and cross_attn members should always exist. Since norm_cross_attn is pretty lightweight I'd instantiate it always, and cross_attn should be nn.Identity when cross-attention is disabled.

self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)
else:
self.cross_attn = nn.Identity()

def forward(
self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None
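A minimal usage sketch (not part of the diff) summarises the behaviour documented in the docstring above and the reviewer's point about always-present members. It assumes einops is installed, as the new tests below require, and uses the same constructor arguments as those tests.

import torch
import torch.nn as nn

from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.blocks.transformerblock import TransformerBlock

# Without cross-attention the attribute still exists, but as a parameter-free nn.Identity,
# so type checking and checkpoint loading always see a cross_attn member.
plain = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
assert isinstance(plain.cross_attn, nn.Identity)

# With cross-attention enabled, cross_attn is a CrossAttentionBlock and forward() accepts
# an external context tensor of shape (batch, context_length, hidden_size).
cross = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
assert isinstance(cross.cross_attn, CrossAttentionBlock)

x = torch.randn(2, 16, 128)       # (batch, sequence_length, hidden_size)
context = torch.randn(2, 8, 128)  # context attended to by the cross-attention layer
out = cross(x, context=context)
assert out.shape == x.shape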
2 changes: 1 addition & 1 deletion monai/networks/nets/transformer.py
@@ -147,7 +147,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:

# fix the renamed norm blocks first: norm2 -> norm_cross_attn, norm3 -> norm2
for k in list(old_state_dict.keys()):
if "norm2" in k:
if "norm2" in k and k.replace("norm2", "norm_cross_attn") in new_state_dict:
new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k)
if "norm3" in k:
new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k)
32 changes: 32 additions & 0 deletions tests/networks/blocks/test_transformerblock.py
@@ -16,9 +16,11 @@

import numpy as np
import torch
import torch.nn as nn
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import optional_import
from tests.test_utils import dict_product
@@ -53,6 +55,36 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_is_identity_when_disabled(self):
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
# attributes always exist for typing and checkpoint compatibility
self.assertTrue(hasattr(block, "cross_attn"))
self.assertTrue(hasattr(block, "norm_cross_attn"))
# cross_attn is nn.Identity (no parameters) when disabled
self.assertIsInstance(block.cross_attn, nn.Identity)
param_names = [name for name, _ in block.named_parameters()]
self.assertFalse(any(n.startswith("cross_attn") for n in param_names))

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_params_registered_when_enabled(self):
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
self.assertIsInstance(block.cross_attn, CrossAttentionBlock)
self.assertTrue(hasattr(block, "norm_cross_attn"))
param_names = [name for name, _ in block.named_parameters()]
self.assertTrue(any(n.startswith("cross_attn.") for n in param_names))
self.assertTrue(any("norm_cross_attn" in n for n in param_names))

@skipUnless(has_einops, "Requires einops")
def test_cross_attention_forward_with_context(self):
hidden_size = 128
block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True)
x = torch.randn(2, 16, hidden_size)
context = torch.randn(2, 8, hidden_size)
with eval_mode(block):
out = block(x, context=context)
self.assertEqual(out.shape, x.shape)

@skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format