 import json
 import tempfile
 
+import pandas as pd
 import gradio as gr
 
 from gradio_i18n import Translate, gettext as _
 from test_api import test_api_connection
 from cache_utils import setup_workspace, cleanup_workspace
+from count_tokens import count_tokens
 
 # pylint: disable=wrong-import-position
 root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 }
 """
 
-
 def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     working_dir = setup_workspace(os.path.join(root_dir, "cache"))
@@ -65,6 +66,9 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
 # pylint: disable=too-many-statements
 def run_graphgen(*arguments: list, progress=gr.Progress()):
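+    # Helper: sum the "total_tokens" recorded by an LLM client across all of its calls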
+    def sum_tokens(client):
+        return sum(u["total_tokens"] for u in client.token_usage)
+
     # Unpack arguments
     config = {
         "if_trainee_model": arguments[0],
@@ -174,14 +178,44 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
         # Clean up workspace
         cleanup_workspace(graph_gen.working_dir)
 
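+        # Aggregate token usage from the synthesizer client (and the trainee client, if enabled)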
+        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
+        trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
+        total_tokens = synthesizer_tokens + trainee_tokens
+
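+        # The last argument is the token stats DataFrame shown in the UI; rebuild it with the actual usage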
+        data_frame = arguments[-1]
+        try:
+            _update_data = [
+                [
+                    data_frame.iloc[0, 0],
+                    data_frame.iloc[0, 1],
+                    str(total_tokens)
+                ]
+            ]
+            new_df = pd.DataFrame(
+                _update_data,
+                columns=data_frame.columns
+            )
+            data_frame = new_df
+
+        except Exception as e:
+            raise gr.Error(f"DataFrame operation error: {str(e)}")
+
         progress(1.0, "Graph traversed")
-        return output_file
+        return output_file, gr.DataFrame(label='Token Stats',
+                                         headers=["Source Text Token Count", "Predicted Token Count", "Token Used"],
+                                         datatype=["str", "str", "str"],
+                                         interactive=False,
+                                         value=data_frame,
+                                         visible=True,
+                                         wrap=True)
 
     except Exception as e:  # pylint: disable=broad-except
         raise gr.Error(f"Error occurred: {str(e)}")
 
-with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
-               css=css) as demo:
+
+with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
+                css=css) as demo):
     # Header
     gr.Image(value=os.path.join(root_dir, 'resources', 'images', 'logo.png'),
              label="GraphGen Banner",
@@ -353,6 +387,14 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
             interactive=False,
         )
 
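+        # Token stats table; kept hidden until a file is uploaded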
+        with gr.Blocks():
+            token_counter = gr.DataFrame(label='Token Stats',
+                                         headers=["Source Text Token Count", "Predicted Token Count", "Token Used"],
+                                         datatype=["str", "str", "str"],
+                                         interactive=False,
+                                         visible=False,
+                                         wrap=True)
+
         submit_btn = gr.Button("Run GraphGen")
 
         # Test Connection
@@ -377,17 +419,32 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
         inputs=if_trainee_model,
         outputs=[trainee_model, quiz_samples, edge_sampling])
 
+    # Count the tokens in the uploaded file and show the stats table
+    upload_file.change(
+        lambda x: gr.update(visible=True),
+        inputs=[upload_file],
+        outputs=[token_counter],
+    ).then(
+        count_tokens,
+        inputs=[upload_file, tokenizer, token_counter],
+        outputs=[token_counter],
+    )
+
     # run GraphGen
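+    # Hide the token stats table while GraphGen runs, then restore it with the refreshed table returned by run_graphgen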
     submit_btn.click(
+        lambda x: gr.update(visible=False),
+        inputs=[token_counter],
+        outputs=[token_counter],
+    ).then(
         run_graphgen,
         inputs=[
             if_trainee_model, upload_file, tokenizer, qa_form,
             bidirectional, expand_method, max_extra_edges, max_tokens,
             max_depth, edge_sampling, isolated_node_strategy,
             loss_strategy, base_url, synthesizer_model, trainee_model,
-            api_key, chunk_size
+            api_key, chunk_size, token_counter
         ],
-        outputs=[output],
+        outputs=[output, token_counter],
     )
 
 if __name__ == "__main__":