Skip to content

Commit

Permalink
Overview: fix nearest resampling to be exact with all data types
Browse files Browse the repository at this point in the history
Fixes #10758

Also make sure that for other resampling methods, the working data type
is large enough (e.g using Float64 for Int32/UInt32/Int64/UInt64). While
doing so, fix an underlying bug in the convolution based code, where a
hard-coded pixel stride value of 4 was used instead of sizeof(TWork),
which I believe wasn't visible before.
  • Loading branch information
rouault authored and github-actions[bot] committed Sep 18, 2024
1 parent 67ac368 commit 87a99d2
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 45 deletions.
136 changes: 118 additions & 18 deletions autotest/gcore/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,29 +884,129 @@ def test_rasterio_12():
# Test cubic resampling with masking


def test_rasterio_13():
@pytest.mark.parametrize(
"dt",
[
"Byte",
"Int8",
"Int16",
"UInt16",
"Int32",
"UInt32",
"Int64",
"UInt64",
"Float32",
"Float64",
],
)
def test_rasterio_13(dt):
numpy = pytest.importorskip("numpy")

for dt in [gdal.GDT_Byte, gdal.GDT_UInt16, gdal.GDT_UInt32]:
dt = gdal.GetDataTypeByName(dt)
mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 3, 1, dt)
mem_ds.GetRasterBand(1).SetNoDataValue(0)
if dt == gdal.GDT_Int8:
x = (1 << 7) - 1
elif dt == gdal.GDT_Byte:
x = (1 << 8) - 1
elif dt == gdal.GDT_Int16:
x = (1 << 15) - 1
elif dt == gdal.GDT_UInt16:
x = (1 << 16) - 1
elif dt == gdal.GDT_Int32:
x = (1 << 31) - 1
elif dt == gdal.GDT_UInt32:
x = (1 << 32) - 1
elif dt == gdal.GDT_Int64:
x = (1 << 63) - 1
elif dt == gdal.GDT_UInt64:
x = (1 << 64) - 2048
elif dt == gdal.GDT_Float32:
x = 1.5
else:
x = 1.23456
mem_ds.GetRasterBand(1).WriteArray(
numpy.array([[0, 0, 0, 0], [0, x, 0, 0], [0, 0, 0, 0]])
)

mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 3, 1, dt)
mem_ds.GetRasterBand(1).SetNoDataValue(0)
mem_ds.GetRasterBand(1).WriteArray(
numpy.array([[0, 0, 0, 0], [0, 255, 0, 0], [0, 0, 0, 0]])
)
ar_ds = mem_ds.ReadAsArray(
0, 0, 4, 3, buf_xsize=8, buf_ysize=3, resample_alg=gdal.GRIORA_Cubic
)

ar_ds = mem_ds.ReadAsArray(
0, 0, 4, 3, buf_xsize=8, buf_ysize=3, resample_alg=gdal.GRIORA_Cubic
)
expected_ar = numpy.array(
[
[0, 0, 0, 0, 0, 0, 0, 0],
[0, x, x, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
]
)
assert numpy.array_equal(ar_ds, expected_ar)

expected_ar = numpy.array(
[
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 255, 255, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
]
)
assert numpy.array_equal(ar_ds, expected_ar), (ar_ds, dt)

###############################################################################
# Test cubic resampling with masking


@pytest.mark.parametrize(
"dt",
[
"Byte",
"Int8",
"Int16",
"UInt16",
"Int32",
"UInt32",
"Int64",
"UInt64",
"Float32",
"Float64",
"CInt16",
"CInt32",
"CFloat32",
"CFloat64",
],
)
def test_rasterio_nearest(dt):
numpy = pytest.importorskip("numpy")
gdal_array = pytest.importorskip("osgeo.gdal_array")

dt = gdal.GetDataTypeByName(dt)
mem_ds = gdal.GetDriverByName("MEM").Create("", 4, 4, 1, dt)
if dt == gdal.GDT_Int8:
x = (1 << 7) - 1
elif dt == gdal.GDT_Byte:
x = (1 << 8) - 1
elif dt == gdal.GDT_Int16 or dt == gdal.GDT_CInt16:
x = (1 << 15) - 1
elif dt == gdal.GDT_UInt16:
x = (1 << 16) - 1
elif dt == gdal.GDT_Int32 or dt == gdal.GDT_CInt32:
x = (1 << 31) - 1
elif dt == gdal.GDT_UInt32:
x = (1 << 32) - 1
elif dt == gdal.GDT_Int64:
x = (1 << 63) - 1
elif dt == gdal.GDT_UInt64:
x = (1 << 64) - 1
elif dt == gdal.GDT_Float32 or dt == gdal.GDT_CFloat32:
x = 1.5
else:
x = 1.234567890123

if gdal.DataTypeIsComplex(dt):
x = complex(x, x)

dtype = gdal_array.flip_code(dt)
mem_ds.GetRasterBand(1).WriteArray(numpy.full((4, 4), x, dtype=dtype))

ar_ds = mem_ds.ReadAsArray(0, 0, 4, 4, buf_xsize=1, buf_ysize=1)

expected_ar = numpy.array([[x]]).astype(dtype)
assert numpy.array_equal(ar_ds, expected_ar)

mem_ds.BuildOverviews("NEAR", [4])
ar_ds = mem_ds.GetRasterBand(1).GetOverview(0).ReadAsArray()
assert numpy.array_equal(ar_ds, expected_ar)


###############################################################################
Expand Down
3 changes: 2 additions & 1 deletion gcore/gdalnodatamaskband.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ static GDALDataType GetWorkDataType(GDALDataType eDataType)
eWrkDT = eDataType;
break;

default:
case GDT_Unknown:
case GDT_TypeCount:
CPLAssert(false);
eWrkDT = GDT_Float64;
break;
Expand Down
96 changes: 70 additions & 26 deletions gcore/overview.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,35 +170,58 @@ static CPLErr GDALResampleChunk_Near(const GDALOverviewResampleArgs &args,
*peDstBufferDataType = args.eWrkDataType;
switch (args.eWrkDataType)
{
// For nearest resampling, as no computation is done, only the
// size of the data type matters.
case GDT_Byte:
case GDT_Int8:
{
CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 1);
return GDALResampleChunk_NearT(
args, static_cast<const GByte *>(pChunk),
reinterpret_cast<GByte **>(ppDstBuffer));
args, static_cast<const uint8_t *>(pChunk),
reinterpret_cast<uint8_t **>(ppDstBuffer));
}

case GDT_Int16:
case GDT_UInt16:
{
CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 2);
return GDALResampleChunk_NearT(
args, static_cast<const GUInt16 *>(pChunk),
reinterpret_cast<GUInt16 **>(ppDstBuffer));
args, static_cast<const uint16_t *>(pChunk),
reinterpret_cast<uint16_t **>(ppDstBuffer));
}

case GDT_CInt16:
case GDT_Int32:
case GDT_UInt32:
case GDT_Float32:
{
CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 4);
return GDALResampleChunk_NearT(
args, static_cast<const float *>(pChunk),
reinterpret_cast<float **>(ppDstBuffer));
args, static_cast<const uint32_t *>(pChunk),
reinterpret_cast<uint32_t **>(ppDstBuffer));
}

case GDT_CInt32:
case GDT_CFloat32:
case GDT_Int64:
case GDT_UInt64:
case GDT_Float64:
{
CPLAssert(GDALGetDataTypeSizeBytes(args.eWrkDataType) == 8);
return GDALResampleChunk_NearT(
args, static_cast<const double *>(pChunk),
reinterpret_cast<double **>(ppDstBuffer));
args, static_cast<const uint64_t *>(pChunk),
reinterpret_cast<uint64_t **>(ppDstBuffer));
}

default:
case GDT_CFloat64:
{
return GDALResampleChunk_NearT(
args, static_cast<const std::complex<double> *>(pChunk),
reinterpret_cast<std::complex<double> **>(ppDstBuffer));
}

case GDT_Unknown:
case GDT_TypeCount:
break;
}
CPLAssert(false);
Expand Down Expand Up @@ -3032,6 +3055,7 @@ static CPLErr GDALResampleChunk_ConvolutionT(
// cppcheck-suppress unreadVariable
const int isIntegerDT = GDALDataTypeIsInteger(dstDataType);
const auto nNodataValueInt64 = static_cast<GInt64>(dfNoDataValue);
constexpr int nWrkDataTypeSize = static_cast<int>(sizeof(Twork));

// TODO: we should have some generic function to do this.
Twork fDstMin = -std::numeric_limits<Twork>::max();
Expand Down Expand Up @@ -3068,6 +3092,20 @@ static CPLErr GDALResampleChunk_ConvolutionT(
// cppcheck-suppress unreadVariable
fDstMax = static_cast<Twork>(std::numeric_limits<GInt32>::max());
}
else if (dstDataType == GDT_UInt64)
{
// cppcheck-suppress unreadVariable
fDstMin = static_cast<Twork>(std::numeric_limits<uint64_t>::min());
// cppcheck-suppress unreadVariable
fDstMax = static_cast<Twork>(std::numeric_limits<uint64_t>::max());
}
else if (dstDataType == GDT_Int64)
{
// cppcheck-suppress unreadVariable
fDstMin = static_cast<Twork>(std::numeric_limits<int64_t>::min());
// cppcheck-suppress unreadVariable
fDstMax = static_cast<Twork>(std::numeric_limits<int64_t>::max());
}

auto replaceValIfNodata = [bHasNoData, isIntegerDT, fDstMin, fDstMax,
nNodataValueInt64, dfNoDataValue,
Expand Down Expand Up @@ -3585,7 +3623,7 @@ static CPLErr GDALResampleChunk_ConvolutionT(

if (pafWrkScanline)
{
GDALCopyWords(pafWrkScanline, eWrkDataType, 4,
GDALCopyWords(pafWrkScanline, eWrkDataType, nWrkDataTypeSize,
static_cast<GByte *>(pDstBuffer) +
static_cast<size_t>(iDstLine - nDstYOff) *
nDstXSize * nDstDataTypeSize,
Expand Down Expand Up @@ -4101,33 +4139,38 @@ GDALResampleFunction GDALGetResampleFunction(const char *pszResampling,
GDALDataType GDALGetOvrWorkDataType(const char *pszResampling,
GDALDataType eSrcDataType)
{
if ((STARTS_WITH_CI(pszResampling, "NEAR") ||
STARTS_WITH_CI(pszResampling, "AVER") || EQUAL(pszResampling, "RMS") ||
EQUAL(pszResampling, "CUBIC") || EQUAL(pszResampling, "CUBICSPLINE") ||
EQUAL(pszResampling, "LANCZOS") || EQUAL(pszResampling, "BILINEAR") ||
EQUAL(pszResampling, "MODE")) &&
eSrcDataType == GDT_Byte)
if (STARTS_WITH_CI(pszResampling, "NEAR"))
{
return eSrcDataType;
}
else if (eSrcDataType == GDT_Byte &&
(STARTS_WITH_CI(pszResampling, "AVER") ||
EQUAL(pszResampling, "RMS") || EQUAL(pszResampling, "CUBIC") ||
EQUAL(pszResampling, "CUBICSPLINE") ||
EQUAL(pszResampling, "LANCZOS") ||
EQUAL(pszResampling, "BILINEAR") || EQUAL(pszResampling, "MODE")))
{
return GDT_Byte;
}
else if ((STARTS_WITH_CI(pszResampling, "NEAR") ||
STARTS_WITH_CI(pszResampling, "AVER") ||
else if (eSrcDataType == GDT_UInt16 &&
(STARTS_WITH_CI(pszResampling, "AVER") ||
EQUAL(pszResampling, "RMS") || EQUAL(pszResampling, "CUBIC") ||
EQUAL(pszResampling, "CUBICSPLINE") ||
EQUAL(pszResampling, "LANCZOS") ||
EQUAL(pszResampling, "BILINEAR") ||
EQUAL(pszResampling, "MODE")) &&
eSrcDataType == GDT_UInt16)
EQUAL(pszResampling, "BILINEAR") || EQUAL(pszResampling, "MODE")))
{
return GDT_UInt16;
}
else if (EQUAL(pszResampling, "GAUSS"))
return GDT_Float64;

if (eSrcDataType == GDT_Float64)
return GDT_Float64;

return GDT_Float32;
if (eSrcDataType == GDT_Byte || eSrcDataType == GDT_Int8 ||
eSrcDataType == GDT_UInt16 || eSrcDataType == GDT_Int16 ||
eSrcDataType == GDT_Float32)
{
return GDT_Float32;
}
return GDT_Float64;
}

namespace
Expand Down Expand Up @@ -4361,7 +4404,8 @@ CPLErr GDALRegenerateOverviewsEx(GDALRasterBandH hSrcBand, int nOverviewCount,

const GDALDataType eSrcDataType = poSrcBand->GetRasterDataType();
const GDALDataType eWrkDataType =
GDALDataTypeIsComplex(eSrcDataType)
(GDALDataTypeIsComplex(eSrcDataType) &&
!STARTS_WITH_CI(pszResampling, "NEAR"))
? GDT_CFloat32
: GDALGetOvrWorkDataType(pszResampling, eSrcDataType);

Expand Down

0 comments on commit 87a99d2

Please sign in to comment.