Commit 8bcbe51

refactor: refactor readers with ray data
1 parent 31c5a64 commit 8bcbe51

File tree

11 files changed: +429 -223 lines changed

graphgen/bases/base_reader.py

Lines changed: 55 additions & 40 deletions
```diff
@@ -1,8 +1,10 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 
+import pandas as pd
 import requests
+from ray.data import Dataset
 
 
 class BaseReader(ABC):
@@ -14,52 +16,65 @@ def __init__(self, text_column: str = "content"):
         self.text_column = text_column
 
     @abstractmethod
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
         """
         Read data from the specified file path.
 
-        :param file_path: Path to the input file.
-        :return: List of dictionaries containing the data.
+        :param input_path: Path to the input file or list of file paths.
+        :return: Ray Dataset containing the read data.
         """
 
-    @staticmethod
-    def filter(data: List[dict]) -> List[dict]:
+    def _should_keep_item(self, item: Dict[str, Any]) -> bool:
+        """
+        Determine whether to keep the given item based on the text column.
+
+        :param item: Dictionary representing a data entry.
+        :return: True if the item should be kept, False otherwise.
         """
-        Filter out entries with empty or missing text in the specified column.
+        item_type = item.get("type")
+        assert item_type in [
+            "text",
+            "image",
+            "table",
+            "equation",
+            "protein",
+        ], f"Unsupported item type: {item_type}"
+        if item_type == "text":
+            content = item.get(self.text_column, "").strip()
+            return bool(content)
+        return True
 
-        :param data: List of dictionaries containing the data.
-        :return: Filtered list of dictionaries.
+    def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
+        """
+        Validate data format.
         """
+        if "type" not in batch.columns:
+            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")
 
-        def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
-            """
-            Check if an image exists at the given local path or URL.
-            :param path_or_url: Local file path or remote URL of the image.
-            :param timeout: Timeout for remote URL requests in seconds.
-            :return: True if the image exists, False otherwise.
-            """
-            if not path_or_url:
-                return False
-            if not path_or_url.startswith(("http://", "https://", "ftp://")):
-                path = path_or_url.replace("file://", "", 1)
-                path = os.path.abspath(path)
-                return os.path.isfile(path)
-            try:
-                resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
-                return resp.status_code == 200
-            except requests.RequestException:
-                return False
+        if "text" in batch["type"].values:
+            if self.text_column not in batch.columns:
+                raise ValueError(
+                    f"Missing '{self.text_column}' column for text documents"
+                )
 
-        filtered_data = []
-        for item in data:
-            if item.get("type") == "text":
-                content = item.get("content", "").strip()
-                if content:
-                    filtered_data.append(item)
-            elif item.get("type") in ("image", "table", "equation"):
-                img_path = item.get("img_path")
-                if _image_exists(img_path):
-                    filtered_data.append(item)
-            else:
-                filtered_data.append(item)
-        return filtered_data
+        return batch
+
+    @staticmethod
+    def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+        """
+        Check if an image exists at the given local path or URL.
+        :param path_or_url: Local file path or remote URL of the image.
+        :param timeout: Timeout for remote URL requests in seconds.
+        :return: True if the image exists, False otherwise.
+        """
+        if not path_or_url:
+            return False
+        if not path_or_url.startswith(("http://", "https://", "ftp://")):
+            path = path_or_url.replace("file://", "", 1)
+            path = os.path.abspath(path)
+            return os.path.isfile(path)
+        try:
+            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+            return resp.status_code == 200
+        except requests.RequestException:
+            return False
```
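
Taken together, every concrete reader now shares the same three-stage Ray Data pipeline: read the source into a lazy Dataset, validate each pandas batch with `_validate_batch`, and drop unusable rows with `_should_keep_item`. A minimal sketch of how a new reader would plug into these inherited hooks (the `TSVReader` class is hypothetical, not part of this commit):

```python
import ray
from pyarrow import csv as pa_csv
from ray.data import Dataset

from graphgen.bases.base_reader import BaseReader


class TSVReader(BaseReader):
    """Hypothetical subclass, shown only to illustrate the shared hooks."""

    def read(self, input_path, parallelism=None) -> Dataset:
        # Stage 1: read files into a lazy Ray Dataset.
        ds = ray.data.read_csv(
            input_path,
            parse_options=pa_csv.ParseOptions(delimiter="\t"),
            override_num_blocks=parallelism,
        )
        # Stage 2: schema checks run once per pandas batch.
        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
        # Stage 3: row-level filtering (empty text rows are dropped).
        ds = ds.filter(self._should_keep_item)
        return ds
```

Note that `_image_exists` survives as a static helper but is no longer called on the filtering path: `_should_keep_item` only checks the text column, so image/table/equation rows now pass through without an existence check.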

graphgen/models/reader/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,6 +1,5 @@
 from .csv_reader import CSVReader
 from .json_reader import JSONReader
-from .jsonl_reader import JSONLReader
 from .parquet_reader import ParquetReader
 from .pdf_reader import PDFReader
 from .pickle_reader import PickleReader
```
graphgen/models/reader/csv_reader.py

Lines changed: 19 additions & 11 deletions
```diff
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List
+from typing import List, Union
 
-import pandas as pd
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
 
@@ -13,13 +14,20 @@ class CSVReader(BaseReader):
     - if type is "text", "content" column must be present.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        parallelism: int = None,
+    ) -> Dataset:
+        """
+        Read CSV files and return Ray Dataset.
 
-        df = pd.read_csv(file_path)
-        for _, row in df.iterrows():
-            assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}"
-            if row["type"] == "text" and self.text_column not in row:
-                raise ValueError(
-                    f"Missing '{self.text_column}' in document: {row.to_dict()}"
-                )
-        return self.filter(df.to_dict(orient="records"))
+        :param input_path: Path to CSV file or list of CSV files.
+        :param parallelism: Number of blocks for Ray Dataset reading.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+
+        ds = ray.data.read_csv(input_path, override_num_blocks=parallelism)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
```
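
A usage sketch for the new CSV path, assuming a small local `sample.csv` (file name and contents are illustrative):

```python
import pandas as pd

from graphgen.models.reader import CSVReader

# Illustrative input: one valid text row, one whitespace-only row.
pd.DataFrame(
    [
        {"type": "text", "content": "hello world"},
        {"type": "text", "content": "   "},
    ]
).to_csv("sample.csv", index=False)

reader = CSVReader()
ds = reader.read("sample.csv")  # lazy Ray Dataset
print(ds.count())      # 1 -- the whitespace-only row is filtered out
print(ds.take_all())   # [{'type': 'text', 'content': 'hello world'}]
```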
graphgen/models/reader/json_reader.py

Lines changed: 21 additions & 15 deletions
```diff
@@ -1,26 +1,32 @@
-import json
-from typing import Any, Dict, List
+from typing import List, Union
+
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
 
 
 class JSONReader(BaseReader):
     """
-    Reader for JSON files.
+    Reader for JSON and JSONL files.
     Columns:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        with open(file_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        if isinstance(data, list):
-            for doc in data:
-                assert "type" in doc, f"Missing 'type' in document: {doc}"
-                if doc.get("type") == "text" and self.text_column not in doc:
-                    raise ValueError(
-                        f"Missing '{self.text_column}' in document: {doc}"
-                    )
-            return self.filter(data)
-        raise ValueError("JSON file must contain a list of documents.")
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        parallelism: int = 4,
+    ) -> Dataset:
+        """
+        Read JSON file and return Ray Dataset.
+        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
+        :param parallelism: Number of parallel workers for reading files.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+
+        ds = ray.data.read_json(input_path, override_num_blocks=parallelism)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
```
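
`ray.data.read_json` handles both regular JSON and line-delimited JSONL input, which is why the docstring now says "JSON and JSONL files" and the standalone `JSONLReader` below is deleted. A usage sketch with illustrative paths:

```python
from graphgen.models.reader import JSONReader

reader = JSONReader()

# Both array-style .json and line-delimited .jsonl files work here.
ds = reader.read(["docs.json", "docs.jsonl"], parallelism=4)
for row in ds.take(3):
    print(row["type"], row.get("content", ""))
```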

graphgen/models/reader/jsonl_reader.py

Lines changed: 0 additions & 30 deletions
This file was deleted; JSONL reading is now handled by JSONReader (see its updated docstring above).
graphgen/models/reader/parquet_reader.py

Lines changed: 21 additions & 10 deletions
```diff
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List
+from typing import List, Union
 
-import pandas as pd
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
 
@@ -13,12 +14,22 @@ class ParquetReader(BaseReader):
     - if type is "text", "content" column must be present.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        df = pd.read_parquet(file_path)
-        data: List[Dict[str, Any]] = df.to_dict(orient="records")
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        parallelism: int = None,
+    ) -> Dataset:
+        """
+        Read Parquet files using Ray Data.
 
-        for doc in data:
-            assert "type" in doc, f"Missing 'type' in document: {doc}"
-            if doc.get("type") == "text" and self.text_column not in doc:
-                raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
-        return self.filter(data)
+        :param input_path: Path to Parquet file or list of Parquet files.
+        :param parallelism: Number of blocks for Ray Dataset reading.
+        :return: Ray Dataset containing validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        ds = ray.data.read_parquet(input_path, override_num_blocks=parallelism)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
```
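
`ParquetReader` is the only reader in this commit that guards `ray.init()` itself; the others presumably rely on Ray auto-initializing on first use. A usage sketch with an illustrative file:

```python
import pandas as pd

from graphgen.models.reader import ParquetReader

# Illustrative input with a mixed schema.
pd.DataFrame(
    [
        {"type": "text", "content": "alpha"},
        {"type": "image", "img_path": "figure.png"},
    ]
).to_parquet("sample.parquet")

reader = ParquetReader()
ds = reader.read("sample.parquet", parallelism=2)
print(ds.to_pandas())  # both rows survive: image rows skip the text check
```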

graphgen/models/reader/pdf_reader.py

Lines changed: 36 additions & 20 deletions
```diff
@@ -5,6 +5,9 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+import ray
+from ray.data import Dataset
+
 from graphgen.bases.base_reader import BaseReader
 from graphgen.models.reader.txt_reader import TXTReader
 from graphgen.utils import logger, pick_device
@@ -62,19 +65,32 @@ def __init__(
         self.parser = MinerUParser()
         self.txt_reader = TXTReader()
 
-    def read(self, file_path: str, **override) -> List[Dict[str, Any]]:
-        """
-        file_path
-        **override: override MinerU parameters
-        """
-        pdf_path = Path(file_path).expanduser().resolve()
-        if not pdf_path.is_file():
-            raise FileNotFoundError(pdf_path)
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        parallelism: int = 4,
+        **override,
+    ) -> Dataset:
+
+        # Ensure input_path is a list
+        if isinstance(input_path, str):
+            input_path = [input_path]
+
+        paths_ds = ray.data.from_items(input_path)
+
+        def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
+            try:
+                pdf_path = row["item"]
+                kwargs = {**self._default_kwargs, **override}
+                return self._call_mineru(Path(pdf_path), kwargs)
+            except Exception as e:
+                logger.error("Failed to process %s: %s", row, e)
+                return []
 
-        kwargs = {**self._default_kwargs, **override}
+        docs_ds = paths_ds.flat_map(process_pdf)
+        docs_ds = docs_ds.filter(self._should_keep_item)
 
-        mineru_result = self._call_mineru(pdf_path, kwargs)
-        return self.filter(mineru_result)
+        return docs_ds
 
     def _call_mineru(
         self, pdf_path: Path, kwargs: Dict[str, Any]
@@ -161,18 +177,18 @@ def _try_load_cached_result(
 
         base = os.path.dirname(json_file)
         results = []
-        for item in data:
+        for it in data:
             for key in ("img_path", "table_img_path", "equation_img_path"):
-                rel_path = item.get(key)
+                rel_path = it.get(key)
                 if rel_path:
-                    item[key] = str(Path(base).joinpath(rel_path).resolve())
-                if item["type"] == "text":
-                    item["content"] = item["text"]
-                    del item["text"]
+                    it[key] = str(Path(base).joinpath(rel_path).resolve())
+                if it["type"] == "text":
+                    it["content"] = it["text"]
+                    del it["text"]
             for key in ("page_idx", "bbox", "text_level"):
-                if item.get(key) is not None:
-                    del item[key]
-            results.append(item)
+                if it.get(key) is not None:
+                    del it[key]
+            results.append(it)
         return results
 
     @staticmethod
```
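
A usage sketch for the new PDF path; it assumes MinerU and its model weights are available, and the constructor defaults and file names are illustrative:

```python
from graphgen.models.reader import PDFReader

reader = PDFReader()
docs = reader.read(["paper1.pdf", "paper2.pdf"])

# Each PDF is parsed in its own flat_map task; a failure in one file
# is logged and contributes zero rows instead of aborting the run.
print(docs.count())
print(docs.take(2))
```

As the hunk above shows, `parallelism` is accepted but not forwarded to `ray.data.from_items` in this diff; fan-out comes from Ray scheduling `flat_map` tasks over the input paths.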
