Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue where tmp files created by open_atomic cause error in DefaultTileStore #142

Merged
merged 2 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions rslearn/tile_stores/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rslearn.utils import Feature, PixelBounds, Projection, STGeometry
from rslearn.utils.fsspec import (
join_upath,
open_atomic,
open_rasterio_upath_reader,
open_rasterio_upath_writer,
)
Expand Down Expand Up @@ -75,11 +76,48 @@ def set_dataset_path(self, ds_path: UPath) -> None:
def _get_raster_dir(
    self, layer_name: str, item_name: str, bands: list[str]
) -> UPath:
    """Compute the directory that holds one raster of an item.

    The directory is laid out as ``<path>/<layer_name>/<item_name>/<bands>``,
    where ``<bands>`` is the band names joined with underscores.

    Args:
        layer_name: the name of the dataset layer.
        item_name: the name of the item from the data source.
        bands: list of band names that are expected to be stored together.

    Returns:
        the UPath directory where the raster should be stored.

    Raises:
        ValueError: if any band name contains an underscore.
    """
    assert self.path is not None
    # Underscore is the separator in the directory name, so it cannot appear
    # inside an individual band name.
    offending_bands = [band for band in bands if "_" in band]
    if offending_bands:
        raise ValueError("band names must not contain '_'")
    item_dir = self.path / layer_name / item_name
    return item_dir / "_".join(bands)

def _get_raster_fname(
    self, layer_name: str, item_name: str, bands: list[str]
) -> UPath:
    """Locate the raster file stored for the given layer, item, and bands.

    The raster directory is scanned and the first entry that is neither the
    completed sentinel file nor a temporary file is returned.

    Args:
        layer_name: the name of the dataset layer.
        item_name: the name of the item from the data source.
        bands: list of band names that are expected to be stored together.

    Returns:
        the UPath filename of the raster, which should be readable by rasterio.

    Raises:
        ValueError: if no file is found.
    """
    raster_dir = self._get_raster_dir(layer_name, item_name, bands)
    # Skip the completed sentinel as well as temporary files created by
    # open_atomic (in case this tile store is on local filesystem).
    candidates = (
        entry
        for entry in raster_dir.iterdir()
        if entry.name != COMPLETED_FNAME and ".tmp." not in entry.name
    )
    raster_fname = next(candidates, None)
    if raster_fname is None:
        raise ValueError(f"no raster found in {raster_dir}")
    return raster_fname

def is_raster_ready(
self, layer_name: str, item_name: str, bands: list[str]
) -> bool:
Expand Down Expand Up @@ -134,12 +172,7 @@ def get_raster_bounds(
Returns:
the bounds of the raster in the projection.
"""
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
fnames = [
fname for fname in raster_dir.iterdir() if fname.name != COMPLETED_FNAME
]
assert len(fnames) == 1
raster_fname = fnames[0]
raster_fname = self._get_raster_fname(layer_name, item_name, bands)

with open_rasterio_upath_reader(raster_fname) as src:
with rasterio.vrt.WarpedVRT(src, crs=projection.crs) as vrt:
Expand Down Expand Up @@ -179,12 +212,7 @@ def read_raster(
Returns:
the raster data
"""
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
fnames = [
fname for fname in raster_dir.iterdir() if fname.name != COMPLETED_FNAME
]
assert len(fnames) == 1
raster_fname = fnames[0]
raster_fname = self._get_raster_fname(layer_name, item_name, bands)

# Construct the transform to use for the warped dataset.
wanted_transform = affine.Affine(
Expand Down Expand Up @@ -275,7 +303,7 @@ def write_raster_file(
# Just copy the file directly.
dst_fname = raster_dir / fname.name
with fname.open("rb") as src:
with dst_fname.open("wb") as dst:
with open_atomic(dst_fname, "wb") as dst:
shutil.copyfileobj(src, dst)

(raster_dir / COMPLETED_FNAME).touch()
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/tile_stores/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from upath import UPath

from rslearn.tile_stores.default import DefaultTileStore
from rslearn.utils.fsspec import open_atomic
from rslearn.utils.geometry import Projection

LAYER_NAME = "layer"
Expand Down Expand Up @@ -80,3 +81,50 @@ def test_zstd_compression(tmp_path: pathlib.Path) -> None:
fname = tile_store.path / LAYER_NAME / ITEM_NAME / "_".join(BANDS) / "geotiff.tif"
with rasterio.open(fname) as raster:
assert raster.profile["compress"] == "zstd"


def test_leftover_tmp_file(tmp_path: pathlib.Path) -> None:
    """Check that a stale open_atomic tmp file does not break the tile store.

    An interrupted open_atomic write leaves a ``.tmp.`` file in the raster
    directory. With only that file present a read must raise ValueError; once a
    real raster is written next to it, the read must succeed and ignore the tmp
    file.
    """
    store = DefaultTileStore()
    store.set_dataset_path(UPath(tmp_path))
    side = 4
    bounds = (0, 0, side, side)
    raster_dir = store._get_raster_dir(LAYER_NAME, ITEM_NAME, BANDS)
    raster_dir.mkdir(parents=True)

    # Simulate an interrupted write: raising inside the open_atomic block
    # leaves the temporary file behind.
    class InterruptedWrite(Exception):
        pass

    with pytest.raises(InterruptedWrite):
        with open_atomic(raster_dir / "geotiff.tif", "wb") as partial:
            partial.write(b"123")
            raise InterruptedWrite()

    # The only thing in the directory should be the leftover tmp file.
    leftovers = list(raster_dir.iterdir())
    assert len(leftovers) == 1
    assert ".tmp." in leftovers[0].name

    # No real raster exists yet, so reading must fail with ValueError.
    with pytest.raises(ValueError):
        store.read_raster(LAYER_NAME, ITEM_NAME, BANDS, PROJECTION, bounds)

    # Write an all-ones raster alongside the tmp file.
    ones = np.ones((len(BANDS), side, side), dtype=np.uint8)
    store.write_raster(
        LAYER_NAME,
        ITEM_NAME,
        BANDS,
        PROJECTION,
        bounds,
        ones,
    )

    # The read should now succeed and return the all-ones data.
    result = store.read_raster(LAYER_NAME, ITEM_NAME, BANDS, PROJECTION, bounds)
    assert result.min() == 1
    assert result.max() == 1
Loading