Commit dd0377f

feat(webui): count token usage
1 parent 2343e7e commit dd0377f

3 files changed: 127 additions and 9 deletions

graphgen/models/llm/openai_model.py

Lines changed: 8 additions & 3 deletions
@@ -1,5 +1,5 @@
 import math
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import List, Dict, Optional
 import openai
 from openai import AsyncOpenAI, RateLimitError, APIConnectionError, APITimeoutError
@@ -31,10 +31,10 @@ class OpenAIModel(TopkTokenModel):
     model_name: str = "gpt-4o-mini"
     api_key: str = None
     base_url: str = None
-
     system_prompt: str = ""
     json_mode: bool = False
     seed: int = None
+    token_usage: list = field(default_factory=list)

     def __post_init__(self):
         assert self.api_key is not None, "Please provide api key to access openai api."
@@ -99,7 +99,12 @@ async def generate_answer(self, text: str, history: Optional[List[str]] = None,
             model=self.model_name,
             **kwargs
         )
-
+        if hasattr(completion, "usage"):
+            self.token_usage.append({
+                "prompt_tokens": completion.usage.prompt_tokens,
+                "completion_tokens": completion.usage.completion_tokens,
+                "total_tokens": completion.usage.total_tokens,
+            })
         return completion.choices[0].message.content

     async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
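
For reference, a minimal sketch of how the new token_usage list could be consumed by a caller. The constructor arguments and placeholder API key below are illustrative only; the sketch assumes generate_answer is awaited exactly as in the diff above, so each successful completion appends one usage dict.

import asyncio

from graphgen.models.llm.openai_model import OpenAIModel

async def main():
    # Placeholder credentials/endpoint for illustration; any OpenAI-compatible
    # endpoint that returns a `usage` object on completions will do.
    llm = OpenAIModel(model_name="gpt-4o-mini",
                      api_key="sk-placeholder",
                      base_url="https://api.openai.com/v1")

    await llm.generate_answer("Hello, world")
    await llm.generate_answer("Summarize GraphGen in one sentence.")

    # One usage record per completion, so totals are simple sums.
    prompt_total = sum(u["prompt_tokens"] for u in llm.token_usage)
    grand_total = sum(u["total_tokens"] for u in llm.token_usage)
    print(f"prompt tokens: {prompt_total}, total tokens: {grand_total}")

asyncio.run(main())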

webui/app.py

Lines changed: 63 additions & 6 deletions
@@ -3,11 +3,13 @@
 import json
 import tempfile

+import pandas as pd
 import gradio as gr

 from gradio_i18n import Translate, gettext as _
 from test_api import test_api_connection
 from cache_utils import setup_workspace, cleanup_workspace
+from count_tokens import count_tokens

 # pylint: disable=wrong-import-position
 root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -24,7 +26,6 @@
 }
 """

-
 def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     working_dir = setup_workspace(os.path.join(root_dir, "cache"))
@@ -65,6 +66,9 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:

 # pylint: disable=too-many-statements
 def run_graphgen(*arguments: list, progress=gr.Progress()):
+    def sum_tokens(client):
+        return sum(u["total_tokens"] for u in client.token_usage)
+
     # Unpack arguments
     config = {
         "if_trainee_model": arguments[0],
@@ -174,14 +178,44 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
         # Clean up workspace
         cleanup_workspace(graph_gen.working_dir)

+        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
+        trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0
+        total_tokens = synthesizer_tokens + trainee_tokens
+
+        data_frame = arguments[-1]
+        try:
+            data_frame = arguments[-1]
+            _update_data = [
+                [
+                    data_frame.iloc[0, 0],
+                    data_frame.iloc[0, 1],
+                    str(total_tokens)
+                ]
+            ]
+            new_df = pd.DataFrame(
+                _update_data,
+                columns=data_frame.columns
+            )
+            data_frame = new_df
+
+        except Exception as e:
+            raise gr.Error(f"DataFrame operation error: {str(e)}")
+
         progress(1.0, "Graph traversed")
-        return output_file
+        return output_file, gr.DataFrame(label='Token Stats',
+                                         headers=["Source Text Token Count", "Predicted Token Count", "Token Used"],
+                                         datatype=["str", "str", "str"],
+                                         interactive=False,
+                                         value=data_frame,
+                                         visible=True,
+                                         wrap=True)

     except Exception as e:  # pylint: disable=broad-except
         raise gr.Error(f"Error occurred: {str(e)}")

-with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
-               css=css) as demo:
+
+with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(),
+                css=css) as demo):
     # Header
     gr.Image(value=os.path.join(root_dir, 'resources', 'images', 'logo.png'),
              label="GraphGen Banner",
@@ -353,6 +387,14 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
             interactive=False,
         )

+    with gr.Blocks():
+        token_counter = gr.DataFrame(label='Token Stats',
+                                     headers=["Source Text Token Count", "Predicted Token Count", "Token Used"],
+                                     datatype=["str", "str", "str"],
+                                     interactive=False,
+                                     visible=False,
+                                     wrap=True)
+
     submit_btn = gr.Button("Run GraphGen")

     # Test Connection
@@ -377,17 +419,32 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
         inputs=if_trainee_model,
         outputs=[trainee_model, quiz_samples, edge_sampling])

+    # Count the tokens of the uploaded file
+    upload_file.change(
+        lambda x: (gr.update(visible=True)),
+        inputs=[upload_file],
+        outputs=[token_counter],
+    ).then(
+        count_tokens,
+        inputs=[upload_file, tokenizer, token_counter],
+        outputs=[token_counter],
+    )
+
     # run GraphGen
     submit_btn.click(
+        lambda x: (gr.update(visible=False)),
+        inputs=[token_counter],
+        outputs=[token_counter],
+    ).then(
         run_graphgen,
         inputs=[
             if_trainee_model, upload_file, tokenizer, qa_form,
             bidirectional, expand_method, max_extra_edges, max_tokens,
             max_depth, edge_sampling, isolated_node_strategy,
             loss_strategy, base_url, synthesizer_model, trainee_model,
-            api_key, chunk_size
+            api_key, chunk_size, token_counter
         ],
-        outputs=[output],
+        outputs=[output, token_counter],
     )

 if __name__ == "__main__":
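
Aside: the upload_file and submit_btn wiring above relies on Gradio event chaining, where a cheap callback toggles the table's visibility first and .then(...) runs the heavier update afterwards. A stripped-down sketch of that pattern, with made-up component names rather than the app's real ones:

import gradio as gr

def fill_stats(path):
    # Stand-in for the real work (e.g. tokenizing the uploaded file).
    return [[str(path), "pending", "N/A"]]

with gr.Blocks() as demo:
    file_in = gr.File(label="Upload")
    stats = gr.DataFrame(headers=["Source", "Predicted", "Used"],
                         visible=False, interactive=False)

    # Step 1: reveal the (still empty) table as soon as a file arrives.
    # Step 2: compute and fill the row once the first callback has finished.
    file_in.change(
        lambda _: gr.update(visible=True),
        inputs=[file_in],
        outputs=[stats],
    ).then(
        fill_stats,
        inputs=[file_in],
        outputs=[stats],
    )

demo.launch()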

webui/count_tokens.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import os
+import sys
+import json
+import pandas as pd
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(root_dir)
+
+from graphgen.models import Tokenizer
+
+def count_tokens(file, tokenizer_name, data_frame):
+    if file.endswith(".jsonl"):
+        with open(file, "r", encoding='utf-8') as f:
+            data = [json.loads(line) for line in f]
+    elif file.endswith(".json"):
+        with open(file, "r", encoding='utf-8') as f:
+            data = json.load(f)
+        data = [item for sublist in data for item in sublist]
+    elif file.endswith(".txt"):
+        with open(file, "r", encoding='utf-8') as f:
+            data = f.read()
+        chunks = [
+            data[i:i + 512] for i in range(0, len(data), 512)
+        ]
+        data = [{"content": chunk} for chunk in chunks]
+    else:
+        raise ValueError(f"Unsupported file type: {file}")
+
+    tokenizer = Tokenizer(tokenizer_name)
+
+    # Count tokens
+    token_count = 0
+
+    for item in data:
+        if isinstance(item, dict):
+            content = item.get("content", "")
+        else:
+            content = item
+        token_count += len(tokenizer.encode_string(content))
+
+    _update_data = [[
+        str(token_count),
+        str(token_count * 50),
+        "N/A"
+    ]]
+
+    try:
+        new_df = pd.DataFrame(
+            _update_data,
+            columns=data_frame.columns
+        )
+        data_frame = new_df
+
+    except Exception as e:
+        print("[ERROR] DataFrame operation error:", str(e))
+
+    return data_frame
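
A rough usage sketch for the new helper outside of Gradio. The file path and tokenizer name are placeholders (whatever graphgen.models.Tokenizer accepts), and the seed DataFrame simply mirrors the three columns the webui table uses:

import pandas as pd

from count_tokens import count_tokens  # assumes the webui/ directory is on sys.path

seed = pd.DataFrame(
    [["N/A", "N/A", "N/A"]],
    columns=["Source Text Token Count", "Predicted Token Count", "Token Used"],
)

# "corpus.jsonl" and "cl100k_base" are placeholder values for illustration.
stats = count_tokens("corpus.jsonl", "cl100k_base", seed)
print(stats.iloc[0, 0])  # token count of the source text, as a string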
