|
19 | 19 | from graphgen.graphgen import GraphGen |
20 | 20 | from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy |
21 | 21 | from graphgen.models.llm.limitter import RPM, TPM |
| 22 | +from graphgen.utils import set_logger |
| 23 | + |
22 | 24 |
|
23 | 25 | css = """ |
24 | 26 | .center-row { |
|
30 | 32 |
|
31 | 33 | def init_graph_gen(config: dict, env: dict) -> GraphGen: |
32 | 34 | # Set up working directory |
33 | | - working_dir = setup_workspace(os.path.join(root_dir, "cache")) |
| 35 | + log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) |
34 | 36 |
|
| 37 | + set_logger(log_file, if_stream=False) |
35 | 38 | graph_gen = GraphGen( |
36 | 39 | working_dir=working_dir |
37 | 40 | ) |
@@ -86,7 +89,7 @@ def sum_tokens(client): |
86 | 89 | "tokenizer": arguments[2], |
87 | 90 | "qa_form": arguments[3], |
88 | 91 | "web_search": False, |
89 | | - "quiz_samples": 2, |
| 92 | + "quiz_samples": arguments[19], |
90 | 93 | "traverse_strategy": { |
91 | 94 | "bidirectional": arguments[4], |
92 | 95 | "expand_method": arguments[5], |
@@ -159,7 +162,7 @@ def sum_tokens(client): |
159 | 162 |
|
160 | 163 | if config['if_trainee_model']: |
161 | 164 | # Generate quiz |
162 | | - graph_gen.quiz(max_samples=quiz_samples) |
| 165 | + graph_gen.quiz(max_samples=config['quiz_samples']) |
163 | 166 |
|
164 | 167 | # Judge statements |
165 | 168 | graph_gen.judge() |
@@ -472,7 +475,7 @@ def sum_tokens(client): |
472 | 475 | bidirectional, expand_method, max_extra_edges, max_tokens, |
473 | 476 | max_depth, edge_sampling, isolated_node_strategy, |
474 | 477 | loss_strategy, base_url, synthesizer_model, trainee_model, |
475 | | - api_key, chunk_size, rpm, tpm, token_counter |
| 478 | + api_key, chunk_size, rpm, tpm, quiz_samples, token_counter |
476 | 479 | ], |
477 | 480 | outputs=[output, token_counter], |
478 | 481 | ) |
|
0 commit comments