@@ -147,30 +147,36 @@ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_messag
147147
148148class TestGlobalMutualInformationLossBuffers (unittest .TestCase ):
149149 def test_gaussian_kernel_registers_buffers (self ):
150- """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel ."""
150+ """Verify gaussian kernel registers preterm and bin_centers as non-trainable, non- persistent buffers."""
151151 loss = GlobalMutualInformationLoss (kernel_type = "gaussian" )
152152 self .assertIn ("preterm" , loss ._buffers )
153153 self .assertIn ("bin_centers" , loss ._buffers )
154154 self .assertFalse (loss .preterm .requires_grad )
155155 self .assertFalse (loss .bin_centers .requires_grad )
156156 self .assertEqual (loss .bin_centers .ndim , 3 )
157+ state = loss .state_dict ()
158+ self .assertNotIn ("preterm" , state )
159+ self .assertNotIn ("bin_centers" , state )
157160
158161 def test_bspline_kernel_has_no_gaussian_buffers (self ):
159- """b-spline kernel does not register gaussian-specific buffers."""
162+ """Verify b-spline kernel does not populate gaussian-specific buffers."""
160163 loss = GlobalMutualInformationLoss (kernel_type = "b-spline" )
161164 self .assertIsNone (loss .preterm )
162165 self .assertIsNone (loss .bin_centers )
166+ state = loss .state_dict ()
167+ self .assertNotIn ("preterm" , state )
168+ self .assertNotIn ("bin_centers" , state )
163169
164170 def test_gaussian_kernel_forward_correct (self ):
165- """Gaussian kernel forward pass returns a scalar loss."""
171+ """Verify gaussian kernel forward pass returns a scalar loss tensor ."""
166172 pred = torch .rand (2 , 1 , 8 , 8 , dtype = torch .float32 )
167173 target = torch .rand (2 , 1 , 8 , 8 , dtype = torch .float32 )
168174 loss = GlobalMutualInformationLoss (kernel_type = "gaussian" )
169175 result = loss (pred , target )
170176 self .assertEqual (result .shape , torch .Size ([]))
171177
172178 def test_gaussian_buffers_move_with_module (self ):
173- """Buffers move to the correct device when the module is moved with .to() ."""
179+ """Verify preterm and bin_centers buffers move to the target device with the module."""
174180 loss = GlobalMutualInformationLoss (kernel_type = "gaussian" )
175181 self .assertEqual (loss .preterm .device .type , "cpu" )
176182 self .assertEqual (loss .bin_centers .device .type , "cpu" )
0 commit comments