diff --git a/src/async_geotiff/_array.py b/src/async_geotiff/_array.py index 26b0107..4324d61 100644 --- a/src/async_geotiff/_array.py +++ b/src/async_geotiff/_array.py @@ -15,6 +15,8 @@ from numpy.typing import NDArray from pyproj.crs import CRS + from async_geotiff import GeoTIFF + @dataclass(frozen=True, kw_only=True, eq=False) class Array(TransformMixin): @@ -41,11 +43,14 @@ class Array(TransformMixin): transform: Affine """The affine transform mapping pixel coordinates to geographic coordinates.""" - crs: CRS - """The coordinate reference system of the array.""" + _alpha_band_idx: int | None + """The index of the alpha band, if any. + + The alpha band lives in the `data` array but is checked in `as_masked`. + """ - nodata: float | None = None - """The nodata value for the array, if any.""" + _geotiff: GeoTIFF + """A reference to the parent GeoTIFF.""" @classmethod def _create( # noqa: PLR0913 @@ -55,8 +60,8 @@ def _create( # noqa: PLR0913 mask: AsyncTiffArray | None, planar_configuration: PlanarConfiguration, transform: Affine, - crs: CRS, - nodata: float | None, + geotiff: GeoTIFF, + alpha_band_idx: int | None, ) -> Self: """Create an Array from async_tiff data. @@ -92,8 +97,8 @@ def _create( # noqa: PLR0913 height=height, count=count, transform=transform, - crs=crs, - nodata=nodata, + _geotiff=geotiff, + _alpha_band_idx=alpha_band_idx, ) def as_masked(self) -> MaskedArray: @@ -121,4 +126,26 @@ def as_masked(self) -> MaskedArray: if self.nodata is not None: return np.ma.masked_equal(self.data, self.nodata) + if self._alpha_band_idx is not None: + alpha_band = self.data[self._alpha_band_idx] + single_band_mask = alpha_band == np.iinfo(alpha_band.dtype).min + mask = np.broadcast_to( + single_band_mask, + (self.count - 1, self.height, self.width), + ) + # Rasterio semantics say that the alpha band itself in `data` should be + # considered valid data. + mask = np.insert(mask, self._alpha_band_idx, False, axis=0) + return MaskedArray(self.data, mask=mask) + return MaskedArray(self.data) + + @property + def crs(self) -> CRS: + """The coordinate reference system of the array.""" + return self._geotiff.crs + + @property + def nodata(self) -> float | None: + """The nodata value for the array, if any.""" + return self._geotiff.nodata diff --git a/src/async_geotiff/_fetch.py b/src/async_geotiff/_fetch.py index 2e08535..f30466d 100644 --- a/src/async_geotiff/_fetch.py +++ b/src/async_geotiff/_fetch.py @@ -14,7 +14,8 @@ from async_tiff import Array as AsyncTiffArray from async_tiff import ImageFileDirectory - from pyproj import CRS + + from async_geotiff import GeoTIFF class HasTiffReference(HasTransform, Protocol): @@ -31,8 +32,8 @@ def _mask_ifd(self) -> ImageFileDirectory | None: ... @property - def crs(self) -> CRS: - """The coordinate reference system.""" + def _geotiff(self) -> GeoTIFF: + """The parent GeoTIFF object.""" ... @property @@ -45,11 +46,6 @@ def tile_width(self) -> int: """The width of tiles in pixels.""" ... - @property - def nodata(self) -> int | float | None: - """The nodata value for the image, if any.""" - ... - @property def width(self) -> int: """The width of the image in pixels.""" @@ -107,9 +103,12 @@ async def fetch_tile( data=tile_data, mask=mask_data, planar_configuration=self._ifd.planar_configuration, - crs=self.crs, transform=tile_transform, - nodata=self.nodata, + geotiff=self._geotiff, + # TODO: when we support fetching partial bands, we need to check if the + # alpha band is included in the bands we've fetched. + # https://github.com/developmentseed/async-geotiff/issues/113 + alpha_band_idx=self._geotiff._alpha_band_idx, # noqa: SLF001 ) if not boundless: @@ -166,9 +165,12 @@ async def fetch_tiles( data=tile_data, mask=mask_data, planar_configuration=self._ifd.planar_configuration, - crs=self.crs, transform=tile_transform, - nodata=self.nodata, + geotiff=self._geotiff, + # TODO: when we support fetching partial bands, we need to check if the + # alpha band is included in the bands we've fetched. + # https://github.com/developmentseed/async-geotiff/issues/113 + alpha_band_idx=self._geotiff._alpha_band_idx, # noqa: SLF001 ) if not boundless: @@ -219,6 +221,6 @@ def _clip_to_image_bounds( height=clipped_height, count=array.count, transform=array.transform, - crs=array.crs, - nodata=array.nodata, + _geotiff=array._geotiff, # noqa: SLF001 + _alpha_band_idx=array._alpha_band_idx, # noqa: SLF001 ) diff --git a/src/async_geotiff/_geotiff.py b/src/async_geotiff/_geotiff.py index 8d2110e..ebdc8b5 100644 --- a/src/async_geotiff/_geotiff.py +++ b/src/async_geotiff/_geotiff.py @@ -79,11 +79,32 @@ class GeoTIFF(ReadMixin, FetchTileMixin, TiledMixin, TransformMixin): _gdal_metadata: GDALMetadata | None = None """The metadata extracted from the GDALMetadata TIFF tag, if any.""" + @property + def _alpha_band_idx(self) -> int | None: + """The index of the alpha band, if any.""" + # TODO: when we support fetching partial bands, we need to check if the alpha + # band is included in the bands we've fetched. + # https://github.com/developmentseed/async-geotiff/issues/113 + alpha_band_idxs = [ + i + for i, colorinterp in enumerate(self._geotiff.colorinterp) + if colorinterp == ColorInterp.ALPHA + ] + if len(alpha_band_idxs) > 1: + raise ValueError("Multiple alpha bands are not supported") + + return alpha_band_idxs[0] if alpha_band_idxs else None + @property def _ifd(self) -> ImageFileDirectory: """An alias for the primary IFD to satisfy _fetch protocol.""" return self._primary_ifd + @property + def _geotiff(self) -> GeoTIFF: + """An alias for self to satisfy _fetch protocol.""" + return self + def __init__(self, tiff: TIFF) -> None: """Create a GeoTIFF from an existing TIFF instance.""" first_ifd = tiff.ifds[0] diff --git a/src/async_geotiff/_read.py b/src/async_geotiff/_read.py index 608c1e6..8fc49e1 100644 --- a/src/async_geotiff/_read.py +++ b/src/async_geotiff/_read.py @@ -121,8 +121,11 @@ async def read( height=win.height, count=num_bands, transform=window_transform, - crs=self.crs, - nodata=self.nodata, + _geotiff=self._geotiff, + # TODO: when we support fetching partial bands, we need to check if the + # alpha band is included in the bands we've fetched. + # https://github.com/developmentseed/async-geotiff/issues/113 + _alpha_band_idx=self._geotiff._alpha_band_idx, # noqa: SLF001 ) diff --git a/tests/image_list.py b/tests/image_list.py index 53d378a..e03b09d 100644 --- a/tests/image_list.py +++ b/tests/image_list.py @@ -4,6 +4,7 @@ ("nlcd", "nlcd_landcover"), ("rasterio", "cog_uint8_rgb_mask"), ("rasterio", "cog_uint8_rgb_nodata"), + ("rasterio", "cog_uint8_rgba"), ("rasterio", "float32_1band_lerc_block32"), ("rasterio", "float32_1band_lerc_deflate_block32"), ("rasterio", "float32_1band_lerc_zstd_block32"), @@ -20,6 +21,7 @@ ALL_DATA_IMAGES: list[tuple[str, str]] = [ *ALL_COG_IMAGES, ("eox", "eox_cloudless"), + ("rasterio", "antimeridian"), ("rasterio", "pixel_as_point"), ] """All fixtures where the data can be compared with rasterio. diff --git a/tests/test_read.py b/tests/test_read.py index c4fe011..2645a95 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -167,3 +167,27 @@ async def test_read_full( rasterio_data = rasterio_ds.read() np.testing.assert_array_equal(array.data, rasterio_data) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("variant", "file_name"), + ALL_DATA_IMAGES, +) +async def test_read_full_as_masked( + load_geotiff: LoadGeoTIFF, + load_rasterio: LoadRasterio, + variant: str, + file_name: str, +) -> None: + geotiff = await load_geotiff(file_name, variant=variant) + array = await geotiff.read() + masked_array = array.as_masked() + + with load_rasterio(file_name, variant=variant) as rasterio_ds: + rasterio_data = rasterio_ds.read(masked=True) + + np.testing.assert_array_equal(masked_array.mask, rasterio_data.mask) + np.testing.assert_array_equal(masked_array.data, rasterio_data.data) + assert masked_array.shape == rasterio_data.shape + assert masked_array.dtype == rasterio_data.dtype