Skip to content

Commit

Permalink
Python bindings: Avoid linear scan in gdal_array.NumericTypeCodeToGDA…
Browse files Browse the repository at this point in the history
…LTypeCode (#10694)

Avoids a surprisingly expensive lookup in gdal_array.NumericTypeCodeToGDALTypeCode. This was showing up in profiling of a script that calls ReadAsArray many times; these changes improve total script runtime by ~10%.

I re-implemented the function flip_code because I saw it used in the wild: https://github.com/search?q=gdal_array.flip_code&type=code
  • Loading branch information
dbaston authored Aug 31, 2024
1 parent 9b1ab02 commit 526b28c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
26 changes: 26 additions & 0 deletions autotest/gcore/numpy_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,29 @@ def test_numpy_rw_masked_array_2():
assert numpy.all(masked_arr.mask[mask != 255])

assert masked_arr.sum() == arr[mask == 255].sum()


###############################################################################
# Test type code mapping


def test_gdal_type_code_to_numeric_type_code():

assert gdal_array.GDALTypeCodeToNumericTypeCode(gdal.GDT_Float32) == numpy.float32

# invalid type code
assert gdal_array.GDALTypeCodeToNumericTypeCode(802) is None


def test_numeric_type_code_to_gdal_type_code():

assert gdal_array.NumericTypeCodeToGDALTypeCode(numpy.float32) == gdal.GDT_Float32
assert (
gdal_array.NumericTypeCodeToGDALTypeCode(numpy.dtype("int16")) == gdal.GDT_Int16
)


def test_flip_code():

assert gdal_array.flip_code(numpy.float32) == gdal.GDT_Float32
assert gdal_array.flip_code(gdal.GDT_Int16) == numpy.int16
35 changes: 15 additions & 20 deletions swig/include/gdal_array.i
Original file line number Diff line number Diff line change
Expand Up @@ -2389,6 +2389,11 @@ codes = {gdalconst.GDT_Byte: numpy.uint8,
gdalconst.GDT_CFloat32: numpy.complex64,
gdalconst.GDT_CFloat64: numpy.complex128}

np_class_to_gdal_code = { v : k for k, v in codes.items() }
# since several things map to complex64 we must carefully select
# the opposite that is an exact match (ticket 1518)
np_class_to_gdal_code[numpy.complex64] = gdalconst.GDT_CFloat32
np_dtype_to_gdal_code = { numpy.dtype(k) : v for k, v in np_class_to_gdal_code.items() }

def OpenArray(array, prototype_ds=None, interleave='band'):

Expand All @@ -2410,31 +2415,21 @@ def OpenArray(array, prototype_ds=None, interleave='band'):

return ds


def flip_code(code):
if isinstance(code, (numpy.dtype, type)):
# since several things map to complex64 we must carefully select
# the opposite that is an exact match (ticket 1518)
if code == numpy.complex64:
return gdalconst.GDT_CFloat32

for key, value in codes.items():
if value == code:
return key
return None
else:
try:
return codes[code]
except KeyError:
return None
try:
return NumericTypeCodeToGDALTypeCode(code)
except TypeError:
return GDALTypeCodeToNumericTypeCode(code)

def NumericTypeCodeToGDALTypeCode(numeric_type):
if not isinstance(numeric_type, (numpy.dtype, type)):
raise TypeError("Input must be a type")
return flip_code(numeric_type)
if isinstance(numeric_type, type):
return np_class_to_gdal_code.get(numeric_type, None)
elif isinstance(numeric_type, numpy.dtype):
return np_dtype_to_gdal_code.get(numeric_type, None)
raise TypeError("Input must be a type")

def GDALTypeCodeToNumericTypeCode(gdal_code):
return flip_code(gdal_code)
return codes.get(gdal_code, None)

def _RaiseException():
if gdal.GetUseExceptions():
Expand Down

0 comments on commit 526b28c

Please sign in to comment.