Fix: Respect mask_ratio=0.0 in PrithviMAE.forward #1167
akshanthsaik wants to merge 2 commits into torchgeo:main from
Conversation
Thanks @akshanthsaik, will look into it.
Pull request overview
This PR updates PrithviMAE.forward to treat mask_ratio=0.0 as a valid value by replacing a truthiness-based defaulting check with an explicit None check.
Changes:
- Replace `mask_ratio = mask_ratio or self.mask_ratio` with an explicit `None` fallback to preserve `0.0`.
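The change amounts to swapping a truthiness default for an explicit sentinel check. A minimal sketch of the difference (the class below is reduced to the defaulting logic only; the `resolve_*` method names are illustrative, not from torchgeo):

```python
from __future__ import annotations


class PrithviMAE:
    """Reduced stand-in: only the mask_ratio defaulting logic is shown."""

    def __init__(self, mask_ratio: float = 0.75) -> None:
        self.mask_ratio = mask_ratio

    def resolve_before(self, mask_ratio: float | None = None) -> float:
        # Truthiness check: 0.0 is falsy, so a caller passing 0.0
        # silently gets the default 0.75 instead.
        return mask_ratio or self.mask_ratio

    def resolve_after(self, mask_ratio: float | None = None) -> float:
        # Explicit None check: only a missing argument falls back,
        # so 0.0 survives as a valid override.
        return self.mask_ratio if mask_ratio is None else mask_ratio
```

With the old check, `resolve_before(0.0)` returns 0.75; the fixed version returns 0.0 as the caller intended.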
    mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
    latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
    pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
With this change, mask_ratio=0.0 will be passed through to the encoder, which produces an all-zero mask. forward_loss() then divides by mask.sum() (which becomes 0), yielding inf/NaN. If 0.0 is intended to be valid, forward_loss needs a guard (e.g., return 0 or fall back to unmasked loss when mask.sum()==0), or alternatively reject mask_ratio=0.0 with a clear error.
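A hedged sketch of the suggested guard, with tensors replaced by flat Python lists so the divide-by-zero case is easy to see; the real `forward_loss` operates on per-patch torch tensors:

```python
from __future__ import annotations


def forward_loss(pred: list[float], target: list[float],
                 mask: list[int]) -> float:
    """Mean squared error averaged over masked patches only."""
    per_patch = [(p - t) ** 2 for p, t in zip(pred, target)]
    mask_sum = sum(mask)
    if mask_sum == 0:
        # mask_ratio=0.0 produces an all-zero mask: nothing was removed,
        # so there is nothing to reconstruct; return 0 rather than
        # dividing by zero and producing inf/NaN.
        return 0.0
    return sum(l * m for l, m in zip(per_patch, mask)) / mask_sum
```

The alternative raised in the comment, rejecting `mask_ratio=0.0` with a clear error, would move the check up into `forward()` instead of guarding here.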
    mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
    latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
mask_ratio is treated as an optional sentinel (mask_ratio is None), but the parameter is annotated as float in the forward() signature. Consider updating the annotation to float | None (or Optional[float]) to match the actual accepted type and avoid static typing warnings.
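The annotation change is mechanical; a sketch of the suggested signature, with `forward()` reduced to its defaulting line and the real tensor arguments omitted:

```python
from __future__ import annotations


class PrithviMAE:
    def __init__(self, mask_ratio: float = 0.75) -> None:
        self.mask_ratio = mask_ratio

    # Annotating `mask_ratio: float` while defaulting to None makes the
    # sentinel invisible to static checkers; `float | None` (PEP 604) or
    # Optional[float] documents the accepted types explicitly.
    def forward(self, pixel_values, mask_ratio: float | None = None) -> float:
        mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
        # Encoder/decoder pipeline elided; return the resolved ratio so
        # the defaulting behaviour is observable in this sketch.
        return mask_ratio
```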
    mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
    latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
There’s existing test coverage for Prithvi components, but no regression test for PrithviMAE.forward(mask_ratio=0.0). Adding a unit test that exercises this path (and asserts loss is finite / behavior is as intended) would prevent reintroducing the original truthiness bug and catch the divide-by-zero edge case.
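Such a regression test might look like the following pytest-style sketch. `PrithviMAE` is stubbed here so the test is self-contained: it models only the defaulting fix and the `mask_sum` guard, not the real encoder/decoder.

```python
import math


class PrithviMAE:
    """Stub modelling only the behaviour under test."""

    def __init__(self, mask_ratio: float = 0.75) -> None:
        self.mask_ratio = mask_ratio

    def forward(self, n_patches, mask_ratio=None):
        mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
        mask_sum = int(n_patches * mask_ratio)  # number of masked patches
        loss = 0.0 if mask_sum == 0 else 1.0    # stand-in reconstruction loss
        return loss, mask_sum


def test_forward_respects_zero_mask_ratio():
    model = PrithviMAE(mask_ratio=0.75)
    loss, mask_sum = model.forward(16, mask_ratio=0.0)
    assert mask_sum == 0          # 0.0 was not replaced by the default
    assert math.isfinite(loss)    # the guard prevents inf/NaN
```

The same test against the pre-fix code would fail on the first assertion, since `0.0 or 0.75` evaluates to `0.75` and 12 of 16 patches would be masked.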
Added a guard in forward_loss to handle mask_sum==0, which occurs when mask_ratio=0.0, preventing division by zero and NaNs.
Just realised,
Replaced the truthiness check with an explicit None check so that 0.0 is treated as a valid value.