Skip to content

Commit e33941c

Browse files
fix: only instantiate CrossAttentionBlock when with_cross_attention=True (#8848)
Fixes #8845. ### Description `TransformerBlock` previously instantiated `norm_cross_attn` and `cross_attn` unconditionally in `__init__`, even when `with_cross_attention=False`. These unused modules registered dead parameters in `model.parameters()`, consuming memory without contributing to computation. The `forward()` method already had the correct guard (`if self.with_cross_attention:`), so the instantiation and the usage were inconsistent. This fix wraps both instantiations in `if with_cross_attention:`, so the modules are only created when actually needed. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. Signed-off-by: chhayankjain <chhayank44@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 63a518e commit e33941c

8 files changed

Lines changed: 253 additions & 8 deletions

File tree

monai/networks/blocks/transformerblock.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,20 @@ def __init__(
4646
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
4747
qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
4848
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
49+
causal (bool, optional): whether to apply causal masking in self-attention. Defaults to False.
50+
sequence_length (int | None, optional): sequence length required for causal masking. Defaults to None.
51+
with_cross_attention (bool, optional): whether to include cross-attention layers that attend to an
52+
external context tensor. When False, cross_attn is set to nn.Identity() so that the attribute
53+
always exists for typing and checkpoint compatibility. Defaults to False.
4954
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
5055
(see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
5156
include_fc: whether to include the final linear layer. Default to True.
5257
use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
5358
59+
Raises:
60+
ValueError: if dropout_rate is not in [0, 1].
61+
ValueError: if hidden_size is not divisible by num_heads.
62+
5463
"""
5564

5665
super().__init__()
@@ -79,14 +88,18 @@ def __init__(
7988
self.with_cross_attention = with_cross_attention
8089

8190
self.norm_cross_attn = nn.LayerNorm(hidden_size)
82-
self.cross_attn = CrossAttentionBlock(
83-
hidden_size=hidden_size,
84-
num_heads=num_heads,
85-
dropout_rate=dropout_rate,
86-
qkv_bias=qkv_bias,
87-
causal=False,
88-
use_flash_attention=use_flash_attention,
89-
)
91+
self.cross_attn: CrossAttentionBlock | nn.Identity
92+
if with_cross_attention:
93+
self.cross_attn = CrossAttentionBlock(
94+
hidden_size=hidden_size,
95+
num_heads=num_heads,
96+
dropout_rate=dropout_rate,
97+
qkv_bias=qkv_bias,
98+
causal=False,
99+
use_flash_attention=use_flash_attention,
100+
)
101+
else:
102+
self.cross_attn = nn.Identity()
90103

91104
def forward(
92105
self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None

monai/networks/nets/masked_autoencoder_vit.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,41 @@ def forward(self, x, masking_ratio: float | None = None):
209209

210210
x = x[:, 1:, :]
211211
return x, mask
212+
213+
def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
214+
"""
215+
Load a state dict from a MaskedAutoEncoderViT model trained with an older version of MONAI
216+
where ``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
217+
even when ``with_cross_attention=False``. Old checkpoints contain stale
218+
``blocks.{i}.cross_attn.*`` and ``decoder_blocks.{i}.cross_attn.*`` keys that are not
219+
present in the current model and are automatically dropped.
220+
221+
Args:
222+
old_state_dict: state dict from the older MaskedAutoEncoderViT model.
223+
verbose: if True, print keys that are missing or unmatched. Defaults to False.
224+
"""
225+
new_state_dict = self.state_dict()
226+
if all(k in new_state_dict for k in old_state_dict):
227+
if verbose:
228+
print("All keys match, loading state dict.")
229+
self.load_state_dict(old_state_dict)
230+
return
231+
232+
if verbose:
233+
for k in new_state_dict:
234+
if k not in old_state_dict:
235+
print(f"key {k} not found in old state dict")
236+
print("----------------------------------------------")
237+
for k in old_state_dict:
238+
if k not in new_state_dict:
239+
print(f"key {k} not found in new state dict")
240+
241+
# copy over all matching keys; stale cross_attn.* keys in blocks and decoder_blocks
242+
# are left as unmatched leftovers and are not inserted into new_state_dict
243+
for k in new_state_dict:
244+
if k in old_state_dict:
245+
new_state_dict[k] = old_state_dict.pop(k)
246+
247+
if verbose:
248+
print("remaining keys in old_state_dict:", old_state_dict.keys())
249+
self.load_state_dict(new_state_dict)

monai/networks/nets/vit.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,41 @@ def forward(self, x):
131131
if hasattr(self, "classification_head"):
132132
x = self.classification_head(x[:, 0])
133133
return x, hidden_states_out
134+
135+
def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
136+
"""
137+
Load a state dict from a ViT model trained with an older version of MONAI where
138+
``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
139+
even when ``with_cross_attention=False``. Old checkpoints contain stale
140+
``blocks.{i}.cross_attn.*`` keys that are not present in the current model and
141+
are automatically dropped.
142+
143+
Args:
144+
old_state_dict: state dict from the older ViT model.
145+
verbose: if True, print keys that are missing or unmatched. Defaults to False.
146+
"""
147+
new_state_dict = self.state_dict()
148+
if all(k in new_state_dict for k in old_state_dict):
149+
if verbose:
150+
print("All keys match, loading state dict.")
151+
self.load_state_dict(old_state_dict)
152+
return
153+
154+
if verbose:
155+
for k in new_state_dict:
156+
if k not in old_state_dict:
157+
print(f"key {k} not found in old state dict")
158+
print("----------------------------------------------")
159+
for k in old_state_dict:
160+
if k not in new_state_dict:
161+
print(f"key {k} not found in new state dict")
162+
163+
# copy over all matching keys; stale cross_attn.* keys are left as unmatched
164+
# leftovers and are not inserted into new_state_dict
165+
for k in new_state_dict:
166+
if k in old_state_dict:
167+
new_state_dict[k] = old_state_dict.pop(k)
168+
169+
if verbose:
170+
print("remaining keys in old_state_dict:", old_state_dict.keys())
171+
self.load_state_dict(new_state_dict)

monai/networks/nets/vitautoenc.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,41 @@ def forward(self, x):
133133
x = self.conv3d_transpose(x)
134134
x = self.conv3d_transpose_1(x)
135135
return x, hidden_states_out
136+
137+
def load_old_state_dict(self, old_state_dict: dict, verbose: bool = False) -> None:
138+
"""
139+
Load a state dict from a ViTAutoEnc model trained with an older version of MONAI where
140+
``CrossAttentionBlock`` was unconditionally instantiated in ``TransformerBlock``
141+
even when ``with_cross_attention=False``. Old checkpoints contain stale
142+
``blocks.{i}.cross_attn.*`` keys that are not present in the current model and
143+
are automatically dropped.
144+
145+
Args:
146+
old_state_dict: state dict from the older ViTAutoEnc model.
147+
verbose: if True, print keys that are missing or unmatched. Defaults to False.
148+
"""
149+
new_state_dict = self.state_dict()
150+
if all(k in new_state_dict for k in old_state_dict):
151+
if verbose:
152+
print("All keys match, loading state dict.")
153+
self.load_state_dict(old_state_dict)
154+
return
155+
156+
if verbose:
157+
for k in new_state_dict:
158+
if k not in old_state_dict:
159+
print(f"key {k} not found in old state dict")
160+
print("----------------------------------------------")
161+
for k in old_state_dict:
162+
if k not in new_state_dict:
163+
print(f"key {k} not found in new state dict")
164+
165+
# copy over all matching keys; stale cross_attn.* keys are left as unmatched
166+
# leftovers and are not inserted into new_state_dict
167+
for k in new_state_dict:
168+
if k in old_state_dict:
169+
new_state_dict[k] = old_state_dict.pop(k)
170+
171+
if verbose:
172+
print("remaining keys in old_state_dict:", old_state_dict.keys())
173+
self.load_state_dict(new_state_dict)

tests/networks/blocks/test_transformerblock.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
import numpy as np
1818
import torch
19+
import torch.nn as nn
1920
from parameterized import parameterized
2021

2122
from monai.networks import eval_mode
23+
from monai.networks.blocks.crossattention import CrossAttentionBlock
2224
from monai.networks.blocks.transformerblock import TransformerBlock
2325
from monai.utils import optional_import
2426
from tests.test_utils import dict_product
@@ -53,6 +55,36 @@ def test_ill_arg(self):
5355
with self.assertRaises(ValueError):
5456
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
5557

58+
@skipUnless(has_einops, "Requires einops")
59+
def test_cross_attention_is_identity_when_disabled(self):
60+
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
61+
# attributes always exist for typing and checkpoint compatibility
62+
self.assertTrue(hasattr(block, "cross_attn"))
63+
self.assertTrue(hasattr(block, "norm_cross_attn"))
64+
# cross_attn is nn.Identity (no parameters) when disabled
65+
self.assertIsInstance(block.cross_attn, nn.Identity)
66+
param_names = [name for name, _ in block.named_parameters()]
67+
self.assertFalse(any(n.startswith("cross_attn") for n in param_names))
68+
69+
@skipUnless(has_einops, "Requires einops")
70+
def test_cross_attention_params_registered_when_enabled(self):
71+
block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
72+
self.assertIsInstance(block.cross_attn, CrossAttentionBlock)
73+
self.assertTrue(hasattr(block, "norm_cross_attn"))
74+
param_names = [name for name, _ in block.named_parameters()]
75+
self.assertTrue(any(n.startswith("cross_attn.") for n in param_names))
76+
self.assertTrue(any("norm_cross_attn" in n for n in param_names))
77+
78+
@skipUnless(has_einops, "Requires einops")
79+
def test_cross_attention_forward_with_context(self):
80+
hidden_size = 128
81+
block = TransformerBlock(hidden_size=hidden_size, mlp_dim=256, num_heads=4, with_cross_attention=True)
82+
x = torch.randn(2, 16, hidden_size)
83+
context = torch.randn(2, 8, hidden_size)
84+
with eval_mode(block):
85+
out = block(x, context=context)
86+
self.assertEqual(out.shape, x.shape)
87+
5688
@skipUnless(has_einops, "Requires einops")
5789
def test_access_attn_matrix(self):
5890
# input format

tests/networks/nets/test_masked_autoencoder_vit.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,38 @@ def test_masking_ratio(self):
152152

153153
assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens
154154

155+
def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
156+
# simulate an old checkpoint where CrossAttentionBlock was always instantiated
157+
net = MaskedAutoEncoderViT(
158+
in_channels=1,
159+
img_size=(32, 32),
160+
patch_size=(16, 16),
161+
hidden_size=64,
162+
mlp_dim=128,
163+
num_layers=2,
164+
num_heads=4,
165+
decoder_hidden_size=32,
166+
decoder_mlp_dim=64,
167+
decoder_num_layers=2,
168+
decoder_num_heads=4,
169+
spatial_dims=2,
170+
)
171+
old_state = {k: torch.rand_like(v) for k, v in net.state_dict().items()}
172+
# inject stale cross_attn keys from both encoder blocks and decoder blocks
173+
old_state["blocks.0.cross_attn.to_q.weight"] = torch.randn(64, 64)
174+
old_state["blocks.1.cross_attn.out_proj.weight"] = torch.randn(64, 64)
175+
old_state["decoder_blocks.0.cross_attn.to_v.weight"] = torch.randn(32, 32)
176+
old_state["decoder_blocks.1.cross_attn.out_proj.weight"] = torch.randn(32, 32)
177+
178+
# save expected values before the call since load_old_state_dict pops matching keys
179+
expected = {k: v.clone() for k, v in old_state.items() if k in net.state_dict()}
180+
net.load_old_state_dict(old_state)
181+
182+
# all current model keys should be loaded from old_state; stale keys silently dropped
183+
loaded = net.state_dict()
184+
for k in loaded:
185+
self.assertTrue(torch.allclose(loaded[k], expected[k]))
186+
155187

156188
if __name__ == "__main__":
157189
unittest.main()

tests/networks/nets/test_vit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,33 @@ def test_access_attn_matrix(self):
127127
matrix_acess_blk(torch.randn(in_shape))
128128
assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 216, 216)
129129

130+
def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
131+
# simulate an old checkpoint where CrossAttentionBlock was always instantiated
132+
net = ViT(
133+
in_channels=1,
134+
img_size=(32, 32),
135+
patch_size=(16, 16),
136+
hidden_size=64,
137+
mlp_dim=128,
138+
num_layers=2,
139+
num_heads=4,
140+
spatial_dims=2,
141+
)
142+
old_state = {k: torch.rand_like(v) for k, v in net.state_dict().items()}
143+
# inject stale cross_attn keys that the new model no longer has
144+
old_state["blocks.0.cross_attn.to_q.weight"] = torch.randn(64, 64)
145+
old_state["blocks.0.cross_attn.out_proj.weight"] = torch.randn(64, 64)
146+
old_state["blocks.1.cross_attn.to_v.weight"] = torch.randn(64, 64)
147+
148+
# save expected values before the call since load_old_state_dict pops matching keys
149+
expected = {k: v.clone() for k, v in old_state.items() if k in net.state_dict()}
150+
net.load_old_state_dict(old_state)
151+
152+
# all current model keys should be loaded from old_state; stale keys silently dropped
153+
loaded = net.state_dict()
154+
for k in loaded:
155+
self.assertTrue(torch.allclose(loaded[k], expected[k]))
156+
130157

131158
if __name__ == "__main__":
132159
unittest.main()

tests/networks/nets/test_vitautoenc.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,33 @@ def test_ill_arg(
104104
dropout_rate=dropout_rate,
105105
)
106106

107+
def test_load_old_state_dict_drops_stale_cross_attn_keys(self):
108+
# simulate an old checkpoint where CrossAttentionBlock was always instantiated
109+
net = ViTAutoEnc(
110+
in_channels=1,
111+
img_size=(32, 32),
112+
patch_size=(16, 16),
113+
hidden_size=64,
114+
mlp_dim=128,
115+
num_layers=2,
116+
num_heads=4,
117+
spatial_dims=2,
118+
)
119+
old_state = {k: torch.rand_like(v) for k, v in net.state_dict().items()}
120+
# inject stale cross_attn keys that the new model no longer has
121+
old_state["blocks.0.cross_attn.to_q.weight"] = torch.randn(64, 64)
122+
old_state["blocks.0.cross_attn.out_proj.weight"] = torch.randn(64, 64)
123+
old_state["blocks.1.cross_attn.to_v.weight"] = torch.randn(64, 64)
124+
125+
# save expected values before the call since load_old_state_dict pops matching keys
126+
expected = {k: v.clone() for k, v in old_state.items() if k in net.state_dict()}
127+
net.load_old_state_dict(old_state)
128+
129+
# all current model keys should be loaded from old_state; stale keys silently dropped
130+
loaded = net.state_dict()
131+
for k in loaded:
132+
self.assertTrue(torch.allclose(loaded[k], expected[k]))
133+
107134

108135
if __name__ == "__main__":
109136
unittest.main()

0 commit comments

Comments
 (0)