diff --git a/data/sentinel2_vessel_attribute/config.json b/data/sentinel2_vessel_attribute/config.json new file mode 100644 index 00000000..04835c1c --- /dev/null +++ b/data/sentinel2_vessel_attribute/config.json @@ -0,0 +1,93 @@ +{ + "layers": { + "info": { + "format": { + "name": "geojson" + }, + "type": "vector" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2/", + "max_time_delta": "0d", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 1 + }, + "use_rtree_index": false + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "path_suffix": "gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20241212/tiles" + } + } +} diff --git a/data/sentinel2_vessel_attribute/config.yaml b/data/sentinel2_vessel_attribute/config.yaml new file mode 100644 index 00000000..81ad756f --- /dev/null +++ b/data/sentinel2_vessel_attribute/config.yaml @@ -0,0 +1,187 @@ +model: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeLightningModule + init_args: + model: + class_path: 
rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.satlaspretrain.SatlasPretrain + init_args: + model_identifier: "Sentinel2_SwinB_SI_MS" + fpn: true + decoders: + length: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + width: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + speed: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading_x: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading_y: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + ship_type: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 9 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20250205/ + inputs: + image: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 
+ info: + data_type: "vector" + layers: ["info"] + is_target: true + task: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeMultiTask + init_args: + length_buckets: [10, 20, 30, 50, 75, 100, 150, 200] + width_buckets: [5, 10, 20] + speed_buckets: [2, 4, 8] + tasks: + length: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "length" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + width: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "width" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + speed: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "sog" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + heading_x: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog_x" + allow_invalid: true + metric_mode: l1 + heading_y: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog_y" + allow_invalid: true + metric_mode: l1 + ship_type: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "type" + allow_invalid: true + classes: ["cargo", "tanker", "passenger", "service", "tug", "pleasure", "fishing", "enforcement", "sar"] + metric_kwargs: + average: "micro" + input_mapping: + length: + info: "targets" + width: + info: "targets" + speed: + info: "targets" + heading_x: + info: "targets" + heading_y: + info: "targets" + ship_type: + info: "targets" + batch_size: 32 + num_workers: 64 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + train_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + - class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeFlip + tags: + split: "train" 
+ val_config: + tags: + split: "val" + test_config: + tags: + split: "val" + predict_config: + groups: ["attribute_predict"] + skip_targets: true +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: output +rslp_project: sentinel2_vessel_attribute +rslp_experiment: data_20250205_regress_00 diff --git a/data/sentinel2_vessel_attribute/config_bucket.yaml b/data/sentinel2_vessel_attribute/config_bucket.yaml new file mode 100644 index 00000000..1dea9d43 --- /dev/null +++ b/data/sentinel2_vessel_attribute/config_bucket.yaml @@ -0,0 +1,186 @@ +model: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.satlaspretrain.SatlasPretrain + init_args: + model_identifier: "Sentinel2_SwinB_SI_MS" + fpn: true + decoders: + length: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 9 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + width: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 4 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + speed: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 4 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + heading_x: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + 
init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading_y: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + ship_type: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 9 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20250205/ + inputs: + image: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + info: + data_type: "vector" + layers: ["info"] + is_target: true + task: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeMultiTask + init_args: + length_buckets: [10, 20, 30, 50, 75, 100, 150, 200] + width_buckets: [5, 10, 20] + speed_buckets: [2, 4, 8] + tasks: + length: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "length_bucket" + read_class_id: true + allow_invalid: true + classes: ["0-10", "10-20", "20-30", "30-50", "50-75", "75-100", "100-150", "150-200", "200+"] + metric_kwargs: + average: "micro" + width: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "width_bucket" + read_class_id: true + allow_invalid: true + classes: ["0-5", "5-10", "10-20", "20+"] + metric_kwargs: + average: "micro" + speed: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "sog_bucket" + 
read_class_id: true + allow_invalid: true + classes: ["0-2", "2-4", "4-8", "8+"] + metric_kwargs: + average: "micro" + heading_x: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog_x" + allow_invalid: true + metric_mode: l1 + heading_y: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog_y" + allow_invalid: true + metric_mode: l1 + ship_type: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "type" + allow_invalid: true + classes: ["cargo", "tanker", "passenger", "service", "tug", "pleasure", "fishing", "enforcement", "sar"] + metric_kwargs: + average: "micro" + input_mapping: + length: + info: "targets" + width: + info: "targets" + speed: + info: "targets" + heading_x: + info: "targets" + heading_y: + info: "targets" + ship_type: + info: "targets" + batch_size: 32 + num_workers: 64 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + train_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + - class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeFlip + tags: + split: "train" + val_config: + tags: + split: "val" + test_config: + tags: + split: "val" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min +rslp_project: sentinel2_vessel_attribute +rslp_experiment: data_20250205_bucket_00 diff --git a/data/sentinel2_vessel_attribute/config_heading2.yaml b/data/sentinel2_vessel_attribute/config_heading2.yaml new file mode 100644 index 00000000..9063ce86 --- /dev/null +++ 
b/data/sentinel2_vessel_attribute/config_heading2.yaml @@ -0,0 +1,210 @@ +# Use heading2 task which treats forward/backward orientation, and predicts the vessel +# direction separately. +model: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.satlaspretrain.SatlasPretrain + init_args: + model_identifier: "Sentinel2_SwinB_SI_MS" + fpn: true + decoders: + length: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + width: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + speed: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading2_x: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading2_y: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 1 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.regression.RegressionHead + heading2_direction: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + out_channels: 2 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + ship_type: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 128 + 
out_channels: 9 + num_conv_layers: 1 + num_fc_layers: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20250205/ + inputs: + image: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + info: + data_type: "vector" + layers: ["info"] + is_target: true + task: + class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeMultiTask + init_args: + length_buckets: [10, 20, 30, 50, 75, 100, 150, 200] + width_buckets: [5, 10, 20] + speed_buckets: [2, 4, 8] + heading_mode: "XYD" + tasks: + length: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "length" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + width: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "width" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + speed: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "sog" + allow_invalid: true + scale_factor: 0.01 + metric_mode: l1 + heading2_x: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog2_x" + allow_invalid: true + metric_mode: l1 + heading2_y: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + property_name: "cog2_y" + allow_invalid: true + metric_mode: l1 + heading2_direction: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "cog2_direction" + allow_invalid: true + classes: ["<180", ">180"] + read_class_id: true + metric_kwargs: + average: "micro" + ship_type: + class_path: 
rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "type" + allow_invalid: true + classes: ["cargo", "tanker", "passenger", "service", "tug", "pleasure", "fishing", "enforcement", "sar"] + metric_kwargs: + average: "micro" + input_mapping: + length: + info: "targets" + width: + info: "targets" + speed: + info: "targets" + heading2_x: + info: "targets" + heading2_y: + info: "targets" + heading2_direction: + info: "targets" + ship_type: + info: "targets" + batch_size: 32 + num_workers: 64 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + train_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] +# disable flip for now since it doesn't support xyd heading mode +# - class_path: rslp.sentinel2_vessel_attribute.train.VesselAttributeFlip + tags: + split: "train" + val_config: + tags: + split: "val" + test_config: + tags: + split: "val" + predict_config: + groups: ["attribute_predict"] + skip_targets: true +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: placeholder + output_layer: output +rslp_project: sentinel2_vessel_attribute +rslp_experiment: data_20250205_heading2_00 diff --git a/data/sentinel2_vessels/config.json b/data/sentinel2_vessels/config.json index 80eca270..3deabcf6 100644 --- a/data/sentinel2_vessels/config.json +++ b/data/sentinel2_vessels/config.json @@ -25,19 +25,62 @@ "band_sets": [ { "bands": [ - "R", - "G", - "B" + "B02", + "B03", + "B04", + "B08" ], - "dtype": "uint8", + "dtype": "uint16", "format": { + 
"geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, "name": "geotiff" } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -2 } ], "data_source": { "harmonize": true, - "index_cache_dir": "cache/sentinel2", + "index_cache_dir": "file:///dfive-default/rslearn-eai/datasets/sentinel2_vessels/data_source_cache/sentinel2", "max_time_delta": "1d", "modality": "L1C", "name": "rslearn.data_sources.gcp_public_data.Sentinel2", @@ -47,7 +90,14 @@ } }, "tile_store": { - "name": "file", - "root_dir": "tiles" + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "path_suffix": "file:///dfive-default/rslearn-eai/datasets/sentinel2_vessels/tile_store/" + } } } diff --git a/data/sentinel2_vessels/config.yaml b/data/sentinel2_vessels/config.yaml index 6ceb028c..62ff781f 100644 --- a/data/sentinel2_vessels/config.yaml +++ b/data/sentinel2_vessels/config.yaml @@ -8,7 +8,7 @@ model: - class_path: rslearn.models.swin.Swin init_args: pretrained: true - input_channels: 3 + input_channels: 9 output_layers: [1, 3, 5, 7] - class_path: rslearn.models.fpn.Fpn init_args: @@ -29,9 +29,9 @@ model: plateau_min_lr: 1e-6 plateau_cooldown: 10 restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres.pth + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth remap_prefixes: - - ["backbone.backbone.", "encoder.0.model."] + - ["backbone.backbone.backbone.", 
"encoder.0.model."] data: class_path: rslearn.train.data_module.RslearnDataModule init_args: @@ -40,9 +40,9 @@ data: image: data_type: "raster" layers: ["sentinel2"] - bands: ["R", "G", "B"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 mask: data_type: "raster" layers: ["mask"] @@ -66,7 +66,7 @@ data: box_size: 15 remap_values: [[0, 1], [0, 255]] exclude_by_center: true - score_threshold: 0.7 + score_threshold: 0.8 enable_map_metric: true enable_f1_metric: true f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] @@ -84,7 +84,15 @@ data: - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 - std: 255 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] - class_path: rslp.transforms.mask.Mask train_config: patch_size: 512 @@ -92,7 +100,15 @@ data: - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 - std: 255 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip init_args: @@ -110,7 +126,15 @@ data: - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 - std: 255 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] groups: ["detector_predict"] load_all_patches: true skip_targets: true @@ -145,4 +169,4 @@ trainer: module_selector: ["model", "encoder", 0, "model"] unfreeze_at_epoch: 4 rslp_project: 
sentinel2_vessels -rslp_experiment: data_20240213_01_add_freezing_and_fix_fpn_restore +rslp_experiment: data_20250213_02_all_bands diff --git a/data/sentinel2_vessels/config_local_files.json b/data/sentinel2_vessels/config_local_files.json deleted file mode 100644 index ce93a2e1..00000000 --- a/data/sentinel2_vessels/config_local_files.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "layers": { - "label": { - "type": "vector" - }, - "mask": { - "band_sets": [ - { - "bands": [ - "mask" - ], - "dtype": "uint8", - "format": { - "format": "png", - "name": "single_image" - } - } - ], - "type": "raster" - }, - "output": { - "type": "vector" - }, - "sentinel2": { - "band_sets": [ - { - "bands": [ - "R", - "G", - "B" - ], - "dtype": "uint8", - "format": { - "name": "geotiff" - } - } - ], - "data_source": { - "item_specs": "PLACEHOLDER", - "name": "rslearn.data_sources.local_files.LocalFiles", - "src_dir": "PLACEHOLDER" - }, - "type": "raster" - } - }, - "tile_store": { - "name": "file", - "root_dir": "tiles" - } -} diff --git a/data/sentinel2_vessels/config_predict_gcp.json b/data/sentinel2_vessels/config_predict_gcp.json new file mode 100644 index 00000000..4ccd5c2b --- /dev/null +++ b/data/sentinel2_vessels/config_predict_gcp.json @@ -0,0 +1,79 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": 
{ + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "use_rtree_index": false + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore" + } +} diff --git a/data/sentinel2_vessels/config_predict_local_files.json b/data/sentinel2_vessels/config_predict_local_files.json new file mode 100644 index 00000000..88d74c14 --- /dev/null +++ b/data/sentinel2_vessels/config_predict_local_files.json @@ -0,0 +1,102 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16", + "remap": { + "dst": [ + 0, + 10000 + ], + "name": "linear", + "src": [ + 1000, + 11000 + ] + } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "remap": { + "dst": [ + 0, + 10000 + ], + "name": "linear", + "src": [ + 1000, + 11000 + ] + }, + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "remap": { + "dst": [ + 0, + 10000 + ], + "name": "linear", + "src": [ + 1000, + 11000 + ] + }, + "zoom_offset": -2 + } + ], + "data_source": { + "item_specs": "PLACEHOLDER", + "name": "rslearn.data_sources.local_files.LocalFiles", + "src_dir": "PLACEHOLDER" + }, + "type": "raster" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/docs/sentinel2_vessels.md b/docs/sentinel2_vessels.md index 163b5307..45792703 100644 --- a/docs/sentinel2_vessels.md +++ b/docs/sentinel2_vessels.md @@ -20,7 +20,7 @@ First, download the model checkpoint to the `RSLP_PREFIX` directory. 
cd rslearn_projects mkdir -p project_data/projects/sentinel2_vessels/data_20240927_satlaspretrain_patch512_00/checkpoints/ - wget https://storage.googleapis.com/ai2-rslearn-projects-data/projects/sentinel2_vessels/data_20240213_01_add_freezing_and_fix_fpn_restore/checkpoints/best.ckpt -O project_data/projects/sentinel2_vessels/data_20240213_01_add_freezing_and_fix_fpn_restore/checkpoints/best.ckpt + wget https://storage.googleapis.com/ai2-rslearn-projects-data/projects/sentinel2_vessels/data_20250213_02_all_bands/checkpoints/best.ckpt -O project_data/projects/sentinel2_vessels/data_20250213_02_all_bands/checkpoints/best.ckpt The easiest way to apply the model is using the prediction pipeline in `rslp/sentinel2_vessels/predict_pipeline.py`. It accepts a Sentinel-2 scene ID and @@ -73,6 +73,9 @@ Model Version History The version names correspond to the `rslp_experiment` field in the model configuration file (`data/sentinel2_vessels/config.yaml`). +- `data_20250213_02_all_bands`: Train on all bands instead of just RGB. Note that it + uses B01-B12 instead of TCI so it needs "harmonization" (subtracting 1000 from new + Sentinel-2 products). - `data_20240213_01_add_freezing_and_fix_fpn_restore`: Freeze the pre-trained model for the first few epochs before unfreezing. - `data_20240213_00`: Some of the windows contained blank images. I re-ingested the @@ -83,6 +86,13 @@ file (`data/sentinel2_vessels/config.yaml`). Model Performance ----------------- +### data_20250213_02_all_bands + +- Selected threshold: 0.8 +- Results on validation set (split1, split7, sargassum_val) + - Precision: 77.2% + - Recall: 78.6% + ### data_20240213_01_add_freezing_and_fix_fpn_restore - Selected threshold: 0.8 @@ -103,8 +113,8 @@ The Docker container does not contain the model weights. Instead, it expects the weights to be present in a directory based on the `RSLP_PREFIX` environment variable.
So download the model checkpoint: - mkdir -p project_data/projects/sentinel2_vessels/data_20240213_01_add_freezing_and_fix_fpn_restore/checkpoints/ - wget https://storage.googleapis.com/ai2-rslearn-projects-data/projects/sentinel2_vessels/data_20240213_01_add_freezing_and_fix_fpn_restore/checkpoints/best.ckpt -O project_data/projects/sentinel2_vessels/data_20240213_01_add_freezing_and_fix_fpn_restore/checkpoints/best.ckpt + mkdir -p project_data/projects/sentinel2_vessels/data_20250213_02_all_bands/checkpoints/ + wget https://storage.googleapis.com/ai2-rslearn-projects-data/projects/sentinel2_vessels/data_20250213_02_all_bands/checkpoints/best.ckpt -O project_data/projects/sentinel2_vessels/data_20250213_02_all_bands/checkpoints/best.ckpt Run the container: @@ -138,9 +148,26 @@ Alternatively, process the scene by providing the paths to the image assets. The can be URIs but must be accessible from the Docker container. ```bash -curl -X POST http://localhost:${SENTINEL2_PORT}/detections -H "Content-Type: application/json" -d '{"image_files": [{"bands": ["R", "G", "B"], "fname": "gs://gcp-public-data-sentinel-2/tiles/30/U/YD/S2A_MSIL1C_20180904T110621_N0206_R137_T30UYD_20180904T133425.SAFE/GRANULE/L1C_T30UYD_A016722_20180904T110820/IMG_DATA/T30UYD_20180904T110621_TCI.jp2"}, {"bands": ["B08"], "fname": "gs://gcp-public-data-sentinel-2/tiles/30/U/YD/S2A_MSIL1C_20180904T110621_N0206_R137_T30UYD_20180904T133425.SAFE/GRANULE/L1C_T30UYD_A016722_20180904T110820/IMG_DATA/T30UYD_20180904T110621_B08.jp2"}]}' +curl -X POST http://localhost:${SENTINEL2_PORT}/detections -H "Content-Type: application/json" -d '{"image_files": [{"bands": ["B08"], "fname": "gs://gcp-public-data-sentinel-2/tiles/30/U/YD/S2A_MSIL1C_20180904T110621_N0206_R137_T30UYD_20180904T133425.SAFE/GRANULE/L1C_T30UYD_A016722_20180904T110820/IMG_DATA/T30UYD_20180904T110621_B08.jp2"}, ...]}' ``` +These bands must be provided. In this case the scene must be processed with processing baseline 04.00 or later (i.e. 
has N0400 or higher in the scene ID) since it is assumed to be the newer type where the same intensity has 1000 higher pixel value (we subtract 1000 in `data/sentinel2_vessels/config_predict_local_files.json`, see [GEE Harmonized Sentinel-2](https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED) for details). + +- B01 +- B02 +- B03 +- B04 +- B05 +- B06 +- B07 +- B08 +- B09 +- B10 +- B11 +- B12 +- B8A + ### Docker Container Version History +- v0.0.2: add attribute prediction and use model `data_20250213_02_all_bands`. - v0.0.1: initial version. It uses model `data_20240213_01_add_freezing_and_fix_fpn_restore`. diff --git a/rslp/landsat_vessels/predict_pipeline.py b/rslp/landsat_vessels/predict_pipeline.py index 5369a2b5..36943281 100644 --- a/rslp/landsat_vessels/predict_pipeline.py +++ b/rslp/landsat_vessels/predict_pipeline.py @@ -137,7 +137,9 @@ def get_vessel_detections( # Read the detections. layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME) - features = GeojsonVectorFormat().decode_vector(layer_dir, window.bounds) + features = GeojsonVectorFormat().decode_vector( + layer_dir, window.projection, window.bounds + ) detections: list[VesselDetection] = [] for feature in features: geometry = feature.geometry @@ -234,7 +236,9 @@ def run_classifier( good_detections = [] for detection, window in zip(detections, windows): layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME) - features = GeojsonVectorFormat().decode_vector(layer_dir, window.bounds) + features = GeojsonVectorFormat().decode_vector( + layer_dir, window.projection, window.bounds + ) category = features[0].properties["label"] if category == "correct": good_detections.append(detection) diff --git a/rslp/sentinel2_vessel_attribute/README.md b/rslp/sentinel2_vessel_attribute/README.md new file mode 100644 index 00000000..ba5b72f7 --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/README.md @@ -0,0 +1,17 @@ +Dataset +------- + +The dataset is created from CSV files of 
AIS-correlated vessel detections provided by +Skylight. + +To populate the dataset: + +``` +mkdir /path/to/dataset/ +cp data/sentinel2_vessel_attribute/config.json /path/to/dataset/config.json +python -m rslp.main sentinel2_vessel_attribute create_windows detections_bigtable gs://rslearn-eai/datasets/sentinel2_vessel_attribute/artifacts/sentinel2_correlated_detections_bigtable/ gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20250205/ --workers 64 +python -m rslp.main sentinel2_vessel_attribute create_windows detections_jan_470k gs://rslearn-eai/datasets/sentinel2_vessel_attribute/artifacts/sentinel2_correlated_detections_jan_470k/ gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20250205/ --workers 64 +``` + +This puts the first batch of CSVs in one group ("detections_bigtable") and the second +batch in another group. diff --git a/rslp/sentinel2_vessel_attribute/__init__.py b/rslp/sentinel2_vessel_attribute/__init__.py new file mode 100644 index 00000000..6dcf7a3c --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/__init__.py @@ -0,0 +1,7 @@ +"""Sentinel-2 vessel attribute prediction model.""" + +from .create_windows import create_windows + +workflows = { + "create_windows": create_windows, +} diff --git a/rslp/sentinel2_vessel_attribute/create_windows.py b/rslp/sentinel2_vessel_attribute/create_windows.py new file mode 100644 index 00000000..12ace97d --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/create_windows.py @@ -0,0 +1,175 @@ +"""Populate rslearn dataset with windows from the source CSVs.""" + +import csv +import hashlib +import json +import multiprocessing +import shutil +from datetime import datetime, timedelta +from typing import Any + +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Window +from rslearn.utils.feature import Feature +from rslearn.utils.geometry import STGeometry +from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection +from rslearn.utils.mp 
import star_imap_unordered +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +from .ship_types import VESSEL_CATEGORIES + +PIXEL_SIZE = 10 +WINDOW_SIZE = 128 +DATASET_CONFIG_FNAME = "data/sentinel2_vessel_attribute/config.json" + + +def process_row(group: str, ds_upath: UPath, csv_row: dict[str, str]) -> None: + """Create a window from one row in the vessel CSV. + + Args: + group: the rslearn group to add the window to. + ds_upath: the path of the output rslearn dataset. + csv_row: the row from vessel CSV. + """ + + def get_optional_float(k: str) -> float | None: + if csv_row[k]: + return float(csv_row[k]) + else: + return None + + event_id = csv_row["event_id"] + ts = datetime.fromisoformat(csv_row["event_time"]) + lat = float(csv_row["lat"]) + lon = float(csv_row["lon"]) + if csv_row["vessel_category"]: + ship_type = csv_row["vessel_category"] + else: + ship_type = "unknown" + vessel_length = get_optional_float("vessel_length") + vessel_width = get_optional_float("vessel_width") + vessel_cog = get_optional_float("ais_course") + vessel_cog_avg = get_optional_float("course") + vessel_sog = get_optional_float("ais_speed") + vessel_sog_variance = get_optional_float("ais_speed_variance") + if "time_to_closest_position" in csv_row: + time_to_closest_position = get_optional_float("time_to_closest_position") + else: + time_to_closest_position = None + + src_point = shapely.Point(lon, lat) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_projection = get_utm_ups_projection(lon, lat, PIXEL_SIZE, -PIXEL_SIZE) + dst_geometry = src_geometry.to_projection(dst_projection) + + bounds = ( + int(dst_geometry.shp.x) - WINDOW_SIZE // 2, + int(dst_geometry.shp.y) - WINDOW_SIZE // 2, + int(dst_geometry.shp.x) + WINDOW_SIZE // 2, + int(dst_geometry.shp.y) + WINDOW_SIZE // 2, + ) + time_range = (ts - timedelta(hours=1), ts + timedelta(hours=1)) + + # Check if train or val. 
+ is_val = hashlib.sha256(event_id.encode()).hexdigest()[0] in ["0", "1"] + split = "val" if is_val else "train" + + window_name = event_id + window_root = Window.get_window_root(ds_upath, group, window_name) + window = Window( + path=window_root, + group=group, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=time_range, + options=dict( + split=split, + ), + ) + window.save() + + # Save metadata. + with (window_root / "info.json").open("w") as f: + json.dump( + { + "event_id": event_id, + "length": vessel_length, + "width": vessel_width, + "cog": vessel_cog, + "cog_avg": vessel_cog_avg, + "sog": vessel_sog, + "type": ship_type, + "sog_variance": vessel_sog_variance, + "time_to_closest_position": time_to_closest_position, + }, + f, + ) + + info_dir = window.get_layer_dir("info") + info_dir.mkdir(parents=True, exist_ok=True) + properties: dict[str, Any] = { + "event_id": event_id, + } + if vessel_length and vessel_length >= 5 and vessel_length < 460: + properties["length"] = vessel_length + if vessel_width and vessel_width >= 2 and vessel_width < 120: + properties["width"] = vessel_width + if ( + vessel_cog + and vessel_sog + and vessel_sog > 5 + and vessel_sog < 50 + and vessel_cog >= 0 + and vessel_cog < 360 + ): + properties["cog"] = vessel_cog + if vessel_sog and vessel_sog > 0 and vessel_sog < 60: + properties["sog"] = vessel_sog + if ship_type and ship_type in VESSEL_CATEGORIES: + properties["type"] = VESSEL_CATEGORIES[ship_type] + feat = Feature(dst_geometry, properties) + GeojsonVectorFormat().encode_vector(info_dir, [feat]) + window.mark_layer_completed("info") + + +def create_windows(group: str, csv_dir: str, ds_path: str, workers: int = 32) -> None: + """Initialize an rslearn dataset at the specified path. + + Args: + group: which group to use for these windows. + csv_dir: path containing CSVs with AIS-correlated vessel detections, e.g. 
+ gs://rslearn-eai/datasets/sentinel2_vessel_attribute/artifacts/sentinel2_correlated_detections_bigtable/. + ds_path: path to write the dataset, e.g. + gs://rslearn-eai/datasets/sentinel2_vessel_attribute/dataset_v1/20241212/ + workers: number of worker processes to use + """ + csv_upath = UPath(csv_dir) + ds_upath = UPath(ds_path) + + # Copy dataset configuration first. + with open(DATASET_CONFIG_FNAME, "rb") as src: + with (ds_upath / "config.json").open("wb") as dst: + shutil.copyfileobj(src, dst) + + jobs = [] + for fname in csv_upath.iterdir(): + with fname.open() as f: + reader = csv.DictReader(f) + for csv_row in reader: + jobs.append( + dict( + group=group, + ds_upath=ds_upath, + csv_row=csv_row, + ) + ) + + p = multiprocessing.Pool(workers) + outputs = star_imap_unordered(p, process_row, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs)): + pass + p.close() diff --git a/rslp/sentinel2_vessel_attribute/scripts/check_accuracy_from_length_only.py b/rslp/sentinel2_vessel_attribute/scripts/check_accuracy_from_length_only.py new file mode 100644 index 00000000..e68876d2 --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/scripts/check_accuracy_from_length_only.py @@ -0,0 +1,108 @@ +"""Ship type accuracy using only length attribute. + +It splits up vessels into buckets based on length, and then checks the accuracy when +mapping each bucket to the most common ship type category in the bucket. +""" + +import argparse +import json +import multiprocessing +from typing import Any + +import tqdm +from upath import UPath + +import rslp.utils.mp + + +def get_json(fname: UPath) -> dict[str, Any]: + """Read a JSON file. This is for multiprocessing. + + Args: + fname: the filename to read. + + Returns: + the decoded JSON from the file. 
+ """ + with fname.open() as f: + return json.load(f) + + +if __name__ == "__main__": + rslp.utils.mp.init_mp() + + parser = argparse.ArgumentParser( + description="Determine potential ship type classification accuracy using only length attribute", + ) + parser.add_argument( + "--ds_path", + type=str, + help="Dataset path", + required=True, + ) + parser.add_argument( + "--bucket_size", + type=int, + help="Size of buckets to group up vessel length", + required=True, + ) + args = parser.parse_args() + + ds_path = UPath(args.ds_path) + fnames = list(ds_path.glob("windows/default/*/layers/info/data.geojson")) + + p = multiprocessing.Pool(32) + outputs = p.imap_unordered(get_json, fnames) + vessel_datas = list(tqdm.tqdm(outputs, total=len(fnames))) + p.close() + + # Length bucket -> {vessel type -> count} + # So this gives the count of each vessel type within each bucket. + buckets: dict[int, dict[str, int]] = {} + for vessel_data in vessel_datas: + properties = vessel_data["features"][0]["properties"] + + # Some vessels don't have all the labels we need. + if "length" not in properties: + continue + if "type" not in properties: + continue + + length_bucket_idx = int(properties["length"]) // args.bucket_size + if length_bucket_idx not in buckets: + buckets[length_bucket_idx] = {} + bucket_dict = buckets[length_bucket_idx] + + vessel_type = properties["type"] + bucket_dict[vessel_type] = bucket_dict.get(vessel_type, 0) + 1 + + correct = 0 + incorrect = 0 + for bucket_idx, bucket_dict in buckets.items(): + # Get most common type in this bucket and add the numbers that would be + # correctly and incorrectly classified. 
+ most_common_type = None + for vessel_type, count in bucket_dict.items(): + if most_common_type is not None and count <= bucket_dict[most_common_type]: + continue + most_common_type = vessel_type + assert most_common_type is not None + + cur_correct = bucket_dict[most_common_type] + cur_incorrect = 0 + for vessel_type, count in bucket_dict.items(): + if vessel_type == most_common_type: + continue + cur_incorrect += count + + lo = bucket_idx * args.bucket_size + hi = (bucket_idx + 1) * args.bucket_size + print( + f"bucket {lo} to {hi}: correct={cur_correct}, incorrect={cur_incorrect}, most_common={most_common_type}" + ) + + correct += cur_correct + incorrect += cur_incorrect + + accuracy = correct / (correct + incorrect) + print(f"correct={correct}, incorrect={incorrect}, accuracy={accuracy}") diff --git a/rslp/sentinel2_vessel_attribute/scripts/dataset_statistics.py b/rslp/sentinel2_vessel_attribute/scripts/dataset_statistics.py new file mode 100644 index 00000000..542474d1 --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/scripts/dataset_statistics.py @@ -0,0 +1,73 @@ +"""Get the number of vessels that have values for each attribute in the dataset.""" + +import argparse +import json +import multiprocessing +from typing import Any + +import tqdm +from upath import UPath + +import rslp.utils.mp + + +def get_attributes(info_fname: UPath) -> dict[str, Any]: + """Get the attributes that the GeoJSON for info layer includes. + + Args: + info_fname: the filename for a layers/info/data.geojson file. + + Returns: + the attributes that are present in the JSON. 
+ """ + with info_fname.open() as f: + fc = json.load(f) + + if len(fc["features"]) != 1: + raise ValueError( + f"expected info JSON {info_fname} to contain exactly one GeoJSON Feature" + ) + + properties = fc["features"][0]["properties"] + attributes = { + "length": 0, + "width": 0, + "cog": 0, + "sog": 0, + "type": 0, + } + for attr in attributes.keys(): + if attr in properties: + attributes[attr] = 1 + return attributes + + +if __name__ == "__main__": + rslp.utils.mp.init_mp() + + parser = argparse.ArgumentParser( + description="Get dataset statistics", + ) + parser.add_argument( + "--ds_path", + type=str, + help="Dataset path", + required=True, + ) + args = parser.parse_args() + + ds_path = UPath(args.ds_path) + fnames = list(ds_path.glob("windows/*/*/layers/info/data.geojson")) + + # Identify attributes present in each GeoJSON thread in parallel (via the + # get_attributes function). Then here we go through and add them up. + p = multiprocessing.Pool(32) + outputs = p.imap_unordered(get_attributes, fnames) + total_attributes = {} + for attributes in tqdm.tqdm(outputs, total=len(fnames)): + for attr, count in attributes.items(): + if attr not in total_attributes: + total_attributes[attr] = 0 + total_attributes[attr] += count + + print(total_attributes) diff --git a/rslp/sentinel2_vessel_attribute/ship_types.py b/rslp/sentinel2_vessel_attribute/ship_types.py new file mode 100644 index 00000000..4db0a1af --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/ship_types.py @@ -0,0 +1,118 @@ +"""Constants relating to vessel types and categories.""" + +SHIP_TYPES = { + 0: "Not available (default)", + 1: "Reserved for future use", + 2: "Reserved for future use", + 3: "Reserved for future use", + 4: "Reserved for future use", + 5: "Reserved for future use", + 6: "Reserved for future use", + 7: "Reserved for future use", + 8: "Reserved for future use", + 9: "Reserved for future use", + 10: "Fishing", + 11: "Towing", + 12: "Towing: length exceeds 200m or breadth exceeds 
25m", + 13: "Dredging or underwater ops", + 14: "Diving ops", + 15: "Military ops", + 16: "Sailing", + 17: "Pleasure Craft", + 18: "Reserved", + 19: "Reserved", + 20: "Wing in ground (WIG), all ships of this type", + 21: "Wing in ground (WIG), Hazardous category A", + 22: "Wing in ground (WIG), Hazardous category B", + 23: "Wing in ground (WIG), Hazardous category C", + 24: "Wing in ground (WIG), Hazardous category D", + 25: "Wing in ground (WIG), Reserved for future use", + 26: "Wing in ground (WIG), Reserved for future use", + 27: "Wing in ground (WIG), Reserved for future use", + 28: "Wing in ground (WIG), Reserved for future use", + 29: "Wing in ground (WIG), Reserved for future use", + 30: "Fishing", + 31: "Towing", + 32: "Towing: length exceeds 200m or breadth exceeds 25m", + 33: "Dredging or underwater ops", + 34: "Diving ops", + 35: "Military ops", + 36: "Sailing", + 37: "Pleasure Craft", + 38: "Reserved", + 39: "Reserved", + 40: "High speed craft (HSC), all ships of this type", + 41: "High speed craft (HSC), Hazardous category A", + 42: "High speed craft (HSC), Hazardous category B", + 43: "High speed craft (HSC), Hazardous category C", + 44: "High speed craft (HSC), Hazardous category D", + 45: "High speed craft (HSC), Reserved for future use", + 46: "High speed craft (HSC), Reserved for future use", + 47: "High speed craft (HSC), Reserved for future use", + 48: "High speed craft (HSC), Reserved for future use", + 49: "High speed craft (HSC), No additional information", + 50: "Pilot Vessel", + 51: "Search and Rescue vessel", + 52: "Tug", + 53: "Port Tender", + 54: "Anti-pollution equipment", + 55: "Law Enforcement", + 56: "Spare - Local Vessel", + 57: "Spare - Local Vessel", + 58: "Medical Transport", + 59: "Noncombatant ship according to RR Resolution No. 
18", + 60: "Passenger, all ships of this type", + 61: "Passenger, Hazardous category A", + 62: "Passenger, Hazardous category B", + 63: "Passenger, Hazardous category C", + 64: "Passenger, Hazardous category D", + 65: "Passenger, Reserved for future use", + 66: "Passenger, Reserved for future use", + 67: "Passenger, Reserved for future use", + 68: "Passenger, Reserved for future use", + 69: "Passenger, No additional information", + 70: "Cargo, all ships of this type", + 71: "Cargo, Hazardous category A", + 72: "Cargo, Hazardous category B", + 73: "Cargo, Hazardous category C", + 74: "Cargo, Hazardous category D", + 75: "Cargo, Reserved for future use", + 76: "Cargo, Reserved for future use", + 77: "Cargo, Reserved for future use", + 78: "Cargo, Reserved for future use", + 79: "Cargo, No additional information", + 80: "Tanker, all ships of this type", + 81: "Tanker, Hazardous category A", + 82: "Tanker, Hazardous category B", + 83: "Tanker, Hazardous category C", + 84: "Tanker, Hazardous category D", + 85: "Tanker, Reserved for future use", + 86: "Tanker, Reserved for future use", + 87: "Tanker, Reserved for future use", + 88: "Tanker, Reserved for future use", + 89: "Tanker, No additional information", + 90: "Other Type, all ships of this type", + 91: "Other Type, Hazardous category A", + 92: "Other Type, Hazardous category B", + 93: "Other Type, Hazardous category C", + 94: "Other Type, Hazardous category D", + 95: "Other Type, Reserved for future use", + 96: "Other Type, Reserved for future use", + 97: "Other Type, Reserved for future use", + 98: "Other Type, Reserved for future use", + 99: "Other Type, no additional information", +} + +VESSEL_CATEGORIES = { + "cargo": "cargo", + "tanker": "tanker", + "passenger": "passenger", + "service": "service", + "tug": "tug", + "pleasure": "pleasure", + "other pleasure": "pleasure", + "fishing": "fishing", + "other fishing": "fishing", + "enforcement": "enforcement", + "sar": "sar", +} diff --git 
a/rslp/sentinel2_vessel_attribute/train.py b/rslp/sentinel2_vessel_attribute/train.py new file mode 100644 index 00000000..34c0974e --- /dev/null +++ b/rslp/sentinel2_vessel_attribute/train.py @@ -0,0 +1,676 @@ +"""Custom task and augmentation for vessel attribute training.""" + +import math +import os +from enum import Enum +from typing import Any + +import numpy as np +import numpy.typing as npt +import rslearn.main +import torch +import wandb +from PIL import Image, ImageDraw +from rslearn.train.lightning_module import RslearnLightningModule +from rslearn.train.tasks.multi_task import MultiTask +from rslearn.train.tasks.regression import RegressionTask +from rslearn.train.tasks.task import BasicTask, Task +from rslearn.utils.feature import Feature +from torchmetrics import Metric, MetricCollection + +SHIP_TYPE_CATEGORIES = [ + "cargo", + "tanker", + "passenger", + "service", + "tug", + "pleasure", + "fishing", + "enforcement", + "sar", +] + + +class HeadingXYMetric(Metric): + """Metric for heading which comes from heading_x and heading_y combination.""" + + def __init__(self, degrees_tolerance: float = 10): + """Create a new HeadingXYMetric. + + Args: + degrees_tolerance: consider prediction correct as long as it is within this + many degrees of the ground truth. + """ + super().__init__() + self.degrees_tolerance = degrees_tolerance + self.correct = 0 + self.total = 0 + + def update( + self, preds: list[dict[str, Any]], targets: list[dict[str, Any]] + ) -> None: + """Update metric. 
+ + Args: + preds: the predictions + targets: the targets + """ + for output, target_dict in zip(preds, targets): + if not target_dict["heading_x"]["valid"]: + continue + + pred_cog = ( + math.atan2(output["heading_y"], output["heading_x"]) * 180 / math.pi + ) + gt_cog = ( + math.atan2( + target_dict["heading_y"]["value"], + target_dict["heading_x"]["value"], + ) + * 180 + / math.pi + ) + + angle_difference = abs(pred_cog - gt_cog) % 360 + if angle_difference > 180: + angle_difference = 360 - angle_difference + + if angle_difference <= self.degrees_tolerance: + self.correct += 1 + self.total += 1 + + def compute(self) -> Any: + """Returns the computed metric.""" + return torch.tensor(self.correct / self.total) + + def reset(self) -> None: + """Reset metric.""" + super().reset() + self.correct = 0 + self.total = 0 + + def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any: + """Returns a plot of the metric.""" + return None + + +class HeadingXYDMetric(Metric): + """A metric for predicting direction separately from the angle. + + JoeR's suggestion: + Try to predict sin(2*theta), cos(2*theta), p(theta > pi). + + VesselAttributeMultiTask will populate cog2_x, cog2_y, and cog2_direction. + User should set up regression task for the first two and classification task for + the direction. It should be called heading2_x, heading2_y, and heading2_direction + respectively. + """ + + def __init__(self, degrees_tolerance: float = 10): + """Create a new HeadingXYDMetric. + + Args: + degrees_tolerance: consider prediction correct as long as it is within this + many degrees of the ground truth. + """ + super().__init__() + self.degrees_tolerance = degrees_tolerance + self.correct = 0 + self.total = 0 + + def _get_original_angle( + self, heading2_x: float, heading2_y: float, heading2_direction: int + ) -> float: + """Get the original angle from the x/y/direction. + + Args: + heading2_x: the predicted or ground truth x component. 
+ heading2_y: the predicted or ground truth y component. + heading2_direction: the predicted or ground truth direction. + + Returns: + angle that corresponds to these values. + """ + # Compute the angle that went into the cos/sin. + angle = math.atan2(heading2_y, heading2_x) * 180 / math.pi + # The original angle was doubled, so we need to halve the angle. + # However the actual angle could be this one or the opposite. + angle = angle / 2 + # Normalize it to be the smaller angle. + angle = angle % 360 + if angle > 180: + angle = angle - 180 + # Now if the direction is 1 then we need to flip it back. + if heading2_direction == 1: + angle = angle + 180 + return angle + + def update( + self, preds: list[dict[str, Any]], targets: list[dict[str, Any]] + ) -> None: + """Update metric. + + Args: + preds: the predictions + targets: the targets + """ + for output, target_dict in zip(preds, targets): + if not target_dict["heading2_x"]["valid"]: + continue + + pred_cog = self._get_original_angle( + output["heading2_x"].item(), + output["heading2_y"].item(), + output["heading2_direction"].argmax().item(), + ) + gt_cog = self._get_original_angle( + target_dict["heading2_x"]["value"].item(), + target_dict["heading2_y"]["value"].item(), + target_dict["heading2_direction"]["class"].item(), + ) + + angle_difference = abs(pred_cog - gt_cog) % 360 + if angle_difference > 180: + angle_difference = 360 - angle_difference + + if angle_difference <= self.degrees_tolerance: + self.correct += 1 + self.total += 1 + + def compute(self) -> Any: + """Returns the computed metric.""" + return torch.tensor(self.correct / self.total) + + def reset(self) -> None: + """Reset metric.""" + super().reset() + self.correct = 0 + self.total = 0 + + def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any: + """Returns a plot of the metric.""" + return None + + +class HeadingMode(str, Enum): + """The method by which we are representing the heading to the model.""" + + XY = "xy" + XYD = "xyd" + + 
+class VesselAttributeMultiTask(MultiTask): + """Extension of MultiTask with custom input pre-processing and visualization.""" + + def __init__( + self, + tasks: dict[str, Task], + input_mapping: dict[str, dict[str, str]], + length_buckets: list[float] = [], + width_buckets: list[float] = [], + speed_buckets: list[float] = [], + heading_mode: HeadingMode = HeadingMode.XY, + ): + """Create a new VesselAttributeMultiTask. + + Args: + tasks: see MultiTask. + input_mapping: see MultiTask. + length_buckets: which buckets to use for length attribute. + width_buckets: which buckets to use for width attribute. + speed_buckets: which buckets to use for speed attribute. + heading_mode: how heading should be predicted + """ + super().__init__(tasks, input_mapping) + self.heading_mode = heading_mode + self.buckets = dict( + length=length_buckets, + width=width_buckets, + speed=speed_buckets, + ) + + def _get_bucket(self, buckets: list[float], value: float) -> int: + """Get bucket that the value belongs to. + + Args: + buckets: a list of the values that separate buckets. For example, [1, 5] + means there are three buckets, #0 covering 0-1, #1 for 1-5, and #2 for + 5+. + value: the value to bucketize. + + Returns: + the bucket index that the value belongs to. + """ + for bucket_idx, threshold in enumerate(buckets): + if value <= threshold: + return bucket_idx + return len(buckets) + + def _get_bucket_range(self, buckets: list[float], bucket_idx: int) -> str: + """Get string representation of the range of a bucket. + + Args: + buckets: a list of the values that separate buckets. + bucket_idx: the bucket index to get range of. + + Returns: + string representation of the range of the bucket like "5-10" or "10+". 
+ """ + if bucket_idx == len(buckets): + return f"{buckets[bucket_idx-1]}+" + + if bucket_idx == 0: + lo = 0.0 + hi = buckets[bucket_idx] + else: + lo = buckets[bucket_idx - 1] + hi = buckets[bucket_idx] + return f"{lo}-{hi}" + + def process_inputs( + self, + raw_inputs: dict[str, torch.Tensor | list[Feature]], + metadata: dict[str, Any], + load_targets: bool = True, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Processes the data into targets. + + Args: + raw_inputs: raster or vector data to process + metadata: metadata about the patch being read + load_targets: whether to load the targets or only inputs + + Returns: + tuple (input_dict, target_dict) containing the processed inputs and targets + that are compatible with both metrics and loss functions + """ + # Add various derived properties to support different versions of tasks, like + # predicting cog x/y components with regression (instead of directly predicting + # the angle) and classifying the length/width/speed by bucket. + # Then we pass to superclass to handle what each sub-task needs. + if load_targets: + for feat in raw_inputs["info"]: + if "cog" in feat.properties: + angle = 90 - feat.properties["cog"] + if self.heading_mode == HeadingMode.XY: + # Compute x/y components of the angle. + feat.properties["cog_x"] = math.cos(angle * math.pi / 180) + feat.properties["cog_y"] = math.sin(angle * math.pi / 180) + + if self.heading_mode == HeadingMode.XYD: + # For HeadingXYDMetric (see that class for details). 
+ feat.properties["cog2_x"] = math.cos(angle * math.pi / 180 * 2) + feat.properties["cog2_y"] = math.sin(angle * math.pi / 180 * 2) + if angle % 360 > 180: + feat.properties["cog2_direction"] = 1 + else: + feat.properties["cog2_direction"] = 0 + + for task in ["length", "width", "speed"]: + if task == "speed": + prop_name = "sog" + else: + prop_name = task + + if prop_name not in feat.properties: + continue + feat.properties[f"{prop_name}_bucket"] = self._get_bucket( + self.buckets[task], feat.properties[prop_name] + ) + + return super().process_inputs(raw_inputs, metadata, load_targets) + + def process_output(self, raw_output: Any, metadata: dict[str, Any]) -> Feature: + """Processes an output into raster or vector data. + + Args: + raw_output: the output from prediction head. + metadata: metadata about the patch being read + + Returns: + either raster or vector data. + """ + # Merge the Features from the regression and classification tasks into a single + # feature that has all of those properties. + feature = None + for task_name, task in self.tasks.items(): + task_output = task.process_output(raw_output[task_name], metadata) + task_feature = task_output[0] + if not isinstance(task_feature, Feature): + raise ValueError( + f"expected task {task_name} to output a Feature but got {task_feature}" + ) + if feature is None: + feature = task_feature + else: + feature.properties.update(task_feature.properties) + return [feature] + + def visualize( + self, + input_dict: dict[str, Any], + target_dict: dict[str, Any] | None, + output: dict[str, Any], + ) -> dict[str, npt.NDArray[Any]]: + """Visualize the outputs and targets. + + Args: + input_dict: the input dict from process_inputs + target_dict: the target dict from process_inputs + output: the prediction + + Returns: + a dictionary mapping image name to visualization image + """ + # Create combined visualization showing all the attributes. 
+ basic_task = BasicTask(remap_values=[[0.0, 0.3], [0, 255]]) + scale_factor = 0.01 + + image = basic_task.visualize(input_dict, target_dict, output)["image"] + image = image.repeat(axis=0, repeats=8).repeat(axis=1, repeats=8) + image = Image.fromarray(image) + draw = ImageDraw.Draw(image) + + assert target_dict + + # Focus on specific mis-predictions. + # if not target_dict["ship_type"]["valid"]: + # return {} + # if abs(target_dict["speed"]["class"] - output["speed"].argmax()) <= 2: + # return {} + + lines = [] + for task in ["length", "width", "speed"]: + if output[task].shape == (): + # regression + s = f"{task}: {output[task]/scale_factor:.1f}" + if target_dict[task]["valid"]: + s += f" ({target_dict[task]['value']/scale_factor:.1f})" + + else: + # classification + output_bucket_idx = output[task].argmax().item() + output_bucket_range = self._get_bucket_range( + self.buckets[task], output_bucket_idx + ) + s = f"{task}: {output_bucket_range}" + if target_dict[task]["valid"]: + target_bucket_range = self._get_bucket_range( + self.buckets[task], target_dict[task]["class"] + ) + s += f" ({target_bucket_range})" + + lines.append(s) + + for task in ["heading"]: + pred_cog = ( + math.atan2(output[task + "_y"], output[task + "_x"]) * 180 / math.pi + ) + s = f"{task}: {pred_cog:.1f}" + if target_dict[task + "_x"]["valid"]: + gt_cog = ( + math.atan2( + target_dict[task + "_y"]["value"], + target_dict[task + "_x"]["value"], + ) + * 180 + / math.pi + ) + s += f" ({gt_cog:.1f})" + lines.append(s) + + # only visualize heading mis-predictions + # angle_difference = abs(pred_cog - gt_cog) % 360 + # if angle_difference > 180: + # angle_difference = 360 - angle_difference + # if angle_difference < 20: + # return {} + + for task in ["ship_type"]: + pred_category = SHIP_TYPE_CATEGORIES[output[task].argmax()] + s = f"{task}: {pred_category}" + if target_dict[task]["valid"]: + gt_category = SHIP_TYPE_CATEGORIES[target_dict[task]["class"]] + s += f" ({gt_category})" + 
lines.append(s) + + # only visualize cargo/tanker <-> fishing mis-predictions + # okay1 = pred_category == "fishing" and gt_category in ["cargo", "tanker"] + # okay2 = pred_category in ["cargo", "tanker"] and gt_category == "fishing" + # if not (okay1 or okay2): + # return {} + + text = "\n".join(lines) + box = draw.textbbox(xy=(0, 0), text=text, font_size=12) + draw.rectangle(xy=box, fill=(0, 0, 0)) + draw.text(xy=(0, 0), text=text, font_size=12, fill=(255, 255, 255)) + return { + "image": np.array(image), + } + + def get_metrics(self) -> MetricCollection: + """Get metrics for this task.""" + metrics = super().get_metrics() + if self.heading_mode == HeadingMode.XY: + metrics.add_metrics({"heading_accuracy": HeadingXYMetric()}) + elif self.heading_mode == HeadingMode.XYD: + metrics.add_metrics({"heading_accuracy": HeadingXYDMetric()}) + return metrics + + +class VesselAttributeLightningModule(RslearnLightningModule): + """Extend LM to produce confusion matrices for each attribute.""" + + def on_test_epoch_start(self) -> None: + """Called when at beginning of test epoch. + + Here we initialize the confusion matrices. + """ + self.test_cm_probs: dict[str, list[npt.NDArray[npt.float32]]] = { + "ship_type": [], + "length": [], + "width": [], + "speed": [], + } + self.test_cm_gt: dict[str, list[npt.NDArray[np.int32]]] = { + "ship_type": [], + "length": [], + "width": [], + "speed": [], + } + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + """Test step extended with confusion matrix.""" + # Code below is copied from RslearnLightningModule.test_step. 
+ inputs, targets, metadatas = batch + batch_size = len(inputs) + outputs, loss_dict = self(inputs, targets) + test_loss = sum(loss_dict.values()) + self.log_dict( + {"test_" + k: v for k, v in loss_dict.items()}, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) + self.log( + "test_loss", test_loss, batch_size=batch_size, on_step=False, on_epoch=True + ) + self.test_metrics.update(outputs, targets) + self.log_dict(self.test_metrics, batch_size=batch_size, on_epoch=True) + + if self.visualize_dir: + for idx, (inp, target, output, metadata) in enumerate( + zip(inputs, targets, outputs, metadatas) + ): + images = self.task.visualize(inp, target, output) + for image_suffix, image in images.items(): + out_fname = os.path.join( + self.visualize_dir, + f'{metadata["window_name"]}_{metadata["bounds"][0]}_{metadata["bounds"][1]}_{image_suffix}.png', + ) + Image.fromarray(image).save(out_fname) + + # Now we hook in part to compute confusion matrices. + # For length/width/speed, they could be either classification task or + # regression. If it is regression, we need to convert the predicted and gt + # values to a class (bucket). The buckets are specified in the Task. + vessel_attribute_multi_task = self.task + assert isinstance(vessel_attribute_multi_task, VesselAttributeMultiTask) + + for output, target in zip(outputs, targets): + if target["ship_type"]["valid"]: + self.test_cm_probs["ship_type"].append( + output["ship_type"].cpu().numpy() + ) + self.test_cm_gt["ship_type"].append( + target["ship_type"]["class"].cpu().numpy() + ) + + for task_name in ["length", "width", "speed"]: + if not target[task_name]["valid"]: + continue + + if output[task_name].shape == (): + # This means it is using regression (output is a scalar). + # So we need to convert to bucket. 
+ buckets = vessel_attribute_multi_task.buckets[task_name] + sub_task = vessel_attribute_multi_task.tasks[task_name] + assert isinstance(sub_task, RegressionTask) + scale_factor = sub_task.scale_factor + output_bucket = vessel_attribute_multi_task._get_bucket( + buckets, + output[task_name].cpu().numpy() / scale_factor, + ) + # Make fake probabilities for it. + output_probs = np.zeros((len(buckets) + 1,), dtype=np.float32) + output_probs[output_bucket] = 1 + + gt_bucket = vessel_attribute_multi_task._get_bucket( + buckets, + target[task_name]["value"].cpu().numpy() / scale_factor, + ) + self.test_cm_probs[task_name].append(output_probs) + self.test_cm_gt[task_name].append(gt_bucket) + + else: + # It is classification so it is already in buckets. + self.test_cm_probs[task_name].append( + output[task_name].cpu().numpy() + ) + self.test_cm_gt[task_name].append( + target[task_name]["class"].cpu().numpy() + ) + + def on_test_epoch_end(self) -> None: + """Push the confusion matrices to W&B.""" + vessel_attribute_multi_task = self.task + assert isinstance(vessel_attribute_multi_task, VesselAttributeMultiTask) + + for task_name, probs_list in self.test_cm_probs.items(): + if len(probs_list) == 0: + continue + gt_list = self.test_cm_gt[task_name] + + if task_name == "ship_type": + class_names = SHIP_TYPE_CATEGORIES + else: + buckets = vessel_attribute_multi_task.buckets[task_name] + num_buckets = len(buckets) + 1 + class_names = [f"bucket{idx}" for idx in range(num_buckets)] + + self.logger.experiment.log( + { + f"test_{task_name}_cm": wandb.plot.confusion_matrix( + probs=np.stack(probs_list), + y_true=np.stack(gt_list), + class_names=class_names, + ) + } + ) + + +class VesselAttributeFlip(torch.nn.Module): + """Flip inputs horizontally and/or vertically. + + Also extracts x/y component from the heading. + """ + + def __init__( + self, + horizontal: bool = True, + vertical: bool = True, + ): + """Initialize a new VesselAttributeFlip. 
+ + Args: + horizontal: whether to randomly flip horizontally + vertical: whether to randomly flip vertically + """ + super().__init__() + self.horizontal = horizontal + self.vertical = vertical + + def sample_state(self) -> dict[str, bool]: + """Randomly decide how to transform the input. + + Returns: + dict of sampled choices + """ + horizontal = False + if self.horizontal: + horizontal = torch.randint(low=0, high=2, size=()) == 0 + vertical = False + if self.vertical: + vertical = torch.randint(low=0, high=2, size=()) == 0 + return { + "horizontal": horizontal, + "vertical": vertical, + } + + def apply_state( + self, + state: dict[str, bool], + d: dict[str, Any], + image_keys: list[str], + heading_keys: list[str], + ) -> None: + """Apply the flipping. + + Args: + state: the sampled state from sample_state. + d: the input or target dict. + image_keys: image keys to flip. + heading_keys: heading keys to flip. + """ + for k in image_keys: + if state["horizontal"]: + d[k] = torch.flip(d[k], dims=[-1]) + if state["vertical"]: + d[k] = torch.flip(d[k], dims=[-2]) + + for k in heading_keys: + if state["horizontal"]: + d[k + "_x"]["value"] *= -1 + if state["vertical"]: + d[k + "_y"]["value"] *= -1 + + def forward( + self, input_dict: dict[str, Any], target_dict: dict[str, Any] + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Apply transform over the inputs and targets. 
+ + Args: + input_dict: the input + target_dict: the target + + Returns: + transformed (input_dicts, target_dicts) tuple + """ + state = self.sample_state() + self.apply_state(state, input_dict, ["image"], []) + self.apply_state(state, target_dict, [], ["heading"]) + return input_dict, target_dict + + +if __name__ == "__main__": + rslearn.main.main() diff --git a/rslp/sentinel2_vessels/predict_pipeline.py b/rslp/sentinel2_vessels/predict_pipeline.py index 1ce3711e..f23caec6 100644 --- a/rslp/sentinel2_vessels/predict_pipeline.py +++ b/rslp/sentinel2_vessels/predict_pipeline.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import Any +import numpy as np from PIL import Image from rslearn.const import WGS84_PROJECTION from rslearn.data_sources import Item, data_source_from_config @@ -39,9 +40,10 @@ # Name of layer containing the output. OUTPUT_LAYER_NAME = "output" -SCENE_ID_DATASET_CONFIG = "data/sentinel2_vessels/config.json" -IMAGE_FILES_DATASET_CONFIG = "data/sentinel2_vessels/config_local_files.json" +SCENE_ID_DATASET_CONFIG = "data/sentinel2_vessels/config_predict_gcp.json" +IMAGE_FILES_DATASET_CONFIG = "data/sentinel2_vessels/config_predict_local_files.json" DETECT_MODEL_CONFIG = "data/sentinel2_vessels/config.yaml" +ATTRIBUTE_MODEL_CONFIG = "data/sentinel2_vessel_attribute/config.yaml" SENTINEL2_RESOLUTION = 10 CROP_WINDOW_SIZE = 128 @@ -50,8 +52,14 @@ # 0.05 km = 50 m INFRA_DISTANCE_THRESHOLD = 0.05 -# The bands of the TCI image. It must be provided when using ImageFiles mode. -TCI_BANDS = ["R", "G", "B"] +# The bands in the 10 m/pixel band set. +# The first band here is also used to determine projection/bounds in ImageFiles mode. +# B02/B03/B04 are also used to create the RGB image. +HIGH_RES_BAND_SET = ["B02", "B03", "B04", "B08"] +RGB_BAND_INDICES = (2, 1, 0) + +# How much to divide B04/B03/B02 by to get 8-bit image. 
+RGB_NORM_FACTOR = 10 @dataclass @@ -200,17 +208,19 @@ def setup_dataset_with_image_files( # Get the projection and scene bounds for each task from the TCI image. scene_datas: list[SceneData] = [] for image_files in image_files_list: - # Look for TCI image. It is required. - tci_fname: UPath | None = None + # Look for an image at the highest resolution. It is required. + hr_fname: UPath | None = None for image_file in image_files: - if image_file.bands != TCI_BANDS: + if image_file.bands != [HIGH_RES_BAND_SET[0]]: continue - tci_fname = UPath(image_file.fname) + hr_fname = UPath(image_file.fname) - if tci_fname is None: - raise ValueError("provided list of image files does not have TCI image") + if hr_fname is None: + raise ValueError( + f"provided list of image files does not have band {HIGH_RES_BAND_SET[0]}" + ) - with open_rasterio_upath_reader(tci_fname) as raster: + with open_rasterio_upath_reader(hr_fname) as raster: projection = Projection(raster.crs, raster.transform.a, raster.transform.e) left = int(raster.transform.c / projection.x_resolution) top = int(raster.transform.f / projection.y_resolution) @@ -295,7 +305,9 @@ def get_vessel_detections( detections: list[VesselDetection] = [] for task_idx, (window, scene_data) in enumerate(zip(windows, scene_datas)): layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME) - features = GeojsonVectorFormat().decode_vector(layer_dir, window.bounds) + features = GeojsonVectorFormat().decode_vector( + layer_dir, window.projection, window.bounds + ) for feature in features: geometry = feature.geometry score = feature.properties["score"] @@ -320,35 +332,37 @@ def get_vessel_detections( return detections -def get_vessel_crop_windows( - ds_path: UPath, detections: list[VesselDetection], scene_datas: list[SceneData] +def run_attribute_model( + ds_path: UPath, + detections: list[VesselDetection], + scene_datas: list[SceneData], ) -> list[Window]: - """Create a window for each vessel to obtain a cropped image for it. 
+ """Run the attribute prediction model. Args: - ds_path: the rslearn dataset path (same one used for object detector -- we will - put the crop windows in a different group). - detections: list of vessel detections. - scene_datas: list of SceneDatas that we are processing. + ds_path: the dataset path that will be populated with new windows to apply the + attribute model. + detections: the detections from the detector. + scene_datas: the list of SceneDatas. Returns: - list of windows corresponding to the detection list, where cropped images have - been materialized. + the new windows. The detections will also be updated with the predicted + attributes. """ - # Create the windows. - group = "crops" - crop_windows: list[UPath] = [] + # Create windows for applying attribute prediction model. + group = "attribute_predict" + windows: list[Window] = [] for detection in detections: window_name = ( f"{detection.metadata['task_idx']}_{detection.col}_{detection.row}" ) window_path = Window.get_window_root(ds_path, group, window_name) - bounds = ( + bounds = [ detection.col - CROP_WINDOW_SIZE // 2, detection.row - CROP_WINDOW_SIZE // 2, detection.col + CROP_WINDOW_SIZE // 2, detection.row + CROP_WINDOW_SIZE // 2, - ) + ] # task_idx metadata is always set in sentinel2_vessels. task_idx = detection.metadata["task_idx"] @@ -362,6 +376,8 @@ def get_vessel_crop_windows( time_range=scene_data.time_range, ) window.save() + windows.append(window) + detection.metadata["crop_window"] = window if scene_data.item: layer_data = WindowLayerData( @@ -369,9 +385,8 @@ def get_vessel_crop_windows( ) window.save_layer_datas({SENTINEL2_LAYER_NAME: layer_data}) - crop_windows.append(window) - - # Materialize the windows. + # Materialize the dataset. 
+    logger.info("materialize dataset")
     apply_windows_args = ApplyWindowsArgs(group=group, workers=32)
     materialize_pipeline_args = MaterializePipelineArgs(
         disabled_layers=[],
@@ -387,7 +402,26 @@ def get_vessel_crop_windows(
     if len(detections) > 0:
         materialize_dataset(ds_path, materialize_pipeline_args)
 
-    return crop_windows
+    # Verify that no window is unmaterialized.
+    for window in windows:
+        if not window.is_layer_completed(SENTINEL2_LAYER_NAME):
+            raise ValueError(f"window {window.name} does not have materialized Sentinel-2")
+
+    # Run the attribute prediction model.
+    run_model_predict(ATTRIBUTE_MODEL_CONFIG, ds_path, groups=[group])
+
+    # Read the results.
+    for detection, window in zip(detections, windows):
+        layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME)
+        features = GeojsonVectorFormat().decode_vector(
+            layer_dir, window.projection, window.bounds
+        )
+        properties = features[0].properties
+        detection.length = properties["length"]
+        detection.width = properties["width"]
+        detection.speed = properties["sog"]
+
+    return windows
 
 
 def predict_pipeline(
@@ -437,8 +471,9 @@ def predict_pipeline(
     # Apply the vessel detection model.
     detections = get_vessel_detections(ds_path, scene_datas)
 
-    # Create and materialize windows that correspond to a crop of each detection.
-    crop_windows = get_vessel_crop_windows(ds_path, detections, scene_datas)
+    # Apply the attribute prediction model.
+    # This also collects vessel crop windows.
+    crop_windows = run_attribute_model(ds_path, detections, scene_datas)
 
     # Write crops and prepare the JSON data.
json_vessels_by_task: list[list[dict[str, Any]]] = [[] for _ in tasks] @@ -448,7 +483,6 @@ def predict_pipeline( near_infra_filter = NearInfraFilter( infra_distance_threshold=INFRA_DISTANCE_THRESHOLD ) - raster_format = GeotiffRasterFormat() for detection, crop_window in zip(detections, crop_windows): # Apply near infra filter (True -> filter out, False -> keep) lon, lat = detection.get_lon_lat() @@ -465,15 +499,19 @@ def predict_pipeline( # Get RGB crop. raster_dir = crop_window.get_raster_dir( - SENTINEL2_LAYER_NAME, ["R", "G", "B"] + SENTINEL2_LAYER_NAME, + HIGH_RES_BAND_SET, + ) + image = GeotiffRasterFormat().decode_raster( + raster_dir, crop_window.projection, crop_window.bounds ) - raster_bounds = raster_format.get_raster_bounds(raster_dir) - image = GeotiffRasterFormat().decode_raster(raster_dir, raster_bounds) + rgb_image = image[RGB_BAND_INDICES, :, :] + rgb_image = np.clip(rgb_image // RGB_NORM_FACTOR, 0, 255).astype(np.uint8) # And save it under the specified crop path. detection.crop_fname = crop_upath / f"{detection.col}_{detection.row}.png" with detection.crop_fname.open("wb") as f: - Image.fromarray(image.transpose(1, 2, 0)).save(f, format="PNG") + Image.fromarray(rgb_image.transpose(1, 2, 0)).save(f, format="PNG") json_vessels_by_task[task_idx].append(detection.to_dict()) geojson_vessels_by_task[task_idx].append(detection.to_feature()) diff --git a/rslp/vessels/__init__.py b/rslp/vessels/__init__.py index 3b77010f..33308b01 100644 --- a/rslp/vessels/__init__.py +++ b/rslp/vessels/__init__.py @@ -32,6 +32,9 @@ class VesselDetectionDict(TypedDict): crop_fname: filename where crop image for this vessel is stored. longitude: the longitude position of the vessel detection. latitude: the latitude position of the vessel detection. + length: the predicted length (if attribute model is available). + width: the predicted width (if attribute model is available). + speed: the predicted speed (if attribute model is available). 
""" source: VesselDetectionSource @@ -44,6 +47,9 @@ class VesselDetectionDict(TypedDict): crop_fname: str | None longitude: float latitude: float + length: float | None + width: float | None + speed: float | None class VesselDetection: @@ -60,6 +66,9 @@ def __init__( scene_id: str | None = None, crop_fname: UPath | None = None, metadata: dict[str, Any] | None = None, + length: float | None = None, + width: float | None = None, + speed: float | None = None, ) -> None: """Create a new VesselDetection. @@ -73,6 +82,9 @@ def __init__( scene_id: the scene ID that the vessel was detected in (if known). crop_fname: filename where crop image for this vessel is stored. metadata: additional metadata that caller wants to store with this detection. + length: the predicted length (if attribute model is available). + width: the predicted width (if attribute model is available). + speed: the predicted speed (if attribute model is available). """ self.source = source self.col = col @@ -82,6 +94,9 @@ def __init__( self.ts = ts self.scene_id = scene_id self.crop_fname = crop_fname + self.length = length + self.width = width + self.speed = speed if metadata is None: self.metadata = {} @@ -112,6 +127,9 @@ def to_dict(self) -> VesselDetectionDict: crop_fname=str(self.crop_fname) if self.crop_fname else None, longitude=lon, latitude=lat, + length=self.length, + width=self.width, + speed=self.speed, ) def to_feature(self) -> dict[str, Any]: