Skip to content

Commit 36e80ef

Browse files
refactor: refactor json_reader using ray data
1 parent 0422bd0 commit 36e80ef

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

graphgen/models/reader/csv_reader.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ def read(
2828
"""
2929

3030
ds = ray.data.read_csv(input_path, override_num_blocks=override_num_blocks)
31-
3231
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
33-
3432
ds = ds.filter(self._should_keep_item)
35-
3633
return ds
Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
import json
2-
from typing import Any, Dict, List
1+
from typing import List, Union
2+
3+
import ray
4+
from ray.data import Dataset
35

46
from graphgen.bases.base_reader import BaseReader
57

@@ -12,15 +14,19 @@ class JSONReader(BaseReader):
1214
- if type is "text", "content" column must be present.
1315
"""
1416

15-
def read(self, file_path: str) -> List[Dict[str, Any]]:
16-
with open(file_path, "r", encoding="utf-8") as f:
17-
data = json.load(f)
18-
if isinstance(data, list):
19-
for doc in data:
20-
assert "type" in doc, f"Missing 'type' in document: {doc}"
21-
if doc.get("type") == "text" and self.text_column not in doc:
22-
raise ValueError(
23-
f"Missing '{self.text_column}' in document: {doc}"
24-
)
25-
return self.filter(data)
26-
raise ValueError("JSON file must contain a list of documents.")
17+
def read(
18+
self,
19+
input_path: Union[str, List[str]],
20+
parallelism: int = 4,
21+
) -> Dataset:
22+
"""
23+
Read JSON file and return Ray Dataset.
24+
:param input_path: Path to JSON file or list of JSON files.
25+
:param parallelism: Number of parallel workers for reading files.
26+
:return: Ray Dataset containing validated and filtered data.
27+
"""
28+
29+
ds = ray.data.read_json(input_path, parallelism=parallelism)
30+
ds = ds.map_batches(self._validate_batch, batch_format="pandas")
31+
ds = ds.filter(self._should_keep_item)
32+
return ds

0 commit comments

Comments
 (0)