 import pickle
-from typing import Any, Dict, List
+from typing import List, Optional, Union
+
+import pandas as pd
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger
 
 
 class PickleReader(BaseReader):
     """
-    Read pickle files, requiring the top-level object to be List[Dict[str, Any]].
-
-    Columns:
+    Read pickle files, requiring each file to deserialize to List[Dict[str, Any]].
+    Each pickle file should contain a list of dictionaries with at least:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
+
+    Note: Uses ray.data.read_binary_files, as Ray Data provides no dedicated
+    pickle reader; each file is deserialized manually in a map_batches step.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        with open(file_path, "rb") as f:
-            data = pickle.load(f)
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        override_num_blocks: Optional[int] = None,
+    ) -> Dataset:
+        """
+        Read pickle files using Ray Data.
+
+        :param input_path: Path to a pickle file, or a list of pickle file paths.
+        :param override_num_blocks: Number of blocks to split the Ray Dataset into.
+        :return: Ray Dataset containing the validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Read raw bytes; include_paths=True adds a "path" column used in errors
+        ds = ray.data.read_binary_files(
+            input_path, override_num_blocks=override_num_blocks, include_paths=True
+        )
+
+        # Deserialize pickle files and flatten into individual records
+        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+            all_records = []
+            for _, row in batch.iterrows():
+                try:
+                    # Load pickle data from bytes
+                    data = pickle.loads(row["bytes"])
+
+                    # Validate structure
+                    if not isinstance(data, list):
+                        logger.error(
+                            "Pickle file %s must contain a list, got %s",
+                            row["path"],
+                            type(data),
+                        )
+                        continue
+
+                    if not all(isinstance(item, dict) for item in data):
+                        logger.error(
+                            "Pickle file %s must contain a list of dictionaries",
+                            row["path"],
+                        )
+                        continue
+
+                    # Flatten: each dict in the list becomes a separate row
+                    all_records.extend(data)
+                except Exception as e:
+                    logger.error(
+                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                    )
+                    continue
+
+            return pd.DataFrame(all_records)
 
-        if not isinstance(data, list):
-            raise ValueError("Pickle file must contain a list of documents.")
+        # Apply deserialization and flattening
+        ds = ds.map_batches(deserialize_batch, batch_format="pandas")
 
-        for doc in data:
-            if not isinstance(doc, dict):
-                raise ValueError("Every item in the list must be a dict.")
-            assert "type" in doc, f"Missing 'type' in document: {doc}"
-            if doc.get("type") == "text" and self.text_column not in doc:
-                raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
+        # Validate the schema
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
 
-        return self.filter(data)
+        # Filter valid items
+        ds = ds.filter(self._should_keep_item)
+        return ds
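
For context, a minimal end-to-end sketch of how the new reader could be exercised. The import path of PickleReader and its constructor arguments are assumptions for illustration, since neither appears in this diff:

    import pickle
    import tempfile

    from graphgen.readers.pickle_reader import PickleReader  # import path assumed

    # Payload matching the documented schema: List[Dict[str, Any]]
    docs = [
        {"type": "text", "content": "hello world"},
        {"type": "image", "path": "/data/img_001.png"},
    ]
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
        pickle.dump(docs, f)

    reader = PickleReader()  # constructor arguments assumed
    ds = reader.read(f.name)
    print(ds.take_all())  # one Ray Dataset row per dict in the pickled list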
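Design note: the pandas map_batches round-trip works, but Ray Data's flat_map can express the same one-file-to-many-records flattening directly on rows, without building an intermediate DataFrame. A rough sketch of that alternative (not what this PR implements; error logging elided):

    import pickle
    from typing import Any, Dict, List

    import ray

    def deserialize_file(row: Dict[str, Any]) -> List[Dict[str, Any]]:
        # One input row per file; emit one output row per pickled dict
        data = pickle.loads(row["bytes"])
        if isinstance(data, list) and all(isinstance(d, dict) for d in data):
            return data
        return []  # skip malformed files instead of logging, for brevity

    input_path = "docs.pkl"  # placeholder path
    ds = ray.data.read_binary_files(input_path, include_paths=True)
    ds = ds.flat_map(deserialize_file)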