diff --git a/rslearn/tile_stores/__init__.py b/rslearn/tile_stores/__init__.py index 670e56d..4cce29e 100644 --- a/rslearn/tile_stores/__init__.py +++ b/rslearn/tile_stores/__init__.py @@ -6,6 +6,7 @@ from upath import UPath from rslearn.config import LayerConfig +from rslearn.utils.jsonargparse import init_jsonargparse from .default import DefaultTileStore from .tile_store import TileStore, TileStoreWithLayer @@ -32,6 +33,7 @@ def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore: tile_store.set_dataset_path(ds_path) return tile_store + init_jsonargparse() parser = jsonargparse.ArgumentParser() parser.add_argument("--tile_store", type=TileStore) cfg = parser.parse_object({"tile_store": config}) diff --git a/rslearn/utils/jsonargparse.py b/rslearn/utils/jsonargparse.py new file mode 100644 index 0000000..e20d40d --- /dev/null +++ b/rslearn/utils/jsonargparse.py @@ -0,0 +1,33 @@ +"""Custom serialization for jsonargparse.""" + +import jsonargparse +from rasterio.crs import CRS + + +def crs_serializer(v: CRS) -> str: + """Serialize CRS for jsonargparse. + + Args: + v: the CRS object. + + Returns: + the CRS encoded to string + """ + return v.to_string() + + +def crs_deserializer(v: str) -> CRS: + """Deserialize CRS for jsonargparse. + + Args: + v: the encoded CRS. + + Returns: + the decoded CRS object + """ + return CRS.from_string(v) + + +def init_jsonargparse() -> None: + """Initialize custom jsonargparse serializers.""" + jsonargparse.typing.register_type(CRS, crs_serializer, crs_deserializer)