Skip to content

Commit

Permalink
GDALThreadSafeDataset: handle the case of methods that return non pri…
Browse files Browse the repository at this point in the history
…mitive types
  • Loading branch information
rouault committed Sep 9, 2024
1 parent d441f96 commit bf2411b
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 2 deletions.
189 changes: 188 additions & 1 deletion autotest/gcore/thread_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def verify_checksum():
res[0] = False
assert False, (got_cs, expected_cs)

threads = [threading.Thread(target=verify_checksum)]
threads = [threading.Thread(target=verify_checksum) for i in range(2)]
for t in threads:
t.start()
for t in threads:
Expand Down Expand Up @@ -426,3 +426,190 @@ def test_thread_safe_unsupported_rat():
match="not supporting a non-GDALDefaultRasterAttributeTable implementation",
):
ds.GetRasterBand(1).GetDefaultRAT()


def test_thread_safe_many_datasets():

tab_ds = [
gdal.OpenEx(
"data/byte.tif" if (i % 3) < 2 else "data/utmsmall.tif",
gdal.OF_RASTER | gdal.OF_THREAD_SAFE,
)
for i in range(100)
]

res = [True]

def check():
for _ in range(10):
for i, ds in enumerate(tab_ds):
if ds.GetRasterBand(1).Checksum() != (4672 if (i % 3) < 2 else 50054):
res[0] = False

threads = [threading.Thread(target=check) for i in range(2)]
for t in threads:
t.start()
for t in threads:
t.join()
assert res[0]


def test_thread_safe_BeginAsyncReader():

pytest.importorskip("numpy")
pytest.importorskip("osgeo.gdal_array")

with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:
with pytest.raises(Exception, match="not supported"):
ds.BeginAsyncReader(0, 0, ds.RasterXSize, ds.RasterYSize)


def test_thread_safe_GetVirtualMem():

with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:
with pytest.raises(Exception, match="not supported"):
ds.GetRasterBand(1).GetVirtualMemAutoArray(gdal.GF_Read)


def test_thread_safe_GetMetadadata(tmp_vsimem):

filename = str(tmp_vsimem / "test.tif")
with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds:
ds.SetMetadataItem("foo", "bar")
ds.GetRasterBand(1).SetMetadataItem("bar", "baz")

with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:
assert ds.GetMetadataItem("foo") == "bar"
assert ds.GetMetadataItem("not existing") is None
assert ds.GetMetadata() == {"foo": "bar"}
assert ds.GetMetadata("not existing") == {}
assert ds.GetRasterBand(1).GetMetadataItem("bar") == "baz"
assert ds.GetRasterBand(1).GetMetadataItem("not existing") is None
assert ds.GetRasterBand(1).GetMetadata() == {"bar": "baz"}
assert ds.GetRasterBand(1).GetMetadata("not existing") == {}


def test_thread_safe_GetUnitType(tmp_vsimem):

filename = str(tmp_vsimem / "test.tif")
with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds:
ds.GetRasterBand(1).SetUnitType("foo")

with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:
assert ds.GetRasterBand(1).GetUnitType() == "foo"


def test_thread_safe_GetColorTable(tmp_vsimem):

filename = str(tmp_vsimem / "test.tif")
with gdal.GetDriverByName("GTiff").Create(filename, 1, 1) as ds:
ct = gdal.ColorTable()
ct.SetColorEntry(0, (1, 2, 3, 255))
ds.GetRasterBand(1).SetColorTable(ct)

with gdal.OpenEx(filename, gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:
res = [None]

def thread_job():
res[0] = ds.GetRasterBand(1).GetColorTable()

t = threading.Thread(target=thread_job)
t.start()
t.join()
assert res[0]
assert res[0].GetColorEntry(0) == (1, 2, 3, 255)
ct = ds.GetRasterBand(1).GetColorTable()
assert ct.GetColorEntry(0) == (1, 2, 3, 255)


def test_thread_safe_GetSpatialRef():

with gdal.OpenEx("data/byte.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE) as ds:

res = [True]

def check():
for i in range(100):

if len(ds.GetGCPs()) != 0:
res[0] = False
assert False

if ds.GetGCPSpatialRef():
res[0] = False
assert False

if ds.GetGCPProjection():
res[0] = False
assert False

srs = ds.GetSpatialRef()
if not srs:
res[0] = False
assert False
if not srs.IsProjected():
res[0] = False
assert False
if "NAD27 / UTM zone 11N" not in srs.ExportToWkt():
res[0] = False
assert False

if "NAD27 / UTM zone 11N" not in ds.GetProjectionRef():
res[0] = False
assert False

threads = [threading.Thread(target=check) for i in range(2)]
for t in threads:
t.start()
for t in threads:
t.join()
assert res[0]


def test_thread_safe_GetGCPs():

with gdal.OpenEx(
"data/byte_gcp_pixelispoint.tif", gdal.OF_RASTER | gdal.OF_THREAD_SAFE
) as ds:

res = [True]

def check():
for i in range(100):

if len(ds.GetGCPs()) != 4:
res[0] = False
assert False

gcp_srs = ds.GetGCPSpatialRef()
if gcp_srs is None:
res[0] = False
assert False
if not gcp_srs.IsGeographic():
res[0] = False
assert False
if "unretrievable - using WGS84" not in gcp_srs.ExportToWkt():
res[0] = False
assert False

gcp_wkt = ds.GetGCPProjection()
if not gcp_wkt:
res[0] = False
assert False
if "unretrievable - using WGS84" not in gcp_wkt:
res[0] = False
assert False

if ds.GetSpatialRef():
res[0] = False
assert False
if ds.GetProjectionRef() != "":
res[0] = False
assert False

threads = [threading.Thread(target=check) for i in range(2)]
for t in threads:
t.start()
for t in threads:
t.join()
assert res[0]
119 changes: 118 additions & 1 deletion gcore/gdalthreadsafedataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,75 @@ class GDALThreadSafeDataset final : public GDALProxyDataset

static GDALDataset *Create(GDALDataset *poPrototypeDS, int nScopeFlags);

/* All below public methods override GDALDataset methods, and instead of
* forwarding to a thread-local dataset, they act on the prototype dataset,
* because they return a non-trivial type, that could be invalidated
* otherwise if the thread-local dataset is evicted from the LRU cache.
*/
const OGRSpatialReference *GetSpatialRef() const override
{
std::lock_guard oGuard(m_oPrototypeDSMutex);
if (m_oSRS.IsEmpty())
{
auto poSRS = m_poPrototypeDS->GetSpatialRef();
if (poSRS)
{
m_oSRS = *poSRS;
// Make sure the cached object is thread-safe.
m_oSRS.SetThreadSafe();
}
}
return m_oSRS.IsEmpty() ? nullptr : &m_oSRS;
}

const OGRSpatialReference *GetGCPSpatialRef() const override
{
std::lock_guard oGuard(m_oPrototypeDSMutex);
if (m_oGCPSRS.IsEmpty())
{
auto poSRS = m_poPrototypeDS->GetGCPSpatialRef();
if (poSRS)
{
m_oGCPSRS = *poSRS;
// Make sure the cached object is thread-safe.
m_oGCPSRS.SetThreadSafe();
}
}
return m_oGCPSRS.IsEmpty() ? nullptr : &m_oGCPSRS;
}

const GDAL_GCP *GetGCPs() override
{
std::lock_guard oGuard(m_oPrototypeDSMutex);
return const_cast<GDALDataset *>(m_poPrototypeDS)->GetGCPs();
}

const char *GetMetadataItem(const char *pszName,
const char *pszDomain = "") override
{
std::lock_guard oGuard(m_oPrototypeDSMutex);
return const_cast<GDALDataset *>(m_poPrototypeDS)
->GetMetadataItem(pszName, pszDomain);
}

char **GetMetadata(const char *pszDomain = "") override
{
std::lock_guard oGuard(m_oPrototypeDSMutex);
return const_cast<GDALDataset *>(m_poPrototypeDS)
->GetMetadata(pszDomain);
}

/* End of methods that forward on the prototype dataset */

GDALAsyncReader *BeginAsyncReader(int, int, int, int, void *, int, int,
GDALDataType, int, int *, int, int, int,
char **) override
{
CPLError(CE_Failure, CPLE_AppDefined,
"GDALThreadSafeDataset::BeginAsyncReader() not supported");
return nullptr;
}

protected:
GDALDataset *RefUnderlyingDataset() const override;

Expand All @@ -190,7 +259,7 @@ class GDALThreadSafeDataset final : public GDALProxyDataset
friend class GDALThreadLocalDatasetCache;

/** Mutex that protects accesses to m_poPrototypeDS */
std::mutex m_oPrototypeDSMutex{};
mutable std::mutex m_oPrototypeDSMutex{};

/** "Prototype" dataset, that is the dataset that was passed to the
* GDALThreadSafeDataset constructor. All calls on to it should be on
Expand All @@ -208,6 +277,12 @@ class GDALThreadSafeDataset final : public GDALProxyDataset
*/
const CPLStringList m_aosThreadLocalConfigOptions{};

/** Cached value returned by GetSpatialRef() */
mutable OGRSpatialReference m_oSRS{};

/** Cached value returned by GetGCPSpatialRef() */
mutable OGRSpatialReference m_oGCPSRS{};

/** Structure that references all GDALThreadLocalDatasetCache* instances.
*/
struct GlobalCache
Expand Down Expand Up @@ -275,6 +350,48 @@ class GDALThreadSafeRasterBand final : public GDALProxyRasterBand

GDALRasterAttributeTable *GetDefaultRAT() override;

/* All below public methods override GDALRasterBand methods, and instead of
* forwarding to a thread-local dataset, they act on the prototype band,
* because they return a non-trivial type, that could be invalidated
* otherwise if the thread-local dataset is evicted from the LRU cache.
*/
const char *GetMetadataItem(const char *pszName,
const char *pszDomain = "") override
{
std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex);
return const_cast<GDALRasterBand *>(m_poPrototypeBand)
->GetMetadataItem(pszName, pszDomain);
}

char **GetMetadata(const char *pszDomain = "") override
{
std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex);
return const_cast<GDALRasterBand *>(m_poPrototypeBand)
->GetMetadata(pszDomain);
}

const char *GetUnitType() override
{
std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex);
return const_cast<GDALRasterBand *>(m_poPrototypeBand)->GetUnitType();
}

GDALColorTable *GetColorTable() override
{
std::lock_guard oGuard(m_poTSDS->m_oPrototypeDSMutex);
return const_cast<GDALRasterBand *>(m_poPrototypeBand)->GetColorTable();
}

/* End of methods that forward on the prototype band */

CPLVirtualMem *GetVirtualMemAuto(GDALRWFlag, int *, GIntBig *,
char **) override
{
CPLError(CE_Failure, CPLE_AppDefined,
"GDALThreadSafeRasterBand::GetVirtualMemAuto() not supported");
return nullptr;
}

protected:
GDALRasterBand *RefUnderlyingRasterBand(bool bForceOpen) const override;
void UnrefUnderlyingRasterBand(
Expand Down

0 comments on commit bf2411b

Please sign in to comment.