From 48748f15cd6366a8ca5843e0826ade794cca944d Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Mon, 15 Jul 2024 12:27:14 +0530 Subject: [PATCH 01/11] Compile clay model encoder --- src/__init__.py | 0 src/export.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ src/model.py | 16 ++++++------ src/utils.py | 5 +++- 4 files changed, 80 insertions(+), 9 deletions(-) create mode 100644 src/__init__.py create mode 100644 src/export.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/export.py b/src/export.py new file mode 100644 index 00000000..2adccbb9 --- /dev/null +++ b/src/export.py @@ -0,0 +1,68 @@ +from pathlib import Path + +import torch +from torch.export import Dim + +from src.model import ClayMAEModule + +CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cpu") + + +def get_data(): + # Load data + cube = torch.randn(128, 3, 224, 224).to(device) + time = torch.randn(128, 4).to(device) + latlon = torch.randn(128, 4).to(device) + waves = torch.randn(3).to(device) + gsd = torch.randn(1).to(device) + return cube, time, latlon, waves, gsd + + +def load_model(): + module = ClayMAEModule.load_from_checkpoint(CHECKPOINT_PATH) + encoder = module.model.encoder # Get the encoder + encoder = encoder.to(device) # Move to device + return encoder + + +def main(): + # Load data + cube, time, latlon, waves, gsd = get_data() + + # Load model + encoder = load_model() + + # Define dynamic shapes for model export + batch_size = Dim("batch_size", min=2, max=128) # Define batch size range + channel_bands = Dim("channel_bands", min=1, max=12) # Define channel bands range + + dynamic_shapes = { + "cube": {0: batch_size, 1: channel_bands}, + "time": {0: batch_size}, + "latlon": {0: batch_size}, + "waves": {0: channel_bands}, + "gsd": {0: None}, + } + + # Export model + exp_compiled_encoder = torch.export.export( + mod=encoder, + args=(cube, time, latlon, waves, gsd), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # tensortrt compiled model + # trt_encoder = torch_tensorrt.dynamo.compile( + # exp_compiled_encoder, [cube, time, latlon, waves, gsd] + # ) + + # Save model + Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) + torch.export.save(exp_compiled_encoder, "checkpoints/compiled/encoder.pt") + + +if __name__ == "__main__": + main() diff --git a/src/model.py b/src/model.py index ee211b97..1ecf508e 100644 --- a/src/model.py +++ b/src/model.py @@ -160,14 +160,14 @@ def mask_out(self, patches): masked_matrix, ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] - def forward(self, datacube): - cube, time, latlon, gsd, waves = ( - datacube["pixels"], # [B C H W] - datacube["time"], # [B 2] - datacube["latlon"], # [B 2] - datacube["gsd"], # 1 - datacube["waves"], # [N] - ) # [B C H W] + def forward(self, cube, time, latlon, waves, gsd): + # cube, time, latlon, gsd, waves = ( + # datacube["pixels"], # [B C H W] + # datacube["time"], # [B 2] + # datacube["latlon"], # [B 2] + # datacube["gsd"], # 1 + # datacube["waves"], # [N] + # ) # [B C H W] B, C, H, W = cube.shape diff --git a/src/utils.py b/src/utils.py index 539a2acd..1e35731f 100644 --- a/src/utils.py +++ b/src/utils.py @@ -11,6 +11,7 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature**omega) + omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -24,8 +25,9 @@ def posemb_sincos_2d_with_gsd( y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" - omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = torch.arange(dim // 4, device=gsd.device) / (dim // 4 - 1) omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g + omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -41,6 +43,7 @@ def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): omega = torch.arange(dim // 2) / (dim // 2 - 1) omega = 1.0 / (temperature**omega) + omega = omega.to(pos.device) scaled_pos = pos[:, None] * omega[None, :] pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) From 97eb19a23a177031a02d9d73ba5db417d4de3d5f Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 18 Jul 2024 17:37:52 +0530 Subject: [PATCH 02/11] Add benchmark & test files for the compiled clay encoder --- src/benchmark_encoder.py | 80 +++++++++++++++++++++++++++++++++++++ src/export.py | 65 ++++++++++++++++-------------- src/model.py | 9 ++++- src/test_encoder.py | 86 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+), 32 deletions(-) create mode 100644 src/benchmark_encoder.py create mode 100644 src/test_encoder.py diff --git a/src/benchmark_encoder.py b/src/benchmark_encoder.py new file mode 100644 index 00000000..08c2b0ac --- /dev/null +++ b/src/benchmark_encoder.py @@ -0,0 +1,80 @@ +import argparse +import time +import warnings + +import torch + +warnings.filterwarnings("ignore") + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def get_data(): + """ + Generate random data tensors for model input. + """ + cube = torch.randn(128, 3, 256, 256).to(DEVICE) + timestep = torch.randn(128, 4).to(DEVICE) + latlon = torch.randn(128, 4).to(DEVICE) + waves = torch.randn(3).to(DEVICE) + gsd = torch.randn(1).to(DEVICE) + return cube, timestep, latlon, waves, gsd + + +def load_exported_model(eager=True): + """ + Load the exported model from a file. + + Args: + eager (bool): Flag to decide whether to use eager mode or compiled mode. + """ + print("Loading exported model") + ep = torch.export.load("checkpoints/compiled/encoder.pt") + if eager: + model = ep.module() + else: + model = torch.compile(ep.module(), backend="inductor") + return model + + +def benchmark_model(model): + """ + Benchmark the model by running inference on randomly generated data. + + Args: + model: The model to benchmark. + """ + print("Benchmarking model") + start = time.time() + for i in range(20): + cube, timestep, latlon, waves, gsd = get_data() + with torch.inference_mode(): + out = model(cube, timestep, latlon, waves, gsd) + print( + f"Iteration {i}: Output shapes - {out[0].shape}, {out[1].shape}, {out[2].shape}, {out[3].shape}" # noqa E501 + ) + print("Time taken for inference: ", time.time() - start) + + +def run(eager=True): + """ + Run the exported model and benchmark it. + + Args: + eager (bool): Flag to decide whether to use eager mode or compiled mode. + """ + print("Running model") + model = load_exported_model(eager=eager) + benchmark_model(model) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run benchmark for the exported model." + ) + parser.add_argument( + "--eager", action="store_true", help="Use eager mode for running the model." + ) + args = parser.parse_args() + + run(args.eager) diff --git a/src/export.py b/src/export.py index 2adccbb9..70a65f1e 100644 --- a/src/export.py +++ b/src/export.py @@ -1,3 +1,4 @@ +import warnings from pathlib import Path import torch @@ -5,38 +6,47 @@ from src.model import ClayMAEModule +warnings.filterwarnings("ignore") + CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt" -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# device = torch.device("cpu") +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +CHIP_SIZE = 256 def get_data(): - # Load data - cube = torch.randn(128, 3, 224, 224).to(device) - time = torch.randn(128, 4).to(device) - latlon = torch.randn(128, 4).to(device) - waves = torch.randn(3).to(device) - gsd = torch.randn(1).to(device) - return cube, time, latlon, waves, gsd + """ + Generate random data tensors for model input. + """ + cube = torch.randn(128, 3, CHIP_SIZE, CHIP_SIZE).to(DEVICE) + timestep = torch.randn(128, 4).to(DEVICE) + latlon = torch.randn(128, 4).to(DEVICE) + waves = torch.randn(3).to(DEVICE) + gsd = torch.randn(1).to(DEVICE) + return cube, timestep, latlon, waves, gsd def load_model(): - module = ClayMAEModule.load_from_checkpoint(CHECKPOINT_PATH) - encoder = module.model.encoder # Get the encoder - encoder = encoder.to(device) # Move to device + """ + Load the model from a checkpoint and prepare it for evaluation. + """ + module = ClayMAEModule.load_from_checkpoint( + CHECKPOINT_PATH, shuffle=False, mask_ratio=0.0 + ) + encoder = module.model.encoder.eval() # Get the encoder in eval mode + encoder = encoder.to(DEVICE) # Move to the appropriate device return encoder -def main(): - # Load data - cube, time, latlon, waves, gsd = get_data() - - # Load model +def export_model(): + """ + Export the model with dynamic shapes for deployment. + """ + cube, timestep, latlon, waves, gsd = get_data() encoder = load_model() # Define dynamic shapes for model export - batch_size = Dim("batch_size", min=2, max=128) # Define batch size range - channel_bands = Dim("channel_bands", min=1, max=12) # Define channel bands range + batch_size = Dim("batch_size", min=32, max=1200) + channel_bands = Dim("channel_bands", min=1, max=10) dynamic_shapes = { "cube": {0: batch_size, 1: channel_bands}, @@ -47,22 +57,17 @@ def main(): } # Export model - exp_compiled_encoder = torch.export.export( + ep = torch.export.export( mod=encoder, - args=(cube, time, latlon, waves, gsd), + args=(cube, timestep, latlon, waves, gsd), dynamic_shapes=dynamic_shapes, - strict=False, + strict=True, ) - # tensortrt compiled model - # trt_encoder = torch_tensorrt.dynamo.compile( - # exp_compiled_encoder, [cube, time, latlon, waves, gsd] - # ) - - # Save model + # Save the exported model Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) - torch.export.save(exp_compiled_encoder, "checkpoints/compiled/encoder.pt") + torch.export.save(ep, "checkpoints/compiled/encoder.pt") if __name__ == "__main__": - main() + export_model() diff --git a/src/model.py b/src/model.py index 1ecf508e..900b18b3 100644 --- a/src/model.py +++ b/src/model.py @@ -39,6 +39,10 @@ def __init__( # noqa: PLR0913 self.dim = dim self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02) + # Required to compile & export the model + self.grid_size = 256 // 8 + self.num_patches = self.grid_size**2 + self.patch_embedding = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, @@ -64,8 +68,9 @@ def add_encodings(self, patches, time, latlon, gsd): """Add position encoding to the patches""" B, L, D = patches.shape - grid_size = int(math.sqrt(L)) - self.num_patches = grid_size**2 + # grid_size = int(math.sqrt(L)) + # self.num_patches = grid_size**2 + grid_size = self.grid_size pos_encoding = ( posemb_sincos_2d_with_gsd( diff --git a/src/test_encoder.py b/src/test_encoder.py new file mode 100644 index 00000000..14839197 --- /dev/null +++ b/src/test_encoder.py @@ -0,0 +1,86 @@ +import torch + +from src.datamodule import ClayDataModule + +# Load the pre-trained Clay encoder model +clay_encoder = torch.export.load("checkpoints/compiled/encoder.pt").module() + + +def load_batch(): + # Initialize the data module with appropriate parameters + dm = ClayDataModule( + data_dir="/home/ubuntu/data", + size=256, + metadata_path="configs/metadata.yaml", + batch_size=1, + num_workers=1, + ) + + # Setup the data module for the 'fit' stage + dm.setup(stage="fit") + metadata = dm.metadata + + # Get the training data loader and create an iterator + trn_dl = dm.train_dataloader() + iter_dl = iter(trn_dl) + + return iter_dl, metadata + + +def prepare_data(sensor, metadata, device): + """ + Load data from the sensor and transfer it to the specified device. + + Args: + - sensor (dict): Sensor data containing 'pixels', 'time', 'latlon', and 'platform'. + - metadata (dict): Metadata information for different platforms. + - device (torch.device): The device to which the data should be transferred. + + Returns: + - tuple: Transferred cube, timestep, latlon, waves, and gsd tensors. + """ + cube = sensor["pixels"] + timestep = sensor["time"] + latlon = sensor["latlon"] + platform = sensor["platform"][0] + + # Get wavelengths and ground sampling distance (gsd) from metadata + waves = torch.tensor(list(metadata[platform].bands.wavelength.values())) + gsd = torch.tensor([metadata[platform].gsd]) + + # Transfer data to the specified device + cube, timestep, latlon, waves, gsd = map( + lambda x: x.to(device), (cube, timestep, latlon, waves, gsd) + ) + return cube, timestep, latlon, waves, gsd + + +def main(): + dl, metadata = load_batch() + + # Fetch samples from the data loader + l8_c2l1 = next(dl) + l8_c2l2 = next(dl) + linz = next(dl) + naip = next(dl) + s1 = next(dl) + s2 = next(dl) + + # Perform inference with the Clay encoder model + with torch.no_grad(): + for sensor in (l8_c2l1, l8_c2l2, linz, naip, s1, s2): + # Load data and transfer to GPU + batch = prepare_data(sensor, metadata, torch.device("cuda")) + + # Get patch embeddings from the encoder model + patch_embeddings, *_ = clay_encoder(*batch) + + # Extract the class (CLS) embedding + cls_embedding = patch_embeddings[:, 0, :] + + # Print the platform and the shape of the CLS embedding + print(sensor["platform"][0], cls_embedding.shape) + + +if __name__ == "__main__": + main() From eba1867f1c5d443d76dd1369764807174bea52bf Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Wed, 24 Jul 2024 20:03:40 +0530 Subject: [PATCH 03/11] Revert changes to Encoder, don't change the API --- src/model.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/model.py b/src/model.py index 900b18b3..ee211b97 100644 --- a/src/model.py +++ b/src/model.py @@ -39,10 +39,6 @@ def __init__( # noqa: PLR0913 self.dim = dim self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02) - # Required to compile & export the model - self.grid_size = 256 // 8 - self.num_patches = self.grid_size**2 - self.patch_embedding = DynamicEmbedding( wave_dim=128, num_latent_tokens=128, @@ -68,9 +64,8 @@ def add_encodings(self, patches, time, latlon, gsd): """Add position encoding to the patches""" B, L, D = patches.shape - # grid_size = int(math.sqrt(L)) - # self.num_patches = grid_size**2 - grid_size = self.grid_size + grid_size = int(math.sqrt(L)) + self.num_patches = grid_size**2 pos_encoding = ( posemb_sincos_2d_with_gsd( @@ -165,14 +160,14 @@ def mask_out(self, patches): masked_matrix, ) # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L] - def forward(self, cube, time, latlon, waves, gsd): - # cube, time, latlon, gsd, waves = ( - # datacube["pixels"], # [B C H W] - # datacube["time"], # [B 2] - # datacube["latlon"], # [B 2] - # datacube["gsd"], # 1 - # datacube["waves"], # [N] - # ) # [B C H W] + def forward(self, datacube): + cube, time, latlon, gsd, waves = ( + datacube["pixels"], # [B C H W] + datacube["time"], # [B 2] + datacube["latlon"], # [B 2] + datacube["gsd"], # 1 + datacube["waves"], # [N] + ) # [B C H W] B, C, H, W = cube.shape From 73171ddeb7f10adf0f5cd51be31542aa479f6f1a Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Wed, 24 Jul 2024 22:48:49 +0530 Subject: [PATCH 04/11] Add embedder to load clay encoder & save in onnx/ep format --- finetune/embedder/factory.py | 303 +++++++++++++++++++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 finetune/embedder/factory.py diff --git a/finetune/embedder/factory.py b/finetune/embedder/factory.py new file mode 100644 index 00000000..bf3ee6e4 --- /dev/null +++ b/finetune/embedder/factory.py @@ -0,0 +1,303 @@ +"""Export the Clay model to ONNX and pytorch ExportedProgram format. + +This script exports the Clay model to ONNX and pytorch ExportedProgram format +for deployment. The model is exported with dynamic shapes for inference. + +How to use: + +```bash +python -m finetune.embedder.factory \ + --img_size 256 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.onnx \ + --onnx +# exports Clay encoder to ONNX format that can handle chips of size 256x256 +# for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ & Sentinel 1. +``` + +```bash +python -m finetune.embedder.factory \ + --img_size 224 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.pt2 \ + --ep +# exports Clay encoder to pytorch ExportedProgram format that can handle chips +# of size 224x224 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ +# & Sentinel 1. +``` + +""" + +import argparse +import re +import warnings +from pathlib import Path + +import torch +from einops import repeat +from torch import nn +from torch.export import Dim + +from src.model import Encoder +from src.utils import posemb_sincos_2d_with_gsd + +warnings.filterwarnings("ignore", category=UserWarning) + + +class EmbeddingEncoder(Encoder): + """Clay Encoder without mask and shuffle.""" + + def __init__( # noqa: PLR0913 + self, + img_size, + patch_size, + dim, + depth, + heads, + dim_head, + mlp_ratio, + ): + super().__init__( + mask_ratio=0.0, + shuffle=False, + patch_size=patch_size, + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_ratio=mlp_ratio, + ) + self.img_size = img_size + + # Using fixed grid size for inference + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + + def add_encodings(self, patches, time, latlon, gsd): + """Add position encoding to the patches""" + B, L, D = patches.shape + + grid_size = self.grid_size + + pos_encoding = ( + posemb_sincos_2d_with_gsd( + h=grid_size, + w=grid_size, + dim=(self.dim - 8), + gsd=gsd, + ) + .to(patches.device) + .detach() + ) # [L (D - 8)] + + time_latlon = torch.hstack((time, latlon)).to(patches.device).detach() # [B 8] + + pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] + time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] + pos_metadata_encoding = torch.cat( + (pos_encoding, time_latlon), dim=-1 + ) # [B L D] + + patches = patches + pos_metadata_encoding # [B L D] + [B L D] -> [B L D] + return patches # [B L D] + + # def forward(self, cube, time, latlon, waves, gsd): + def forward(self, datacube): + cube, time, latlon, gsd, waves = ( + datacube["pixels"], # [B C H W] + datacube["time"], # [B 2] + datacube["latlon"], # [B 2] + datacube["gsd"], # 1 + datacube["waves"], # [N] + ) # [B C H W] + B, C, H, W = cube.shape + + patches, _ = self.to_patch_embed( + cube, waves + ) # [B L D] - patchify & create embeddings per patch + + # Add time & latlon as encoding to patches + patches = self.add_encodings( + patches, + time, + latlon, + gsd, + ) # [B L D] - add position encoding to the embeddings + + # Add class tokens + cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] + patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] + + # pass the patches through the transformer + patches = self.transformer(patches) # [B (1 + L) D] + + # get the cls token + embeddings = patches[:, 0, :] # [B D] + + return embeddings + + +class Embedder(nn.Module): + def __init__(self, img_size=256, ckpt_path=None, device="cpu"): + super().__init__() + self.clay_encoder = ( + EmbeddingEncoder( # Default parameters for the Clay base model + img_size=img_size, + patch_size=8, + dim=768, + depth=12, + heads=12, + dim_head=64, + mlp_ratio=4.0, + ).to(device) + ) + self.img_size = img_size + self.device = torch.device(device) + self.load_clay_weights(ckpt_path) + + def load_clay_weights(self, ckpt_path): + "Load the weights from the Clay model encoder." + ckpt = torch.load(ckpt_path, map_location=self.device) + state_dict = ckpt.get("state_dict") + state_dict = { + re.sub(r"^model\.encoder\.", "", name): param + for name, param in state_dict.items() + if name.startswith("model.encoder") + } + + with torch.no_grad(): + for name, param in self.clay_encoder.named_parameters(): + if name in state_dict and param.size() == state_dict[name].size(): + param.data.copy_(state_dict[name]) # Copy the weights + else: + print(f"No matching parameter for {name} with size {param.size()}") + + for param in self.clay_encoder.parameters(): + param.requires_grad = False + + self.clay_encoder.eval() + + def forward(self, datacube): + embeddings = self.clay_encoder(datacube) + + return embeddings + + def fake_datacube(self): + "Generate a fake datacube for model export." + dummy_datacube = { + "pixels": torch.randn(2, 3, self.img_size, self.img_size), + "time": torch.randn(2, 4), + "latlon": torch.randn(2, 4), + "waves": torch.randn(3), + "gsd": torch.randn(1), + } + dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()} + return dummy_datacube + + def export_to_onnx(self, name): + "Save the model to ONNX format." + + datacube = self.fake_datacube() + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + + # Export the model to ONNX format + onnx_program = torch.onnx.dynamo_export( + self.eval(), datacube, export_options=export_options + ) + + # Save the exported model + onnx_program.save(f"checkpoints/compiled/{name}") + print(f"Model exported to ONNX format: checkpoints/compiled/{name}") + + return onnx_program + + def export_to_torchep(self, name): + "Save the model to pytorch ExportedProgram format." + + datacube = self.fake_datacube() + + # dynamic shapes for model export + batch_size = Dim("batch_size", min=2, max=1000) + channel_bands = Dim("channel_bands", min=1, max=10) + dynamic_shapes = { + "datacube": { + "pixels": {0: batch_size, 1: channel_bands}, + "time": {0: batch_size}, + "latlon": {0: batch_size}, + "waves": {0: channel_bands}, + "gsd": {0: None}, + } + } + + # Export the model to pytorch ExportedProgram format + ep = torch.export.export( + self.eval(), + (datacube,), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + + # Save the exported model + torch.export.save(ep, f"checkpoints/compiled/{name}") + print( + f"Model exported to pytorch ExportedProgram format: checkpoints/compiled/{name}" # noqa: E501 + ) + + return ep + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export the Clay model.") + parser.add_argument( + "--img_size", + type=int, + default=256, + help="Image size for the model", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="checkpoints/clay-v1-base.ckpt", + help="Path to the Clay model checkpoint", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for the model", + ) + parser.add_argument( + "--name", + type=str, + default="clay-base.pt", + help="Name of the exported model", + ) + parser.add_argument( + "--onnx", + action="store_true", + help="Export the model to ONNX format", + ) + parser.add_argument( + "--ep", + action="store_true", + help="Export the model to pytorch ExportedProgram format", + ) + + args = parser.parse_args() + + Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) + embedder = Embedder( + img_size=args.img_size, + ckpt_path=args.ckpt_path, + device=args.device, + ) + + if args.onnx: + embedder.export_to_onnx(args.name) + elif args.ep: + embedder.export_to_torchep(args.name) + else: + print("Please specify the format to export the model.") + parser.print_help() From 1f2fcc9a0c74518907bb2a41c9cba86ca97e5f29 Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 13:16:38 +0530 Subject: [PATCH 05/11] Remove files from src, fix utils to run everything on same device --- src/benchmark_encoder.py | 80 ------------------------------------- src/export.py | 73 ---------------------------------- src/test_encoder.py | 86 ---------------------------------------- src/utils.py | 16 ++++---- 4 files changed, 7 insertions(+), 248 deletions(-) delete mode 100644 src/benchmark_encoder.py delete mode 100644 src/export.py delete mode 100644 src/test_encoder.py diff --git a/src/benchmark_encoder.py b/src/benchmark_encoder.py deleted file mode 100644 index 08c2b0ac..00000000 --- a/src/benchmark_encoder.py +++ /dev/null @@ -1,80 +0,0 @@ -import argparse -import time -import warnings - -import torch - -warnings.filterwarnings("ignore") - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -def get_data(): - """ - Generate random data tensors for model input. - """ - cube = torch.randn(128, 3, 256, 256).to(DEVICE) - timestep = torch.randn(128, 4).to(DEVICE) - latlon = torch.randn(128, 4).to(DEVICE) - waves = torch.randn(3).to(DEVICE) - gsd = torch.randn(1).to(DEVICE) - return cube, timestep, latlon, waves, gsd - - -def load_exported_model(eager=True): - """ - Load the exported model from a file. - - Args: - eager (bool): Flag to decide whether to use eager mode or compiled mode. - """ - print("Loading exported model") - ep = torch.export.load("checkpoints/compiled/encoder.pt") - if eager: - model = ep.module() - else: - model = torch.compile(ep.module(), backend="inductor") - return model - - -def benchmark_model(model): - """ - Benchmark the model by running inference on randomly generated data. - - Args: - model: The model to benchmark. - """ - print("Benchmarking model") - start = time.time() - for i in range(20): - cube, timestep, latlon, waves, gsd = get_data() - with torch.inference_mode(): - out = model(cube, timestep, latlon, waves, gsd) - print( - f"Iteration {i}: Output shapes - {out[0].shape}, {out[1].shape}, {out[2].shape}, {out[3].shape}" # noqa E501 - ) - print("Time taken for inference: ", time.time() - start) - - -def run(eager=True): - """ - Run the exported model and benchmark it. - - Args: - eager (bool): Flag to decide whether to use eager mode or compiled mode. - """ - print("Running model") - model = load_exported_model(eager=eager) - benchmark_model(model) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run benchmark for the exported model." - ) - parser.add_argument( - "--eager", action="store_true", help="Use eager mode for running the model." - ) - args = parser.parse_args() - - run(args.eager) diff --git a/src/export.py b/src/export.py deleted file mode 100644 index 70a65f1e..00000000 --- a/src/export.py +++ /dev/null @@ -1,73 +0,0 @@ -import warnings -from pathlib import Path - -import torch -from torch.export import Dim - -from src.model import ClayMAEModule - -warnings.filterwarnings("ignore") - -CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt" -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -CHIP_SIZE = 256 - - -def get_data(): - """ - Generate random data tensors for model input. - """ - cube = torch.randn(128, 3, CHIP_SIZE, CHIP_SIZE).to(DEVICE) - timestep = torch.randn(128, 4).to(DEVICE) - latlon = torch.randn(128, 4).to(DEVICE) - waves = torch.randn(3).to(DEVICE) - gsd = torch.randn(1).to(DEVICE) - return cube, timestep, latlon, waves, gsd - - -def load_model(): - """ - Load the model from a checkpoint and prepare it for evaluation. - """ - module = ClayMAEModule.load_from_checkpoint( - CHECKPOINT_PATH, shuffle=False, mask_ratio=0.0 - ) - encoder = module.model.encoder.eval() # Get the encoder in eval mode - encoder = encoder.to(DEVICE) # Move to the appropriate device - return encoder - - -def export_model(): - """ - Export the model with dynamic shapes for deployment. - """ - cube, timestep, latlon, waves, gsd = get_data() - encoder = load_model() - - # Define dynamic shapes for model export - batch_size = Dim("batch_size", min=32, max=1200) - channel_bands = Dim("channel_bands", min=1, max=10) - - dynamic_shapes = { - "cube": {0: batch_size, 1: channel_bands}, - "time": {0: batch_size}, - "latlon": {0: batch_size}, - "waves": {0: channel_bands}, - "gsd": {0: None}, - } - - # Export model - ep = torch.export.export( - mod=encoder, - args=(cube, timestep, latlon, waves, gsd), - dynamic_shapes=dynamic_shapes, - strict=True, - ) - - # Save the exported model - Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) - torch.export.save(ep, "checkpoints/compiled/encoder.pt") - - -if __name__ == "__main__": - export_model() diff --git a/src/test_encoder.py b/src/test_encoder.py deleted file mode 100644 index 14839197..00000000 --- a/src/test_encoder.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch - -from src.datamodule import ClayDataModule - -# Load the pre-trained Clay encoder model -clay_encoder = torch.export.load("checkpoints/compiled/encoder.pt").module() - - -def load_batch(): - # Initialize the data module with appropriate parameters - dm = ClayDataModule( - data_dir="/home/ubuntu/data", - size=256, - metadata_path="configs/metadata.yaml", - batch_size=1, - num_workers=1, - ) - - # Setup the data module for the 'fit' stage - dm.setup(stage="fit") - metadata = dm.metadata - - # Get the training data loader and create an iterator - trn_dl = dm.train_dataloader() - iter_dl = iter(trn_dl) - - return iter_dl, metadata - - -def prepare_data(sensor, metadata, device): - """ - Load data from the sensor and transfer it to the specified device. - - Args: - - sensor (dict): Sensor data containing 'pixels', 'time', 'latlon', and 'platform'. - - metadata (dict): Metadata information for different platforms. - - device (torch.device): The device to which the data should be transferred. - - Returns: - - tuple: Transferred cube, timestep, latlon, waves, and gsd tensors. - """ - cube = sensor["pixels"] - timestep = sensor["time"] - latlon = sensor["latlon"] - platform = sensor["platform"][0] - - # Get wavelengths and ground sampling distance (gsd) from metadata - waves = torch.tensor(list(metadata[platform].bands.wavelength.values())) - gsd = torch.tensor([metadata[platform].gsd]) - - # Transfer data to the specified device - cube, timestep, latlon, waves, gsd = map( - lambda x: x.to(device), (cube, timestep, latlon, waves, gsd) - ) - return cube, timestep, latlon, waves, gsd - - -def main(): - dl, metadata = load_batch() - - # Fetch samples from the data loader - l8_c2l1 = next(dl) - l8_c2l2 = next(dl) - linz = next(dl) - naip = next(dl) - s1 = next(dl) - s2 = next(dl) - - # Perform inference with the Clay encoder model - with torch.no_grad(): - for sensor in (l8_c2l1, l8_c2l2, linz, naip, s1, s2): - # Load data and transfer to GPU - batch = prepare_data(sensor, metadata, torch.device("cuda")) - - # Get patch embeddings from the encoder model - patch_embeddings, *_ = clay_encoder(*batch) - - # Extract the class (CLS) embedding - cls_embedding = patch_embeddings[:, 0, :] - - # Print the platform and the shape of the CLS embedding - print(sensor["platform"][0], cls_embedding.shape) - - -if __name__ == "__main__": - main() diff --git a/src/utils.py b/src/utils.py index 1e35731f..b0f2bcce 100644 --- a/src/utils.py +++ b/src/utils.py @@ -11,7 +11,6 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature**omega) - omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -25,9 +24,9 @@ def posemb_sincos_2d_with_gsd( y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" - omega = torch.arange(dim // 4, device=gsd.device) / (dim // 4 - 1) + gsd = gsd.to(x.device) + omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g - omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -35,17 +34,16 @@ def posemb_sincos_2d_with_gsd( return pe.type(dtype) -def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): +def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32): assert ( dim % 2 == 0 ), "Feature dimension must be a multiple of 2 for sincos embedding" - pos = torch.arange(pos) if isinstance(pos, int) else pos + waves = torch.arange(waves) if isinstance(waves, int) else waves - omega = torch.arange(dim // 2) / (dim // 2 - 1) + omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1) omega = 1.0 / (temperature**omega) - omega = omega.to(pos.device) - scaled_pos = pos[:, None] * omega[None, :] - pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) + scaled_waves = waves[:, None] * omega[None, :] + pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1) return pe.type(dtype) From 0a3ce9d90e4e9f2871b7d9db26b06d2583aef6bd Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 14:07:42 +0530 Subject: [PATCH 06/11] Bump torch==2.3.1 & torchvision==0.18.1, add onnx & onnxsxript as dependency --- environment.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 699df9df..2190062f 100644 --- a/environment.yml +++ b/environment.yml @@ -14,10 +14,12 @@ dependencies: - lancedb~=0.10.2 - lightning~=2.1.0 - matplotlib-base~=3.8.2 + - onnx~=1.16.1 + - onnxscript~=0.1.0.dev20240724 - planetary-computer~=1.0.0 - python-box~=7.1.0 - pytorch~=2.1.0 # [osx] - - pytorch~=2.1.0 *cuda12* # [linux] + - pytorch~=2.3.1 *cuda12* # [linux] - python~=3.11.0 - pyarrow~=16.1.0 - rioxarray~=0.15.0 @@ -29,7 +31,7 @@ dependencies: - timm~=0.9.16 - torchdata~=0.7.1 - torchgeo~=0.5.2 - - torchvision~=0.16.1 + - torchvision~=0.18.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - vit-pytorch~=1.6.4 From 37503feb487e377021ae62a0ee3386e5f976469e Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 16:40:03 +0530 Subject: [PATCH 07/11] Release few contraints on env --- environment.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/environment.yml b/environment.yml index 2190062f..79a56ec8 100644 --- a/environment.yml +++ b/environment.yml @@ -11,14 +11,13 @@ dependencies: - jupyter-book~=1.0.0 - jupyterlab~=4.0.7 - jsonargparse~=4.27.0 - - lancedb~=0.10.2 - lightning~=2.1.0 - matplotlib-base~=3.8.2 - onnx~=1.16.1 - - onnxscript~=0.1.0.dev20240724 + - onnxscript - planetary-computer~=1.0.0 - python-box~=7.1.0 - - pytorch~=2.1.0 # [osx] + - pytorch~=2.3.1 # [osx] - pytorch~=2.3.1 *cuda12* # [linux] - python~=3.11.0 - pyarrow~=16.1.0 @@ -29,13 +28,12 @@ dependencies: - scikit-learn~=1.4.0 - stackstac~=0.5.0 - timm~=0.9.16 - - torchdata~=0.7.1 - - torchgeo~=0.5.2 + - torchgeo - torchvision~=0.18.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - vit-pytorch~=1.6.4 - - wandb~=0.15.12 + - wandb - zarr~=2.16.1 platforms: - linux-64 From db8a3f2a033835606e1360f6efdec9130599ae4b Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 12:52:26 +0000 Subject: [PATCH 08/11] Add notebook to show how to embed using compiled embedders --- environment.yml | 18 +- finetune/embedder/how-to-embed.ipynb | 637 +++++++++++++++++++++++++++ 2 files changed, 647 insertions(+), 8 deletions(-) create mode 100644 finetune/embedder/how-to-embed.ipynb diff --git a/environment.yml b/environment.yml index 79a56ec8..2eb2d1c3 100644 --- a/environment.yml +++ b/environment.yml @@ -7,33 +7,35 @@ dependencies: - einops~=0.7.0 - fiona~=1.9.5 - geopandas-base~=0.14.1 - - h5netcdf~=1.3.0 - - jupyter-book~=1.0.0 - - jupyterlab~=4.0.7 - jsonargparse~=4.27.0 - lightning~=2.1.0 - matplotlib-base~=3.8.2 - - onnx~=1.16.1 - - onnxscript - planetary-computer~=1.0.0 - python-box~=7.1.0 - pytorch~=2.3.1 # [osx] - pytorch~=2.3.1 *cuda12* # [linux] - python~=3.11.0 - pyarrow~=16.1.0 - - rioxarray~=0.15.0 - rasterio~=1.3.10 - s3fs~=2024.3.1 - scikit-image~=0.22.0 - scikit-learn~=1.4.0 - stackstac~=0.5.0 - timm~=0.9.16 - - torchgeo - torchvision~=0.18.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - vit-pytorch~=1.6.4 - - wandb - zarr~=2.16.1 + - pip: + - geoarrow-pyarrow==0.1.2 + - jupyter-book==1.0.2 + - jupyterlab==4.2.4 + - onnx==1.16.1 + - onnxscript + - onnxruntime + - torchgeo==0.5.2 + - stacchip==0.1.35 + - wandb==0.17.5 platforms: - linux-64 diff --git a/finetune/embedder/how-to-embed.ipynb b/finetune/embedder/how-to-embed.ipynb new file mode 100644 index 00000000..06f55cc7 --- /dev/null +++ b/finetune/embedder/how-to-embed.ipynb @@ -0,0 +1,637 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d9960547-640d-425c-8180-fc5523a80e42", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import os\n", + "import requests\n", + "import warnings\n", + "\n", + "import geoarrow.pyarrow as ga\n", + "import numpy as np\n", + "import pystac_client\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from torchvision.transforms import v2\n", + "\n", + "from stacchip.indexer import Sentinel2Indexer\n", + "from stacchip.chipper import Chipper\n", + "\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "markdown", + "id": "598fec81-2cc1-4c5a-9e46-7c46a5591484", + "metadata": {}, + "source": [ + "### Find data for AOI\n", + "The first step is to find STAC items of imagery that we want to use to create embeddings. In this example we are going to use Earth Genome's composite dataset which comes with a great STAC catalog.\n", + "\n", + "We are also going to create embeddings along time so that we have multiple embeddings for the same location at different moments in time." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3e1d46ee-40f6-49f5-99ad-83819339561e", + "metadata": {}, + "outputs": [], + "source": [ + "# Point over Monchique Portugal\n", + "lat, lon = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b7825318-23f3-449f-9104-eae6562a55ab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 12 items\n" + ] + } + ], + "source": [ + "# Optimize GDAL settings for cloud optimized reading\n", + "os.environ[\"GDAL_DISABLE_READDIR_ON_OPEN\"] = \"EMPTY_DIR\"\n", + "os.environ[\"AWS_REQUEST_PAYER\"] = \"requester\"\n", + "\n", + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"\n", + "\n", + "# Search the catalogue\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "search = catalog.search(\n", + " collections=[COLLECTION],\n", + " datetime=f\"{start}/{end}\",\n", + " bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", + " max_items=100,\n", + " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "all_items = search.get_all_items()\n", + "\n", + "# Reduce to one per date (there might be some duplicates\n", + "# based on the location)\n", + "items = []\n", + "dates = []\n", + "for item in all_items:\n", + " if item.datetime.date() not in dates:\n", + " items.append(item)\n", + " dates.append(item.datetime.date())\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "id": "600f3cfb-ce4e-4409-ae15-20f3a7107a62", + "metadata": {}, + "source": [ + "To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "183975c7-8afb-49ef-8e70-790265719aea", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n" + ] + } + ], + "source": [ + "chips = []\n", + "datetimes = []\n", + "bboxs = []\n", + "chip_ids = []\n", + "item_ids = []\n", + "\n", + "for item in items:\n", + " print(f\"Working on {item}\")\n", + "\n", + " # Index the chips in the item\n", + " indexer = Sentinel2Indexer(item)\n", + "\n", + " # Instanciate the chipper\n", + " chipper = Chipper(indexer, assets=[\"red\", \"green\", \"blue\", \"nir\", \"scl\"])\n", + "\n", + " # Get first chip for the \"image\" asset key\n", + " for idx, (x, y, chip) in enumerate(chipper):\n", + " if idx > 2:\n", + " break\n", + " del chip[\"scl\"]\n", + " chips.append(chip)\n", + " datetimes.append(item.datetime)\n", + " bboxs.append(indexer.get_chip_bbox(x, y))\n", + " chip_ids.append((x, y))\n", + " item_ids.append(item.id)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "71902ab7-3320-43cd-85c3-362c2500f241", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(36, 4, 256, 256)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])\n", + "pixels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6f7ce367-4e12-4648-bb79-119b4f50ead8", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract mean, std, and wavelengths from metadata\n", + "platform = \"sentinel-2-l2a\"\n", + "# Retrieve the file content from the URL\n", + "\n", + "url = (\n", + " \"https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml\"\n", + ")\n", + "response = requests.get(url, allow_redirects=True)\n", + "\n", + "# Convert bytes to string\n", + "content = response.content.decode(\"utf-8\")\n", + "\n", + "# Load the yaml\n", + "content = yaml.safe_load(content)\n", + "\n", + "metadata = Box(content)\n", + "mean = []\n", + "std = []\n", + "waves = []\n", + "# Use the band names to get the correct values in the correct order.\n", + "for band in chips[0].keys():\n", + " mean.append(metadata[platform].bands.mean[band])\n", + " std.append(metadata[platform].bands.std[band])\n", + " waves.append(metadata[platform].bands.wavelength[band])\n", + "\n", + "# Prepare the normalization transform function using the mean and std values.\n", + "transform = v2.Compose(\n", + " [\n", + " v2.Normalize(mean=mean, std=std),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a8ec8c2d-ecb9-42a2-9e8c-3f95c67ef07b", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_timestamp(date):\n", + " week = date.isocalendar().week * 2 * np.pi / 52\n", + " hour = date.hour * 2 * np.pi / 24\n", + "\n", + " return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))\n", + "\n", + "\n", + "times = [normalize_timestamp(dat) for dat in datetimes]\n", + "week_norm = [dat[0] for dat in times]\n", + "hour_norm = [dat[1] for dat in times]\n", + "\n", + "\n", + "# Prep lat/lon embedding using the\n", + "def normalize_latlon(lat, lon):\n", + " lat = lat * np.pi / 180\n", + " lon = lon * np.pi / 180\n", + "\n", + " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n", + "\n", + "\n", + "latlons = [normalize_latlon(lat, lon)] * len(times)\n", + "lat_norm = [dat[0] for dat in latlons]\n", + "lon_norm = [dat[1] for dat in latlons]\n", + "\n", + "# Prep gsd\n", + "gsd = [10]\n", + "\n", + "# Normalize pixels\n", + "pixels = transform(pixels)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2640eb17-a85c-4972-8d5d-e45e9ed8eba5", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {\n", + " \"pixels\": torch.tensor(pixels, dtype=torch.float32),\n", + " \"time\": torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32),\n", + " \"latlon\": torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32),\n", + " \"waves\": torch.tensor(waves, dtype=torch.float32),\n", + " \"gsd\": torch.tensor(gsd, dtype=torch.float32),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7f6711a9-e7ed-44d5-add7-2c3a498cd422", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pixels torch.Size([36, 4, 256, 256])\n", + "time torch.Size([36, 4])\n", + "latlon torch.Size([36, 4])\n", + "waves torch.Size([4])\n", + "gsd torch.Size([1])\n" + ] + } + ], + "source": [ + "for k,v in datacube.items():\n", + " print(k, v.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "83243912-a2a8-4fa5-a39c-a9c3b07c7569", + "metadata": {}, + "source": [ + "### Clay Embedder\n", + "\n", + "#### Load the embedder that is stored in ExportedProgram format using **cpu**." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4eb468af-d468-46aa-a8fb-23ff95c56288", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9eb797f7-5238-49e0-9950-e85f10132454", + "metadata": {}, + "outputs": [], + "source": [ + "ep_embedder_cpu = torch.export.load(\"clay-v1-encoder-cpu.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "eefe4811-7290-47c3-a10e-45257e6d42e0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2min 36s, sys: 26.9 s, total: 3min 3s\n", + "Wall time: 51.3 s\n" + ] + }, + { + "data": { + "text/plain": [ + "(torch.Size([36, 4, 256, 256]), torch.Size([36, 768]))" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder_cpu(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "8e927b01-c855-4172-a4d9-2c10ba794ed4", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "fa0810b4-34ad-490e-bbcd-c0c3288f017c", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ExportedProgram format using **gpu**." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9c1bbfd4-7dc6-4ad0-8a0b-b3745a9f35ca", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "e285a543-20ab-44ba-b676-2303284dc477", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k:v.to(\"cuda\") for k,v in datacube.items()}\n", + "ep_embedder = torch.export.load(\"clay-v1-encoder.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "edefee90-e6b8-4701-bb5d-2bf7febc806c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 313 ms, sys: 41.5 ms, total: 354 ms\n", + "Wall time: 239 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "(torch.Size([36, 4, 256, 256]), torch.Size([36, 768]))" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "196f2121-46b5-4b02-94d3-75e648c329c3", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "5b1cb0f9-a434-419b-a88b-4d4edd84fea6", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ONNX format using **cpu**." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "aa10d696-740a-458e-ae10-eec9a43fb362", + "metadata": {}, + "outputs": [], + "source": [ + "import onnx\n", + "import onnxruntime as ort" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "992524e5-2c2a-4e48-ae95-bd2aa87b72a9", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.onnx" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "dc3fa967-73d5-431c-88a2-84b088aff06f", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k:v.to(\"cpu\") for k,v in datacube.items()}\n", + "onnx_embedder = ort.InferenceSession(\"clay-v1-encoder-cpu.onnx\", \n", + " providers=[\"CPUExecutionProvider\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "24591d17-d1c8-452b-9b20-676a9b6f8643", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 48s, sys: 1.82 s, total: 3min 50s\n", + "Wall time: 30.6 s\n" + ] + }, + { + "data": { + "text/plain": [ + "(36, 768)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "embeddings = onnx_embedder.run([], {\n", + " \"cube\": datacube[\"pixels\"].numpy(),\n", + " \"time\": datacube[\"time\"].numpy(),\n", + " \"latlon\": datacube[\"latlon\"].numpy(),\n", + " \"waves\": datacube[\"waves\"].numpy(),\n", + " \"gsd\": datacube[\"gsd\"].numpy()\n", + "})[0]\n", + "embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "9c07216e-a109-4cd8-8c74-9a3fc9a37757", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "2e8d5900-9a4b-4e2d-b992-4fb0a1e8c835", + "metadata": {}, + "source": [ + "### Store the results\n", + "\n", + "We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "677f04d3-db38-4d44-9b55-c103d54adcd5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "pyarrow.Table\n", + "datetimes: timestamp[us, tz=UTC]\n", + "chip_ids: list\n", + " child 0, item: int64\n", + "item_ids: string\n", + "emeddings: list\n", + " child 0, item: float\n", + "geometry: extension>\n", + "----\n", + "datetimes: [[2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-23 11:30:50.574000Z,2018-08-23 11:30:50.574000Z,...,2018-07-09 11:24:55.535000Z,2018-07-09 11:24:55.535000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z]]\n", + "chip_ids: [[[0,0],[1,0],...,[1,0],[2,0]]]\n", + "item_ids: [[\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",...,\"S2A_29SNB_20180709_0_L2A\",\"S2A_29SNB_20180709_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\"]]\n", + "emeddings: [[[-0.14773342,0.08466571,0.13797832,0.11150883,0.06517959,...,0.036681578,-0.092160255,0.025934512,-0.12496276,-0.034070153],[-0.14430065,0.085857555,0.13839196,0.10963549,0.0652737,...,0.03711322,-0.09153629,0.02631686,-0.12422915,-0.03333628],...,[-0.09626354,0.062443394,0.24817112,0.012715777,0.043093704,...,0.011770063,-0.037860263,0.027813748,-0.11962952,-0.02246455],[-0.10004063,0.06320572,0.24851695,0.012129029,0.043350283,...,0.011444314,-0.03733269,0.027787287,-0.12139094,-0.021088997]]]\n", + "geometry: [[[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.825403979293151,-8.825730459265694,-9.000227209792856,-9.000227635454767,-8.825403979293151]\n", + " -- child 1 type: double\n", + "[37.947460030545635,37.809019655564406,37.809148556380286,37.947589571562965,37.947460030545635]],[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", + " -- child 1 type: double\n", + "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],...,[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", + " -- child 1 type: double\n", + "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.475765647330832,-8.476745873271028,-8.651235936821893,-8.650582567535476,-8.475765647330832]\n", + " -- child 1 type: double\n", + "[37.94642170369997,37.80798646012822,37.80863228507305,37.94707073614538,37.94642170369997]]]]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Write data to pyarrow table\n", + "index = {\n", + " \"datetimes\": datetimes,\n", + " \"chip_ids\": chip_ids,\n", + " \"item_ids\": item_ids,\n", + " \"emeddings\": [np.ascontiguousarray(dat) for dat in embeddings],\n", + " \"geometry\": ga.as_geoarrow([dat.wkt for dat in bboxs]),\n", + "}\n", + "table = pa.table(index)\n", + "table" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "d62a9e8a-b4f9-491c-a437-6a164a9e74fe", + "metadata": {}, + "outputs": [], + "source": [ + "pq.write_table(table, \"embeddings.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d30fb8c7-d04d-453f-93f6-dc3599f1df15", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From c0552bda2c326c84fcda2a259f5621e1fa843747 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:53:09 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- finetune/embedder/how-to-embed.ipynb | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/finetune/embedder/how-to-embed.ipynb b/finetune/embedder/how-to-embed.ipynb index 06f55cc7..3799c12a 100644 --- a/finetune/embedder/how-to-embed.ipynb +++ b/finetune/embedder/how-to-embed.ipynb @@ -296,7 +296,7 @@ } ], "source": [ - "for k,v in datacube.items():\n", + "for k, v in datacube.items():\n", " print(k, v.shape)" ] }, @@ -395,7 +395,7 @@ "metadata": {}, "outputs": [], "source": [ - "datacube = {k:v.to(\"cuda\") for k,v in datacube.items()}\n", + "datacube = {k: v.to(\"cuda\") for k, v in datacube.items()}\n", "ep_embedder = torch.export.load(\"clay-v1-encoder.pt2\").module()" ] }, @@ -475,9 +475,10 @@ "metadata": {}, "outputs": [], "source": [ - "datacube = {k:v.to(\"cpu\") for k,v in datacube.items()}\n", - "onnx_embedder = ort.InferenceSession(\"clay-v1-encoder-cpu.onnx\", \n", - " providers=[\"CPUExecutionProvider\"])" + "datacube = {k: v.to(\"cpu\") for k, v in datacube.items()}\n", + "onnx_embedder = ort.InferenceSession(\n", + " \"clay-v1-encoder-cpu.onnx\", providers=[\"CPUExecutionProvider\"]\n", + ")" ] }, { @@ -507,13 +508,16 @@ ], "source": [ "%%time\n", - "embeddings = onnx_embedder.run([], {\n", - " \"cube\": datacube[\"pixels\"].numpy(),\n", - " \"time\": datacube[\"time\"].numpy(),\n", - " \"latlon\": datacube[\"latlon\"].numpy(),\n", - " \"waves\": datacube[\"waves\"].numpy(),\n", - " \"gsd\": datacube[\"gsd\"].numpy()\n", - "})[0]\n", + "embeddings = onnx_embedder.run(\n", + " [],\n", + " {\n", + " \"cube\": datacube[\"pixels\"].numpy(),\n", + " \"time\": datacube[\"time\"].numpy(),\n", + " \"latlon\": datacube[\"latlon\"].numpy(),\n", + " \"waves\": datacube[\"waves\"].numpy(),\n", + " \"gsd\": datacube[\"gsd\"].numpy(),\n", + " },\n", + ")[0]\n", "embeddings.shape" ] }, From 1e9750661ff534b504600ce70c6c3677c97e0cc4 Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 12:57:50 +0000 Subject: [PATCH 10/11] Clear outputs from the notebook --- finetune/embedder/how-to-embed.ipynb | 205 ++++----------------------- 1 file changed, 29 insertions(+), 176 deletions(-) diff --git a/finetune/embedder/how-to-embed.ipynb b/finetune/embedder/how-to-embed.ipynb index 06f55cc7..77e4ed2c 100644 --- a/finetune/embedder/how-to-embed.ipynb +++ b/finetune/embedder/how-to-embed.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d9960547-640d-425c-8180-fc5523a80e42", "metadata": {}, "outputs": [], @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "3e1d46ee-40f6-49f5-99ad-83819339561e", "metadata": {}, "outputs": [], @@ -56,18 +56,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "b7825318-23f3-449f-9104-eae6562a55ab", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 12 items\n" - ] - } - ], + "outputs": [], "source": [ "# Optimize GDAL settings for cloud optimized reading\n", "os.environ[\"GDAL_DISABLE_READDIR_ON_OPEN\"] = \"EMPTY_DIR\"\n", @@ -110,29 +102,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "183975c7-8afb-49ef-8e70-790265719aea", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n", - "Working on \n" - ] - } - ], + "outputs": [], "source": [ "chips = []\n", "datetimes = []\n", @@ -163,21 +136,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "71902ab7-3320-43cd-85c3-362c2500f241", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(36, 4, 256, 256)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])\n", "pixels.shape" @@ -185,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "6f7ce367-4e12-4648-bb79-119b4f50ead8", "metadata": {}, "outputs": [], @@ -225,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "a8ec8c2d-ecb9-42a2-9e8c-3f95c67ef07b", "metadata": {}, "outputs": [], @@ -263,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "2640eb17-a85c-4972-8d5d-e45e9ed8eba5", "metadata": {}, "outputs": [], @@ -279,22 +241,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "7f6711a9-e7ed-44d5-add7-2c3a498cd422", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pixels torch.Size([36, 4, 256, 256])\n", - "time torch.Size([36, 4])\n", - "latlon torch.Size([36, 4])\n", - "waves torch.Size([4])\n", - "gsd torch.Size([1])\n" - ] - } - ], + "outputs": [], "source": [ "for k,v in datacube.items():\n", " print(k, v.shape)" @@ -312,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "4eb468af-d468-46aa-a8fb-23ff95c56288", "metadata": {}, "outputs": [], @@ -322,7 +272,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "9eb797f7-5238-49e0-9950-e85f10132454", "metadata": {}, "outputs": [], @@ -332,29 +282,10 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "eefe4811-7290-47c3-a10e-45257e6d42e0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2min 36s, sys: 26.9 s, total: 3min 3s\n", - "Wall time: 51.3 s\n" - ] - }, - { - "data": { - "text/plain": [ - "(torch.Size([36, 4, 256, 256]), torch.Size([36, 768]))" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "%%time\n", "with torch.no_grad():\n", @@ -380,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "9c1bbfd4-7dc6-4ad0-8a0b-b3745a9f35ca", "metadata": {}, "outputs": [], @@ -390,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "e285a543-20ab-44ba-b676-2303284dc477", "metadata": {}, "outputs": [], @@ -401,29 +332,10 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "edefee90-e6b8-4701-bb5d-2bf7febc806c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 313 ms, sys: 41.5 ms, total: 354 ms\n", - "Wall time: 239 ms\n" - ] - }, - { - "data": { - "text/plain": [ - "(torch.Size([36, 4, 256, 256]), torch.Size([36, 768]))" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "%%time\n", "with torch.no_grad():\n", @@ -449,7 +361,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "aa10d696-740a-458e-ae10-eec9a43fb362", "metadata": {}, "outputs": [], @@ -460,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "id": "992524e5-2c2a-4e48-ae95-bd2aa87b72a9", "metadata": {}, "outputs": [], @@ -470,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "id": "dc3fa967-73d5-431c-88a2-84b088aff06f", "metadata": {}, "outputs": [], @@ -482,29 +394,10 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "id": "24591d17-d1c8-452b-9b20-676a9b6f8643", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 3min 48s, sys: 1.82 s, total: 3min 50s\n", - "Wall time: 30.6 s\n" - ] - }, - { - "data": { - "text/plain": [ - "(36, 768)" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "%%time\n", "embeddings = onnx_embedder.run([], {\n", @@ -537,50 +430,10 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "id": "677f04d3-db38-4d44-9b55-c103d54adcd5", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pyarrow.Table\n", - "datetimes: timestamp[us, tz=UTC]\n", - "chip_ids: list\n", - " child 0, item: int64\n", - "item_ids: string\n", - "emeddings: list\n", - " child 0, item: float\n", - "geometry: extension>\n", - "----\n", - "datetimes: [[2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-23 11:30:50.574000Z,2018-08-23 11:30:50.574000Z,...,2018-07-09 11:24:55.535000Z,2018-07-09 11:24:55.535000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z]]\n", - "chip_ids: [[[0,0],[1,0],...,[1,0],[2,0]]]\n", - "item_ids: [[\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",...,\"S2A_29SNB_20180709_0_L2A\",\"S2A_29SNB_20180709_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\"]]\n", - "emeddings: [[[-0.14773342,0.08466571,0.13797832,0.11150883,0.06517959,...,0.036681578,-0.092160255,0.025934512,-0.12496276,-0.034070153],[-0.14430065,0.085857555,0.13839196,0.10963549,0.0652737,...,0.03711322,-0.09153629,0.02631686,-0.12422915,-0.03333628],...,[-0.09626354,0.062443394,0.24817112,0.012715777,0.043093704,...,0.011770063,-0.037860263,0.027813748,-0.11962952,-0.02246455],[-0.10004063,0.06320572,0.24851695,0.012129029,0.043350283,...,0.011444314,-0.03733269,0.027787287,-0.12139094,-0.021088997]]]\n", - "geometry: [[[ -- is_valid: all not null\n", - " -- child 0 type: double\n", - "[-8.825403979293151,-8.825730459265694,-9.000227209792856,-9.000227635454767,-8.825403979293151]\n", - " -- child 1 type: double\n", - "[37.947460030545635,37.809019655564406,37.809148556380286,37.947589571562965,37.947460030545635]],[ -- is_valid: all not null\n", - " -- child 0 type: double\n", - "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", - " -- child 1 type: double\n", - "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],...,[ -- is_valid: all not null\n", - " -- child 0 type: double\n", - "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", - " -- child 1 type: double\n", - "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],[ -- is_valid: all not null\n", - " -- child 0 type: double\n", - "[-8.475765647330832,-8.476745873271028,-8.651235936821893,-8.650582567535476,-8.475765647330832]\n", - " -- child 1 type: double\n", - "[37.94642170369997,37.80798646012822,37.80863228507305,37.94707073614538,37.94642170369997]]]]" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Write data to pyarrow table\n", "index = {\n", @@ -596,7 +449,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "id": "d62a9e8a-b4f9-491c-a437-6a164a9e74fe", "metadata": {}, "outputs": [], From 1803954f73983ede18483c009361ec32ab7f89d7 Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Thu, 25 Jul 2024 15:26:17 +0000 Subject: [PATCH 11/11] Add torchdata as a pip dependency --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 2eb2d1c3..ac0ddeaf 100644 --- a/environment.yml +++ b/environment.yml @@ -34,6 +34,7 @@ dependencies: - onnx==1.16.1 - onnxscript - onnxruntime + - torchdata==0.7.1 - torchgeo==0.5.2 - stacchip==0.1.35 - wandb==0.17.5