diff --git a/src/async_geotiff/_colorinterp.py b/src/async_geotiff/_colorinterp.py index 02a58e5..6cbf10d 100644 --- a/src/async_geotiff/_colorinterp.py +++ b/src/async_geotiff/_colorinterp.py @@ -6,7 +6,7 @@ from async_tiff.enums import ExtraSamples -from .enums import ColorInterp, PhotometricInterpretation +from async_geotiff.enums import ColorInterp, PhotometricInterpretation if TYPE_CHECKING: from collections.abc import Sequence @@ -23,7 +23,7 @@ def infer_color_interpretation( # noqa: PLR0911 case None: return (ColorInterp.UNDEFINED,) * count case PhotometricInterpretation.BLACK_IS_ZERO: - return (ColorInterp.GRAY,) * count + return (ColorInterp.GRAY,) + (ColorInterp.UNDEFINED,) * (count - 1) case PhotometricInterpretation.RGB: if count <= 2: raise NotImplementedError( diff --git a/src/async_geotiff/_gdal_metadata.py b/src/async_geotiff/_gdal_metadata.py index 6d95a3a..b9e1c47 100644 --- a/src/async_geotiff/_gdal_metadata.py +++ b/src/async_geotiff/_gdal_metadata.py @@ -1,10 +1,12 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field import defusedxml.ElementTree as ET # noqa: N817 +from async_geotiff.enums import ColorInterp + @dataclass class BandStatistics: @@ -36,6 +38,18 @@ class GDALMetadata: scales: tuple[float, ...] + colorinterp: dict[int, ColorInterp] = field(default_factory=dict) + """A mapping of 1-based band index to overridden ColorInterp. + + When present, these values override what would otherwise be inferred + from the photometric interpretation tag. + """ + + +LOWER_CASE_COLORINTERP_MAPPING: dict[str, ColorInterp] = {} +for color in ColorInterp: + LOWER_CASE_COLORINTERP_MAPPING[color.name.lower()] = color + def parse_gdal_metadata( # noqa: C901 gdal_metadata: str | None, @@ -53,37 +67,37 @@ def parse_gdal_metadata( # noqa: C901 band_statistics: defaultdict[int, BandStatistics] = defaultdict(BandStatistics) offsets: list[float] = [0.0] * count scales: list[float] = [1.0] * count + colorinterp: dict[int, ColorInterp] = {} for elem in root.findall("Item"): name = elem.attrib.get("name") sample = elem.attrib.get("sample") + role = elem.attrib.get("role") text = elem.text or "" match name: # Add 1 to get a 1-based band index to match GDAL. - case "STATISTICS_MAXIMUM": - assert sample is not None # noqa: S101 + case "STATISTICS_MAXIMUM" if sample is not None: band_statistics[int(sample) + 1].max = float(text) - case "STATISTICS_MEAN": - assert sample is not None # noqa: S101 + case "STATISTICS_MEAN" if sample is not None: band_statistics[int(sample) + 1].mean = float(text) - case "STATISTICS_MINIMUM": - assert sample is not None # noqa: S101 + case "STATISTICS_MINIMUM" if sample is not None: band_statistics[int(sample) + 1].min = float(text) - case "STATISTICS_STDDEV": - assert sample is not None # noqa: S101 + case "STATISTICS_STDDEV" if sample is not None: band_statistics[int(sample) + 1].std = float(text) - case "STATISTICS_VALID_PERCENT": - assert sample is not None # noqa: S101 + case "STATISTICS_VALID_PERCENT" if sample is not None: band_statistics[int(sample) + 1].valid_percent = float(text) - case "OFFSET": - assert sample is not None # noqa: S101 + case "OFFSET" if sample is not None: offsets[int(sample)] = float(text) - case "SCALE": - assert sample is not None # noqa: S101 + case "SCALE" if sample is not None: scales[int(sample)] = float(text) + case "COLORINTERP" if role == "colorinterp" and sample is not None: + colorinterp[int(sample) + 1] = LOWER_CASE_COLORINTERP_MAPPING[ + text.lower() + ] return GDALMetadata( band_statistics=dict(band_statistics), offsets=tuple(offsets), scales=tuple(scales), + colorinterp=colorinterp, ) diff --git a/src/async_geotiff/_geotiff.py b/src/async_geotiff/_geotiff.py index 0999dfa..f420a26 100644 --- a/src/async_geotiff/_geotiff.py +++ b/src/async_geotiff/_geotiff.py @@ -215,12 +215,20 @@ async def open( @property def colorinterp(self) -> tuple[ColorInterp, ...]: """The color interpretation of each band in index order.""" - return infer_color_interpretation( - count=self.count, - photometric=self.photometric, - extra_samples=self._primary_ifd.extra_samples or [], + interps = list( + infer_color_interpretation( + count=self.count, + photometric=self.photometric, + extra_samples=self._primary_ifd.extra_samples or [], + ), ) + if gdal_metadata := self._gdal_metadata: + for band_index, color_interp in gdal_metadata.colorinterp.items(): + interps[band_index - 1] = color_interp + + return tuple(interps) + @property def colormap(self) -> Colormap | None: """Return the Colormap stored in the file, if any. diff --git a/tests/image_list.py b/tests/image_list.py index 34f6b73..b17ada6 100644 --- a/tests/image_list.py +++ b/tests/image_list.py @@ -10,9 +10,14 @@ ("rasterio", "float32_1band_lerc_zstd_block32"), ("rasterio", "uint16_1band_lzw_block128_predictor2"), ("rasterio", "uint16_1band_scale_offset"), + ("rasterio", "uint8_1band_and_alpha_deflate_block64_cog"), + ("rasterio", "uint8_1band_deflate_block128_unaligned_mask"), + ("rasterio", "uint8_1band_deflate_block128_unaligned_predictor2"), ("rasterio", "uint8_1band_deflate_block128_unaligned"), ("rasterio", "uint8_1band_lzma_block64"), ("rasterio", "uint8_1band_lzw_block64_predictor2"), + ("rasterio", "uint8_1band_zstd_level1_block64"), + ("rasterio", "uint8_nonrgb_deflate_block64_cog"), ("rasterio", "uint8_rgb_deflate_block64_cog"), ("rasterio", "uint8_rgb_webp_block64_cog"), ("rasterio", "uint8_rgba_webp_block64_cog"), @@ -35,6 +40,7 @@ ALL_TEST_IMAGES: list[tuple[str, str]] = [ *ALL_DATA_IMAGES, + ("rio-tiler", "cog_rgb_with_stats"), # YCbCr is auto-decompressed by rasterio ("vantor", "maxar_opendata_yellowstone_visual"), ("source-coop-alpha-earth", "xjejfvrbm1fbu1ecw-0000000000-0000008192"), diff --git a/tests/test_colorinterp.py b/tests/test_colorinterp.py index ef59603..524c6d5 100644 --- a/tests/test_colorinterp.py +++ b/tests/test_colorinterp.py @@ -6,7 +6,7 @@ import pytest -from .image_list import ALL_DATA_IMAGES +from .image_list import ALL_TEST_IMAGES if TYPE_CHECKING: from .conftest import LoadGeoTIFF, LoadRasterio @@ -15,7 +15,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize( ("variant", "file_name"), - ALL_DATA_IMAGES, + ALL_TEST_IMAGES, ) async def test_colorinterp( load_geotiff: LoadGeoTIFF, @@ -23,6 +23,11 @@ async def test_colorinterp( variant: str, file_name: str, ) -> None: + if (variant == "vantor" and file_name == "maxar_opendata_yellowstone_visual") or ( + variant == "rio-tiler" and file_name == "cog_rgb_with_stats" + ): + pytest.skip("Should our colorinterp map YCbCr to RGB?") + geotiff = await load_geotiff(file_name, variant=variant) colorinterp = geotiff.colorinterp assert colorinterp is not None, "Expected color interpretation to be present."