Skip to content

Commit

Permalink
Implement Sentinel-2 data source for AWS bucket.
Browse files Browse the repository at this point in the history
Resolves #1.

Also changes item to have STGeometry instead of shp+time which resolves #5.

Partially deals with #2 by having data sources now provide config loading
function.
  • Loading branch information
favyen2 committed Feb 9, 2024
1 parent b67b60f commit 284bac5
Show file tree
Hide file tree
Showing 12 changed files with 473 additions and 93 deletions.
3 changes: 1 addition & 2 deletions bin/rslearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import argparse
import sys

from rslearn.dataset import (Dataset, ingest_dataset_windows,
prepare_dataset_windows)
from rslearn.dataset import Dataset, ingest_dataset_windows, prepare_dataset_windows

handler_registry = {}

Expand Down
2 changes: 2 additions & 0 deletions extra_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mgrs
rtree
31 changes: 7 additions & 24 deletions rslearn/data_sources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,22 @@
import importlib
from typing import Any

from .data_source import DataSource, Item, QueryConfig, SpaceMode, TimeMode
from .raster_source import DType, RasterFormat, RasterOptions


def load_raster_options(kwargs: dict) -> RasterOptions:
"""Loads a RasterOptions instance by arguments.
Args:
kwargs: the arguments to pass to RasterOptions.
"""
for k, v in kwargs.items():
if k == "format":
kwargs[k] = RasterFormat(v)
if k == "dtype":
kwargs[k] = DType(v)
return RasterOptions(**kwargs)


def load_data_source(name: str, kwargs: dict) -> DataSource:
"""Loads a data source by name and arguments.
def data_source_from_config(config: dict[str, Any]) -> DataSource:
"""Loads a data source from config dict.
Args:
name: the class name of the data source
kwargs: the arguments to pass to the data
"""
module_name = ".".join(name.split(".")[:-1])
class_name = name.split(".")[-1]
module_name = ".".join(config["name"].split(".")[:-1])
class_name = config["name"].split(".")[-1]
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
for k, v in kwargs.items():
if k == "raster_options":
kwargs[k] = load_raster_options(v)
return class_(**kwargs)
return class_.from_config(config)


__all__ = (
Expand All @@ -44,6 +28,5 @@ def load_data_source(name: str, kwargs: dict) -> DataSource:
"DType",
"RasterFormat",
"RasterOptions",
"load_data_source",
"load_raster_options",
"data_source_from_config",
)
Loading

0 comments on commit 284bac5

Please sign in to comment.