Skip to content

Commit da90335

Browse files
Merge pull request #51 from open-sciencelab/splitter
Add Splitter classes
2 parents 6aaad9d + fdaef0e commit da90335

30 files changed

+750
-178
lines changed

graphgen/bases/base_splitter.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
import copy
import hashlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger
9+
10+
11+
@dataclass
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses implement :meth:`split_text`; this base class provides the
    machinery to merge small splits into chunks of at most ``chunk_size``
    (as measured by ``length_function``) with up to ``chunk_overlap`` of
    carried-over context between consecutive chunks.
    """

    chunk_size: int = 1024  # maximum chunk length, in length_function units
    chunk_overlap: int = 100  # length carried over between consecutive chunks
    length_function: Callable[[str], int] = len  # how text length is measured
    keep_separator: bool = False  # keep the separator in the produced splits
    add_start_index: bool = False  # record each chunk's start offset in metadata
    strip_whitespace: bool = True  # strip whitespace around each joined chunk

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List["Chunk"]:
        """Create chunks from a list of texts.

        :param texts: source documents to split.
        :param metadatas: optional per-document metadata, parallel to ``texts``.
        :return: one :class:`Chunk` per split, each carrying a deep copy of its
            source document's metadata.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep-copy so sibling chunks never share a metadata object.
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Start searching just before the previous chunk's end so
                    # repeated substrings resolve to the correct occurrence.
                    offset = index + previous_chunk_len - self.chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                # Chunk.id is a required field; derive a deterministic id from
                # the content so identical content always maps to the same id.
                new_chunk = Chunk(
                    id=hashlib.md5(chunk.encode("utf-8")).hexdigest(),
                    content=chunk,
                    metadata=metadata,
                )
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """Join splits with ``separator``; return None if the result is empty."""
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Combine small splits into medium-size chunks for the LLM.

        Splits are accumulated until adding the next one would exceed
        ``chunk_size``; the buffer is then flushed and its tail (up to
        ``chunk_overlap``) is retained as the start of the next chunk.
        """
        separator_len = self.length_function(separator)

        chunks = []
        current_chunk: List[str] = []
        total = 0
        for d in splits:
            _len = self.length_function(d)
            if (
                total + _len + (separator_len if len(current_chunk) > 0 else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    # A single split can exceed chunk_size; warn but still emit
                    # it rather than dropping content.
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if len(current_chunk) > 0:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap or (
                        total + _len + (separator_len if len(current_chunk) > 0 else 0)
                        > self.chunk_size
                        and total > 0
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk = current_chunk[1:]
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        # Flush whatever remains in the buffer.
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """Split ``text`` on a regex ``separator``.

        :param keep_separator: False drops separators; True/"start" attaches
            each separator to the following split; "end" attaches it to the
            preceding split.
        """
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                _splits = re.split(f"({separator})", text)
                splits = (
                    (
                        [
                            _splits[i] + _splits[i + 1]
                            for i in range(0, len(_splits) - 1, 2)
                        ]
                    )
                    if keep_separator == "end"
                    else (
                        [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                    )
                )
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                # Re-attach the dangling boundary piece not covered by the pairing.
                splits = (
                    (splits + [_splits[-1]])
                    if keep_separator == "end"
                    else ([_splits[0]] + splits)
                )
            else:
                splits = re.split(separator, text)
        else:
            # No separator: fall back to character-level splitting.
            splits = list(text)
        return [s for s in splits if s != ""]

graphgen/bases/datatypes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from dataclasses import dataclass, field
2+
3+
4+
@dataclass
class Chunk:
    """A piece of text with an identifier and optional metadata."""

    # Defaults allow keyword construction without an explicit id
    # (e.g. splitters that build Chunk(content=..., metadata=...));
    # positional callers Chunk(id, content, ...) are unaffected.
    id: str = ""
    content: str = ""
    metadata: dict = field(default_factory=dict)
9+
10+
11+
@dataclass
class QAPair:
    """
    A pair of question and answer.
    """

    question: str  # the question text
    answer: str  # the corresponding answer text

graphgen/configs/aggregated_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/jsonl_demo.jsonl # input file path; supported formats: json, jsonl, txt. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
37
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/atomic_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/json_demo.json # input file path; supported formats: json, jsonl, txt, csv. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: atomic # atomic, aggregated, multi_hop, cot
37
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/cot_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/txt_demo.txt # input file path; supported formats: json, jsonl, txt. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: cot # atomic, aggregated, multi_hop, cot
37
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

graphgen/configs/multi_hop_config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
1+
read:
2+
input_file: resources/input_examples/csv_demo.csv # input file path; supported formats: json, jsonl, txt, csv. See resources/input_examples for examples
3+
split:
4+
chunk_size: 1024 # chunk size for text splitting
5+
chunk_overlap: 100 # chunk overlap for text splitting
26
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
37
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
48
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path

0 commit comments

Comments
 (0)