1 | 1 | from pathlib import Path |
2 | | -from typing import Any, Iterable, List, Optional, Union |
| 2 | +from typing import Any, List, Optional, Union |
3 | 3 |
4 | | -import pyarrow as pa |
5 | 4 | import ray |
6 | | -from ray.data.block import Block, BlockMetadata |
7 | | -from ray.data.datasource import Datasource, ReadTask |
8 | 5 |
9 | 6 | from graphgen.models import ( |
10 | 7 | CSVReader, |
@@ -50,230 +47,77 @@ def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): |
50 | 47 | return reader_cls(**reader_kwargs) |
51 | 48 |
52 | 49 |
53 | | -# pylint: disable=abstract-method |
54 | | -class UnifiedFileDatasource(Datasource): |
55 | | - """ |
56 | | - A unified Ray DataSource that can read multiple file types |
57 | | - and automatically route to the appropriate reader. |
| 50 | +def read_files( |
| 51 | + input_path: Union[str, List[str]], |
| 52 | + allowed_suffix: Optional[List[str]] = None, |
| 53 | + cache_dir: Optional[str] = None, |
| 54 | + parallelism: int = 4, |
| 55 | + recursive: bool = True, |
| 56 | + **reader_kwargs: Any, |
| 57 | +) -> ray.data.Dataset: |
58 | 58 | """ |
| 59 | + Unified entry point to read files of multiple types using Ray Data. |
59 | 60 |
60 | | - def __init__( |
61 | | - self, |
62 | | - paths: Union[str, List[str]], |
63 | | - allowed_suffix: Optional[List[str]] = None, |
64 | | - cache_dir: Optional[str] = None, |
65 | | - recursive: bool = True, |
66 | | - **reader_kwargs, |
67 | | - ): |
68 | | - """ |
69 | | - Initialize the datasource. |
70 | | -
71 | | - :param paths: File or directory paths to read |
72 | | - :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) |
73 | | - :param cache_dir: Directory to cache intermediate files (used for PDF processing) |
74 | | - :param recursive: Whether to scan directories recursively |
75 | | - :param reader_kwargs: Additional kwargs passed to readers |
76 | | - """ |
77 | | - self.paths = [paths] if isinstance(paths, str) else paths |
78 | | - self.allowed_suffix = ( |
79 | | - [s.lower().lstrip(".") for s in allowed_suffix] |
80 | | - if allowed_suffix |
81 | | - else list(_MAPPING.keys()) |
82 | | - ) |
83 | | - self.cache_dir = cache_dir |
84 | | - self.recursive = recursive |
85 | | - self.reader_kwargs = reader_kwargs |
86 | | - |
87 | | - # Validate allowed suffixes |
88 | | - unsupported = set(self.allowed_suffix) - set(_MAPPING.keys()) |
89 | | - if unsupported: |
90 | | - raise ValueError(f"Unsupported file suffixes: {unsupported}") |
91 | | - |
92 | | - def get_read_tasks( |
93 | | - self, parallelism: int, per_task_row_limit: Optional[int] = None |
94 | | - ) -> List[ReadTask]: |
95 | | - """ |
96 | | - Create read tasks for all discovered files. |
97 | | -
98 | | - :param parallelism: Number of parallel workers |
99 | | - :param per_task_row_limit: Optional limit on rows per task |
100 | | - :return: List of ReadTask objects |
101 | | - """ |
| 61 | + :param input_path: File or directory path(s) to read from |
| 62 | + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) |
| 63 | +    :param cache_dir: Directory to cache intermediate files (used for PDF processing) |
| 64 | + :param parallelism: Number of parallel workers |
| 65 | + :param recursive: Whether to scan directories recursively |
| 66 | + :param reader_kwargs: Additional kwargs passed to readers |
| 67 | + :return: Ray Dataset containing all documents |
| 68 | + """ |
| 69 | + try: |
102 | 70 | # 1. Scan all paths to discover files |
103 | | - logger.info("[READ] Scanning paths: %s", self.paths) |
| 71 | + logger.info("[READ] Scanning paths: %s", input_path) |
104 | 72 | scanner = ParallelFileScanner( |
105 | | - cache_dir=self.cache_dir, |
106 | | - allowed_suffix=self.allowed_suffix, |
| 73 | + cache_dir=cache_dir, |
| 74 | + allowed_suffix=allowed_suffix, |
107 | 75 | rescan=False, |
108 | 76 | max_workers=parallelism if parallelism > 0 else 1, |
109 | 77 | ) |
110 | 78 |
111 | 79 | all_files = [] |
112 | | - scan_results = scanner.scan(self.paths, recursive=self.recursive) |
| 80 | + scan_results = scanner.scan(input_path, recursive=recursive) |
113 | 81 |
114 | 82 | for result in scan_results.values(): |
115 | 83 | all_files.extend(result.get("files", [])) |
116 | 84 |
117 | 85 | logger.info("[READ] Found %d files to process", len(all_files)) |
| 86 | + |
118 | 87 | if not all_files: |
119 | | - return [] |
| 88 | + return ray.data.from_items([]) |
120 | 89 |
121 | 90 | # 2. Group files by suffix to use appropriate reader |
122 | 91 | files_by_suffix = {} |
123 | 92 | for file_info in all_files: |
124 | 93 | suffix = Path(file_info["path"]).suffix.lower().lstrip(".") |
125 | | - if suffix not in self.allowed_suffix: |
| 94 | + if allowed_suffix and suffix not in [ |
| 95 | + s.lower().lstrip(".") for s in allowed_suffix |
| 96 | + ]: |
126 | 97 | continue |
127 | 98 | files_by_suffix.setdefault(suffix, []).append(file_info["path"]) |
128 | 99 |
129 | 100 | # 3. Create read tasks |
130 | 101 | read_tasks = [] |
131 | | - |
132 | 102 | for suffix, file_paths in files_by_suffix.items(): |
133 | | - # Split files into chunks for parallel processing |
134 | | - num_chunks = min(parallelism, len(file_paths)) |
135 | | - if num_chunks == 0: |
136 | | - continue |
137 | | - |
138 | | - chunks = [[] for _ in range(num_chunks)] |
139 | | - for i, path in enumerate(file_paths): |
140 | | - chunks[i % num_chunks].append(path) |
141 | | - |
142 | | - # Create a task for each chunk |
143 | | - for chunk in chunks: |
144 | | - if not chunk: |
145 | | - continue |
146 | | - |
147 | | - # Use factory function to avoid mutable default argument issue |
148 | | - def make_read_fn( |
149 | | - file_paths_chunk, suffix_val, reader_kwargs_val, cache_dir_val |
150 | | - ): |
151 | | - def _read_fn() -> Iterable[Block]: |
152 | | - """ |
153 | | - Read a chunk of files and return blocks. |
154 | | - This function runs in a Ray worker. |
155 | | - """ |
156 | | - all_records = [] |
157 | | - |
158 | | - for file_path in file_paths_chunk: |
159 | | - try: |
160 | | - # Build reader for this file |
161 | | - reader = _build_reader( |
162 | | - suffix_val, cache_dir_val, **reader_kwargs_val |
163 | | - ) |
164 | | - |
165 | | - # Read the file - readers return Dataset |
166 | | - ds = reader.read(file_path, parallelism=parallelism) |
167 | | - |
168 | | - # Convert Dataset to list of dicts |
169 | | - records = ds.take_all() |
170 | | - all_records.extend(records) |
| 103 | + reader = _build_reader(suffix, cache_dir, **reader_kwargs) |
| 104 | + ds = reader.read(file_paths, parallelism=parallelism) |
| 105 | + read_tasks.append(ds) |
171 | 106 |
172 | | - except Exception as e: |
173 | | - logger.error( |
174 | | - "[READ] Error reading file %s: %s", file_path, e |
175 | | - ) |
176 | | - continue |
| 107 | + # 4. Combine all datasets |
| 108 | + if not read_tasks: |
| 109 | + logger.warning("[READ] No datasets created") |
| 110 | + return ray.data.from_items([]) |
177 | 111 |
178 | | - # Convert list of dicts to PyArrow Table (Block) |
179 | | - if all_records: |
180 | | - # Create PyArrow Table from records |
181 | | - # pylint: disable=no-value-for-parameter |
182 | | - table = pa.Table.from_pylist(mapping=all_records) |
183 | | - yield table |
| 112 | + if len(read_tasks) == 1: |
| 113 | + logger.info("[READ] Successfully read files from %s", input_path) |
| 114 | + return read_tasks[0] |
| 115 | +        # More than one suffix group: union them into a single dataset |
| 116 | + combined_ds = read_tasks[0].union(*read_tasks[1:]) |
184 | 117 |
185 | | - return _read_fn |
| 118 | + logger.info("[READ] Successfully read files from %s", input_path) |
| 119 | + return combined_ds |
186 | 120 |
187 | | - # Create closure with current loop variables |
188 | | - read_fn = make_read_fn( |
189 | | - chunk, suffix, self.reader_kwargs, self.cache_dir |
190 | | - ) |
191 | | - |
192 | | - # Calculate metadata for this task |
193 | | - total_bytes = sum( |
194 | | - Path(fp).stat().st_size for fp in chunk if Path(fp).exists() |
195 | | - ) |
196 | | - |
197 | | - # input_files must be Optional[str], not List[str] |
198 | | - # Use first file as representative or None if empty |
199 | | - first_file = chunk[0] if chunk else None |
200 | | - |
201 | | - metadata = BlockMetadata( |
202 | | - num_rows=None, # Unknown until read |
203 | | - size_bytes=total_bytes, |
204 | | - input_files=first_file, |
205 | | - exec_stats=None, |
206 | | - ) |
207 | | - |
208 | | - read_tasks.append( |
209 | | - ReadTask( |
210 | | - read_fn=read_fn, |
211 | | - metadata=metadata, |
212 | | - schema=None, # Will be inferred |
213 | | - per_task_row_limit=per_task_row_limit, |
214 | | - ) |
215 | | - ) |
216 | | - |
217 | | - logger.info("[READ] Created %d read tasks", len(read_tasks)) |
218 | | - return read_tasks |
219 | | - |
220 | | - def estimate_inmemory_data_size(self) -> Optional[int]: |
221 | | - """ |
222 | | - Estimate the total size of data in memory. |
223 | | - This helps Ray optimize task scheduling. |
224 | | - """ |
225 | | - try: |
226 | | - total_size = 0 |
227 | | - for path in self.paths: |
228 | | - scan_results = ParallelFileScanner( |
229 | | - cache_dir=self.cache_dir, |
230 | | - allowed_suffix=self.allowed_suffix, |
231 | | - rescan=False, |
232 | | - max_workers=1, |
233 | | - ).scan(path, recursive=self.recursive) |
234 | | - |
235 | | - for result in scan_results.values(): |
236 | | - total_size += result.get("stats", {}).get("total_size", 0) |
237 | | - return total_size |
238 | | - except Exception: |
239 | | - # Return None if estimation fails |
240 | | - return None |
241 | | - |
242 | | - |
243 | | -def read_files( |
244 | | - input_path: Union[str, List[str]], |
245 | | - allowed_suffix: Optional[List[str]] = None, |
246 | | - cache_dir: Optional[str] = None, |
247 | | - parallelism: int = 4, |
248 | | - recursive: bool = True, |
249 | | - **reader_kwargs: Any, |
250 | | -) -> ray.data.Dataset: |
251 | | - """ |
252 | | - Unified entry point to read files of multiple types using Ray Data. |
253 | | -
254 | | - :param input_path: File or directory path(s) to read from |
255 | | - :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) |
256 | | - :param cache_dir: Directory to cache intermediate files (PDF processing) |
257 | | - :param parallelism: Number of parallel workers |
258 | | - :param recursive: Whether to scan directories recursively |
259 | | - :param reader_kwargs: Additional kwargs passed to readers |
260 | | - :return: Ray Dataset containing all documents |
261 | | - """ |
262 | | - |
263 | | - if not ray.is_initialized(): |
264 | | - ray.init() |
265 | | - |
266 | | - try: |
267 | | - return ray.data.read_datasource( |
268 | | - UnifiedFileDatasource( |
269 | | - paths=input_path, |
270 | | - allowed_suffix=allowed_suffix, |
271 | | - cache_dir=cache_dir, |
272 | | - recursive=recursive, |
273 | | - **reader_kwargs, |
274 | | - ), |
275 | | - parallelism=parallelism, |
276 | | - ) |
277 | 121 | except Exception as e: |
278 | 122 | logger.error("[READ] Failed to read files from %s: %s", input_path, e) |
279 | 123 | raise |
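
For reference, a minimal usage sketch of the new read_files() entry point. The import path and the sample arguments below are illustrative assumptions, not part of this commit:

    import ray

    # Illustrative import path -- adjust to wherever this module
    # actually lives in the graphgen package.
    from graphgen.readers import read_files

    # This commit removed the ray.init() guard from read_files(); Ray
    # normally auto-initializes on first use, but initializing explicitly
    # keeps the dependency visible.
    if not ray.is_initialized():
        ray.init()

    ds = read_files(
        input_path="./corpus",          # file or directory path(s)
        allowed_suffix=["pdf", "txt"],  # must be suffixes known to _MAPPING
        cache_dir="./.cache",           # intermediate files for PDF processing
        parallelism=4,                  # scanner workers and reader parallelism
        recursive=True,                 # walk subdirectories
    )

    print(ds.count())  # total documents across all file types
    print(ds.take(1))  # peek at the first record

Because the per-suffix datasets are merged with Dataset.union(), the individual readers should emit records with a compatible schema; otherwise operations on the combined dataset may fail downstream.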