Skip to content

Commit fc8d6cb

Browse files
fix: fix parameter passing error & add logger
1 parent 2e6fba9 commit fc8d6cb

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

webui/app.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from graphgen.graphgen import GraphGen
2020
from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
2121
from graphgen.models.llm.limitter import RPM, TPM
22+
from graphgen.utils import set_logger
23+
2224

2325
css = """
2426
.center-row {
@@ -30,8 +32,9 @@
3032

3133
def init_graph_gen(config: dict, env: dict) -> GraphGen:
3234
# Set up working directory
33-
working_dir = setup_workspace(os.path.join(root_dir, "cache"))
35+
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
3436

37+
set_logger(log_file, if_stream=False)
3538
graph_gen = GraphGen(
3639
working_dir=working_dir
3740
)
@@ -86,7 +89,7 @@ def sum_tokens(client):
8689
"tokenizer": arguments[2],
8790
"qa_form": arguments[3],
8891
"web_search": False,
89-
"quiz_samples": 2,
92+
"quiz_samples": arguments[19],
9093
"traverse_strategy": {
9194
"bidirectional": arguments[4],
9295
"expand_method": arguments[5],
@@ -159,7 +162,7 @@ def sum_tokens(client):
159162

160163
if config['if_trainee_model']:
161164
# Generate quiz
162-
graph_gen.quiz(max_samples=quiz_samples)
165+
graph_gen.quiz(max_samples=config['quiz_samples'])
163166

164167
# Judge statements
165168
graph_gen.judge()
@@ -472,7 +475,7 @@ def sum_tokens(client):
472475
bidirectional, expand_method, max_extra_edges, max_tokens,
473476
max_depth, edge_sampling, isolated_node_strategy,
474477
loss_strategy, base_url, synthesizer_model, trainee_model,
475-
api_key, chunk_size, rpm, tpm, token_counter
478+
api_key, chunk_size, rpm, tpm, quiz_samples, token_counter
476479
],
477480
outputs=[output, token_counter],
478481
)

webui/cache_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ def setup_workspace(folder):
99
working_dir = os.path.join(folder, request_id)
1010
os.makedirs(working_dir, exist_ok=True)
1111

12-
return working_dir
12+
log_dir = os.path.join(folder, "logs")
13+
os.makedirs(log_dir, exist_ok=True)
14+
log_file = os.path.join(log_dir, f"{request_id}.log")
15+
16+
return log_file, working_dir
1317

1418

1519
def cleanup_workspace(folder):

0 commit comments

Comments
 (0)