Commit 42fcb09

refactor read and chunk operators with no side effects
1 parent 319e1e7 commit 42fcb09

7 files changed: +100 −129 lines

File renamed without changes.

graphgen/models/storage/json_storage.py renamed to graphgen/models/storage/kv/json_storage.py

Lines changed: 1 addition & 39 deletions

@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass

-from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage
+from graphgen.bases.base_storage import BaseKVStorage
 from graphgen.utils import load_json, logger, write_json


@@ -54,41 +54,3 @@ def upsert(self, data: dict):
     def drop(self):
         if self._data:
             self._data.clear()
-
-
-@dataclass
-class JsonListStorage(BaseListStorage):
-    working_dir: str = None
-    namespace: str = None
-    _data: list = None
-
-    def __post_init__(self):
-        self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
-        self._data = load_json(self._file_name) or []
-        logger.info("Load List %s with %d data", self.namespace, len(self._data))
-
-    @property
-    def data(self):
-        return self._data
-
-    def all_items(self) -> list:
-        return self._data
-
-    def index_done_callback(self):
-        write_json(self._data, self._file_name)
-
-    def get_by_index(self, index: int):
-        if index < 0 or index >= len(self._data):
-            return None
-        return self._data[index]
-
-    def append(self, data):
-        self._data.append(data)
-
-    def upsert(self, data: list):
-        left_data = [d for d in data if d not in self._data]
-        self._data.extend(left_data)
-        return left_data
-
-    def drop(self):
-        self._data = []

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .chunk_service import ChunkService

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+import asyncio
+import os
+from functools import lru_cache
+from typing import Union
+
+import pandas as pd
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.models import (
+    ChineseRecursiveTextSplitter,
+    RecursiveCharacterSplitter,
+    Tokenizer,
+)
+from graphgen.utils import compute_content_hash, detect_main_language
+
+_MAPPING = {
+    "en": RecursiveCharacterSplitter,
+    "zh": ChineseRecursiveTextSplitter,
+}
+
+SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
+
+
+@lru_cache(maxsize=None)
+def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
+    cls = _MAPPING[language]
+    kwargs = dict(frozen_kwargs)
+    return cls(**kwargs)
+
+
+def split_chunks(text: str, language: str = "en", **kwargs) -> list:
+    if language not in _MAPPING:
+        raise ValueError(
+            f"Unsupported language: {language}. "
+            f"Supported languages are: {list(_MAPPING.keys())}"
+        )
+    frozen_kwargs = frozenset(
+        (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
+    )
+    splitter = _get_splitter(language, frozen_kwargs)
+    return splitter.split_text(text)
+
+
+class ChunkService:
+    def __init__(self, **chunk_kwargs):
+        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
+        self.chunk_kwargs = chunk_kwargs
+
+    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
+        docs = batch.to_dict(orient="records")
+        return pd.DataFrame(self.chunk_documents(docs))
+
+    def chunk_documents(self, new_docs: list) -> list:
+        for doc in new_docs:
+            doc_id = doc.get("_doc_id")
+            doc_type = doc.get("type")
+
+            if doc_type == "text":
+                doc_language = detect_main_language(doc["content"])
+                text_chunks = split_chunks(
+                    doc["content"],
+                    language=doc_language,
+                    **self.chunk_kwargs,
+                )
+
+                return [
+                    {
+                        "_chunk_id": compute_content_hash(chunk_text, prefix="chunk-"),
+                        "content": chunk_text,
+                        "type": "text",
+                        "_doc_id": doc_id,
+                        "length": len(self.tokenizer_instance.encode(chunk_text))
+                        if self.tokenizer_instance
+                        else len(chunk_text),
+                        "language": doc_language,
+                    }
+                    for chunk_text in text_chunks
+                ]
+
+            # other types of documents(images, sequences) are not chunked
+            return [
+                {
+                    "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"),
+                    **doc,
+                }
+            ]
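
For orientation, a minimal sketch of driving the new ChunkService directly with a pandas batch, the same shape its __call__ receives. The import path and the sample document are assumptions based on this diff (the new files' paths are not shown in this view); no chunk kwargs are passed, so the splitter defaults apply.

    import pandas as pd

    # Import path assumed from the new `from .chunk_service import ChunkService` __init__ line.
    from graphgen.operators.chunk import ChunkService

    # One text document shaped like the records produced by the read operator.
    batch = pd.DataFrame(
        [
            {
                "_doc_id": "doc-0123abcd",  # hypothetical id; real ids come from compute_mm_hash
                "type": "text",
                "content": "GraphGen reads raw documents and splits them into chunks. " * 20,
            }
        ]
    )

    chunker = ChunkService()  # extra kwargs would be forwarded to the text splitter
    chunks = chunker(batch)   # returns a DataFrame of chunk records
    print(chunks[["_chunk_id", "_doc_id", "length", "language"]])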

graphgen/operators/read/read.py

Lines changed: 11 additions & 5 deletions

@@ -12,7 +12,7 @@
     RDFReader,
     TXTReader,
 )
-from graphgen.utils import logger
+from graphgen.utils import compute_mm_hash, logger

 from .parallel_file_scanner import ParallelFileScanner


@@ -110,10 +110,16 @@ def read(
         return ray.data.from_items([])

     if len(read_tasks) == 1:
-        logger.info("[READ] Successfully read files from %s", input_path)
-        return read_tasks[0]
-    # len(read_tasks) > 1
-    combined_ds = read_tasks[0].union(*read_tasks[1:])
+        combined_ds = read_tasks[0]
+    else:
+        combined_ds = read_tasks[0].union(*read_tasks[1:])
+
+    combined_ds = combined_ds.map(
+        lambda record: {
+            **record,
+            "_doc_id": compute_mm_hash(record),
+        }
+    )

     logger.info("[READ] Successfully read files from %s", input_path)
     return combined_ds
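
The new map step attaches a stable _doc_id to every record before downstream operators (such as ChunkService) see it. Below is a minimal sketch of that pattern; the _stand_in_mm_hash helper is hypothetical and only stands in for graphgen.utils.compute_mm_hash, whose actual implementation is not part of this diff.

    import hashlib

    import ray


    def _stand_in_mm_hash(record: dict) -> str:
        # Hypothetical stand-in: hash the record's content to get a deterministic id.
        digest = hashlib.md5(str(record.get("content", "")).encode("utf-8")).hexdigest()
        return f"doc-{digest}"


    ds = ray.data.from_items([{"type": "text", "content": "hello graphgen"}])
    ds = ds.map(lambda record: {**record, "_doc_id": _stand_in_mm_hash(record)})
    print(ds.take(1))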

graphgen/operators/split/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

graphgen/operators/split/split_chunks.py

Lines changed: 0 additions & 84 deletions
This file was deleted.
