Skip to content

Commit fa1a46e

Browse files
fix(losses): address review comments on buffer tests
- Add state_dict assertions to verify non-persistent=False contract - Update test docstrings to use Verify... format Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
1 parent 0c07558 commit fa1a46e

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,36 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag
147147

148148
class TestGlobalMutualInformationLossBuffers(unittest.TestCase):
149149
def test_gaussian_kernel_registers_buffers(self):
150-
"""preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
150+
"""Verify gaussian kernel registers preterm and bin_centers as non-trainable, non-persistent buffers."""
151151
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
152152
self.assertIn("preterm", loss._buffers)
153153
self.assertIn("bin_centers", loss._buffers)
154154
self.assertFalse(loss.preterm.requires_grad)
155155
self.assertFalse(loss.bin_centers.requires_grad)
156156
self.assertEqual(loss.bin_centers.ndim, 3)
157+
state = loss.state_dict()
158+
self.assertNotIn("preterm", state)
159+
self.assertNotIn("bin_centers", state)
157160

158161
def test_bspline_kernel_has_no_gaussian_buffers(self):
159-
"""b-spline kernel does not register gaussian-specific buffers."""
162+
"""Verify b-spline kernel does not populate gaussian-specific buffers."""
160163
loss = GlobalMutualInformationLoss(kernel_type="b-spline")
161164
self.assertIsNone(loss.preterm)
162165
self.assertIsNone(loss.bin_centers)
166+
state = loss.state_dict()
167+
self.assertNotIn("preterm", state)
168+
self.assertNotIn("bin_centers", state)
163169

164170
def test_gaussian_kernel_forward_correct(self):
165-
"""Gaussian kernel forward pass returns a scalar loss."""
171+
"""Verify gaussian kernel forward pass returns a scalar loss tensor."""
166172
pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
167173
target = torch.rand(2, 1, 8, 8, dtype=torch.float32)
168174
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
169175
result = loss(pred, target)
170176
self.assertEqual(result.shape, torch.Size([]))
171177

172178
def test_gaussian_buffers_move_with_module(self):
173-
"""Buffers move to the correct device when the module is moved with .to()."""
179+
"""Verify preterm and bin_centers buffers move to the target device with the module."""
174180
loss = GlobalMutualInformationLoss(kernel_type="gaussian")
175181
self.assertEqual(loss.preterm.device.type, "cpu")
176182
self.assertEqual(loss.bin_centers.device.type, "cpu")

0 commit comments

Comments
 (0)