88
99from gradio_i18n import Translate , gettext as _
1010
11+ from base import GraphGenParams
1112from test_api import test_api_connection
1213from cache_utils import setup_workspace , cleanup_workspace
1314from count_tokens import count_tokens
3031}
3132"""
3233
34+
3335def init_graph_gen (config : dict , env : dict ) -> GraphGen :
3436 # Set up working directory
3537 log_file , working_dir = setup_workspace (os .path .join (root_dir , "cache" ))
@@ -77,40 +79,39 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
7779 return graph_gen
7880
7981# pylint: disable=too-many-statements
80- def run_graphgen (* arguments : list , progress = gr .Progress ()):
82+ def run_graphgen (params , progress = gr .Progress ()):
8183 def sum_tokens (client ):
8284 return sum (u ["total_tokens" ] for u in client .token_usage )
8385
84- # Unpack arguments
8586 config = {
86- "if_trainee_model" : arguments [ 0 ] ,
87- "input_file" : arguments [ 1 ] ,
88- "tokenizer" : arguments [ 2 ] ,
89- "qa_form" : arguments [ 3 ] ,
87+ "if_trainee_model" : params . if_trainee_model ,
88+ "input_file" : params . input_file ,
89+ "tokenizer" : params . tokenizer ,
90+ "qa_form" : params . qa_form ,
9091 "web_search" : False ,
91- "quiz_samples" : arguments [ 19 ] ,
92+ "quiz_samples" : params . quiz_samples ,
9293 "traverse_strategy" : {
93- "bidirectional" : arguments [ 4 ] ,
94- "expand_method" : arguments [ 5 ] ,
95- "max_extra_edges" : arguments [ 6 ] ,
96- "max_tokens" : arguments [ 7 ] ,
97- "max_depth" : arguments [ 8 ] ,
98- "edge_sampling" : arguments [ 9 ] ,
99- "isolated_node_strategy" : arguments [ 10 ] ,
100- "loss_strategy" : arguments [ 11 ]
94+ "bidirectional" : params . bidirectional ,
95+ "expand_method" : params . expand_method ,
96+ "max_extra_edges" : params . max_extra_edges ,
97+ "max_tokens" : params . max_tokens ,
98+ "max_depth" : params . max_depth ,
99+ "edge_sampling" : params . edge_sampling ,
100+ "isolated_node_strategy" : params . isolated_node_strategy ,
101+ "loss_strategy" : params . loss_strategy
101102 },
102- "chunk_size" : arguments [ 16 ] ,
103+ "chunk_size" : params . chunk_size ,
103104 }
104105
105106 env = {
106- "SYNTHESIZER_BASE_URL" : arguments [ 12 ] ,
107- "SYNTHESIZER_MODEL" : arguments [ 13 ] ,
108- "TRAINEE_BASE_URL" : arguments [ 20 ] ,
109- "TRAINEE_MODEL" : arguments [ 14 ] ,
110- "SYNTHESIZER_API_KEY" : arguments [ 15 ] ,
111- "TRAINEE_API_KEY" : arguments [ 21 ] ,
112- "RPM" : arguments [ 17 ] ,
113- "TPM" : arguments [ 18 ] ,
107+ "SYNTHESIZER_BASE_URL" : params . synthesizer_url ,
108+ "SYNTHESIZER_MODEL" : params . synthesizer_model ,
109+ "TRAINEE_BASE_URL" : params . trainee_url ,
110+ "TRAINEE_MODEL" : params . trainee_model ,
111+ "SYNTHESIZER_API_KEY" : params . api_key ,
112+ "TRAINEE_API_KEY" : params . trainee_api_key ,
113+ "RPM" : params . rpm ,
114+ "TPM" : params . tpm ,
114115 }
115116
116117 # Test API connection
@@ -189,7 +190,7 @@ def sum_tokens(client):
189190 trainee_tokens = sum_tokens (graph_gen .trainee_llm_client ) if config ['if_trainee_model' ] else 0
190191 total_tokens = synthesizer_tokens + trainee_tokens
191192
192- data_frame = arguments [ - 1 ]
193+ data_frame = params . token_counter
193194 try :
194195 _update_data = [
195196 [
@@ -460,7 +461,6 @@ def sum_tokens(client):
460461 inputs = if_trainee_model ,
461462 outputs = [trainee_url , trainee_model , quiz_samples , edge_sampling , trainee_api_key ])
462463
463- # 计算上传文件的token数
464464 upload_file .change (
465465 lambda x : (gr .update (visible = True )),
466466 inputs = [upload_file ],
@@ -476,8 +476,34 @@ def sum_tokens(client):
476476 lambda x : (gr .update (visible = False )),
477477 inputs = [token_counter ],
478478 outputs = [token_counter ],
479- ).then (
480- run_graphgen ,
479+ )
480+
481+ submit_btn .click (
482+ lambda * args : run_graphgen (GraphGenParams (
483+ if_trainee_model = args [0 ],
484+ input_file = args [1 ],
485+ tokenizer = args [2 ],
486+ qa_form = args [3 ],
487+ bidirectional = args [4 ],
488+ expand_method = args [5 ],
489+ max_extra_edges = args [6 ],
490+ max_tokens = args [7 ],
491+ max_depth = args [8 ],
492+ edge_sampling = args [9 ],
493+ isolated_node_strategy = args [10 ],
494+ loss_strategy = args [11 ],
495+ synthesizer_url = args [12 ],
496+ synthesizer_model = args [13 ],
497+ trainee_model = args [14 ],
498+ api_key = args [15 ],
499+ chunk_size = args [16 ],
500+ rpm = args [17 ],
501+ tpm = args [18 ],
502+ quiz_samples = args [19 ],
503+ trainee_url = args [20 ],
504+ trainee_api_key = args [21 ],
505+ token_counter = args [22 ],
506+ )),
481507 inputs = [
482508 if_trainee_model , upload_file , tokenizer , qa_form ,
483509 bidirectional , expand_method , max_extra_edges , max_tokens ,
0 commit comments