|
16 | 16 |
|
17 | 17 | import numpy as np |
18 | 18 | import torch |
| 19 | +import torch.nn as nn |
19 | 20 | from parameterized import parameterized |
20 | 21 |
|
21 | 22 | from monai.networks import eval_mode |
| 23 | +from monai.networks.blocks.crossattention import CrossAttentionBlock |
22 | 24 | from monai.networks.blocks.transformerblock import TransformerBlock |
23 | 25 | from monai.utils import optional_import |
24 | 26 | from tests.test_utils import dict_product |
@@ -53,6 +55,36 @@ def test_ill_arg(self): |
53 | 55 | with self.assertRaises(ValueError): |
54 | 56 | TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) |
55 | 57 |
|
@skipUnless(has_einops, "Requires einops")
def test_cross_attention_is_identity_when_disabled(self):
    # Build a block with cross-attention switched off.
    blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=False)
    # attributes always exist for typing and checkpoint compatibility
    for attr_name in ("cross_attn", "norm_cross_attn"):
        self.assertTrue(hasattr(blk, attr_name))
    # when disabled, cross_attn is expected to be a plain nn.Identity ...
    self.assertIsInstance(blk.cross_attn, nn.Identity)
    # ... and no registered parameter name may begin with "cross_attn"
    registered = [pname for pname, _ in blk.named_parameters()]
    self.assertFalse(any(pname.startswith("cross_attn") for pname in registered))
| 68 | + |
@skipUnless(has_einops, "Requires einops")
def test_cross_attention_params_registered_when_enabled(self):
    # Build a block with cross-attention switched on.
    blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
    self.assertIsInstance(blk.cross_attn, CrossAttentionBlock)
    self.assertTrue(hasattr(blk, "norm_cross_attn"))
    # Collect parameter names once, then check both name patterns.
    registered = [pname for pname, _ in blk.named_parameters()]
    has_attn_params = any(pname.startswith("cross_attn.") for pname in registered)
    has_norm_params = any("norm_cross_attn" in pname for pname in registered)
    self.assertTrue(has_attn_params)
    self.assertTrue(has_norm_params)
| 77 | + |
@skipUnless(has_einops, "Requires einops")
def test_cross_attention_forward_with_context(self):
    # A forward pass that receives a context tensor must return a tensor
    # with the same shape as the input tensor.
    dim = 128
    net = TransformerBlock(hidden_size=dim, mlp_dim=256, num_heads=4, with_cross_attention=True)
    inputs = torch.randn(2, 16, dim)
    ctx = torch.randn(2, 8, dim)  # second dimension differs from the input's (8 vs 16)
    with eval_mode(net):
        result = net(inputs, context=ctx)
        self.assertEqual(result.shape, inputs.shape)
| 87 | + |
56 | 88 | @skipUnless(has_einops, "Requires einops") |
57 | 89 | def test_access_attn_matrix(self): |
58 | 90 | # input format |
|
0 commit comments