
Fix: Respect mask_ratio=0.0 in PrithviMAE.forward#1167

Open
akshanthsaik wants to merge 2 commits into torchgeo:main from akshanthsaik:issue-#1096

Conversation

@akshanthsaik

Replaced the truthiness check with an explicit None check so that 0.0 is treated as a valid value.
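For context, the difference between the two checks can be sketched as follows (`DEFAULT_MASK_RATIO` is a stand-in for `self.mask_ratio`; the real code lives inside `PrithviMAE.forward`):

```python
DEFAULT_MASK_RATIO = 0.75  # stand-in for self.mask_ratio

def resolve_truthy(mask_ratio):
    # Old behavior: 0.0 is falsy, so it silently falls back to the default.
    return mask_ratio or DEFAULT_MASK_RATIO

def resolve_explicit(mask_ratio):
    # New behavior: only None triggers the default; 0.0 is preserved.
    return DEFAULT_MASK_RATIO if mask_ratio is None else mask_ratio

print(resolve_truthy(0.0))    # 0.75 — bug: explicit 0.0 is overridden
print(resolve_explicit(0.0))  # 0.0  — fix: explicit 0.0 is respected
print(resolve_explicit(None)) # 0.75 — default still applies when unset
```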

@romeokienzler
Collaborator

Thanks @akshanthsaik, will look into it

Contributor

Copilot AI left a comment


Pull request overview

This PR updates PrithviMAE.forward to treat mask_ratio=0.0 as a valid value by replacing a truthiness-based defaulting check with an explicit None check.

Changes:

  • Replace mask_ratio = mask_ratio or self.mask_ratio with an explicit None fallback to preserve 0.0.


Comment on lines +855 to 857
mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)

Copilot AI Apr 10, 2026


With this change, mask_ratio=0.0 will be passed through to the encoder, which produces an all-zero mask. forward_loss() then divides by mask.sum() (which becomes 0), yielding inf/NaN. If 0.0 is intended to be valid, forward_loss needs a guard (e.g., return 0 or fall back to unmasked loss when mask.sum()==0), or alternatively reject mask_ratio=0.0 with a clear error.
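A minimal sketch of such a guard, using plain Python as a stand-in for the tensor math (`per_patch_loss` and `mask` are illustrative names, not the actual torchgeo code):

```python
def masked_mean_loss(per_patch_loss, mask):
    """Mean of per-patch losses over masked patches only.

    Guard: when no patches are masked (mask_ratio=0.0), mask.sum() is 0
    and the usual `(loss * mask).sum() / mask.sum()` pattern would divide
    by zero. Return 0.0 instead of producing inf/NaN in that case.
    """
    mask_sum = sum(mask)
    if mask_sum == 0:
        return 0.0  # nothing was masked, so there is nothing to reconstruct
    return sum(l * m for l, m in zip(per_patch_loss, mask)) / mask_sum

print(masked_mean_loss([0.5, 1.5], [1, 1]))  # 1.0
print(masked_mean_loss([0.5, 1.5], [0, 0]))  # 0.0 — guarded, no ZeroDivisionError
```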

Collaborator


@copilot apply changes based on this feedback

Comment on lines +855 to 856
mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)

Copilot AI Apr 10, 2026


mask_ratio is treated as an optional sentinel (mask_ratio is None), but the parameter is annotated as float in the forward() signature. Consider updating the annotation to float | None (or Optional[float]) to match the actual accepted type and avoid static typing warnings.

Collaborator


@copilot apply changes based on this feedback

Comment on lines +855 to 856
mask_ratio = self.mask_ratio if mask_ratio is None else mask_ratio
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)

Copilot AI Apr 10, 2026


There’s existing test coverage for Prithvi components, but no regression test for PrithviMAE.forward(mask_ratio=0.0). Adding a unit test that exercises this path (and asserts loss is finite / behavior is as intended) would prevent reintroducing the original truthiness bug and catch the divide-by-zero edge case.
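A regression test along these lines might look like the sketch below. `toy_forward` is a self-contained stand-in that mirrors the guarded masking logic; a real torchgeo test would instead instantiate PrithviMAE via the existing Prithvi test fixtures and assert on its returned loss:

```python
import math

def toy_forward(per_patch_loss, mask_ratio):
    # Mask the first round(mask_ratio * n) patches, then average the
    # loss over masked patches — guarded against an all-zero mask.
    n = len(per_patch_loss)
    n_masked = int(round(mask_ratio * n))
    mask = [1] * n_masked + [0] * (n - n_masked)
    mask_sum = sum(mask)
    if mask_sum == 0:
        return 0.0
    return sum(l * m for l, m in zip(per_patch_loss, mask)) / mask_sum

def test_mask_ratio_zero_gives_finite_loss():
    # Regression for the truthiness bug: mask_ratio=0.0 must reach the
    # masking logic unchanged and still yield a finite loss.
    loss = toy_forward([0.5, 1.5, 2.5, 3.5], mask_ratio=0.0)
    assert math.isfinite(loss)

test_mask_ratio_zero_gives_finite_loss()
print("ok")
```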

Collaborator


@copilot apply changes based on this feedback

@akshanthsaik
Author

akshanthsaik commented Apr 10, 2026

Added a guard in forward_loss to handle mask_sum == 0, which occurs when mask_ratio=0.0, preventing division by zero and NaNs.
I think this ensures the forward pass remains numerically stable.
@romeokienzler please take a look

@akshanthsaik
Author

Just realised: instead of silently defaulting or returning zero loss, would you consider just raising an error here?

