Commit f391c24

perf: optimize read_files.py by deleting implementation of ray.data.DataSource
1 parent 00551e3 commit f391c24

1 file changed

Lines changed: 42 additions & 198 deletions

@@ -1,10 +1,7 @@
 from pathlib import Path
-from typing import Any, Iterable, List, Optional, Union
+from typing import Any, List, Optional, Union
 
-import pyarrow as pa
 import ray
-from ray.data.block import Block, BlockMetadata
-from ray.data.datasource import Datasource, ReadTask
 
 from graphgen.models import (
     CSVReader,
@@ -50,230 +47,77 @@ def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs):
     return reader_cls(**reader_kwargs)
 
 
-# pylint: disable=abstract-method
-class UnifiedFileDatasource(Datasource):
-    """
-    A unified Ray DataSource that can read multiple file types
-    and automatically route to the appropriate reader.
+def read_files(
+    input_path: Union[str, List[str]],
+    allowed_suffix: Optional[List[str]] = None,
+    cache_dir: Optional[str] = None,
+    parallelism: int = 4,
+    recursive: bool = True,
+    **reader_kwargs: Any,
+) -> ray.data.Dataset:
     """
+    Unified entry point to read files of multiple types using Ray Data.
 
-    def __init__(
-        self,
-        paths: Union[str, List[str]],
-        allowed_suffix: Optional[List[str]] = None,
-        cache_dir: Optional[str] = None,
-        recursive: bool = True,
-        **reader_kwargs,
-    ):
-        """
-        Initialize the datasource.
-
-        :param paths: File or directory paths to read
-        :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt'])
-        :param cache_dir: Directory to cache intermediate files (used for PDF processing)
-        :param recursive: Whether to scan directories recursively
-        :param reader_kwargs: Additional kwargs passed to readers
-        """
-        self.paths = [paths] if isinstance(paths, str) else paths
-        self.allowed_suffix = (
-            [s.lower().lstrip(".") for s in allowed_suffix]
-            if allowed_suffix
-            else list(_MAPPING.keys())
-        )
-        self.cache_dir = cache_dir
-        self.recursive = recursive
-        self.reader_kwargs = reader_kwargs
-
-        # Validate allowed suffixes
-        unsupported = set(self.allowed_suffix) - set(_MAPPING.keys())
-        if unsupported:
-            raise ValueError(f"Unsupported file suffixes: {unsupported}")
-
-    def get_read_tasks(
-        self, parallelism: int, per_task_row_limit: Optional[int] = None
-    ) -> List[ReadTask]:
-        """
-        Create read tasks for all discovered files.
-
-        :param parallelism: Number of parallel workers
-        :param per_task_row_limit: Optional limit on rows per task
-        :return: List of ReadTask objects
-        """
+    :param input_path: File or directory path(s) to read from
+    :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt'])
+    :param cache_dir: Directory to cache intermediate files (PDF processing)
+    :param parallelism: Number of parallel workers
+    :param recursive: Whether to scan directories recursively
+    :param reader_kwargs: Additional kwargs passed to readers
+    :return: Ray Dataset containing all documents
+    """
+    try:
         # 1. Scan all paths to discover files
-        logger.info("[READ] Scanning paths: %s", self.paths)
+        logger.info("[READ] Scanning paths: %s", input_path)
         scanner = ParallelFileScanner(
-            cache_dir=self.cache_dir,
-            allowed_suffix=self.allowed_suffix,
+            cache_dir=cache_dir,
+            allowed_suffix=allowed_suffix,
             rescan=False,
             max_workers=parallelism if parallelism > 0 else 1,
         )
 
         all_files = []
-        scan_results = scanner.scan(self.paths, recursive=self.recursive)
+        scan_results = scanner.scan(input_path, recursive=recursive)
 
         for result in scan_results.values():
             all_files.extend(result.get("files", []))
 
         logger.info("[READ] Found %d files to process", len(all_files))
+
         if not all_files:
-            return []
+            return ray.data.from_items([])
 
         # 2. Group files by suffix to use appropriate reader
         files_by_suffix = {}
         for file_info in all_files:
             suffix = Path(file_info["path"]).suffix.lower().lstrip(".")
-            if suffix not in self.allowed_suffix:
+            if allowed_suffix and suffix not in [
+                s.lower().lstrip(".") for s in allowed_suffix
+            ]:
                 continue
             files_by_suffix.setdefault(suffix, []).append(file_info["path"])
 
         # 3. Create read tasks
         read_tasks = []
-
         for suffix, file_paths in files_by_suffix.items():
-            # Split files into chunks for parallel processing
-            num_chunks = min(parallelism, len(file_paths))
-            if num_chunks == 0:
-                continue
-
-            chunks = [[] for _ in range(num_chunks)]
-            for i, path in enumerate(file_paths):
-                chunks[i % num_chunks].append(path)
-
-            # Create a task for each chunk
-            for chunk in chunks:
-                if not chunk:
-                    continue
-
-                # Use factory function to avoid mutable default argument issue
-                def make_read_fn(
-                    file_paths_chunk, suffix_val, reader_kwargs_val, cache_dir_val
-                ):
-                    def _read_fn() -> Iterable[Block]:
-                        """
-                        Read a chunk of files and return blocks.
-                        This function runs in a Ray worker.
-                        """
-                        all_records = []
-
-                        for file_path in file_paths_chunk:
-                            try:
-                                # Build reader for this file
-                                reader = _build_reader(
-                                    suffix_val, cache_dir_val, **reader_kwargs_val
-                                )
-
-                                # Read the file - readers return Dataset
-                                ds = reader.read(file_path, parallelism=parallelism)
-
-                                # Convert Dataset to list of dicts
-                                records = ds.take_all()
-                                all_records.extend(records)
+            reader = _build_reader(suffix, cache_dir, **reader_kwargs)
+            ds = reader.read(file_paths, parallelism=parallelism)
+            read_tasks.append(ds)
 
-                            except Exception as e:
-                                logger.error(
-                                    "[READ] Error reading file %s: %s", file_path, e
-                                )
-                                continue
+        # 4. Combine all datasets
+        if not read_tasks:
+            logger.warning("[READ] No datasets created")
+            return ray.data.from_items([])
 
-                        # Convert list of dicts to PyArrow Table (Block)
-                        if all_records:
-                            # Create PyArrow Table from records
-                            # pylint: disable=no-value-for-parameter
-                            table = pa.Table.from_pylist(mapping=all_records)
-                            yield table
+        if len(read_tasks) == 1:
+            logger.info("[READ] Successfully read files from %s", input_path)
+            return read_tasks[0]
+        # len(read_tasks) > 1
+        combined_ds = read_tasks[0].union(*read_tasks[1:])
 
-                    return _read_fn
+        logger.info("[READ] Successfully read files from %s", input_path)
+        return combined_ds
 
-                # Create closure with current loop variables
-                read_fn = make_read_fn(
-                    chunk, suffix, self.reader_kwargs, self.cache_dir
-                )
-
-                # Calculate metadata for this task
-                total_bytes = sum(
-                    Path(fp).stat().st_size for fp in chunk if Path(fp).exists()
-                )
-
-                # input_files must be Optional[str], not List[str]
-                # Use first file as representative or None if empty
-                first_file = chunk[0] if chunk else None
-
-                metadata = BlockMetadata(
-                    num_rows=None,  # Unknown until read
-                    size_bytes=total_bytes,
-                    input_files=first_file,
-                    exec_stats=None,
-                )
-
-                read_tasks.append(
-                    ReadTask(
-                        read_fn=read_fn,
-                        metadata=metadata,
-                        schema=None,  # Will be inferred
-                        per_task_row_limit=per_task_row_limit,
-                    )
-                )
-
-        logger.info("[READ] Created %d read tasks", len(read_tasks))
-        return read_tasks
-
-    def estimate_inmemory_data_size(self) -> Optional[int]:
-        """
-        Estimate the total size of data in memory.
-        This helps Ray optimize task scheduling.
-        """
-        try:
-            total_size = 0
-            for path in self.paths:
-                scan_results = ParallelFileScanner(
-                    cache_dir=self.cache_dir,
-                    allowed_suffix=self.allowed_suffix,
-                    rescan=False,
-                    max_workers=1,
-                ).scan(path, recursive=self.recursive)
-
-                for result in scan_results.values():
-                    total_size += result.get("stats", {}).get("total_size", 0)
-            return total_size
-        except Exception:
-            # Return None if estimation fails
-            return None
-
-
-def read_files(
-    input_path: Union[str, List[str]],
-    allowed_suffix: Optional[List[str]] = None,
-    cache_dir: Optional[str] = None,
-    parallelism: int = 4,
-    recursive: bool = True,
-    **reader_kwargs: Any,
-) -> ray.data.Dataset:
-    """
-    Unified entry point to read files of multiple types using Ray Data.
-
-    :param input_path: File or directory path(s) to read from
-    :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt'])
-    :param cache_dir: Directory to cache intermediate files (PDF processing)
-    :param parallelism: Number of parallel workers
-    :param recursive: Whether to scan directories recursively
-    :param reader_kwargs: Additional kwargs passed to readers
-    :return: Ray Dataset containing all documents
-    """
-
-    if not ray.is_initialized():
-        ray.init()
-
-    try:
-        return ray.data.read_datasource(
-            UnifiedFileDatasource(
-                paths=input_path,
-                allowed_suffix=allowed_suffix,
-                cache_dir=cache_dir,
-                recursive=recursive,
-                **reader_kwargs,
-            ),
-            parallelism=parallelism,
-        )
     except Exception as e:
         logger.error("[READ] Failed to read files from %s: %s", input_path, e)
         raise
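The core of the new code path is to build one Ray Dataset per file suffix and merge them with Dataset.union, instead of yielding Arrow blocks from a custom Datasource. A minimal sketch of that combination pattern, with toy in-memory datasets standing in for the per-reader results (the record fields below are illustrative, not taken from the commit):

import ray

# Toy stand-ins for the datasets that _build_reader(suffix, ...).read(...)
# returns per suffix group in the real code; the fields are made up.
pdf_ds = ray.data.from_items([{"path": "a.pdf", "text": "pdf content"}])
txt_ds = ray.data.from_items([{"path": "b.txt", "text": "txt content"}])

read_tasks = [pdf_ds, txt_ds]

# Same combination logic as the new read_files(): a single dataset is
# returned as-is, multiple datasets are merged with Dataset.union().
if len(read_tasks) == 1:
    combined_ds = read_tasks[0]
else:
    combined_ds = read_tasks[0].union(*read_tasks[1:])

print(combined_ds.count())  # 2 rows, one per toy document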

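For reference, a usage sketch of the new entry point. The import path, input directory, and suffix list are assumptions for illustration; only the read_files signature comes from the diff above:

import ray

# Hypothetical import path; the commit shows read_files.py but not its package location.
from graphgen.read_files import read_files

# The deleted implementation called ray.init() when needed; the new function
# does not, so initialize Ray explicitly here.
ray.init()

ds = read_files(
    "./docs",                       # illustrative input directory
    allowed_suffix=["pdf", "txt"],  # restrict to these readers
    cache_dir="/tmp/graphgen_cache",
    parallelism=4,
    recursive=True,
)

print(ds.count())  # total number of document records
print(ds.take(3))  # peek at the first few records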