
Commit b79ee12

refactor(webui): restructure GraphGen parameter handling
1 parent e1a80c7 commit b79ee12

2 files changed: +85 -28 lines changed
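
In short: webui/app.py previously unpacked 23 positional Gradio inputs by index inside run_graphgen; this commit adds a GraphGenParams dataclass (new file webui/base.py) and moves the positional-to-named mapping into an adapter lambda on submit_btn.click, so the function body reads params.chunk_size instead of arguments[16].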

webui/app.py

Lines changed: 54 additions & 28 deletions

@@ -8,6 +8,7 @@
 
 from gradio_i18n import Translate, gettext as _
 
+from base import GraphGenParams
 from test_api import test_api_connection
 from cache_utils import setup_workspace, cleanup_workspace
 from count_tokens import count_tokens
@@ -30,6 +31,7 @@
 }
 """
 
+
 def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
@@ -77,40 +79,39 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     return graph_gen
 
 
 # pylint: disable=too-many-statements
-def run_graphgen(*arguments: list, progress=gr.Progress()):
+def run_graphgen(params, progress=gr.Progress()):
     def sum_tokens(client):
         return sum(u["total_tokens"] for u in client.token_usage)
 
-    # Unpack arguments
     config = {
-        "if_trainee_model": arguments[0],
-        "input_file": arguments[1],
-        "tokenizer": arguments[2],
-        "qa_form": arguments[3],
+        "if_trainee_model": params.if_trainee_model,
+        "input_file": params.input_file,
+        "tokenizer": params.tokenizer,
+        "qa_form": params.qa_form,
         "web_search": False,
-        "quiz_samples": arguments[19],
+        "quiz_samples": params.quiz_samples,
         "traverse_strategy": {
-            "bidirectional": arguments[4],
-            "expand_method": arguments[5],
-            "max_extra_edges": arguments[6],
-            "max_tokens": arguments[7],
-            "max_depth": arguments[8],
-            "edge_sampling": arguments[9],
-            "isolated_node_strategy": arguments[10],
-            "loss_strategy": arguments[11]
+            "bidirectional": params.bidirectional,
+            "expand_method": params.expand_method,
+            "max_extra_edges": params.max_extra_edges,
+            "max_tokens": params.max_tokens,
+            "max_depth": params.max_depth,
+            "edge_sampling": params.edge_sampling,
+            "isolated_node_strategy": params.isolated_node_strategy,
+            "loss_strategy": params.loss_strategy
         },
-        "chunk_size": arguments[16],
+        "chunk_size": params.chunk_size,
     }
 
     env = {
-        "SYNTHESIZER_BASE_URL": arguments[12],
-        "SYNTHESIZER_MODEL": arguments[13],
-        "TRAINEE_BASE_URL": arguments[20],
-        "TRAINEE_MODEL": arguments[14],
-        "SYNTHESIZER_API_KEY": arguments[15],
-        "TRAINEE_API_KEY": arguments[21],
-        "RPM": arguments[17],
-        "TPM": arguments[18],
+        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
+        "SYNTHESIZER_MODEL": params.synthesizer_model,
+        "TRAINEE_BASE_URL": params.trainee_url,
+        "TRAINEE_MODEL": params.trainee_model,
+        "SYNTHESIZER_API_KEY": params.api_key,
+        "TRAINEE_API_KEY": params.trainee_api_key,
+        "RPM": params.rpm,
+        "TPM": params.tpm,
     }
 
     # Test API connection
@@ -189,7 +190,7 @@ def sum_tokens(client):
     trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
     total_tokens = synthesizer_tokens + trainee_tokens
 
-    data_frame = arguments[-1]
+    data_frame = params.token_counter
     try:
         _update_data = [
             [
@@ -460,7 +461,6 @@ def sum_tokens(client):
         inputs=if_trainee_model,
         outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key])
 
-    # Count the tokens of the uploaded file
     upload_file.change(
         lambda x: (gr.update(visible=True)),
         inputs=[upload_file],
@@ -476,8 +476,34 @@ def sum_tokens(client):
         lambda x: (gr.update(visible=False)),
         inputs=[token_counter],
         outputs=[token_counter],
-    ).then(
-        run_graphgen,
+    )
+
+    submit_btn.click(
+        lambda *args: run_graphgen(GraphGenParams(
+            if_trainee_model=args[0],
+            input_file=args[1],
+            tokenizer=args[2],
+            qa_form=args[3],
+            bidirectional=args[4],
+            expand_method=args[5],
+            max_extra_edges=args[6],
+            max_tokens=args[7],
+            max_depth=args[8],
+            edge_sampling=args[9],
+            isolated_node_strategy=args[10],
+            loss_strategy=args[11],
+            synthesizer_url=args[12],
+            synthesizer_model=args[13],
+            trainee_model=args[14],
+            api_key=args[15],
+            chunk_size=args[16],
+            rpm=args[17],
+            tpm=args[18],
+            quiz_samples=args[19],
+            trainee_url=args[20],
+            trainee_api_key=args[21],
+            token_counter=args[22],
+        )),
         inputs=[
             if_trainee_model, upload_file, tokenizer, qa_form,
             bidirectional, expand_method, max_extra_edges, max_tokens,
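
Since Gradio still delivers component values positionally, the click handler above packs them into the dataclass before calling run_graphgen. Below is a minimal, self-contained sketch of that adapter pattern; DemoParams and run_demo are hypothetical stand-ins for illustration, not the project's actual components.

# Sketch of the positional-to-dataclass adapter pattern used above.
# DemoParams and run_demo are hypothetical names, not project code.
from dataclasses import dataclass

import gradio as gr


@dataclass
class DemoParams:
    input_file: str
    chunk_size: int


def run_demo(params: DemoParams) -> str:
    # The worker reads named fields instead of indexing into *arguments.
    return f"{params.input_file} (chunk_size={params.chunk_size})"


with gr.Blocks() as demo:
    input_file = gr.Textbox(value="data.txt", label="Input file")
    chunk_size = gr.Number(value=512, precision=0, label="Chunk size")
    output = gr.Textbox(label="Result")
    submit = gr.Button("Run")
    # Gradio passes inputs positionally; the lambda packs them into
    # the dataclass before calling the worker.
    submit.click(
        lambda *args: run_demo(DemoParams(input_file=args[0], chunk_size=int(args[1]))),
        inputs=[input_file, chunk_size],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()

The trade-off mirrors the commit: the lambda keeps a single index-ordered mapping at the wiring site, and everything downstream of it uses named fields.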

webui/base.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Any
+
+@dataclass
+class GraphGenParams:
+    """
+    GraphGen parameters
+    """
+    if_trainee_model: bool
+    input_file: str
+    tokenizer: str
+    qa_form: str
+    bidirectional: bool
+    expand_method: str
+    max_extra_edges: int
+    max_tokens: int
+    max_depth: int
+    edge_sampling: str
+    isolated_node_strategy: str
+    loss_strategy: str
+    synthesizer_url: str
+    synthesizer_model: str
+    trainee_model: str
+    api_key: str
+    chunk_size: int
+    rpm: int
+    tpm: int
+    quiz_samples: int
+    trainee_url: str
+    trainee_api_key: str
+    token_counter: Any
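
With the dataclass in place, call sites construct parameters by keyword. A minimal construction sketch, assuming only the field names and types shown in the diff above; every value below is an illustrative placeholder, not a project default.

# Hypothetical values for illustration only -- the real webui fills
# these fields from its Gradio components.
from base import GraphGenParams

params = GraphGenParams(
    if_trainee_model=False,
    input_file="cache/example.txt",    # placeholder path
    tokenizer="cl100k_base",           # placeholder tokenizer name
    qa_form="atomic",                  # placeholder form
    bidirectional=True,
    expand_method="max_width",         # placeholder strategy
    max_extra_edges=5,
    max_tokens=256,
    max_depth=2,
    edge_sampling="random",            # placeholder strategy
    isolated_node_strategy="ignore",   # placeholder strategy
    loss_strategy="only_edge",         # placeholder strategy
    synthesizer_url="https://api.example.com/v1",
    synthesizer_model="synthesizer-model-name",
    trainee_model="trainee-model-name",
    api_key="YOUR_SYNTHESIZER_KEY",
    chunk_size=512,
    rpm=1000,
    tpm=50000,
    quiz_samples=2,
    trainee_url="https://api.example.com/v1",
    trainee_api_key="YOUR_TRAINEE_KEY",
    token_counter=None,  # the webui passes its token-count dataframe here
)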
