I finally found the time to have a closer look at what was discussed in #592 with romeokienzler, as I was trying to move my segmentation training from Clay 1.0 to Clay 1.5 with terratorch.
The problem
The current `Clay1_5ModelFactory` returns the MAE pretraining losses (`loss`, `reconstruction_loss`, `representation_loss`) instead of a `ModelOutput` with predictions. Looking at the code:

- Line 548, `ClayMAE.forward()`: returns `(loss, reconstruction_loss, representation_loss)`
- Line 89, `Clay1_5ModelFactory`: just wraps `ClayMAE` in `ModelWrapper`, and doesn't use `PixelWiseModel` or `ScalarOutputModel`
This breaks downstream tasks because the model is doing MAE pretraining (reconstruction), not feature extraction for segmentation/classification.
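To make the mismatch concrete, here is a toy stand-in (the class below is illustrative, not the real terratorch/Clay code): a model whose forward pass returns pretraining losses hands the task a tuple of scalars where a segmentation head would expect a tensor of per-pixel logits:

```python
import torch
import torch.nn as nn

class ToyMAE(nn.Module):
    """Illustrative stand-in mimicking ClayMAE.forward's return shape:
    a tuple of loss scalars, not predictions."""
    def forward(self, x):
        loss = torch.tensor(1.0)
        # (loss, reconstruction_loss, representation_loss)
        return loss, loss * 0.5, loss * 0.5

mae = ToyMAE()
out = mae(torch.randn(2, 3, 32, 32))
# A downstream segmentation task gets a tuple of 0-dim tensors,
# not a (B, num_classes, H, W) prediction tensor:
print(type(out), [t.shape for t in out])
```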
The fix
I created an adapter that uses only the `ClayMAE` encoder for feature extraction (skipping the masking and loss computation), then wrapped it in the standard `PixelWiseModel` pipeline, just like Clay 1.0 does.
The key insight is that for downstream tasks we don't want the full MAE forward pass - we just need the encoder to extract features:
```python
import torch
from torch import nn
from einops import repeat

class ClayMAEBackbone(nn.Module):
    # time/latlon/waves/gsd are the Clay metadata tensors from the batch
    def forward(self, x, time, latlon, waves, gsd):
        B = x.shape[0]
        # Use encoder only (no masking for downstream tasks)
        patches, waves_encoded = self.clayma.encoder.to_patch_embed(x, waves)
        patches = self.clayma.encoder.add_encodings(patches, time, latlon, gsd)
        # Add CLS token and pass through transformer
        cls_tokens = repeat(self.clayma.encoder.cls_token, "1 1 D -> B 1 D", B=B)
        patches_with_cls = torch.cat((cls_tokens, patches), dim=1)
        encoded = self.clayma.encoder.transformer(patches_with_cls)
        # Drop CLS, move channels first, then reshape tokens to a spatial grid
        grid_size = int((encoded.shape[1] - 1) ** 0.5)
        tokens = encoded[:, 1:, :].transpose(1, 2)  # (B, dim, N)
        features = tokens.reshape(B, self.dim, grid_size, grid_size)
        return [features]
```
Then the factory wraps this in `PixelWiseModel`, which handles the decoder and head properly.
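A minimal sketch of that wiring, with toy stand-ins for the backbone, decoder, and head (the real factory would use terratorch's registered decoders and `PixelWiseModel`; every class name below is illustrative, only the pattern matches):

```python
import torch
import torch.nn as nn

class ToyEncoderBackbone(nn.Module):
    """Stand-in for ClayMAEBackbone: emits spatial feature maps, no losses."""
    def __init__(self, dim=16):
        super().__init__()
        self.dim = dim
        self.proj = nn.Conv2d(3, dim, kernel_size=4, stride=4)  # 32x32 -> 8x8

    def forward(self, x):
        # Return a list of feature maps, as terratorch backbones do
        return [self.proj(x)]

class ToyPixelWiseModel(nn.Module):
    """Stand-in for PixelWiseModel: runs decoder/head over backbone features."""
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Conv2d(backbone.dim, num_classes, kernel_size=1)
        self.upsample = nn.Upsample(scale_factor=4, mode="bilinear",
                                    align_corners=False)

    def forward(self, x):
        feats = self.backbone(x)[-1]
        # Per-pixel logits at input resolution, as a segmentation task expects
        return self.upsample(self.head(feats))

model = ToyPixelWiseModel(ToyEncoderBackbone(), num_classes=5)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 5, 32, 32])
```

The point of the pattern is that the backbone returns feature maps only; losses and predictions are the responsibility of the pixel-wise wrapper and the task.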
I've validated this with full training runs: it works for semantic segmentation (~7 it/s, 313M params).
Happy to share the full implementation or help with a PR if you're interested.