Skip to content

Commit 20c702a

Browse files
fix(losses): address review comments on GlobalMutualInformationLoss buffer fix
Use register_buffer("preterm", None) / register_buffer("bin_centers", None) unconditionally so that both buffers are always present in _buffers (with None for b-spline). This avoids a KeyError that occurred when plain instance attribute assignment conflicted with a subsequent register_buffer call. Also add docstrings to the new test methods and a device-movement test that verifies buffers follow the module when .cuda() is called. Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
1 parent f20d3f6 commit 20c702a

2 files changed

Lines changed: 22 additions & 5 deletions

File tree

monai/losses/image_dissimilarity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ def __init__(
233233
self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
234234
self.num_bins = num_bins
235235
self.kernel_type = kernel_type
236-
self.preterm: torch.Tensor
237-
self.bin_centers: torch.Tensor
236+
self.register_buffer("preterm", None, persistent=False)
237+
self.register_buffer("bin_centers", None, persistent=False)
238238
if self.kernel_type == "gaussian":
239239
self.register_buffer("preterm", 1 / (2 * sigma**2), persistent=False)
240240
self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False)

tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,26 +147,43 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag
147147

148148
class TestGlobalMutualInformationLossBuffers(unittest.TestCase):
149149
def test_gaussian_kernel_registers_buffers(self):
150+
"""preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
150151
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
151-
# preterm and bin_centers must be registered buffers so .to() moves them
152152
self.assertIn("preterm", loss._buffers)
153153
self.assertIn("bin_centers", loss._buffers)
154154
self.assertFalse(loss.preterm.requires_grad)
155155
self.assertFalse(loss.bin_centers.requires_grad)
156156
self.assertEqual(loss.bin_centers.ndim, 3)
157157

158158
def test_bspline_kernel_has_no_gaussian_buffers(self):
159+
"""b-spline kernel does not register gaussian-specific buffers."""
159160
loss = GlobalMutualInformationLoss(kernel_type="b-spline")
160-
self.assertNotIn("preterm", loss._buffers)
161-
self.assertNotIn("bin_centers", loss._buffers)
161+
self.assertIsNone(loss.preterm)
162+
self.assertIsNone(loss.bin_centers)
162163

163164
def test_gaussian_kernel_forward_correct(self):
165+
"""Gaussian kernel forward pass returns a scalar loss."""
164166
pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
165167
target = torch.rand(2, 1, 8, 8, dtype=torch.float32)
166168
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
167169
result = loss(pred, target)
168170
self.assertEqual(result.shape, torch.Size([]))
169171

172+
def test_gaussian_buffers_move_with_module(self):
173+
"""Buffers move to the correct device when the module is moved with .to()."""
174+
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
175+
self.assertEqual(loss.preterm.device.type, "cpu")
176+
self.assertEqual(loss.bin_centers.device.type, "cpu")
177+
if not torch.cuda.is_available():
178+
return
179+
loss = loss.cuda()
180+
self.assertEqual(loss.preterm.device.type, "cuda")
181+
self.assertEqual(loss.bin_centers.device.type, "cuda")
182+
pred = torch.rand(2, 1, 8, 8, device="cuda")
183+
target = torch.rand(2, 1, 8, 8, device="cuda")
184+
result = loss(pred, target)
185+
self.assertEqual(result.device.type, "cuda")
186+
170187

171188
if __name__ == "__main__":
172189
unittest.main()

0 commit comments

Comments
 (0)