
Commit ac99aa8

refactor: refactor pickle_reader using ray data
1 parent ba865c8 commit ac99aa8

File tree

2 files changed (+68, -35 lines)


graphgen/models/reader/pdf_reader.py

Lines changed: 0 additions & 19 deletions
@@ -247,22 +247,3 @@ def _check_bin() -> None:
             "MinerU is not installed or not found in PATH. Please install it from pip: \n"
             "pip install -U 'mineru[core]'"
         ) from exc
-
-
-if __name__ == "__main__":
-    reader = PDFReader(
-        output_dir="./output",
-        method="auto",
-        backend="pipeline",
-        device="cpu",
-        lang="en",
-        formula=True,
-        table=True,
-    )
-    dataset = reader.read(
-        "/home/PJLAB/chenzihong/Project/graphgen/resources/input_examples/pdf_demo.pdf",
-        parallelism=2,
-    )
-
-    for item in dataset.take_all():
-        print(item)
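
With the inline smoke test removed from pdf_reader.py, the same check can be kept as a standalone script. Below is a minimal sketch assembled from the deleted lines; the script name, import path, and input path are assumptions, not part of this commit:

# run_pdf_reader_demo.py - hypothetical standalone replacement for the
# removed __main__ block; the arguments mirror the deleted demo code.
from graphgen.models.reader.pdf_reader import PDFReader

reader = PDFReader(
    output_dir="./output",
    method="auto",
    backend="pipeline",
    device="cpu",
    lang="en",
    formula=True,
    table=True,
)

# read() returns a Ray Dataset; parallelism controls the number of read tasks
dataset = reader.read("resources/input_examples/pdf_demo.pdf", parallelism=2)

for item in dataset.take_all():
    print(item)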
graphgen/models/reader/pickle_reader.py

Lines changed: 68 additions & 16 deletions
@@ -1,30 +1,82 @@
 import pickle
-from typing import Any, Dict, List
+from typing import List, Optional, Union
+
+import pandas as pd
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger
 
 
 class PickleReader(BaseReader):
     """
-    Read pickle files, requiring the top-level object to be List[Dict[str, Any]].
-
-    Columns:
+    Read pickle files whose deserialized top-level object is List[Dict[str, Any]].
+    Each pickle file should contain a list of dictionaries with at least:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
+
+    Note: uses ray.data.read_binary_files, since ray.data.read_pickle is not available.
+    For Ray >= 2.5, consider using read_pickle if it is available in your version.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        with open(file_path, "rb") as f:
-            data = pickle.load(f)
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        override_num_blocks: Optional[int] = None,
+    ) -> Dataset:
+        """
+        Read pickle files using Ray Data.
+
+        :param input_path: Path to a pickle file or a list of pickle files.
+        :param override_num_blocks: Number of blocks for Ray Dataset reading.
+        :return: Ray Dataset containing the validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Use read_binary_files as a reliable alternative to read_pickle
+        ds = ray.data.read_binary_files(
+            input_path, override_num_blocks=override_num_blocks, include_paths=True
+        )
+
+        # Deserialize pickle files and flatten into individual records
+        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+            all_records = []
+            for _, row in batch.iterrows():
+                try:
+                    # Load pickle data from bytes
+                    data = pickle.loads(row["bytes"])
+
+                    # Validate structure
+                    if not isinstance(data, list):
+                        logger.error(
+                            "Pickle file %s must contain a list, got %s", row["path"], type(data)
+                        )
+                        continue
+
+                    if not all(isinstance(item, dict) for item in data):
+                        logger.error(
+                            "Pickle file %s must contain a list of dictionaries", row["path"]
+                        )
+                        continue
+
+                    # Flatten: each dict in the list becomes a separate row
+                    all_records.extend(data)
+                except Exception as e:
+                    logger.error(
+                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                    )
+                    continue
+
+            return pd.DataFrame(all_records)
 
-        if not isinstance(data, list):
-            raise ValueError("Pickle file must contain a list of documents.")
+        # Apply deserialization and flattening
+        ds = ds.map_batches(deserialize_batch, batch_format="pandas")
 
-        for doc in data:
-            if not isinstance(doc, dict):
-                raise ValueError("Every item in the list must be a dict.")
-            assert "type" in doc, f"Missing 'type' in document: {doc}"
-            if doc.get("type") == "text" and self.text_column not in doc:
-                raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
+        # Validate the schema
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
 
-        return self.filter(data)
+        # Filter valid items
+        ds = ds.filter(self._should_keep_item)
+        return ds
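
To exercise the new reader end to end, a small script can write a schema-compliant pickle file and read it back through Ray Data. A minimal sketch, assuming PickleReader can be constructed with no arguments (BaseReader's constructor is not shown in this diff) and that the module lives at graphgen/models/reader/pickle_reader.py; file names are illustrative:

# demo_pickle_reader.py - hypothetical usage sketch
import pickle

from graphgen.models.reader.pickle_reader import PickleReader

# A list of dicts matching the documented schema: each record carries a
# "type", and "content" is required when type is "text".
docs = [
    {"type": "text", "content": "hello world"},
    {"type": "text", "content": "another document"},
]
with open("demo_docs.pkl", "wb") as f:
    pickle.dump(docs, f)

reader = PickleReader()  # assumption: no required constructor arguments
ds = reader.read("demo_docs.pkl")  # returns a ray.data.Dataset

# After the flattening map_batches step, each dict is one dataset row
for row in ds.take_all():
    print(row)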
