40 changes: 24 additions & 16 deletions src/datasets/arrow_dataset.py
@@ -40,7 +40,6 @@
 from collections.abc import Sequence as Sequence_
 from copy import deepcopy
 from functools import partial, wraps
-from io import BytesIO
 from math import ceil, floor
 from pathlib import Path
 from random import sample
@@ -5550,21 +5549,30 @@ def _push_parquet_shards_to_hub_single(
                 )
                 shard = shard.with_format(**format)
             shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
-            buffer = BytesIO()
-            shard.to_parquet(buffer, batch_size=writer_batch_size)
-            parquet_content = buffer.getvalue()
-            uploaded_size += len(parquet_content)
-            del buffer
-            shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=parquet_content)
-            api.preupload_lfs_files(
-                repo_id=repo_id,
-                additions=[shard_addition],
-                repo_type="dataset",
-                revision=revision,
-                create_pr=create_pr,
-            )
-            additions.append(shard_addition)
-            yield job_id, False, 1
+            # Write to temp file instead of BytesIO to avoid holding all shard bytes in memory.
+            # This fixes OOM when uploading large datasets with many shards.
+            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f:
+                temp_path = f.name
+            try:
+                shard.to_parquet(temp_path, batch_size=writer_batch_size)
+                uploaded_size += os.path.getsize(temp_path)
+                shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=temp_path)
+                api.preupload_lfs_files(
+                    repo_id=repo_id,
+                    additions=[shard_addition],
+                    repo_type="dataset",
+                    revision=revision,
+                    create_pr=create_pr,
+                )
+                additions.append(shard_addition)
+                yield job_id, False, 1
+            finally:
+                # For LFS uploads, content now lives on the Hub; the local temp file can be
+                # safely removed. For regular uploads, create_commit still needs to read
+                # from disk, so we must keep the file until after the commit completes.
+                if getattr(shard_addition, "_upload_mode", None) == "lfs":
+                    if os.path.exists(temp_path):
+                        os.unlink(temp_path)
 
         yield job_id, True, additions
 
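For reference, here is a minimal standalone sketch of the pattern this hunk implements: write each shard to a temporary Parquet file, preupload it, and delay deletion of non-LFS files until the commit has been made. The helper name upload_parquet_shards, the write_shard callback, and the cleanup-after-commit step are illustrative assumptions (the actual commit happens elsewhere in push_to_hub); only CommitOperationAdd, HfApi.preupload_lfs_files, and HfApi.create_commit are real huggingface_hub APIs, and _upload_mode is the same private attribute the diff and the tests below rely on.

import os
import tempfile

from huggingface_hub import CommitOperationAdd, HfApi


def upload_parquet_shards(shard_paths_in_repo, write_shard, repo_id):
    # Hypothetical helper: write_shard(local_path, index) stands in for
    # shard.to_parquet(local_path, batch_size=...) in the real code.
    api = HfApi()
    additions = []
    kept_temp_paths = []  # non-LFS uploads: create_commit still reads these from disk
    for index, path_in_repo in enumerate(shard_paths_in_repo):
        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f:
            temp_path = f.name
        write_shard(temp_path, index)
        addition = CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=temp_path)
        # LFS content is uploaded here, so peak memory stays around one shard.
        api.preupload_lfs_files(repo_id, additions=[addition], repo_type="dataset")
        additions.append(addition)
        if getattr(addition, "_upload_mode", None) == "lfs":
            os.unlink(temp_path)  # bytes already live on the Hub
        else:
            kept_temp_paths.append(temp_path)
    api.create_commit(repo_id, operations=additions, commit_message="Upload shards", repo_type="dataset")
    for temp_path in kept_temp_paths:
        os.unlink(temp_path)  # safe to remove once the commit has completed

The key property is that at most one shard's bytes sit on the local machine (outside the Hub) at a time for LFS files, which is what the tests in the next file check.
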
231 changes: 231 additions & 0 deletions tests/test_push_to_hub_memory.py
@@ -0,0 +1,231 @@
"""Tests for memory-safe push_to_hub with large datasets.

Regression tests for OOM when uploading large datasets due to memory
accumulation in the additions list.
"""

import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

from datasets import Dataset


class TestPushToHubMemorySafe:
    """Tests for memory-safe push_to_hub implementation."""

    def test_push_to_hub_uses_file_path_not_bytes_in_commit_operation(self):
        """CommitOperationAdd should use file path, not bytes, to enable streaming.

        This is the core fix - by using file paths instead of bytes, the upload
        can stream from disk instead of holding all shard bytes in memory.
        """
        ds = Dataset.from_dict({"x": list(range(100))})

        commit_operations = []

        with patch("datasets.arrow_dataset.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api

            def capture_preupload(repo_id, additions, **kwargs):
                for add in additions:
                    # Simulate LFS upload - set _upload_mode like huggingface_hub does
                    add._upload_mode = "lfs"
                    commit_operations.append(add)

            mock_api.preupload_lfs_files = capture_preupload

            # Consume the generator
            list(
                ds._push_parquet_shards_to_hub_single(
                    job_id=0,
                    num_jobs=1,
                    repo_id="test/repo",
                    data_dir="data",
                    split="train",
                    token="fake",
                    revision=None,
                    create_pr=False,
                    num_shards=2,
                    embed_external_files=False,
                    writer_batch_size=1000,
                )
            )

        # Should have captured at least one operation
        assert len(commit_operations) > 0, "No commit operations captured"

        # Each CommitOperationAdd should have a Path or str, not bytes
        for op in commit_operations:
            assert isinstance(op.path_or_fileobj, (str, Path)), (
                f"Expected file path (str or Path), got {type(op.path_or_fileobj).__name__}. "
                "This indicates bytes are being held in memory instead of streamed from disk."
            )

    def test_push_to_hub_cleans_up_temp_files_for_lfs_uploads(self):
        """Temp files should be deleted after LFS upload completes.

        For LFS uploads, content is uploaded to the Hub during preupload_lfs_files,
        so the local temp file can be safely deleted to avoid disk exhaustion.
        """
        ds = Dataset.from_dict({"x": list(range(100))})

        created_temp_files = []

        # Patch at the module level where it's used
        with patch("datasets.arrow_dataset.tempfile") as mock_tempfile:
            # Create real temp files but track them
            real_tempfile = tempfile

            def track_named_temp(*args, **kwargs):
                kwargs["delete"] = False  # We'll delete manually to track
                f = real_tempfile.NamedTemporaryFile(*args, **kwargs)
                created_temp_files.append(Path(f.name))
                return f

            mock_tempfile.NamedTemporaryFile = track_named_temp

            with patch("datasets.arrow_dataset.HfApi") as mock_api_class:
                mock_api = MagicMock()
                mock_api_class.return_value = mock_api

                def simulate_lfs_preupload(repo_id, additions, **kwargs):
                    # Simulate huggingface_hub behavior: set _upload_mode to "lfs"
                    for add in additions:
                        add._upload_mode = "lfs"

                mock_api.preupload_lfs_files = simulate_lfs_preupload

                # Consume the generator
                list(
                    ds._push_parquet_shards_to_hub_single(
                        job_id=0,
                        num_jobs=1,
                        repo_id="test/repo",
                        data_dir="data",
                        split="train",
                        token="fake",
                        revision=None,
                        create_pr=False,
                        num_shards=3,
                        embed_external_files=False,
                        writer_batch_size=1000,
                    )
                )

        # All temp files should be cleaned up after LFS upload completes
        for temp_file in created_temp_files:
            assert not temp_file.exists(), (
                f"Temp file not cleaned up: {temp_file}. This will cause disk exhaustion on large datasets."
            )

    def test_push_to_hub_keeps_temp_files_for_regular_uploads(self):
        """Temp files should be kept for regular (non-LFS) uploads.

        For regular uploads, create_commit needs to read the file content from disk,
        so we must not delete the temp file until after the commit completes.
        """
        ds = Dataset.from_dict({"x": list(range(100))})

        created_temp_files = []

        # Patch at the module level where it's used
        with patch("datasets.arrow_dataset.tempfile") as mock_tempfile:
            # Create real temp files but track them
            real_tempfile = tempfile

            def track_named_temp(*args, **kwargs):
                kwargs["delete"] = False  # We'll delete manually to track
                f = real_tempfile.NamedTemporaryFile(*args, **kwargs)
                created_temp_files.append(Path(f.name))
                return f

            mock_tempfile.NamedTemporaryFile = track_named_temp

            with patch("datasets.arrow_dataset.HfApi") as mock_api_class:
                mock_api = MagicMock()
                mock_api_class.return_value = mock_api

                def simulate_regular_preupload(repo_id, additions, **kwargs):
                    # Simulate huggingface_hub behavior: set _upload_mode to "regular"
                    # This happens for small files that don't need LFS
                    for add in additions:
                        add._upload_mode = "regular"

                mock_api.preupload_lfs_files = simulate_regular_preupload

                # Consume the generator
                list(
                    ds._push_parquet_shards_to_hub_single(
                        job_id=0,
                        num_jobs=1,
                        repo_id="test/repo",
                        data_dir="data",
                        split="train",
                        token="fake",
                        revision=None,
                        create_pr=False,
                        num_shards=3,
                        embed_external_files=False,
                        writer_batch_size=1000,
                    )
                )

        # Temp files should still exist for regular uploads (create_commit needs them)
        for temp_file in created_temp_files:
            assert temp_file.exists(), (
                f"Temp file was deleted too early: {temp_file}. "
                "Regular uploads need the file to exist until create_commit completes."
            )
            # Clean up manually
            temp_file.unlink()

    def test_push_to_hub_uploaded_size_still_calculated(self):
        """uploaded_size should still be calculated correctly with file-based approach."""
        ds = Dataset.from_dict({"x": list(range(100))})

        with patch("datasets.arrow_dataset.HfApi") as mock_api_class:
            mock_api = MagicMock()
            mock_api_class.return_value = mock_api

            def simulate_preupload_with_upload_info(repo_id, additions, **kwargs):
                # Simulate huggingface_hub behavior: set _upload_mode and upload_info
                for add in additions:
                    add._upload_mode = "lfs"
                    # Create a mock upload_info with size
                    add.upload_info = MagicMock()
                    add.upload_info.size = 1024  # Simulate 1KB upload

            mock_api.preupload_lfs_files = simulate_preupload_with_upload_info

            # Collect all yields to get the final result
            results = list(
                ds._push_parquet_shards_to_hub_single(
                    job_id=0,
                    num_jobs=1,
                    repo_id="test/repo",
                    data_dir="data",
                    split="train",
                    token="fake",
                    revision=None,
                    create_pr=False,
                    num_shards=1,
                    embed_external_files=False,
                    writer_batch_size=1000,
                )
            )

        # The function yields (job_id, done, content) tuples
        # Final yield has done=True and content=additions list
        final_result = results[-1]
        assert final_result[1] is True, "Expected final yield to have done=True"

        additions = final_result[2]
        assert len(additions) > 0, "Expected at least one addition"

        # Each addition should have upload_info with size > 0
        for add in additions:
            assert hasattr(add, "upload_info"), "CommitOperationAdd missing upload_info"
            assert add.upload_info is not None, "upload_info should not be None after preupload"
            assert add.upload_info.size > 0, "upload_info.size should be > 0"
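
As a usage-level illustration (not part of this PR), the code path exercised by these tests is the one Dataset.push_to_hub takes for each Parquet shard; with the temp-file approach, peak memory should stay on the order of a single shard rather than the whole dataset. The repository name and shard count below are placeholders.

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1_000_000))})
# Each of the 16 shards is written to a temporary .parquet file, preuploaded,
# and (for LFS uploads) deleted again before the next shard is processed.
ds.push_to_hub("username/large-dataset", num_shards=16)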