diff --git a/cluster/ml-cluster.yaml.template b/cluster/ml-cluster.yaml.template new file mode 100644 index 00000000..03a12836 --- /dev/null +++ b/cluster/ml-cluster.yaml.template @@ -0,0 +1,61 @@ +Region: us-east-2 + +# DL AMI +Image: + Os: ubuntu2004 + CustomAmi: + +# FSx LUSTRE SHARED STORAGE +SharedStorage: + - MountDir: /fsx + Name: fsx + StorageType: FsxLustre + FsxLustreSettings: + FileSystemId: + +# HEAD NODE +HeadNode: + InstanceType: c5.12xlarge + Networking: + SubnetId: + SecurityGroups: + - # EFA enabled SG + Ssh: + KeyName: + LocalStorage: + RootVolume: + Size: 200 + Iam: + S3Access: + - BucketName: + EnableWriteAccess: false + - BucketName: + EnableWriteAccess: true + + +# SCHEDULER +Scheduling: + Scheduler: slurm + SlurmQueues: + - Name: gpu-queue + ComputeResources: + - Name: + Instances: + - InstanceType: + MinCount: 0 + MaxCount: 8 + Efa: + Enabled: true + Networking: + SubnetIds: + - + SecurityGroups: + - # EFA enabled SG + PlacementGroup: + Enabled: true + Iam: + S3Access: + - BucketName: + EnableWriteAccess: false + - BucketName: + EnableWriteAccess: true diff --git a/configs/classify_eurosat.yaml b/configs/classify_eurosat.yaml index 38e72eba..946a7659 100644 --- a/configs/classify_eurosat.yaml +++ b/configs/classify_eurosat.yaml @@ -2,12 +2,12 @@ seed_everything: 42 data: metadata_path: configs/metadata.yaml - batch_size: 256 + batch_size: 128 num_workers: 8 model: num_classes: 10 - ckpt_path: checkpoints/clay-v1-base.ckpt - lr: 1e-4 + ckpt_path: checkpoints/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt + lr: 5e-5 wd: 0.05 b1: 0.9 b2: 0.95 @@ -28,6 +28,7 @@ trainer: init_args: entity: developmentseed project: clay-classify + group: v1.5-test log_model: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint diff --git a/configs/config.yaml b/configs/config.yaml index 217bf7fe..aad3353f 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -1,56 +1,63 @@ -# lightning.pytorch==2.1.2 -seed_everything: 42 +seed_everything: 108 data: - data_dir: data - size: 224 + data_dir: /fsx + size: 256 metadata_path: configs/metadata.yaml platforms: - landsat-c2l1 - landsat-c2l2-sr - linz + - modis - naip - sentinel-1-rtc - sentinel-2-l2a - batch_size: 8 - num_workers: 8 + batch_size: 1 + num_workers: 12 model: - model_size: base + model_size: large mask_ratio: 0.75 - norm_pix_loss: True + norm_pix_loss: False patch_size: 8 shuffle: True metadata_path: configs/metadata.yaml - teacher: vit_base_patch16_224.dino - lr: 1e-5 + teacher: vit_large_patch14_reg4_dinov2.lvd142m + dolls: [16, 32, 64, 128, 256, 768, 1024] + doll_weights: [1, 1, 1, 1, 1, 1, 1] + lr: 5e-6 wd: 0.05 b1: 0.9 b2: 0.95 embeddings_level: mean trainer: - accelerator: auto + accelerator: gpu strategy: ddp - devices: auto - num_nodes: 1 + devices: 8 + num_nodes: 48 precision: bf16-mixed - log_every_n_steps: 10 - max_epochs: 200 + log_every_n_steps: 1 + max_epochs: 1000 accumulate_grad_batches: 1 - default_root_dir: s3://clay-model-ckpt/v1.0.0/ + default_root_dir: checkpoints/v1.5.0/ fast_dev_run: False num_sanity_val_steps: 0 use_distributed_sampler: False + limit_train_batches: 0.99 + limit_val_batches: 0.99 logger: - class_path: lightning.pytorch.loggers.WandbLogger init_args: entity: developmentseed project: clay + group: v1.5-nomrl-dinov2 + id: 0uy3in7l + resume: must log_model: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - dirpath: s3://clay-model-ckpt/v1.0.0/ + dirpath: checkpoints/v1.5.0/ auto_insert_metric_name: False - filename: mae_v1.0.0_epoch-{epoch:02d}_val-loss-{val/loss:.4f} + filename: mae_v1.5.0_epoch-{epoch:02d}_val-loss-{val/loss:.4f} monitor: val/loss mode: min save_last: True @@ -63,4 +70,4 @@ trainer: - class_path: src.callbacks_wandb.LogIntermediatePredictions plugins: - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO -ckpt_path: null +ckpt_path: checkpoints/v1.5.0/last.ckpt diff --git a/configs/metadata.yaml b/configs/metadata.yaml index 7467b120..193bc438 100644 --- a/configs/metadata.yaml +++ b/configs/metadata.yaml @@ -176,11 +176,50 @@ sentinel-1-rtc: gsd: 10 bands: mean: - vv: 0.123273 - vh: 0.027337 + vv: -12.113 + vh: -18.673 std: - vv: 1.492154 - vh: 0.122182 + vv: 8.314 + vh: 8.017 wavelength: vv: 3.5 vh: 4.0 +modis: + band_order: + - sur_refl_b01 + - sur_refl_b02 + - sur_refl_b03 + - sur_refl_b04 + - sur_refl_b05 + - sur_refl_b06 + - sur_refl_b07 + rgb_indices: + - 0 + - 3 + - 2 + gsd: 500 + bands: + mean: + sur_refl_b01: 1072. + sur_refl_b02: 1624. + sur_refl_b03: 931. + sur_refl_b04: 1023. + sur_refl_b05: 1599. + sur_refl_b06: 1404. + sur_refl_b07: 1051. + std: + sur_refl_b01: 1643. + sur_refl_b02: 1878. + sur_refl_b03: 1449. + sur_refl_b04: 1538. + sur_refl_b05: 1763. + sur_refl_b06: 1618. + sur_refl_b07: 1396. + wavelength: + sur_refl_b01: .645 + sur_refl_b02: .858 + sur_refl_b03: .469 + sur_refl_b04: .555 + sur_refl_b05: 1.240 + sur_refl_b06: 1.640 + sur_refl_b07: 2.130 diff --git a/configs/segment_chesapeake.yaml b/configs/segment_chesapeake.yaml index d1d89dff..57e3858c 100644 --- a/configs/segment_chesapeake.yaml +++ b/configs/segment_chesapeake.yaml @@ -6,17 +6,17 @@ data: val_chip_dir: data/cvpr/ny/val/chips/ val_label_dir: data/cvpr/ny/val/labels/ metadata_path: configs/metadata.yaml - batch_size: 40 + batch_size: 16 num_workers: 8 platform: naip model: num_classes: 7 feature_maps: - - 3 - 5 - - 7 - 11 - ckpt_path: checkpoints/clay-v1-base.ckpt + - 15 + - 23 + ckpt_path: checkpoints/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-05_val-loss-0.1734.ckpt lr: 1e-5 wd: 0.05 b1: 0.9 @@ -38,6 +38,7 @@ trainer: init_args: entity: developmentseed project: clay-segment + group: v1.5-test log_model: false callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint diff --git a/copy_data.sh b/copy_data.sh new file mode 100644 index 00000000..43c63516 --- /dev/null +++ b/copy_data.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Define source and destination directories +src="/fsx" +dest="data/pretrain" + +# Create the destination directory if it doesn't exist +mkdir -p "$dest" + +# Find all directories in the source directory +find "$src" -type d -print0 | while IFS= read -r -d '' dir; do + # Create corresponding directory in the destination + newdir="$dest${dir#$src}" + mkdir -p "$newdir" + + # Copy the first 100 files from the source directory to the new directory + find "$dir" -maxdepth 1 -type f -print0 | head -z -n 100 | xargs -0 -I{} cp {} "$newdir" +done diff --git a/docs/release-notes/data_sampling.md b/docs/release-notes/data_sampling.md index 4741064a..c3650964 100644 --- a/docs/release-notes/data_sampling.md +++ b/docs/release-notes/data_sampling.md @@ -112,6 +112,67 @@ and a maximum of 2000 scenes for each catalog that was included. We selected the latest imagery for each of the available regions of new zealand. The list of catalogs is in the linz processor file. +### MODIS sampling strategy + +For MODIS we used the [Surface Reflectance 8-Day (500m)](https://planetarycomputer.microsoft.com/dataset/modis-09A1-061) +product. The data is distributed in SIN grid tiles. We included all SIN grid +tiles that do not have any nodata inside. The selected SIN grid tiles are then +transform to EPSG:3857 for all tiles. This results in some variation between the +nominal resolution, although the original resolution from the SIN projection is +500 meters. For input to the model, we assumed the 500m resolution as a fixed +resolution size for all tiles. + +Algorithm to determine which tiles do not have nodata is shown in the code block +below. This resulted in 233 SIN grid tiles to be selected. For each of these +we sampled the first STAC search result for each month in each year from 2018 +until 2023. This therefore resulted in 72 (`6 years * 12 months`) separate scenes +for each of the 233 SIN grid tiles. + +Script for selection of SIN grid tiles included in the sampling: + +```python +from multiprocessing import Pool +import rasterio +import planetary_computer as pc +import pystac_client +import numpy as np + +SIN_GRID_TILES = [] +for i in SIN_VERTICAL_RANGE: + for j in SIN_HORIZONTAL_RANGE: + SIN_GRID_TILES.append((i, j)) + +def evaluate_nodata(i, j): + catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace) + items = catalog.search( + collections=[COLLECTION], + query={ + "modis:vertical-tile": { + "eq": i, + }, + "modis:horizontal-tile": { + "eq": j, + }, + }, + max_items=1, + ) + item = list(items.item_collection())[0] + + with rasterio.open(item.assets["sur_refl_b01"].href) as src: + data = src.read() + + nodata = np.sum(data == -28672) + + if nodata == 0: + print(i, j) + return i, j + +if __name__ == '__main__': + with Pool(16) as p: + indexes = p.starmap(evaluate_nodata, SIN_GRID_TILES) + print("done") + print(indexes) +``` ## Data preparation @@ -136,6 +197,7 @@ Using stacchip, we created a dataset with a size of 33.8 TB of imagery, with abo | Landsat-c2l1 | 5827333 | | Landsat-c2l2-sr | 5790651 | | Sentinel-1-rtc | 16133394 | +| MODIS | 1350864 | # Older versions diff --git a/embeddings/Dockerfile b/embeddings/Dockerfile new file mode 100644 index 00000000..eb532ab7 --- /dev/null +++ b/embeddings/Dockerfile @@ -0,0 +1,45 @@ +FROM 763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-ec2 + +WORKDIR /model + +RUN git clone -b all-of-naip https://github.com/Clay-foundation/model.git . + +RUN aws s3 cp --no-sign-request s3://clay-model-ckpt/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt data/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt +RUN aws s3 cp --no-sign-request s3://clay-mgrs-samples/naip-manifest.txt.zip data/naip-manifest.txt.zip +RUN aws s3 cp --no-sign-request s3://clay-mgrs-samples/element84-tiles-2023.gz data/element84-tiles-2023.gz + +RUN pip install \ + einops~=0.7.0 \ + fiona~=1.9.5 \ + geopandas~=0.14.1 \ + jsonargparse~=4.27.0 \ + lightning~=2.1.0 \ + matplotlib~=3.9.0 \ + planetary-computer~=1.0.0 \ + python-box~=7.1.0 \ + pyarrow~=15.0.2 \ + rasterio~=1.3.10 \ + s3fs~=2024.6.0 \ + boto3~=1.34.122 \ + botocore~=1.34.122 \ + scikit-image~=0.22.0 \ + scikit-learn~=1.4.0 \ + stackstac~=0.5.0 \ + timm~=0.9.16 \ + transformers~=4.35.2 \ + typeshed-client~=2.4.0 \ + vit-pytorch~=1.6.4 \ + zarr~=2.16.1 \ + geoarrow-pyarrow==0.1.2 \ + torchdata==0.7.1 \ + stacchip==0.1.35 \ + wandb==0.17.5 \ + rio_stac~=0.10.0 + +RUN git pull && git checkout ceecb6138705cb28a5f4d3f61f22b19a2f625edb + +# Move file to home directory so that relative imports work +RUN cp embeddings/all-naip.py . +RUN cp embeddings/all-sentinel.py . + +ENTRYPOINT ["python"] diff --git a/embeddings/README.md b/embeddings/README.md new file mode 100644 index 00000000..f8a835f3 --- /dev/null +++ b/embeddings/README.md @@ -0,0 +1,55 @@ +## Large scale embedding runs + +The code in this section has been used to create embedding runs over large +archives. Currently this covers NAIP and Sentinel-2. + +The algorithms are dockerized to be ran in a batch setup. AWS Batch is what +was used to execute the algorithms but it is not a strict requirement. + +The scripts rely on the `AWS_BATCH_JOB_ARRAY_INDEX` environment variable +to choose which files from the archives to process. This is set automatically +by AWS Batch when using array jobs. Outside of array jobs, this index variable +needs to be specified manually. + +The script also requires the `EMBEDDINGS_BUCKET` environment variable, +specifying the name of the output bucket. + +To specify a custom bucket location (for source coop for instance), use the +`ENDPOINT_URL`, `ENDPOINT_KEY_ID`, and `ENDPOINT_ACCESS_KEY` environment +variables. + + +### Build docker image + +Embedding runs are dockerized for parallel computing. To build the docker image +use the Dockerfile in the embeddings directory. Then push the image to ECR or +another docker repository of your choice. + +```bash +aws ecr get-login-password --region us-east-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-2.amazonaws.com +docker pull 763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-ec2 + +docker build -t clay-embeddings -f embeddings/Dockerfile . + +docker tag clay-embeddings:latest 875815656045.dkr.ecr.us-east-2.amazonaws.com/clay-embeddings:latest +aws ecr get-login-password --region us-east-2 | docker login --username AWS --password-stdin 875815656045.dkr.ecr.us-east-2.amazonaws.com +docker push 875815656045.dkr.ecr.us-east-2.amazonaws.com/clay-embeddings:latest +``` + +### NAIP + +For NAIP, we use the `naip-analytic` bucket. We leverage the manifest file that +lists all files in the bucket. This list is parsed in the beginning and each +job processes a section of the naip scenes. + +At the moment of processing there were 1'231'441 NAIP scenes. + +### Sentinel-2 + +For Sentinel-2 we use the `sentinel-cogs` bucket. Also here we use the manifest +file, but parse it beforehand because it contains references to each single +asset for each product. + +The parser is essentially copied from [this gist](https://github.com/alexgleith/sinergise-element84-sentinel-2-qa/blob/main/0-parse-inventory-element84.py) +by @alexgleith. +The resulting zip file contains a list of static STAC json files for 2023 and 2024. diff --git a/embeddings/all-naip.py b/embeddings/all-naip.py new file mode 100644 index 00000000..69594251 --- /dev/null +++ b/embeddings/all-naip.py @@ -0,0 +1,154 @@ +import datetime +import io +import logging +import os +import tempfile +import zipfile +from pathlib import Path + +import boto3 +from rasterio.errors import RasterioIOError +from rio_stac import create_stac_item +from stacchip.chipper import Chipper +from stacchip.indexer import NoStatsChipIndexer + +from embeddings.utils import ( + check_exists, + get_embeddings, + get_pixels, + load_clay, + load_metadata, + prepare_datacube, + write_to_table, +) + +logging.basicConfig() +logger = logging.getLogger("clay") +logger.setLevel(logging.DEBUG) + + +MANIFEST = "data/naip-manifest.txt.zip" +EMBEDDINGS_BUCKET = os.environ["EMBEDDINGS_BUCKET"] +HOUR_OF_DAY = 12 + + +def open_scene_list(limit_to_state=None): + """ + Read the naip-analytic manifest file and extract a list of NAIP + scenes as tif files to process. + + The file used here is the zipped version of the original manifest file. + """ + with zipfile.ZipFile(MANIFEST) as zf: + with io.TextIOWrapper(zf.open("naip-manifest.txt"), encoding="utf-8") as f: + data = f.readlines() + data = [Path(dat.rstrip()) for dat in data if "rgbir_cog"] + data = [dat for dat in data if dat.suffix == ".tif"] + + logger.debug(f"Found {len(data)} NAIP scenes in manifest") + + if limit_to_state is not None: + data = [dat for dat in data if str(dat).startswith(limit_to_state)] + logger.debug(f"Found {len(data)} NAIP scenes for state {limit_to_state}") + + return data + + +def process_scene(clay, path, batchsize): + """ + Embeds a slingle NAIP scene. + """ + state = path.parts[0] + datestr = path.stem.split("_")[-1] + date = datetime.datetime( + int(datestr[:4]), int(datestr[4:6]), int(datestr[6:8]), HOUR_OF_DAY + ) + gsd = float(path.parts[2].replace("cm", "")) / 100 + bands, waves, mean, std = load_metadata("naip") + + logger.debug(f"Processing {path} in state {state} and date {date}") + + with tempfile.NamedTemporaryFile(mode="w+b", suffix=".tif") as fl: + s3 = boto3.client("s3") + s3.download_fileobj( + "naip-analytic", str(path), fl, ExtraArgs={"RequestPayer": "requester"} + ) + + # Prepare properties, some NAIP imagery contains date stamps that + # raise an error in create_stac_item. + props = {"start_datetime": date, "end_datetime": date} + + item = create_stac_item( + fl.name, + with_proj=True, + input_datetime=date, + id=f"{state}_{path.stem}", + properties=props, + ) + + try: + indexer = NoStatsChipIndexer(item) + chipper = Chipper(indexer) + bboxs, datetimes, pixels = get_pixels( + item=item, + indexer=indexer, + chipper=chipper, + ) + except RasterioIOError: + logger.warning("Skipping scene due to rasterio io error") + return + + time_norm, latlon_norm, gsd, pixels_norm = prepare_datacube( + mean=mean, std=std, datetimes=datetimes, bboxs=bboxs, pixels=pixels, gsd=gsd + ) + # Embed data + cls_embeddings = get_embeddings( + clay=clay, + pixels_norm=pixels_norm, + time_norm=time_norm, + latlon_norm=latlon_norm, + waves=waves, + gsd=gsd, + batchsize=batchsize, + ) + # Write class embeddings + kwargs = dict( + bboxs=bboxs, + datestr=datestr, + gsd=gsd, + destination_bucket=EMBEDDINGS_BUCKET, + path=path, + source_bucket="naip-analytic", + ) + logger.debug("Writing class embeddings") + write_to_table(embeddings=cls_embeddings, **kwargs) + + +def process(): + if "AWS_BATCH_JOB_ARRAY_INDEX" not in os.environ: + raise ValueError("AWS_BATCH_JOB_ARRAY_INDEX env var not set") + index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", 0)) + items_per_job = int(os.environ.get("ITEMS_PER_JOB", 100)) + batchsize = int(os.environ.get("EMBEDDING_BATCH_SIZE", 50)) + limit_to_state = os.environ.get("LIMIT_TO_STATE", None) + + scenes = open_scene_list(limit_to_state) + clay = load_clay() + + for i in range(index * items_per_job, (index + 1) * items_per_job): + scene = scenes[i] + if check_exists(scene): + logger.debug(f"Skipping scene because exists: {scene}") + continue + + process_scene( + clay=clay, + path=scene, + batchsize=batchsize, + ) + + +if __name__ == "__main__": + logger.debug("Starting") + process() + logger.debug("Done!") diff --git a/embeddings/all-sentinel.py b/embeddings/all-sentinel.py new file mode 100644 index 00000000..505d273e --- /dev/null +++ b/embeddings/all-sentinel.py @@ -0,0 +1,161 @@ +import gzip +import json +import logging +import os +import tempfile +from pathlib import Path + +import boto3 +import numpy as np +from pystac import Item +from stacchip.chipper import Chipper +from stacchip.indexer import Sentinel2Indexer + +from embeddings.utils import ( + check_exists, + get_embeddings, + get_pixels, + load_clay, + load_metadata, + prepare_datacube, + write_to_table, +) + +logging.basicConfig() +logger = logging.getLogger("clay") +logger.setLevel(logging.DEBUG) + +SCENES_LIST = "data/element84-tiles-2023-brazil.gz" +EMBEDDINGS_BUCKET = "clay-embeddings-sentinel-2" +GSD = 10 +S2_BUCKET = "sentinel-2-cogs" + + +def open_scenes_list(): + with gzip.open(SCENES_LIST) as fl: + data = fl.readlines() + data = [dat.decode().rstrip() for dat in data] + data = [dat for dat in data if dat.split("/")[7] == "2024"] + # Process the X, C, and D regions last + data = sorted(data, key=lambda dat: dat.split("/")[5] in ["X", "C", "D"]) + data = [Path(dat.replace("s3://sentinel-cogs/", "")) for dat in data] + logger.debug(f"Found {len(data)} scenes to process") + return data + + +def download_scenes_local(tmp, item, bands): + s3 = boto3.client("s3") + for band in bands: + local_asset_path = f"{tmp}/{band}.tif" + remote_asset_key = item.assets[band].href.replace( + "https://sentinel-cogs.s3.us-west-2.amazonaws.com/", "" + ) + print(f"Downloading band {band} to {local_asset_path}") + with open(local_asset_path, mode="w+b") as fl: + s3.download_fileobj("sentinel-cogs", remote_asset_key, fl) + item.assets[band].href = local_asset_path + + return item + + +def process_scene(clay, path, batchsize): + bands, waves, mean, std = load_metadata("sentinel-2-l2a") + + s3 = boto3.resource("s3") + + stac_json = json.load(s3.Object("sentinel-cogs", str(path)).get()["Body"]) + + item = Item.from_dict(stac_json) + + # Sanity checks + if "red" not in item.assets: + logger.debug(f"No red band for {path}") + return + elif not item.ext.has("proj"): + logger.debug(f"No proj for {path}") + return + + all_bboxs = [] + all_cls_embeddings = None + + with tempfile.TemporaryDirectory() as tmp: + item = download_scenes_local(tmp, item, bands) + indexer = Sentinel2Indexer(item, chip_max_nodata=0.1) + chipper = Chipper(indexer, assets=bands) + logger.debug(f"Creating chips for {item.id}") + STEP = 50 + for index in range(0, len(chipper), STEP): + bboxs, datetimes, pixels = get_pixels( + item=item, + indexer=indexer, + chipper=chipper, + start=index, + end=index + STEP, + ) + + if not len(pixels): + continue + + time_norm, latlon_norm, gsd, pixels_norm = prepare_datacube( + mean=mean, + std=std, + datetimes=datetimes, + bboxs=bboxs, + pixels=pixels, + gsd=GSD, + ) + + # Embed data + cls_embeddings = get_embeddings( + clay=clay, + pixels_norm=pixels_norm, + time_norm=time_norm, + latlon_norm=latlon_norm, + waves=waves, + gsd=GSD, + batchsize=batchsize, + ) + all_bboxs += bboxs + if all_cls_embeddings is None: + all_cls_embeddings = cls_embeddings + else: + all_cls_embeddings = np.vstack((all_cls_embeddings, cls_embeddings)) + + kwargs = dict( + bboxs=all_bboxs, + datestr=str(item.datetime.date()), + gsd=GSD, + destination_bucket=EMBEDDINGS_BUCKET, + path=path, + source_bucket="sentinel-cogs", + ) + + write_to_table(embeddings=all_cls_embeddings, **kwargs) + + +def process(): + if "AWS_BATCH_JOB_ARRAY_INDEX" not in os.environ: + raise ValueError("AWS_BATCH_JOB_ARRAY_INDEX env var not set") + index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", 0)) + items_per_job = int(os.environ.get("ITEMS_PER_JOB", 100)) + batchsize = int(os.environ.get("EMBEDDING_BATCH_SIZE", 50)) + + scenes = open_scenes_list() + clay = load_clay() + + for i in range(index * items_per_job, (index + 1) * items_per_job): + if check_exists(scenes[i]): + logger.debug(f"Skipping scene because exists: {scenes[i]}") + continue + + process_scene( + clay=clay, + path=scenes[i], + batchsize=batchsize, + ) + + +if __name__ == "__main__": + logger.debug("Starting") + process() + logger.debug("Done!") diff --git a/embeddings/environment.yml b/embeddings/environment.yml new file mode 100644 index 00000000..2c047768 --- /dev/null +++ b/embeddings/environment.yml @@ -0,0 +1,40 @@ +name: claymodel +channels: + - conda-forge + - nodefaults +dependencies: + - conda-lock~=2.5.6 + - einops~=0.7.0 + - fiona~=1.9.5 + - geopandas-base~=0.14.1 + - jsonargparse~=4.27.0 + - lightning~=2.1.0 + - matplotlib-base~=3.8.2 + - planetary-computer~=1.0.0 + - python-box~=7.1.0 + - python~=3.11.0 + - pyarrow~=16.1.0 + - rasterio~=1.3.10 + - s3fs~=2024.3.1 + - scikit-image~=0.22.0 + - scikit-learn~=1.4.0 + - stackstac~=0.5.0 + - timm~=0.9.16 + - torchvision~=0.18.1 + - transformers~=4.35.2 + - typeshed-client~=2.4.0 + - vit-pytorch~=1.6.4 + - zarr~=2.16.1 + - pip: + - geoarrow-pyarrow==0.1.2 + - jupyter-book==1.0.2 + - jupyterlab==4.2.4 + - onnx==1.16.1 + - onnxscript + - onnxruntime + - torchdata==0.7.1 + - torchgeo==0.5.2 + - wandb==0.17.5 + - stacchip==0.1.38 +platforms: + - linux-64 diff --git a/embeddings/parse-sentinel-2-inventory.py b/embeddings/parse-sentinel-2-inventory.py new file mode 100644 index 00000000..397674d8 --- /dev/null +++ b/embeddings/parse-sentinel-2-inventory.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# From https://github.com/alexgleith/sinergise-element84-sentinel-2-qa/blob/main/0-parse-inventory-element84.py +import csv +import gzip +import json +import sys + +import boto3 + +SPECIAL_YEAR = "2019" +CUTOFF_YEAR = 2023 + +s3 = boto3.resource("s3") + +bucket = "sentinel-cogs-inventory" +manifest_key = "sentinel-cogs/sentinel-cogs/2024-10-03T01-00Z/manifest.json" + +print("Starting up...") + + +def log(comment): + sys.stdout.write(f"\r{comment}") + + +# Stolen from https://alukach.com/posts/parsing-s3-inventory-output +def list_keys(bucket, manifest_key): + manifest = json.load(s3.Object(bucket, manifest_key).get()["Body"]) + for item in manifest["files"]: + gzip_obj = s3.Object(bucket_name=bucket, key=item["key"]) + buffer = gzip.open(gzip_obj.get()["Body"], mode="rt") + reader = csv.reader(buffer) + yield from reader + + +limit = 2 +count = 0 +valid = 0 +log_every = 10000 +cutoff_year = 2023 + +if __name__ == "__main__": + # Parse zip file for all scenes + with gzip.open("data/element84-tiles.list.gz", "wt") as text_file: + for tiles_bucket, key, *rest in list_keys(bucket, manifest_key): + if ".json" in key: + c = key.split("/") + # Counting scenes + count += 1 + if count % log_every == 0: + log(f"Found {count} scenes...") + tile = f"{c[1]}{c[2]}{c[3]}" + text_file.write(f"s3://{tiles_bucket}/{key}\n") + + print(f"Found {count} scenes") + + # Reduce to 2023 and 20204 + with gzip.open("data/element84-tiles-2023.gz", "wt") as dst: + with gzip.open("data/element84-tiles.list.gz") as fl: + line = fl.readline() + while line: + line = line.decode().rstrip() + c = line.split("/") + # Skip data befor 2023. Some scenes from 2019 have the year + # in a different part of the prefix. + if c[4] == SPECIAL_YEAR: + line = fl.readline() + continue + elif int(c[7]) < CUTOFF_YEAR: + line = fl.readline() + continue + elif not line.endswith("L2A.json"): + line = fl.readline() + continue + + count += 1 + if count % log_every == 0: + log(f"Found {count} scenes... {line}") + + dst.write(line + "\n") + line = fl.readline() diff --git a/embeddings/utils.py b/embeddings/utils.py new file mode 100644 index 00000000..61cc3919 --- /dev/null +++ b/embeddings/utils.py @@ -0,0 +1,226 @@ +import logging +import math +import os + +import boto3 +import botocore +import geoarrow.pyarrow as ga +import numpy as np +import pyarrow as pa +import torch +import yaml +from box import Box +from geoarrow.pyarrow import io as gaio +from torchvision.transforms import v2 + +from src.module import ClayMAEModule + +CHECKPOINT = "data/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt" +EMBEDDING_SHAPE_CLASS = 2 +EMBEDDING_SHAPE_PATCH = 3 +EMBEDDINGS_BUCKET = os.environ["EMBEDDINGS_BUCKET"] + +CLOUD_LIMIT = 0.1 +NODATA_LIMIT = 0.01 + +logger = logging.getLogger("clay") + + +def check_exists(path): + if "ENDPOINT_URL" in os.environ: + s3 = boto3.client( + "s3", + endpoint_url=os.environ.get("ENDPOINT_URL"), + aws_access_key_id=os.environ.get("ENDPOINT_KEY_ID"), + aws_secret_access_key=os.environ.get("ENDPOINT_ACCESS_KEY"), + ) + else: + s3 = boto3.client("s3") + try: + s3.head_object( + Bucket=EMBEDDINGS_BUCKET, + Key=f"{path.parent}/{path.stem}.parquet", + ) + return True + except botocore.exceptions.ClientError: + return False + + +def load_metadata(platform): + metadata = Box(yaml.safe_load(open("configs/metadata.yaml"))) + platform_meta = getattr(metadata, platform) + + bands = list(platform_meta.bands.wavelength.keys()) + waves = list(platform_meta.bands.wavelength.values()) + mean = list(platform_meta.bands.mean.values()) + std = list(platform_meta.bands.std.values()) + + return bands, waves, mean, std + + +def normalize_timestamp(date): + week = date.isocalendar().week * 2 * np.pi / 52 + hour = date.hour * 2 * np.pi / 24 + + return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour)) + + +def normalize_latlon(lat, lon): + lat = lat * np.pi / 180 + lon = lon * np.pi / 180 + + return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon)) + + +def prepare_datacube(mean, std, datetimes, bboxs, pixels, gsd): + transform = v2.Compose( + [ + v2.Normalize(mean=mean, std=std), + ] + ) + + times = [normalize_timestamp(dat) for dat in datetimes] + week_norm = [dat[0] for dat in times] + hour_norm = [dat[1] for dat in times] + time_norm = np.hstack((week_norm, hour_norm)) + + latlons = [normalize_latlon(*bbox.centroid.coords[0]) for bbox in bboxs] + lat_norm = [dat[0] for dat in latlons] + lon_norm = [dat[1] for dat in latlons] + latlon_norm = np.hstack((lat_norm, lon_norm)) + + gsd = [gsd] + + pixels_norm = transform(torch.tensor(pixels, dtype=torch.float32)).numpy() + + return time_norm, latlon_norm, gsd, pixels_norm + + +def get_pixels(item, indexer, chipper, start=None, end=None): + chips = [] + datetimes = [] + bboxs = [] + chip_ids = [] + item_ids = [] + if start: + index_range = range(start, min(end, len(chipper))) + else: + index_range = range(len(chipper)) + for index in index_range: + y = index // chipper.indexer.x_size + x = index % chipper.indexer.x_size + + cloud_percentage, nodata_percentage = chipper.indexer.get_stats(x, y) + + if cloud_percentage > CLOUD_LIMIT: + continue + elif nodata_percentage > NODATA_LIMIT: + continue + + chip = chipper.chip(x, y) + + chips.append(chip) + datetimes.append(item.datetime) + bboxs.append(indexer.get_chip_bbox(x, y)) + chip_ids.append((x, y)) + item_ids.append(item.id) + + pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips]) + + return bboxs, datetimes, pixels + + +def get_embeddings(clay, pixels_norm, time_norm, latlon_norm, waves, gsd, batchsize): # noqa: PLR0913 + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + logger.debug(f"Using device {device} to create {len(pixels_norm)} embeddings") + # Run the clay encoder + cls_embeddings = None + for i in range(0, len(pixels_norm), batchsize): + if i / batchsize % 5 == 0: + logger.debug(f"Iteration {i}") + + datacube = { + "pixels": torch.tensor( + pixels_norm[i : (i + batchsize)], dtype=torch.float32, device=device + ), + "time": torch.tensor( + time_norm[i : (i + batchsize)], dtype=torch.float32, device=device + ), + "latlon": torch.tensor( + latlon_norm[i : (i + batchsize)], dtype=torch.float32, device=device + ), + "waves": torch.tensor(waves, dtype=torch.float32, device=device), + "gsd": torch.tensor(gsd, dtype=torch.float32, device=device), + "platform": ["naip"], + } + with torch.no_grad(): + unmsk_patch, unmsk_idx, msk_idx, msk_matrix = clay.model.encoder(datacube) + # The first embedding is the class token, which is the + # overall single embedding we want to keep. + batch_cls_embeddings = unmsk_patch[:, 0, :].cpu().numpy() + if cls_embeddings is None: + cls_embeddings = batch_cls_embeddings + else: + cls_embeddings = np.vstack((cls_embeddings, batch_cls_embeddings)) + + return cls_embeddings + + +def load_clay(): + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.debug(f"Loading model on device {device}") + model = ClayMAEModule.load_from_checkpoint( + checkpoint_path=CHECKPOINT, + metadata_path="configs/metadata.yaml", + model_size="large", + dolls=[16, 32, 64, 128, 256, 768, 1024], + doll_weights=[1, 1, 1, 1, 1, 1, 1], + mask_ratio=0.0, + shuffle=False, + ) + model.eval() + + return model.to(device) + + +def write_to_table( # noqa: PLR0913 + embeddings, bboxs, datestr, gsd, destination_bucket, path, source_bucket +): + index = {"geometry": ga.as_geoarrow([dat.wkt for dat in bboxs])} + if len(embeddings.shape) == EMBEDDING_SHAPE_CLASS: + # Handle class embeddings + index["embeddings"] = [np.ascontiguousarray(dat) for dat in embeddings] + elif len(embeddings.shape) == EMBEDDING_SHAPE_PATCH: + # Handle patch embeddings + for i in range(embeddings.shape[1]): + index[f"patch_embeddings_{i}"] = [ + np.ascontiguousarray(dat) for dat in embeddings[:, i, :] + ] + + table = pa.table( + index, + metadata={ + "date": datestr, + "gsd": str(gsd), + "uri": f"s3://{source_bucket}/{path}", + }, + ) + + writer = pa.BufferOutputStream() + gaio.write_geoparquet_table(table, writer) + body = bytes(writer.getvalue()) + if "ENDPOINT_URL" in os.environ: + s3_resource = boto3.resource( + "s3", + endpoint_url=os.environ.get("ENDPOINT_URL"), + aws_access_key_id=os.environ.get("ENDPOINT_KEY_ID"), + aws_secret_access_key=os.environ.get("ENDPOINT_ACCESS_KEY"), + ) + else: + s3_resource = boto3.resource("s3") + + s3_bucket = s3_resource.Bucket(name=destination_bucket) + s3_bucket.put_object( + Body=body, + Key=f"{path.parent}/{path.stem}.parquet", + ) diff --git a/environment.yml b/environment.yml index 699df9df..fed9665d 100644 --- a/environment.yml +++ b/environment.yml @@ -7,33 +7,36 @@ dependencies: - einops~=0.7.0 - fiona~=1.9.5 - geopandas-base~=0.14.1 - - h5netcdf~=1.3.0 - - jupyter-book~=1.0.0 - - jupyterlab~=4.0.7 - jsonargparse~=4.27.0 - - lancedb~=0.10.2 - lightning~=2.1.0 - matplotlib-base~=3.8.2 - planetary-computer~=1.0.0 - python-box~=7.1.0 - - pytorch~=2.1.0 # [osx] - - pytorch~=2.1.0 *cuda12* # [linux] + - pytorch~=2.3.1 # [osx] + - pytorch~=2.3.1 *cuda12* # [linux] - python~=3.11.0 - pyarrow~=16.1.0 - - rioxarray~=0.15.0 - rasterio~=1.3.10 - s3fs~=2024.3.1 - scikit-image~=0.22.0 - scikit-learn~=1.4.0 - stackstac~=0.5.0 - timm~=0.9.16 - - torchdata~=0.7.1 - - torchgeo~=0.5.2 - - torchvision~=0.16.1 + - torchvision~=0.18.1 - transformers~=4.35.2 - typeshed-client~=2.4.0 - vit-pytorch~=1.6.4 - - wandb~=0.15.12 - zarr~=2.16.1 + - pip: + - geoarrow-pyarrow==0.1.2 + - jupyter-book==1.0.2 + - jupyterlab==4.2.4 + - onnx==1.16.1 + - onnxscript + - onnxruntime + - torchdata==0.7.1 + - torchgeo==0.5.2 + - stacchip==0.1.38 + - wandb==0.17.5 platforms: - linux-64 diff --git a/finetune/classify/classify.py b/finetune/classify/classify.py index 352afd15..7b87f2e4 100644 --- a/finetune/classify/classify.py +++ b/finetune/classify/classify.py @@ -21,7 +21,9 @@ def cli_main(): """ Command-line inteface to run Clasifier model with EuroSATDataModule. """ - cli = LightningCLI(EuroSATClassifier, EuroSATDataModule) + cli = LightningCLI( + EuroSATClassifier, EuroSATDataModule, save_config_kwargs={"overwrite": True} + ) return cli diff --git a/finetune/classify/factory.py b/finetune/classify/factory.py index cebdeca2..079d3f39 100644 --- a/finetune/classify/factory.py +++ b/finetune/classify/factory.py @@ -31,20 +31,32 @@ def __init__(self, num_classes=10, ckpt_path=None): # Initialize Clay Encoder with parameters from base model. Set # mask_ratio to 0.0 & shuffle to False for downstream tasks. + # self.clay_encoder = Encoder( + # mask_ratio=0.0, + # patch_size=8, + # shuffle=False, + # dim=768, + # depth=12, + # heads=12, + # dim_head=64, + # mlp_ratio=4.0, + # ) self.clay_encoder = Encoder( mask_ratio=0.0, patch_size=8, shuffle=False, - dim=768, - depth=12, - heads=12, + dim=1024, + depth=24, + heads=16, dim_head=64, mlp_ratio=4.0, + # feature_maps=feature_maps, + # ckpt_path=ckpt_path, ) # Simple 2 layer MLP head for classification self.head = nn.Sequential( - nn.Linear(768, 512), + nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.25), nn.Linear(512, num_classes), diff --git a/finetune/embedder/factory.py b/finetune/embedder/factory.py new file mode 100644 index 00000000..bf3ee6e4 --- /dev/null +++ b/finetune/embedder/factory.py @@ -0,0 +1,303 @@ +"""Export the Clay model to ONNX and pytorch ExportedProgram format. + +This script exports the Clay model to ONNX and pytorch ExportedProgram format +for deployment. The model is exported with dynamic shapes for inference. + +How to use: + +```bash +python -m finetune.embedder.factory \ + --img_size 256 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.onnx \ + --onnx +# exports Clay encoder to ONNX format that can handle chips of size 256x256 +# for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ & Sentinel 1. +``` + +```bash +python -m finetune.embedder.factory \ + --img_size 224 \ + --ckpt_path checkpoints/clay-v1-base.ckpt \ + --device cuda \ + --name clay-v1-encoder.pt2 \ + --ep +# exports Clay encoder to pytorch ExportedProgram format that can handle chips +# of size 224x224 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ +# & Sentinel 1. +``` + +""" + +import argparse +import re +import warnings +from pathlib import Path + +import torch +from einops import repeat +from torch import nn +from torch.export import Dim + +from src.model import Encoder +from src.utils import posemb_sincos_2d_with_gsd + +warnings.filterwarnings("ignore", category=UserWarning) + + +class EmbeddingEncoder(Encoder): + """Clay Encoder without mask and shuffle.""" + + def __init__( # noqa: PLR0913 + self, + img_size, + patch_size, + dim, + depth, + heads, + dim_head, + mlp_ratio, + ): + super().__init__( + mask_ratio=0.0, + shuffle=False, + patch_size=patch_size, + dim=dim, + depth=depth, + heads=heads, + dim_head=dim_head, + mlp_ratio=mlp_ratio, + ) + self.img_size = img_size + + # Using fixed grid size for inference + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + + def add_encodings(self, patches, time, latlon, gsd): + """Add position encoding to the patches""" + B, L, D = patches.shape + + grid_size = self.grid_size + + pos_encoding = ( + posemb_sincos_2d_with_gsd( + h=grid_size, + w=grid_size, + dim=(self.dim - 8), + gsd=gsd, + ) + .to(patches.device) + .detach() + ) # [L (D - 8)] + + time_latlon = torch.hstack((time, latlon)).to(patches.device).detach() # [B 8] + + pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B) # [B L (D - 8)] + time_latlon = repeat(time_latlon, "B D -> B L D", L=L) # [B L 8] + pos_metadata_encoding = torch.cat( + (pos_encoding, time_latlon), dim=-1 + ) # [B L D] + + patches = patches + pos_metadata_encoding # [B L D] + [B L D] -> [B L D] + return patches # [B L D] + + # def forward(self, cube, time, latlon, waves, gsd): + def forward(self, datacube): + cube, time, latlon, gsd, waves = ( + datacube["pixels"], # [B C H W] + datacube["time"], # [B 2] + datacube["latlon"], # [B 2] + datacube["gsd"], # 1 + datacube["waves"], # [N] + ) # [B C H W] + B, C, H, W = cube.shape + + patches, _ = self.to_patch_embed( + cube, waves + ) # [B L D] - patchify & create embeddings per patch + + # Add time & latlon as encoding to patches + patches = self.add_encodings( + patches, + time, + latlon, + gsd, + ) # [B L D] - add position encoding to the embeddings + + # Add class tokens + cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D] + patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D] + + # pass the patches through the transformer + patches = self.transformer(patches) # [B (1 + L) D] + + # get the cls token + embeddings = patches[:, 0, :] # [B D] + + return embeddings + + +class Embedder(nn.Module): + def __init__(self, img_size=256, ckpt_path=None, device="cpu"): + super().__init__() + self.clay_encoder = ( + EmbeddingEncoder( # Default parameters for the Clay base model + img_size=img_size, + patch_size=8, + dim=768, + depth=12, + heads=12, + dim_head=64, + mlp_ratio=4.0, + ).to(device) + ) + self.img_size = img_size + self.device = torch.device(device) + self.load_clay_weights(ckpt_path) + + def load_clay_weights(self, ckpt_path): + "Load the weights from the Clay model encoder." + ckpt = torch.load(ckpt_path, map_location=self.device) + state_dict = ckpt.get("state_dict") + state_dict = { + re.sub(r"^model\.encoder\.", "", name): param + for name, param in state_dict.items() + if name.startswith("model.encoder") + } + + with torch.no_grad(): + for name, param in self.clay_encoder.named_parameters(): + if name in state_dict and param.size() == state_dict[name].size(): + param.data.copy_(state_dict[name]) # Copy the weights + else: + print(f"No matching parameter for {name} with size {param.size()}") + + for param in self.clay_encoder.parameters(): + param.requires_grad = False + + self.clay_encoder.eval() + + def forward(self, datacube): + embeddings = self.clay_encoder(datacube) + + return embeddings + + def fake_datacube(self): + "Generate a fake datacube for model export." + dummy_datacube = { + "pixels": torch.randn(2, 3, self.img_size, self.img_size), + "time": torch.randn(2, 4), + "latlon": torch.randn(2, 4), + "waves": torch.randn(3), + "gsd": torch.randn(1), + } + dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()} + return dummy_datacube + + def export_to_onnx(self, name): + "Save the model to ONNX format." + + datacube = self.fake_datacube() + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + + # Export the model to ONNX format + onnx_program = torch.onnx.dynamo_export( + self.eval(), datacube, export_options=export_options + ) + + # Save the exported model + onnx_program.save(f"checkpoints/compiled/{name}") + print(f"Model exported to ONNX format: checkpoints/compiled/{name}") + + return onnx_program + + def export_to_torchep(self, name): + "Save the model to pytorch ExportedProgram format." + + datacube = self.fake_datacube() + + # dynamic shapes for model export + batch_size = Dim("batch_size", min=2, max=1000) + channel_bands = Dim("channel_bands", min=1, max=10) + dynamic_shapes = { + "datacube": { + "pixels": {0: batch_size, 1: channel_bands}, + "time": {0: batch_size}, + "latlon": {0: batch_size}, + "waves": {0: channel_bands}, + "gsd": {0: None}, + } + } + + # Export the model to pytorch ExportedProgram format + ep = torch.export.export( + self.eval(), + (datacube,), + dynamic_shapes=dynamic_shapes, + strict=True, + ) + + # Save the exported model + torch.export.save(ep, f"checkpoints/compiled/{name}") + print( + f"Model exported to pytorch ExportedProgram format: checkpoints/compiled/{name}" # noqa: E501 + ) + + return ep + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export the Clay model.") + parser.add_argument( + "--img_size", + type=int, + default=256, + help="Image size for the model", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="checkpoints/clay-v1-base.ckpt", + help="Path to the Clay model checkpoint", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for the model", + ) + parser.add_argument( + "--name", + type=str, + default="clay-base.pt", + help="Name of the exported model", + ) + parser.add_argument( + "--onnx", + action="store_true", + help="Export the model to ONNX format", + ) + parser.add_argument( + "--ep", + action="store_true", + help="Export the model to pytorch ExportedProgram format", + ) + + args = parser.parse_args() + + Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) + embedder = Embedder( + img_size=args.img_size, + ckpt_path=args.ckpt_path, + device=args.device, + ) + + if args.onnx: + embedder.export_to_onnx(args.name) + elif args.ep: + embedder.export_to_torchep(args.name) + else: + print("Please specify the format to export the model.") + parser.print_help() diff --git a/finetune/embedder/how-to-embed.ipynb b/finetune/embedder/how-to-embed.ipynb new file mode 100644 index 00000000..5f482846 --- /dev/null +++ b/finetune/embedder/how-to-embed.ipynb @@ -0,0 +1,494 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "d9960547-640d-425c-8180-fc5523a80e42", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import os\n", + "import requests\n", + "import warnings\n", + "\n", + "import geoarrow.pyarrow as ga\n", + "import numpy as np\n", + "import pystac_client\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from torchvision.transforms import v2\n", + "\n", + "from stacchip.indexer import Sentinel2Indexer\n", + "from stacchip.chipper import Chipper\n", + "\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "markdown", + "id": "598fec81-2cc1-4c5a-9e46-7c46a5591484", + "metadata": {}, + "source": [ + "### Find data for AOI\n", + "The first step is to find STAC items of imagery that we want to use to create embeddings. In this example we are going to use Earth Genome's composite dataset which comes with a great STAC catalog.\n", + "\n", + "We are also going to create embeddings along time so that we have multiple embeddings for the same location at different moments in time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e1d46ee-40f6-49f5-99ad-83819339561e", + "metadata": {}, + "outputs": [], + "source": [ + "# Point over Monchique Portugal\n", + "lat, lon = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7825318-23f3-449f-9104-eae6562a55ab", + "metadata": {}, + "outputs": [], + "source": [ + "# Optimize GDAL settings for cloud optimized reading\n", + "os.environ[\"GDAL_DISABLE_READDIR_ON_OPEN\"] = \"EMPTY_DIR\"\n", + "os.environ[\"AWS_REQUEST_PAYER\"] = \"requester\"\n", + "\n", + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"\n", + "\n", + "# Search the catalogue\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "search = catalog.search(\n", + " collections=[COLLECTION],\n", + " datetime=f\"{start}/{end}\",\n", + " bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", + " max_items=100,\n", + " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "all_items = search.get_all_items()\n", + "\n", + "# Reduce to one per date (there might be some duplicates\n", + "# based on the location)\n", + "items = []\n", + "dates = []\n", + "for item in all_items:\n", + " if item.datetime.date() not in dates:\n", + " items.append(item)\n", + " dates.append(item.datetime.date())\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "id": "600f3cfb-ce4e-4409-ae15-20f3a7107a62", + "metadata": {}, + "source": [ + "To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "183975c7-8afb-49ef-8e70-790265719aea", + "metadata": {}, + "outputs": [], + "source": [ + "chips = []\n", + "datetimes = []\n", + "bboxs = []\n", + "chip_ids = []\n", + "item_ids = []\n", + "\n", + "for item in items:\n", + " print(f\"Working on {item}\")\n", + "\n", + " # Index the chips in the item\n", + " indexer = Sentinel2Indexer(item)\n", + "\n", + " # Instanciate the chipper\n", + " chipper = Chipper(indexer, assets=[\"red\", \"green\", \"blue\", \"nir\", \"scl\"])\n", + "\n", + " # Get first chip for the \"image\" asset key\n", + " for idx, (x, y, chip) in enumerate(chipper):\n", + " if idx > 2:\n", + " break\n", + " del chip[\"scl\"]\n", + " chips.append(chip)\n", + " datetimes.append(item.datetime)\n", + " bboxs.append(indexer.get_chip_bbox(x, y))\n", + " chip_ids.append((x, y))\n", + " item_ids.append(item.id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71902ab7-3320-43cd-85c3-362c2500f241", + "metadata": {}, + "outputs": [], + "source": [ + "pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])\n", + "pixels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f7ce367-4e12-4648-bb79-119b4f50ead8", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract mean, std, and wavelengths from metadata\n", + "platform = \"sentinel-2-l2a\"\n", + "# Retrieve the file content from the URL\n", + "\n", + "url = (\n", + " \"https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml\"\n", + ")\n", + "response = requests.get(url, allow_redirects=True)\n", + "\n", + "# Convert bytes to string\n", + "content = response.content.decode(\"utf-8\")\n", + "\n", + "# Load the yaml\n", + "content = yaml.safe_load(content)\n", + "\n", + "metadata = Box(content)\n", + "mean = []\n", + "std = []\n", + "waves = []\n", + "# Use the band names to get the correct values in the correct order.\n", + "for band in chips[0].keys():\n", + " mean.append(metadata[platform].bands.mean[band])\n", + " std.append(metadata[platform].bands.std[band])\n", + " waves.append(metadata[platform].bands.wavelength[band])\n", + "\n", + "# Prepare the normalization transform function using the mean and std values.\n", + "transform = v2.Compose(\n", + " [\n", + " v2.Normalize(mean=mean, std=std),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8ec8c2d-ecb9-42a2-9e8c-3f95c67ef07b", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_timestamp(date):\n", + " week = date.isocalendar().week * 2 * np.pi / 52\n", + " hour = date.hour * 2 * np.pi / 24\n", + "\n", + " return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))\n", + "\n", + "\n", + "times = [normalize_timestamp(dat) for dat in datetimes]\n", + "week_norm = [dat[0] for dat in times]\n", + "hour_norm = [dat[1] for dat in times]\n", + "\n", + "\n", + "# Prep lat/lon embedding using the\n", + "def normalize_latlon(lat, lon):\n", + " lat = lat * np.pi / 180\n", + " lon = lon * np.pi / 180\n", + "\n", + " return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n", + "\n", + "\n", + "latlons = [normalize_latlon(lat, lon)] * len(times)\n", + "lat_norm = [dat[0] for dat in latlons]\n", + "lon_norm = [dat[1] for dat in latlons]\n", + "\n", + "# Prep gsd\n", + "gsd = [10]\n", + "\n", + "# Normalize pixels\n", + "pixels = transform(pixels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2640eb17-a85c-4972-8d5d-e45e9ed8eba5", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {\n", + " \"pixels\": torch.tensor(pixels, dtype=torch.float32),\n", + " \"time\": torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32),\n", + " \"latlon\": torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32),\n", + " \"waves\": torch.tensor(waves, dtype=torch.float32),\n", + " \"gsd\": torch.tensor(gsd, dtype=torch.float32),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f6711a9-e7ed-44d5-add7-2c3a498cd422", + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in datacube.items():\n", + " print(k, v.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "83243912-a2a8-4fa5-a39c-a9c3b07c7569", + "metadata": {}, + "source": [ + "### Clay Embedder\n", + "\n", + "#### Load the embedder that is stored in ExportedProgram format using **cpu**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4eb468af-d468-46aa-a8fb-23ff95c56288", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9eb797f7-5238-49e0-9950-e85f10132454", + "metadata": {}, + "outputs": [], + "source": [ + "ep_embedder_cpu = torch.export.load(\"clay-v1-encoder-cpu.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eefe4811-7290-47c3-a10e-45257e6d42e0", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder_cpu(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "8e927b01-c855-4172-a4d9-2c10ba794ed4", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "fa0810b4-34ad-490e-bbcd-c0c3288f017c", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ExportedProgram format using **gpu**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c1bbfd4-7dc6-4ad0-8a0b-b3745a9f35ca", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder.pt2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e285a543-20ab-44ba-b676-2303284dc477", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k: v.to(\"cuda\") for k, v in datacube.items()}\n", + "ep_embedder = torch.export.load(\"clay-v1-encoder.pt2\").module()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edefee90-e6b8-4701-bb5d-2bf7febc806c", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "with torch.no_grad():\n", + " embeddings = ep_embedder(datacube)\n", + "datacube[\"pixels\"].shape, embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "196f2121-46b5-4b02-94d3-75e648c329c3", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "5b1cb0f9-a434-419b-a88b-4d4edd84fea6", + "metadata": {}, + "source": [ + "#### Load the embedder that is stored in ONNX format using **cpu**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa10d696-740a-458e-ae10-eec9a43fb362", + "metadata": {}, + "outputs": [], + "source": [ + "import onnx\n", + "import onnxruntime as ort" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "992524e5-2c2a-4e48-ae95-bd2aa87b72a9", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.onnx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc3fa967-73d5-431c-88a2-84b088aff06f", + "metadata": {}, + "outputs": [], + "source": [ + "datacube = {k: v.to(\"cpu\") for k, v in datacube.items()}\n", + "onnx_embedder = ort.InferenceSession(\n", + " \"clay-v1-encoder-cpu.onnx\", providers=[\"CPUExecutionProvider\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24591d17-d1c8-452b-9b20-676a9b6f8643", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "embeddings = onnx_embedder.run(\n", + " [],\n", + " {\n", + " \"cube\": datacube[\"pixels\"].numpy(),\n", + " \"time\": datacube[\"time\"].numpy(),\n", + " \"latlon\": datacube[\"latlon\"].numpy(),\n", + " \"waves\": datacube[\"waves\"].numpy(),\n", + " \"gsd\": datacube[\"gsd\"].numpy(),\n", + " },\n", + ")[0]\n", + "embeddings.shape" + ] + }, + { + "cell_type": "markdown", + "id": "9c07216e-a109-4cd8-8c74-9a3fc9a37757", + "metadata": {}, + "source": [ + "For each chip, we have an embedding of size `768`" + ] + }, + { + "cell_type": "markdown", + "id": "2e8d5900-9a4b-4e2d-b992-4fb0a1e8c835", + "metadata": {}, + "source": [ + "### Store the results\n", + "\n", + "We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "677f04d3-db38-4d44-9b55-c103d54adcd5", + "metadata": {}, + "outputs": [], + "source": [ + "# Write data to pyarrow table\n", + "index = {\n", + " \"datetimes\": datetimes,\n", + " \"chip_ids\": chip_ids,\n", + " \"item_ids\": item_ids,\n", + " \"emeddings\": [np.ascontiguousarray(dat) for dat in embeddings],\n", + " \"geometry\": ga.as_geoarrow([dat.wkt for dat in bboxs]),\n", + "}\n", + "table = pa.table(index)\n", + "table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d62a9e8a-b4f9-491c-a437-6a164a9e74fe", + "metadata": {}, + "outputs": [], + "source": [ + "pq.write_table(table, \"embeddings.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d30fb8c7-d04d-453f-93f6-dc3599f1df15", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/finetune/segment/chesapeake_datamodule.py b/finetune/segment/chesapeake_datamodule.py index ec7e16d0..310f2099 100644 --- a/finetune/segment/chesapeake_datamodule.py +++ b/finetune/segment/chesapeake_datamodule.py @@ -46,7 +46,9 @@ def __init__(self, chip_dir, label_dir, metadata, platform): ) # Load chip and label file names - self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")] + self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")][ + :1000 + ] self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips] def create_transforms(self, mean, std): diff --git a/finetune/segment/chesapeake_model.py b/finetune/segment/chesapeake_model.py index b5964ab3..949e5223 100644 --- a/finetune/segment/chesapeake_model.py +++ b/finetune/segment/chesapeake_model.py @@ -99,9 +99,9 @@ def configure_optimizers(self): ) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, - T_0=1000, + T_0=100, T_mult=1, - eta_min=self.hparams.lr * 100, + eta_min=self.hparams.lr * 10, last_epoch=-1, ) return { diff --git a/finetune/segment/factory.py b/finetune/segment/factory.py index 0ee95db8..de439b90 100644 --- a/finetune/segment/factory.py +++ b/finetune/segment/factory.py @@ -182,9 +182,9 @@ def __init__(self, num_classes, feature_maps, ckpt_path): mask_ratio=0.0, patch_size=8, shuffle=False, - dim=768, - depth=12, - heads=12, + dim=1024, + depth=24, + heads=16, dim_head=64, mlp_ratio=4.0, feature_maps=feature_maps, diff --git a/finetune/segment/segment.py b/finetune/segment/segment.py index 7531b4d8..50b61d26 100644 --- a/finetune/segment/segment.py +++ b/finetune/segment/segment.py @@ -21,7 +21,11 @@ def cli_main(): """ Command-line inteface to run Segmentation Model with ChesapeakeDataModule. """ - cli = LightningCLI(ChesapeakeSegmentor, ChesapeakeDataModule) + cli = LightningCLI( + ChesapeakeSegmentor, + ChesapeakeDataModule, + save_config_kwargs={"overwrite": True}, + ) return cli diff --git a/src/README.md b/src/README.md deleted file mode 100644 index 7326daf7..00000000 --- a/src/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Clay Foundation Model Modules - -This folder contains several LightningDataModule, LightningModule and callback -classes. - -## DataModules (data pipeline) - -- datamodule.py - Data pipeline to read in Earth Observation chips from GeoTIFF files - -## LightningModule (model architecture) - -- model_clay.py - Clay Foundation Model architecture with spatiotemporal encoders -- model_vit.py - Vanilla Vision Transformer neural network model architecture - -## Callbacks (custom plugins) - -- callbacks_wandb.py - Log metrics and predictions to Weights and Biases while training. - -## References - -- https://lightning.ai/docs/pytorch/2.1.0/data/datamodule.html -- https://lightning.ai/docs/pytorch/2.1.0/common/lightning_module.html -- https://lightning.ai/docs/pytorch/2.1.0/extensions/callbacks.html diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/backbone.py b/src/backbone.py new file mode 100644 index 00000000..a6e2ebb2 --- /dev/null +++ b/src/backbone.py @@ -0,0 +1,83 @@ +"""Code for Transformer from Phil Wangs vit-pytorch library. +Repository: https://github.com/lucidrains/vit-pytorch +""" + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, dim), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, fused_attn=True): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head**-0.5 + self.norm = nn.LayerNorm(dim) + self.fused_attn = fused_attn + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) + else: + attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = attn.softmax(dim=-1) + x = torch.matmul(attn, v) + + x = rearrange(x, "b h n d -> b n (h d)") + return self.to_out(x) + + +class Transformer(nn.Module): + def __init__( # noqa: PLR0913 + self, + dim, + depth, + heads, + dim_head, + mlp_dim, + fused_attn, + ): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim, heads=heads, dim_head=dim_head, fused_attn=fused_attn + ), + FeedForward(dim, mlp_dim), + ] + ) + ) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) diff --git a/src/callbacks_wandb.py b/src/callbacks_wandb.py index 0f4d4a10..374867fc 100644 --- a/src/callbacks_wandb.py +++ b/src/callbacks_wandb.py @@ -247,6 +247,8 @@ def on_validation_end( ) assert pixels.shape == batch["pixels"].shape + batch["pixels"] = batch["pixels"].detach().cpu().numpy() + pixels = pixels.detach().cpu().numpy() n_rows = 4 # 2 for actual and 2 for predicted n_cols = 8 @@ -255,29 +257,23 @@ def on_validation_end( for j in range(n_cols): # Plot actual images in rows 0 and 2 - axs[0, j].imshow( - batch["pixels"][j][0].detach().cpu().numpy(), cmap="viridis" - ) + axs[0, j].imshow(batch["pixels"][j][0], cmap="viridis") axs[0, j].set_title(f"Actual {j}") axs[0, j].axis("off") axs[2, j].imshow( - batch["pixels"][j + n_cols][0].detach().cpu().numpy(), + batch["pixels"][j + n_cols][0], cmap="viridis", ) axs[2, j].set_title(f"Actual {j+n_cols}") axs[2, j].axis("off") # Plot predicted images in rows 1 and 3 - axs[1, j].imshow( - pixels[j][0].detach().cpu().numpy(), cmap="viridis" - ) + axs[1, j].imshow(pixels[j][0], cmap="viridis") axs[1, j].set_title(f"Pred {j}") axs[1, j].axis("off") - axs[3, j].imshow( - pixels[j + n_cols][0].detach().cpu().numpy(), cmap="viridis" - ) + axs[3, j].imshow(pixels[j + n_cols][0], cmap="viridis") axs[3, j].set_title(f"Pred {j+n_cols}") axs[3, j].axis("off") diff --git a/src/datamodule.py b/src/datamodule.py index 7ba22e98..dc6c0901 100644 --- a/src/datamodule.py +++ b/src/datamodule.py @@ -3,6 +3,8 @@ rasterio. """ +import math +import random from collections import defaultdict from pathlib import Path from typing import List, Literal @@ -10,7 +12,8 @@ import lightning as L import numpy as np import torch -import torchdata + +# import torchdata import yaml from box import Box from einops import rearrange @@ -42,7 +45,7 @@ def create_transforms(self, mean, std): [ v2.RandomHorizontalFlip(p=0.5), v2.RandomVerticalFlip(p=0.5), - v2.RandomCrop(size=(self.size, self.size)), + # v2.RandomCrop(size=(self.size, self.size)), v2.Normalize(mean=mean, std=std), ] ) @@ -53,20 +56,38 @@ def __len__(self): def __getitem__(self, idx): chip_path = self.chips_path[idx] with np.load(chip_path, allow_pickle=False) as chip: - pixels = torch.from_numpy(chip["pixels"].astype(np.float32)) platform = chip_path.parent.name + if platform == "sentinel-1-rtc": + pixels = chip["pixels"].astype(np.float32) + pixels[pixels <= 0] = ( + 1e-10 # replace corrupted pixels in sentinel-1-rtc with small value + ) + pixels = 10 * np.log10( + pixels + ) # convert to dB scale, more interpretable pixels + else: + pixels = chip["pixels"].astype(np.float32) + + pixels = torch.from_numpy(pixels) pixels = self.transforms[platform](pixels) + time_tensor = torch.tensor( + np.hstack((chip["week_norm"], chip["hour_norm"]), dtype=np.float32) + ) + latlon_tensor = torch.tensor( + np.hstack((chip["lat_norm"], chip["lon_norm"]), dtype=np.float32) + ) + + # Randomly set time & latlon to zero for 20% of the chips + if random.random() < 0.2: # noqa: PLR2004 + time_tensor.zero_() + latlon_tensor.zero_() + # Prepare additional information additional_info = { "platform": platform, - "time": torch.tensor( - np.hstack((chip["week_norm"], chip["hour_norm"])), - dtype=torch.float32, - ), - "latlon": torch.tensor( - np.hstack((chip["lat_norm"], chip["lon_norm"])), dtype=torch.float32 - ), + "time": time_tensor, + "latlon": latlon_tensor, } return {"pixels": pixels, **additional_info} @@ -106,6 +127,77 @@ def __len__(self): return len(self.dataset.chips_path) // self.batch_size +class ClayDistributedSampler(Sampler): + def __init__( # noqa: PLR0913 + self, + dataset, + platforms, + batch_size, + num_replicas=None, + rank=None, + shuffle=True, + ): + self.dataset = dataset + self.platforms = platforms + self.batch_size = batch_size + self.num_replicas = ( + num_replicas + if num_replicas is not None + else torch.distributed.get_world_size() + ) + self.rank = rank if rank is not None else torch.distributed.get_rank() + self.shuffle = shuffle + self.epoch = 0 + + self.platform_indices = {platform: [] for platform in platforms} + for idx, chip_path in enumerate(self.dataset.chips_path): + platform = chip_path.parent.name + self.platform_indices[platform].append(idx) + + self.max_len = max(len(indices) for indices in self.platform_indices.values()) + self.adjusted_indices = {} + # Normalize the length of indices for each platform by replicating the indices + # to match the max_len + for platform, indices in self.platform_indices.items(): + if len(indices) < self.max_len: + extended_indices = np.tile(indices, (self.max_len // len(indices) + 1))[ + : self.max_len + ] + self.adjusted_indices[platform] = extended_indices + else: + self.adjusted_indices[platform] = indices + + self.num_samples = math.ceil( + ((self.max_len * len(self.platforms)) - self.num_replicas) + / self.num_replicas + ) + self.total_size = self.num_samples * self.num_replicas + self.num_samples_per_platform = self.max_len // self.num_replicas + + def __iter__(self): + rng = np.random.default_rng(self.epoch) + platform_batches = {} + for platform, indices in self.adjusted_indices.items(): + if self.shuffle: + rng.shuffle(indices) + # Distribute the indices to each process + start_idx = self.rank * self.num_samples_per_platform + end_idx = start_idx + self.num_samples_per_platform + platform_batches[platform] = indices[start_idx:end_idx] + + for i in range(0, self.num_samples_per_platform, self.batch_size): + for platform in self.platforms: + batch = platform_batches[platform][i : i + self.batch_size] + if len(batch) == self.batch_size: + yield batch + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + def batch_collate(batch): """Collate function for DataLoader. @@ -135,12 +227,14 @@ def __init__( # noqa: PLR0913 "landsat-c2l1", "landsat-c2l2-sr", "linz", + "modis", "naip", "sentinel-1-rtc", "sentinel-2-l2a", ], batch_size: int = 10, num_workers: int = 8, + prefetch_factor: int = 2, ): super().__init__() self.data_dir = data_dir @@ -149,17 +243,18 @@ def __init__( # noqa: PLR0913 self.metadata = Box(yaml.safe_load(open(metadata_path))) self.batch_size = batch_size self.num_workers = num_workers + self.prefetch_factor = prefetch_factor self.split_ratio = 0.8 def setup(self, stage: Literal["fit", "predict"] | None = None) -> None: # Get list of GeoTIFF filepaths from s3 bucket or data/ folder - if self.data_dir.startswith("s3://"): - dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir]) - chips_path = list(dp.list_files_by_s3(masks="*.npz")) - else: # if self.data_dir is a local data path - chips_path = sorted(list(Path(self.data_dir).glob("**/*.npz"))) - chips_platform = [chip.parent.parent.name for chip in chips_path] - # chips_platform = [chip.parent.parent.name for chip in chips_path] + # if self.data_dir.startswith("s3://"): + # dp = torchdata.datapipes.iter.IterableWrapper(iterable=[self.data_dir]) + # chips_path = list(dp.list_files_by_s3(masks="*.npz")) + # else: # if self.data_dir is a local data path + chips_path = sorted(list(Path(self.data_dir).glob("**/*.npz"))) + chips_platform = [chip.parent.name for chip in chips_path] + # chips_platform = [chip.parent.parent.name for chip in chips_path] print(f"Total number of chips: {len(chips_path)}") if stage == "fit": @@ -176,7 +271,7 @@ def setup(self, stage: Literal["fit", "predict"] | None = None) -> None: platforms=self.platforms, metadata=self.metadata, ) - self.trn_sampler = ClaySampler( + self.trn_sampler = ClayDistributedSampler( dataset=self.trn_ds, platforms=self.platforms, batch_size=self.batch_size, @@ -187,7 +282,7 @@ def setup(self, stage: Literal["fit", "predict"] | None = None) -> None: platforms=self.platforms, metadata=self.metadata, ) - self.val_sampler = ClaySampler( + self.val_sampler = ClayDistributedSampler( dataset=self.val_ds, platforms=self.platforms, batch_size=self.batch_size, @@ -207,7 +302,7 @@ def train_dataloader(self): batch_sampler=self.trn_sampler, collate_fn=batch_collate, pin_memory=True, - prefetch_factor=4, + prefetch_factor=self.prefetch_factor, ) def val_dataloader(self): @@ -217,7 +312,7 @@ def val_dataloader(self): batch_sampler=self.val_sampler, collate_fn=batch_collate, pin_memory=True, - prefetch_factor=4, + prefetch_factor=self.prefetch_factor, ) def predict_dataloader(self): diff --git a/src/factory.py b/src/factory.py index 9f10fab9..84a0d5b5 100644 --- a/src/factory.py +++ b/src/factory.py @@ -44,7 +44,7 @@ def __init__( # noqa: PLR0913 activation="gelu", dropout=0, norm_first=False, - batch_first=False, + batch_first=True, ) self.encoder = nn.TransformerEncoder(layer, num_layers) diff --git a/src/model.py b/src/model.py index ee211b97..9648002f 100644 --- a/src/model.py +++ b/src/model.py @@ -1,18 +1,15 @@ import math import os -from typing import Literal +import random -import lightning as L import timm import torch import torch.nn.functional as F -import yaml -from box import Box from einops import rearrange, reduce, repeat from torch import nn from torchvision.transforms import v2 -from vit_pytorch.simple_vit import Transformer +from src.backbone import Transformer from src.factory import DynamicEmbedding from src.utils import posemb_sincos_2d_with_gsd @@ -53,6 +50,7 @@ def __init__( # noqa: PLR0913 heads=heads, dim_head=dim_head, mlp_dim=int(dim * mlp_ratio), + fused_attn=True, ) def to_patch_embed(self, cube, waves): @@ -239,6 +237,7 @@ def __init__( # noqa: PLR0913 heads=heads, dim_head=dim_head, mlp_dim=int(dim * mlp_ratio), + fused_attn=True, ) self.embed_to_pixels = DynamicEmbedding( wave_dim=128, @@ -364,6 +363,8 @@ def __init__( # noqa: PLR0913 shuffle, metadata, teacher, + dolls, + doll_weights, # ENCODER dim, depth, @@ -385,10 +386,12 @@ def __init__( # noqa: PLR0913 self.shuffle = shuffle self.metadata = metadata self.teacher = timm.create_model(teacher, pretrained=True, num_classes=0) - self.teacher_chip_size = 224 + self.teacher_chip_size = 518 self.teacher_resize = v2.Resize( size=(self.teacher_chip_size, self.teacher_chip_size) ) + # self.mrl = MRL(features=self.teacher.num_features, dolls=dolls) + # self.mrl_loss = MRLLoss(weights=doll_weights) self.proj = nn.Linear(dim, self.teacher.num_features) self.encoder = Encoder( @@ -418,6 +421,7 @@ def __init__( # noqa: PLR0913 def freeze_teacher(self): for param in self.teacher.parameters(): param.requires_grad = False + self.teacher.eval() def per_pixel_loss(self, cube, pixels, masked_matrix): """ @@ -459,6 +463,27 @@ def forward(self, datacube): waves = torch.tensor(list(self.metadata[platform].bands.wavelength.values())) gsd = torch.tensor(self.metadata[platform].gsd) + # Drop channels randomly + _pixels = datacube["pixels"].clone() + batch_size, channels, _, _ = _pixels.size() + + # Define probabilities for dropping channels + prob_drop_all = 0.10 # 10% probability to drop all channels + prob_drop_half = 0.20 # 20% probability to drop half the channels + + for i in range(batch_size): + if torch.any( + datacube["latlon"][i] != 0 + ): # Check if latlon is not all zeros + rand_val = random.random() + if rand_val < prob_drop_all: + _pixels[i, :, :, :] = 0 # Drop all channels + elif rand_val < prob_drop_all + prob_drop_half: + channel_indices = torch.randperm(channels)[ + : channels // 2 + ] # Get 50% of channel indices + _pixels[i, channel_indices, :, :] = 0 # Drop 50% of channels + # ENCODER ( encoded_unmasked_patches, # [B (1 + L):(1 - mask_ratio) D] @@ -467,7 +492,7 @@ def forward(self, datacube): masked_matrix, # [B L] ) = self.encoder( { - "pixels": datacube["pixels"], + "pixels": _pixels, "time": datacube["time"], "latlon": datacube["latlon"], "gsd": gsd, @@ -487,32 +512,39 @@ def forward(self, datacube): waves, ) # [B L (C P P)] - # LOSS + # MAE reconstruction_loss = self.per_pixel_loss( datacube["pixels"], pixels, masked_matrix ) + # MODIS has a 10x reconstruction loss compared to all the other sensors, + # so we need to scale it down to improve the learning capability. + if platform == "modis": + reconstruction_loss /= 10 + + # # MRL + # representations = self.mrl(encoded_unmasked_patches[:, 0, :]) # [(B D') ...] + + # PROJ + representations = self.proj(encoded_unmasked_patches[:, 0, :]) # [B D'] - # TEACHER - encoder_output = self.proj(encoded_unmasked_patches[:, 0, :]) # [B D'] with torch.no_grad(): if platform == "sentinel-1-rtc": r = datacube["pixels"][:, 0, :, :] g = datacube["pixels"][:, 1, :, :] - b = r - g + b = (r + g) / 2 rgb = torch.stack((r, g, b), dim=1) else: # Read RGB bands from the sensor to feed the teacher model indices = self.metadata[platform].rgb_indices rgb = datacube["pixels"][:, indices, :, :] rgb = self.teacher_resize(rgb) - teacher_output = self.teacher(rgb) + target = self.teacher(rgb) + # target = self.teacher(rgb) - representation_loss = -( - F.cosine_similarity(encoder_output, teacher_output).mean() - - 1.0 # change range from [-1, 1] to [-2, 0] - ) # negative cosine similarity, [0, 2] -> 0 is similar & 2 is opposite + # representation_loss = self.mrl_loss(representations, target) + representation_loss = 1.0 - F.cosine_similarity(representations, target).mean() - loss = 0.90 * reconstruction_loss + 0.10 * representation_loss + loss = 0.9 * reconstruction_loss + 0.1 * representation_loss return (loss, reconstruction_loss, representation_loss) @@ -564,8 +596,8 @@ def clay_mae_base(**kwargs): "mlp_ratio": 4, # DECODER "decoder_dim": 512, - "decoder_depth": 6, - "decoder_heads": 6, + "decoder_depth": 4, + "decoder_heads": 4, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } @@ -583,114 +615,10 @@ def clay_mae_large(**kwargs): "mlp_ratio": 4, # DECODER "decoder_dim": 512, - "decoder_depth": 8, - "decoder_heads": 8, + "decoder_depth": 4, + "decoder_heads": 4, "decoder_dim_head": 64, "decoder_mlp_ratio": 4, } args.update(kwargs) return ClayMAE(**args) - - -class ClayMAEModule(L.LightningModule): - def __init__( # noqa: PLR0913 - self, - model_size="base", - mask_ratio=0.75, - norm_pix_loss=False, - patch_size=16, - shuffle=False, - metadata_path="configs/metadata.yaml", - teacher="vit_base_patch16_224.dino", - lr=1e-4, - wd=0.05, - b1=0.9, - b2=0.95, - embeddings_level: Literal["mean", "patch", "group"] = "mean", - ): - super().__init__() - self.save_hyperparameters(logger=True) - self.metadata = Box(yaml.safe_load(open(metadata_path))) - model_map = { - "tiny": clay_mae_tiny, - "small": clay_mae_small, - "base": clay_mae_base, - "large": clay_mae_large, - } - if model_size in model_map: - model_args = { - "mask_ratio": mask_ratio, - "patch_size": patch_size, - "norm_pix_loss": norm_pix_loss, - "shuffle": shuffle, - "metadata": self.metadata, - "teacher": teacher, - } - self.model = model_map[model_size](**model_args) - else: - raise ValueError( - f"Invalid model size {model_size}. Expected one of {model_map.keys()}" - ) - - def on_train_epoch_start(self): - self.model.teacher.eval() - - def forward(self, datacube: dict[str, torch.Tensor]): - return self.model(datacube) - - def configure_optimizers(self): - optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.hparams.lr, - weight_decay=self.hparams.wd, - betas=(self.hparams.b1, self.hparams.b2), - ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - optimizer, T_0=1000, T_mult=2, eta_min=self.hparams.lr * 100, last_epoch=-1 - ) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": "step", - }, - } - - def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): - datacube = batch - loss, reconstruction_loss, representation_loss = self(datacube) - self.log( - name=f"{phase}/loss", - value=loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - self.log( - name=f"{phase}/rec_loss", - value=reconstruction_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - self.log( - name=f"{phase}/rep_loss", - value=representation_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return loss - - def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): - return self.shared_step(batch, batch_idx, phase="train") - - def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): - return self.shared_step(batch, batch_idx, phase="val") diff --git a/src/module.py b/src/module.py new file mode 100644 index 00000000..eafcf815 --- /dev/null +++ b/src/module.py @@ -0,0 +1,138 @@ +from typing import Literal + +import lightning as L +import torch +import yaml +from box import Box + +from src.model import clay_mae_base, clay_mae_large, clay_mae_small, clay_mae_tiny + + +class ClayMAEModule(L.LightningModule): + def __init__( # noqa: PLR0913 + self, + model_size="base", + mask_ratio=0.75, + norm_pix_loss=False, + patch_size=8, + shuffle=False, + metadata_path="configs/metadata.yaml", + teacher="samvit_base_patch16.sa1b", + dolls=[16, 32, 64, 128, 256, 768], + doll_weights=[1, 1, 1, 1, 1, 1], + lr=1e-5, + wd=0.05, + b1=0.9, + b2=0.95, + embeddings_level: Literal["mean", "patch", "group"] = "mean", + ): + super().__init__() + # self.strict_loading = False # Allow partial loading to check if MRL was the bug + self.save_hyperparameters(logger=True) + self.metadata = Box(yaml.safe_load(open(metadata_path))) + model_map = { + "tiny": clay_mae_tiny, + "small": clay_mae_small, + "base": clay_mae_base, + "large": clay_mae_large, + } + if model_size in model_map: + model_args = { + "mask_ratio": mask_ratio, + "patch_size": patch_size, + "norm_pix_loss": norm_pix_loss, + "shuffle": shuffle, + "metadata": self.metadata, + "teacher": teacher, + "dolls": dolls, + "doll_weights": doll_weights, + } + self.model = model_map[model_size](**model_args) + # checkpoint_path = 'mae_v1.5.0_epoch-76_val-loss-0.1612.ckpt' + # checkpoint = torch.load(checkpoint_path, map_location="cpu") + # # Extract the state dictionary + # state_dict = checkpoint['state_dict'] + + # # Modify the state dictionary + # new_state_dict = OrderedDict() + # for k, v in state_dict.items(): + # # Remove 'model.' prefix if it exists + # if k.startswith('model.'): + # k = k[len('model.'):] + # # Exclude keys related to the 'teacher' + # if not (k.startswith('teacher') or k.startswith('mrl')): + # new_state_dict[k] = v + # with torch.no_grad(): + # # Load the modified state dictionary into your model + # missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False) + # # Optionally, print missing and unexpected keys + # print(f"Missing keys: {missing_keys}") + # print(f"Unexpected keys: {unexpected_keys}") + else: + raise ValueError( + f"Invalid model size {model_size}. Expected one of {model_map.keys()}" + ) + + def on_train_epoch_start(self): + self.model.teacher.eval() + + def forward(self, datacube: dict[str, torch.Tensor]): + return self.model(datacube) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.wd, + betas=(self.hparams.b1, self.hparams.b2), + fused=True, + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=5000, T_mult=1, eta_min=self.hparams.lr * 100, last_epoch=-1 + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + }, + } + + def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str): + platform = batch["platform"][0] + loss, reconstruction_loss, representation_loss = self(batch) + + losses = { + "loss": loss, + "rec_loss": reconstruction_loss, + "rep_loss": representation_loss, + } + + for loss_name, loss_value in losses.items(): + self.log( + name=f"{phase}/{loss_name}", + value=loss_value, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + self.log( + name=f"{phase}_{platform}/{loss_name}", + value=loss_value, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): + return self.shared_step(batch, batch_idx, phase="train") + + def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): + return self.shared_step(batch, batch_idx, phase="val") diff --git a/src/mrl.py b/src/mrl.py new file mode 100644 index 00000000..12ee22e7 --- /dev/null +++ b/src/mrl.py @@ -0,0 +1,37 @@ +from torch import nn + + +class MRL(nn.Module): + """ + Matryoshka Representation Learning from the paper: https://arxiv.org/abs/2205.13147 + """ + + def __init__(self, features, dolls: list = [16, 32, 64, 128, 256, 768]) -> None: + super().__init__() + self.dolls = dolls + self.layers = nn.ModuleDict() + for doll in dolls: + self.layers[f"mrl_{doll}"] = nn.Linear(doll, features) + + def forward(self, x): + "x: (batch, features)" + logits = [self.layers[f"mrl_{doll}"](x[:, :doll]) for doll in self.dolls] + return logits + + +class MRLLoss(nn.Module): + def __init__(self, weights) -> None: + super().__init__() + self.weights = weights + self.criterion = nn.CosineSimilarity(dim=1, eps=1e-6) + + def forward(self, representations, targets): + """ + representations: [(batch, features), ...] + targets: (batch, features) + """ + losses = [ + self.weights[i] * (1 - self.criterion(rep, targets)).mean() + for i, rep in enumerate(representations) + ] + return sum(losses) / len(losses) diff --git a/src/utils.py b/src/utils.py index 539a2acd..b0f2bcce 100644 --- a/src/utils.py +++ b/src/utils.py @@ -24,6 +24,7 @@ def posemb_sincos_2d_with_gsd( y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + gsd = gsd.to(x.device) omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g @@ -33,16 +34,16 @@ def posemb_sincos_2d_with_gsd( return pe.type(dtype) -def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): +def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32): assert ( dim % 2 == 0 ), "Feature dimension must be a multiple of 2 for sincos embedding" - pos = torch.arange(pos) if isinstance(pos, int) else pos + waves = torch.arange(waves) if isinstance(waves, int) else waves - omega = torch.arange(dim // 2) / (dim // 2 - 1) + omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1) omega = 1.0 / (temperature**omega) - scaled_pos = pos[:, None] * omega[None, :] - pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) + scaled_waves = waves[:, None] * omega[None, :] + pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1) return pe.type(dtype) diff --git a/train_clay_v2.sh b/train_clay_v2.sh new file mode 100644 index 00000000..f680d7bc --- /dev/null +++ b/train_clay_v2.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +#SBATCH --job-name=clay-laucher +#SBATCH --nodes=24 +#SBATCH --ntasks-per-node=8 # EDIT if it's not 8-gpus per node +#SBATCH --cpus-per-task=12 # EDIT this to how many cpu cores the node has divided by num of gpus +#SBATCH --gres=gpu:8 # EDIT this if it's not 8-gpus per node +#SBATCH --time=0-00:00:00 # EDIT the desired runtime +#SBATCH --exclusive +#SBATCH --partition=gpu # EDIT to the desired partition name +#SBATCH --nodelist=gpu-dy-g6-[1-12],gpu-dy-g5-[1-12] +#SBATCH --output=%x-%j-%N.out + +echo "START TIME: $(date)" + +# auto-fail on any errors in this script +set -eo pipefail + +# logging script's variables/commands for future debug needs +set -x + +# EDIT the conda evn and any startup scripts +# source /path/to/start-xxx-user # if you have something to preload before the job +# Load any required modules (environments, libraries etc.) +eval "$(conda 'shell.bash' 'hook' 2> /dev/null)" + +# initialize conda +conda activate /home/ubuntu/claymodel # if you have conda env to activate + +LOG_PATH="main_log.txt" + +# PTL doesn't need a special launcher +LAUNCHER="python -u" + +# Capture the number of nodes allocated by Slurm +NUM_NODES=$SLURM_JOB_NUM_NODES + +# EDIT the path+name of the python script and whatever args it needs +PROGRAM="trainer.py fit --config configs/config.yaml --trainer.num_nodes=$NUM_NODES" + +export CMD="$LAUNCHER $PROGRAM" + +echo $CMD + +# EDIT if you want to redirect /tmp to /scratch (some local SSD path) since /tmp is tiny on compute nodes +# export TMPDIR=/scratch + +# EDIT: useful for debug if needed +# +# to debug NCCL issues +# export NCCL_DEBUG=INFO +# +# to unravel async errors w/o the correct traceback - potentially makes everything very slower +# export CUDA_LAUNCH_BLOCKING=1 +# +# to force crashing on nccl issues like hanging broadcast +# export NCCL_ASYNC_ERROR_HANDLING=1 + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + --jobid $SLURM_JOB_ID \ + " + +# bash -c is needed for the delayed interpolation of env vars to work +srun $SRUN_ARGS bash -c "$CMD" 2>&1 | tee -a $LOG_PATH + +echo "END TIME: $(date)" diff --git a/train_environment.yml b/train_environment.yml new file mode 100644 index 00000000..78923977 --- /dev/null +++ b/train_environment.yml @@ -0,0 +1,23 @@ +name: claymodel +channels: + - conda-forge + - nvidia + - pytorch +dependencies: + - python=3.11 + - pip + - pip: + - --extra-index-url https://download.pytorch.org/whl/cu121 + - torch==2.4.0+cu121 + - torchvision==0.19.0+cu121 + - einops~=0.7.0 + - geopandas + - jsonargparse[signatures]>=4.27.7 + - lightning + - matplotlib + - python-box + - scikit-image + - scikit-learn + - timm + - vit-pytorch + - wandb diff --git a/trainer.py b/trainer.py index 986574e8..509925fc 100644 --- a/trainer.py +++ b/trainer.py @@ -13,7 +13,7 @@ from lightning.pytorch.cli import LightningCLI from src.datamodule import ClayDataModule # noqa: F401 -from src.model import ClayMAEModule # noqa: F401 +from src.module import ClayMAEModule # noqa: F401 # %% @@ -21,7 +21,9 @@ def cli_main(): """ Command-line inteface to run ClayMAE with ClayDataModule. """ - cli = LightningCLI(save_config_kwargs={"overwrite": True}) + cli = LightningCLI( + ClayMAEModule, ClayDataModule, save_config_kwargs={"overwrite": True} + ) return cli diff --git a/utils/check_data_sanity.py b/utils/check_data_sanity.py new file mode 100644 index 00000000..a5276239 --- /dev/null +++ b/utils/check_data_sanity.py @@ -0,0 +1,60 @@ +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np + + +def check_and_delete_npz(file_path): + try: + # Attempt to load the .npz file using numpy + data = np.load(file_path) + + # Check if the 'pixel' key exists and has shape 128 in the 0th dimension + if "pixels" in data: + if data["pixels"].shape[0] != 128: # noqa: PLR2004 + os.remove(file_path) + return ( + None, + f"Invalid shape (not 128 in 0th dim): {file_path} - Deleted", + ) + else: + return f"Valid: {file_path}", None + else: + os.remove(file_path) + return None, f"'pixels' key missing: {file_path} - Deleted" + + except Exception as e: + os.remove(file_path) + return None, f"Invalid (Exception): {file_path} - {str(e)} - Deleted" + + +def process_directory_in_parallel(directory, max_workers=4): + invalid_files = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".npz"): + file_path = os.path.join(root, file) + futures.append(executor.submit(check_and_delete_npz, file_path)) + + for future in as_completed(futures): + valid_msg, invalid_msg = future.result() + if valid_msg: + print(valid_msg) + if invalid_msg: + print(invalid_msg) + invalid_files.append(invalid_msg) + + return invalid_files + + +# Replace 'your_directory_path' with the path to the directory you want to check +invalid_files = process_directory_in_parallel("/fsx", max_workers=24) + +if invalid_files: + print("\nInvalid or corrupted .npz files found and deleted:") + for file in invalid_files: + print(file) +else: + print("\nAll .npz files are valid and meet the shape criteria for 'pixel' key.") diff --git a/utils/split_npz.py b/utils/split_npz.py new file mode 100644 index 00000000..cb2e87e5 --- /dev/null +++ b/utils/split_npz.py @@ -0,0 +1,53 @@ +import os +from concurrent.futures import ProcessPoolExecutor + +import numpy as np + + +def split_npz_file(file_path): + # Load the .npz file + with np.load(file_path) as data: + # Check if the file has the required batch size of 128 + if "pixels" in data and data["pixels"].shape[0] == 128: # noqa: PLR2004 + # Extract all arrays + keys = data.files + arrays = {key: data[key] for key in keys} + + # Determine the batch size and the number of splits + batch_size = 32 + num_splits = 4 # Since we want to split into 4 files, each with 32 samples + + # Split and save the smaller .npz files + for i in range(num_splits): + split_data = { + key: value[i * batch_size : (i + 1) * batch_size] + for key, value in arrays.items() + } + split_file_path = file_path.replace(".npz", f"_{i}.npz") + np.savez(split_file_path, **split_data) + print(f"Saved {split_file_path}") + + # Delete the original file + os.remove(file_path) + print(f"Deleted original file: {file_path}") + else: + print(f"Skipped {file_path}: Does not have a batch size of 128") + + +def process_directory(root_dir): + # Collect all .npz files + npz_files = [] + for dirpath, _, filenames in os.walk(root_dir): + for filename in filenames: + if filename.endswith(".npz"): + file_path = os.path.join(dirpath, filename) + npz_files.append(file_path) + + # Process files in parallel + with ProcessPoolExecutor() as executor: + executor.map(split_npz_file, npz_files) + + +# Example usage +root_dir = "/fsx" +process_directory(root_dir)