Skip to content

Commit e8803a4

Browse files
Merge pull request #3 from developmentseed/ab/implement-multi-base-custom-reader
Ab/implement multi base custom reader
2 parents 3b80678 + af38c92 commit e8803a4

File tree

8 files changed

+301
-320
lines changed

8 files changed

+301
-320
lines changed

tests/test_assets_reader.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Test titiler.stacapi.stac_reader functions."""
2+
3+
import json
4+
import os
5+
from unittest.mock import patch
6+
7+
import pytest
8+
from rio_tiler.io import Reader
9+
from rio_tiler.models import ImageData
10+
11+
from titiler.stacapi.assets_reader import AssetsReader
12+
from titiler.stacapi.models import AssetInfo
13+
14+
from .conftest import mock_rasterio_open
15+
16+
item_file = os.path.join(
17+
os.path.dirname(__file__), "fixtures", "20200307aC0853900w361030.json"
18+
)
19+
item_json = json.loads(open(item_file).read())
20+
21+
22+
def test_get_asset_info():
23+
"""Test get_asset_info function"""
24+
assets_reader = AssetsReader(item_json)
25+
expected_asset_info = AssetInfo(
26+
url=item_json["assets"]["cog"]["href"],
27+
type=item_json["assets"]["cog"]["type"],
28+
env={},
29+
)
30+
assert assets_reader._get_asset_info("cog") == expected_asset_info
31+
32+
33+
def test_get_reader_any():
34+
"""Test reader is rio_tiler.io.Reader"""
35+
asset_info = AssetInfo(url="https://file.tif")
36+
empty_stac_reader = AssetsReader({"bbox": [], "assets": []})
37+
assert empty_stac_reader._get_reader(asset_info) == Reader
38+
39+
40+
@pytest.mark.xfail(reason="To be implemented.")
41+
def test_get_reader_netcdf():
42+
"""Test reader attribute is titiler.stacapi.XarrayReader"""
43+
asset_info = AssetInfo(url="https://file.nc", type="application/netcdf")
44+
empty_stac_reader = AssetsReader({"bbox": [], "assets": []})
45+
empty_stac_reader._get_reader(asset_info)
46+
47+
48+
@pytest.mark.skip(reason="Too slow.")
49+
@patch("rio_tiler.io.rasterio.rasterio")
50+
def test_tile_cog(rio):
51+
"""Test tile function with COG asset."""
52+
rio.open = mock_rasterio_open
53+
54+
with AssetsReader(item_json) as reader:
55+
img = reader.tile(0, 0, 0, assets=["cog"])
56+
assert isinstance(img, ImageData)
57+
58+
59+
@pytest.mark.skip(reason="To be implemented.")
60+
def test_tile_netcdf():
61+
"""Test tile function with netcdf asset."""
62+
pass

tests/test_items.py

-44
This file was deleted.

titiler/stacapi/assets_reader.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
"""titiler-stacapi Asset Reader."""
2+
3+
import warnings
4+
from typing import Any, Dict, Optional, Sequence, Set, Type, Union
5+
6+
import attr
7+
import rasterio
8+
from morecantile import TileMatrixSet
9+
from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS
10+
from rio_tiler.errors import (
11+
AssetAsBandError,
12+
ExpressionMixingWarning,
13+
InvalidAssetName,
14+
MissingAssets,
15+
TileOutsideBounds,
16+
)
17+
from rio_tiler.io import Reader
18+
from rio_tiler.io.base import BaseReader, MultiBaseReader
19+
from rio_tiler.models import ImageData
20+
from rio_tiler.tasks import multi_arrays
21+
from rio_tiler.types import Indexes
22+
23+
from titiler.stacapi.models import AssetInfo
24+
from titiler.stacapi.settings import STACSettings
25+
26+
stac_config = STACSettings()
27+
28+
valid_types = {
29+
"image/tiff; application=geotiff",
30+
"image/tiff; application=geotiff; profile=cloud-optimized",
31+
"image/tiff; profile=cloud-optimized; application=geotiff",
32+
"image/vnd.stac.geotiff; cloud-optimized=true",
33+
"image/tiff",
34+
"image/x.geotiff",
35+
"image/jp2",
36+
"application/x-hdf5",
37+
"application/x-hdf",
38+
"application/vnd+zarr",
39+
"application/x-netcdf",
40+
}
41+
42+
43+
@attr.s
44+
class AssetsReader(MultiBaseReader):
45+
"""
46+
Asset reader for STAC items.
47+
"""
48+
49+
# bounds and assets are required
50+
input: Any = attr.ib()
51+
tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)
52+
minzoom: int = attr.ib()
53+
maxzoom: int = attr.ib()
54+
55+
reader: Type[BaseReader] = attr.ib(default=Reader)
56+
reader_options: Dict = attr.ib(factory=dict)
57+
58+
ctx: Any = attr.ib(default=rasterio.Env)
59+
60+
include_asset_types: Set[str] = attr.ib(default=valid_types)
61+
62+
@minzoom.default
63+
def _minzoom(self):
64+
return self.tms.minzoom
65+
66+
@maxzoom.default
67+
def _maxzoom(self):
68+
return self.tms.maxzoom
69+
70+
def __attrs_post_init__(self):
71+
"""
72+
Post Init.
73+
"""
74+
# MultibaseReader includes the spatial mixin so these attributes are required to assert that the tile exists inside the bounds of the item
75+
self.crs = WGS84_CRS # Per specification STAC items are in WGS84
76+
self.bounds = self.input["bbox"]
77+
self.assets = list(self.input["assets"])
78+
79+
def _get_reader(self, asset_info: AssetInfo) -> Type[BaseReader]:
80+
"""Get Asset Reader."""
81+
asset_type = asset_info.get("type", None)
82+
83+
if asset_type and asset_type in [
84+
"application/x-hdf5",
85+
"application/x-hdf",
86+
"application/vnd.zarr",
87+
"application/x-netcdf",
88+
"application/netcdf",
89+
]:
90+
raise NotImplementedError("XarrayReader not yet implemented")
91+
92+
return Reader
93+
94+
def _get_asset_info(self, asset: str) -> AssetInfo:
95+
"""
96+
Validate asset names and return asset's info.
97+
98+
Args:
99+
asset (str): asset name.
100+
101+
Returns:
102+
AssetInfo: Asset info
103+
104+
"""
105+
if asset not in self.assets:
106+
raise InvalidAssetName(
107+
f"{asset} is not valid. Should be one of {self.assets}"
108+
)
109+
110+
asset_info = self.input["assets"][asset]
111+
112+
url = asset_info["href"]
113+
if alternate := stac_config.alternate_url:
114+
url = asset_info["alternate"][alternate]["href"]
115+
116+
info = AssetInfo(url=url, env={})
117+
118+
if asset_info.get("type"):
119+
info["type"] = asset_info["type"]
120+
121+
# there is a file STAC extension for which `header_size` is the size of the header in the file
122+
# if this value is present, we want to use the GDAL_INGESTED_BYTES_AT_OPEN env variable to read that many bytes at file open.
123+
if header_size := asset_info.get("file:header_size"):
124+
info["env"]["GDAL_INGESTED_BYTES_AT_OPEN"] = header_size # type: ignore
125+
126+
if bands := asset_info.get("raster:bands"):
127+
stats = [
128+
(b["statistics"]["minimum"], b["statistics"]["maximum"])
129+
for b in bands
130+
if {"minimum", "maximum"}.issubset(b.get("statistics", {}))
131+
]
132+
if len(stats) == len(bands):
133+
info["dataset_statistics"] = stats
134+
135+
return info
136+
137+
def tile( # noqa: C901
138+
self,
139+
tile_x: int,
140+
tile_y: int,
141+
tile_z: int,
142+
assets: Union[Sequence[str], str] = (),
143+
expression: Optional[str] = None,
144+
asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset
145+
asset_as_band: bool = False,
146+
**kwargs: Any,
147+
) -> ImageData:
148+
"""Read and merge Wep Map tiles from multiple assets.
149+
150+
Args:
151+
tile_x (int): Tile's horizontal index.
152+
tile_y (int): Tile's vertical index.
153+
tile_z (int): Tile's zoom level index.
154+
assets (sequence of str or str, optional): assets to fetch info from.
155+
expression (str, optional): rio-tiler expression for the asset list (e.g. asset1/asset2+asset3).
156+
asset_indexes (dict, optional): Band indexes for each asset (e.g {"asset1": 1, "asset2": (1, 2,)}).
157+
kwargs (optional): Options to forward to the `self.reader.tile` method.
158+
159+
Returns:
160+
rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info.
161+
162+
"""
163+
if not self.tile_exists(tile_x, tile_y, tile_z):
164+
raise TileOutsideBounds(
165+
f"Tile {tile_z}/{tile_x}/{tile_y} is outside image bounds"
166+
)
167+
168+
if isinstance(assets, str):
169+
assets = (assets,)
170+
171+
if assets and expression:
172+
warnings.warn(
173+
"Both expression and assets passed; expression will overwrite assets parameter.",
174+
ExpressionMixingWarning,
175+
stacklevel=2,
176+
)
177+
178+
if expression:
179+
assets = self.parse_expression(expression, asset_as_band=asset_as_band)
180+
181+
if not assets:
182+
raise MissingAssets(
183+
"assets must be passed either via `expression` or `assets` options."
184+
)
185+
186+
# indexes comes from the bidx query-parameter.
187+
# but for asset based backend we usually use asset_bidx option.
188+
asset_indexes = asset_indexes or {}
189+
190+
# We fall back to `indexes` if provided
191+
indexes = kwargs.pop("indexes", None)
192+
193+
def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData:
194+
idx = asset_indexes.get(asset) or indexes # type: ignore
195+
asset_info = self._get_asset_info(asset)
196+
reader = self._get_reader(asset_info)
197+
198+
with self.ctx(**asset_info.get("env", {})):
199+
with reader(
200+
asset_info["url"], tms=self.tms, **self.reader_options
201+
) as src:
202+
if idx is not None:
203+
kwargs.update({"indexes": idx})
204+
data = src.tile(*args, **kwargs)
205+
206+
if asset_as_band:
207+
if len(data.band_names) > 1:
208+
raise AssetAsBandError(
209+
"Can't use `asset_as_band` for multibands asset"
210+
)
211+
data.band_names = [asset]
212+
else:
213+
data.band_names = [f"{asset}_{n}" for n in data.band_names]
214+
215+
return data
216+
217+
img = multi_arrays(assets, _reader, tile_x, tile_y, tile_z, **kwargs)
218+
if expression:
219+
return img.apply_expression(expression)
220+
221+
return img

0 commit comments

Comments
 (0)