Skip to content

Commit 3ec48c3

Browse files
Merge branch 'main' of https://github.com/open-sciencelab/GraphGen into feature/schema_guided_build
2 parents f89a320 + 519dfef commit 3ec48c3

File tree

1 file changed

+56
-28
lines changed

1 file changed

+56
-28
lines changed

webui/app.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
from dotenv import load_dotenv
1010

11+
from graphgen.engine import Context, Engine, collect_ops
1112
from graphgen.graphgen import GraphGen
1213
from graphgen.models import OpenAIClient, Tokenizer
1314
from graphgen.models.llm.limitter import RPM, TPM
@@ -97,26 +98,61 @@ def sum_tokens(client):
9798
"unit_sampling": params.ece_unit_sampling,
9899
}
99100

101+
pipeline = [
102+
{
103+
"name": "read",
104+
"params": {
105+
"input_file": params.upload_file,
106+
"chunk_size": params.chunk_size,
107+
"chunk_overlap": params.chunk_overlap,
108+
},
109+
},
110+
{
111+
"name": "build_kg",
112+
},
113+
]
114+
115+
if params.if_trainee_model:
116+
pipeline.append(
117+
{
118+
"name": "quiz_and_judge",
119+
"params": {"quiz_samples": params.quiz_samples, "re_judge": True},
120+
}
121+
)
122+
pipeline.append(
123+
{
124+
"name": "partition",
125+
"deps": ["quiz_and_judge"],
126+
"params": {
127+
"method": params.partition_method,
128+
"method_params": partition_params,
129+
},
130+
}
131+
)
132+
else:
133+
pipeline.append(
134+
{
135+
"name": "partition",
136+
"params": {
137+
"method": params.partition_method,
138+
"method_params": partition_params,
139+
},
140+
}
141+
)
142+
pipeline.append(
143+
{
144+
"name": "generate",
145+
"params": {
146+
"method": params.mode,
147+
"data_format": params.data_format,
148+
},
149+
}
150+
)
151+
100152
config = {
101153
"if_trainee_model": params.if_trainee_model,
102154
"read": {"input_file": params.upload_file},
103-
"split": {
104-
"chunk_size": params.chunk_size,
105-
"chunk_overlap": params.chunk_overlap,
106-
},
107-
"search": {"enabled": False},
108-
"quiz_and_judge": {
109-
"enabled": params.if_trainee_model,
110-
"quiz_samples": params.quiz_samples,
111-
},
112-
"partition": {
113-
"method": params.partition_method,
114-
"method_params": partition_params,
115-
},
116-
"generate": {
117-
"mode": params.mode,
118-
"data_format": params.data_format,
119-
},
155+
"pipeline": pipeline,
120156
}
121157

122158
env = {
@@ -145,20 +181,12 @@ def sum_tokens(client):
145181
# Initialize GraphGen
146182
graph_gen = init_graph_gen(config, env)
147183
graph_gen.clear()
148-
149184
graph_gen.progress_bar = progress
150185

151186
try:
152-
# Process the data
153-
graph_gen.insert(read_config=config["read"], split_config=config["split"])
154-
155-
if config["if_trainee_model"]:
156-
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
157-
158-
graph_gen.generate(
159-
partition_config=config["partition"],
160-
generate_config=config["generate"],
161-
)
187+
ctx = Context(config=config, graph_gen=graph_gen)
188+
ops = collect_ops(config, graph_gen)
189+
Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
162190

163191
# Save output
164192
output_data = graph_gen.qa_storage.data

0 commit comments

Comments (0)