From 87a99d207fe8bacacb9c9362ce7d2b152293dfcb Mon Sep 17 00:00:00 2001 From: Even Rouault Date: Tue, 10 Sep 2024 17:12:05 +0200 Subject: [PATCH] Overview: fix nearest resampling to be exact with all data types Fixes #10758 Also make sure that for other resampling methods, the working data type is large enough (e.g using Float64 for Int32/UInt32/Int64/UInt64). While doing so, fix an underlying bug in the convolution based code, where a hard-coded pixel stride value of 4 was used instead of sizeof(TWork), which I believe wasn't visible before. --- autotest/gcore/rasterio.py | 136 ++++++++++++++++++++++++++++++----- gcore/gdalnodatamaskband.cpp | 3 +- gcore/overview.cpp | 96 ++++++++++++++++++------- 3 files changed, 190 insertions(+), 45 deletions(-) diff --git a/autotest/gcore/rasterio.py b/autotest/gcore/rasterio.py index cee967b15724..ed44b28ffdf2 100755 --- a/autotest/gcore/rasterio.py +++ b/autotest/gcore/rasterio.py @@ -884,29 +884,129 @@ def test_rasterio_12(): # Test cubic resampling with masking -def test_rasterio_13(): +@pytest.mark.parametrize( + "dt", + [ + "Byte", + "Int8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + ], +) +def test_rasterio_13(dt): numpy = pytest.importorskip("numpy") - for dt in [gdal.GDT_Byte, gdal.GDT_UInt16, gdal.GDT_UInt32]: + dt = gdal.GetDataTypeByName(dt) + mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 3, 1, dt) + mem_ds.GetRasterBand(1).SetNoDataValue(0) + if dt == gdal.GDT_Int8: + x = (1 << 7) - 1 + elif dt == gdal.GDT_Byte: + x = (1 << 8) - 1 + elif dt == gdal.GDT_Int16: + x = (1 << 15) - 1 + elif dt == gdal.GDT_UInt16: + x = (1 << 16) - 1 + elif dt == gdal.GDT_Int32: + x = (1 << 31) - 1 + elif dt == gdal.GDT_UInt32: + x = (1 << 32) - 1 + elif dt == gdal.GDT_Int64: + x = (1 << 63) - 1 + elif dt == gdal.GDT_UInt64: + x = (1 << 64) - 2048 + elif dt == gdal.GDT_Float32: + x = 1.5 + else: + x = 1.23456 + mem_ds.GetRasterBand(1).WriteArray( + numpy.array([[0, 0, 0, 0], [0, x, 0, 0], [0, 0, 0, 0]]) + ) - mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 3, 1, dt) - mem_ds.GetRasterBand(1).SetNoDataValue(0) - mem_ds.GetRasterBand(1).WriteArray( - numpy.array([[0, 0, 0, 0], [0, 255, 0, 0], [0, 0, 0, 0]]) - ) + ar_ds = mem_ds.ReadAsArray( + 0, 0, 4, 3, buf_xsize=8, buf_ysize=3, resample_alg=gdal.GRIORA_Cubic + ) - ar_ds = mem_ds.ReadAsArray( - 0, 0, 4, 3, buf_xsize=8, buf_ysize=3, resample_alg=gdal.GRIORA_Cubic - ) + expected_ar = numpy.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, x, x, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + assert numpy.array_equal(ar_ds, expected_ar) - expected_ar = numpy.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 255, 255, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ) - assert numpy.array_equal(ar_ds, expected_ar), (ar_ds, dt) + +############################################################################### +# Test cubic resampling with masking + + +@pytest.mark.parametrize( + "dt", + [ + "Byte", + "Int8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + "CInt16", + "CInt32", + "CFloat32", + "CFloat64", + ], +) +def test_rasterio_nearest(dt): + numpy = pytest.importorskip("numpy") + gdal_array = pytest.importorskip("osgeo.gdal_array") + + dt = gdal.GetDataTypeByName(dt) + mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, dt) + if dt == gdal.GDT_Int8: + x = (1 << 7) - 1 + elif dt == gdal.GDT_Byte: + x = (1 << 8) - 1 + elif dt == gdal.GDT_Int16 or dt == gdal.GDT_CInt16: + x = (1 << 15) - 1 + elif dt == gdal.GDT_UInt16: + x = (1 << 16) - 1 + elif dt == gdal.GDT_Int32 or dt == gdal.GDT_CInt32: + x = (1 << 31) - 1 + elif dt == gdal.GDT_UInt32: + x = (1 << 32) - 1 + elif dt == gdal.GDT_Int64: + x = (1 << 63) - 1 + elif dt == gdal.GDT_UInt64: + x = (1 << 64) - 1 + elif dt == gdal.GDT_Float32 or dt == gdal.GDT_CFloat32: + x = 1.5 + else: + x = 1.234567890123 + + if gdal.DataTypeIsComplex(dt): + x = complex(x, x) + + dtype = gdal_array.flip_code(dt) + mem_ds.GetRasterBand(1).WriteArray(numpy.full((4, 4), x, dtype=dtype)) + + ar_ds = mem_ds.ReadAsArray(0, 0, 4, 4, buf_xsize=1, buf_ysize=1) + + expected_ar = numpy.array([[x]]).astype(dtype) + assert numpy.array_equal(ar_ds, expected_ar) + + mem_ds.BuildOverviews("NEAR", [4]) + ar_ds = mem_ds.GetRasterBand(1).GetOverview(0).ReadAsArray() + assert numpy.array_equal(ar_ds, expected_ar) ############################################################################### diff --git a/gcore/gdalnodatamaskband.cpp b/gcore/gdalnodatamaskband.cpp index 135679647b25..273a801560b1 100644 --- a/gcore/gdalnodatamaskband.cpp +++ b/gcore/gdalnodatamaskband.cpp @@ -140,7 +140,8 @@ static GDALDataType GetWorkDataType(GDALDataType eDataType) eWrkDT = eDataType; break; - default: + case GDT_Unknown: + case GDT_TypeCount: CPLAssert(false); eWrkDT = GDT_Float64; break; diff --git a/gcore/overview.cpp b/gcore/overview.cpp index 78febc2ae656..856caae70af4 100644 --- a/gcore/overview.cpp +++ b/gcore/overview.cpp @@ -170,35 +170,58 @@ static CPLErr GDALResampleChunk_Near(const GDALOverviewResampleArgs &args, *peDstBufferDataType = args.eWrkDataType; switch (args.eWrkDataType) { + // For nearest resampling, as no computation is done, only the + // size of the data type matters. case GDT_Byte: + case GDT_Int8: { + CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 1); return GDALResampleChunk_NearT( - args, static_cast(pChunk), - reinterpret_cast(ppDstBuffer)); + args, static_cast(pChunk), + reinterpret_cast(ppDstBuffer)); } + case GDT_Int16: case GDT_UInt16: { + CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 2); return GDALResampleChunk_NearT( - args, static_cast(pChunk), - reinterpret_cast(ppDstBuffer)); + args, static_cast(pChunk), + reinterpret_cast(ppDstBuffer)); } + case GDT_CInt16: + case GDT_Int32: + case GDT_UInt32: case GDT_Float32: { + CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 4); return GDALResampleChunk_NearT( - args, static_cast(pChunk), - reinterpret_cast(ppDstBuffer)); + args, static_cast(pChunk), + reinterpret_cast(ppDstBuffer)); } + case GDT_CInt32: + case GDT_CFloat32: + case GDT_Int64: + case GDT_UInt64: case GDT_Float64: { + CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 8); return GDALResampleChunk_NearT( - args, static_cast(pChunk), - reinterpret_cast(ppDstBuffer)); + args, static_cast(pChunk), + reinterpret_cast(ppDstBuffer)); } - default: + case GDT_CFloat64: + { + return GDALResampleChunk_NearT( + args, static_cast *>(pChunk), + reinterpret_cast **>(ppDstBuffer)); + } + + case GDT_Unknown: + case GDT_TypeCount: break; } CPLAssert(false); @@ -3032,6 +3055,7 @@ static CPLErr GDALResampleChunk_ConvolutionT( // cppcheck-suppress unreadVariable const int isIntegerDT = GDALDataTypeIsInteger(dstDataType); const auto nNodataValueInt64 = static_cast(dfNoDataValue); + constexpr int nWrkDataTypeSize = static_cast(sizeof(Twork)); // TODO: we should have some generic function to do this. Twork fDstMin = -std::numeric_limits::max(); @@ -3068,6 +3092,20 @@ static CPLErr GDALResampleChunk_ConvolutionT( // cppcheck-suppress unreadVariable fDstMax = static_cast(std::numeric_limits::max()); } + else if (dstDataType == GDT_UInt64) + { + // cppcheck-suppress unreadVariable + fDstMin = static_cast(std::numeric_limits::min()); + // cppcheck-suppress unreadVariable + fDstMax = static_cast(std::numeric_limits::max()); + } + else if (dstDataType == GDT_Int64) + { + // cppcheck-suppress unreadVariable + fDstMin = static_cast(std::numeric_limits::min()); + // cppcheck-suppress unreadVariable + fDstMax = static_cast(std::numeric_limits::max()); + } auto replaceValIfNodata = [bHasNoData, isIntegerDT, fDstMin, fDstMax, nNodataValueInt64, dfNoDataValue, @@ -3585,7 +3623,7 @@ static CPLErr GDALResampleChunk_ConvolutionT( if (pafWrkScanline) { - GDALCopyWords(pafWrkScanline, eWrkDataType, 4, + GDALCopyWords(pafWrkScanline, eWrkDataType, nWrkDataTypeSize, static_cast(pDstBuffer) + static_cast(iDstLine - nDstYOff) * nDstXSize * nDstDataTypeSize, @@ -4101,33 +4139,38 @@ GDALResampleFunction GDALGetResampleFunction(const char *pszResampling, GDALDataType GDALGetOvrWorkDataType(const char *pszResampling, GDALDataType eSrcDataType) { - if ((STARTS_WITH_CI(pszResampling, "NEAR") || - STARTS_WITH_CI(pszResampling, "AVER") || EQUAL(pszResampling, "RMS") || - EQUAL(pszResampling, "CUBIC") || EQUAL(pszResampling, "CUBICSPLINE") || - EQUAL(pszResampling, "LANCZOS") || EQUAL(pszResampling, "BILINEAR") || - EQUAL(pszResampling, "MODE")) && - eSrcDataType == GDT_Byte) + if (STARTS_WITH_CI(pszResampling, "NEAR")) + { + return eSrcDataType; + } + else if (eSrcDataType == GDT_Byte && + (STARTS_WITH_CI(pszResampling, "AVER") || + EQUAL(pszResampling, "RMS") || EQUAL(pszResampling, "CUBIC") || + EQUAL(pszResampling, "CUBICSPLINE") || + EQUAL(pszResampling, "LANCZOS") || + EQUAL(pszResampling, "BILINEAR") || EQUAL(pszResampling, "MODE"))) { return GDT_Byte; } - else if ((STARTS_WITH_CI(pszResampling, "NEAR") || - STARTS_WITH_CI(pszResampling, "AVER") || + else if (eSrcDataType == GDT_UInt16 && + (STARTS_WITH_CI(pszResampling, "AVER") || EQUAL(pszResampling, "RMS") || EQUAL(pszResampling, "CUBIC") || EQUAL(pszResampling, "CUBICSPLINE") || EQUAL(pszResampling, "LANCZOS") || - EQUAL(pszResampling, "BILINEAR") || - EQUAL(pszResampling, "MODE")) && - eSrcDataType == GDT_UInt16) + EQUAL(pszResampling, "BILINEAR") || EQUAL(pszResampling, "MODE"))) { return GDT_UInt16; } else if (EQUAL(pszResampling, "GAUSS")) return GDT_Float64; - if (eSrcDataType == GDT_Float64) - return GDT_Float64; - - return GDT_Float32; + if (eSrcDataType == GDT_Byte || eSrcDataType == GDT_Int8 || + eSrcDataType == GDT_UInt16 || eSrcDataType == GDT_Int16 || + eSrcDataType == GDT_Float32) + { + return GDT_Float32; + } + return GDT_Float64; } namespace @@ -4361,7 +4404,8 @@ CPLErr GDALRegenerateOverviewsEx(GDALRasterBandH hSrcBand, int nOverviewCount, const GDALDataType eSrcDataType = poSrcBand->GetRasterDataType(); const GDALDataType eWrkDataType = - GDALDataTypeIsComplex(eSrcDataType) + (GDALDataTypeIsComplex(eSrcDataType) && + !STARTS_WITH_CI(pszResampling, "NEAR")) ? GDT_CFloat32 : GDALGetOvrWorkDataType(pszResampling, eSrcDataType);