|
8 | 8 | import pandas as pd |
9 | 9 | from dotenv import load_dotenv |
10 | 10 |
|
| 11 | +from graphgen.engine import Context, Engine, collect_ops |
11 | 12 | from graphgen.graphgen import GraphGen |
12 | 13 | from graphgen.models import OpenAIClient, Tokenizer |
13 | 14 | from graphgen.models.llm.limitter import RPM, TPM |
@@ -97,26 +98,61 @@ def sum_tokens(client): |
97 | 98 | "unit_sampling": params.ece_unit_sampling, |
98 | 99 | } |
99 | 100 |
|
| 101 | + pipeline = [ |
| 102 | + { |
| 103 | + "name": "read", |
| 104 | + "params": { |
| 105 | + "input_file": params.upload_file, |
| 106 | + "chunk_size": params.chunk_size, |
| 107 | + "chunk_overlap": params.chunk_overlap, |
| 108 | + }, |
| 109 | + }, |
| 110 | + { |
| 111 | + "name": "build_kg", |
| 112 | + }, |
| 113 | + ] |
| 114 | + |
| 115 | + if params.if_trainee_model: |
| 116 | + pipeline.append( |
| 117 | + { |
| 118 | + "name": "quiz_and_judge", |
| 119 | + "params": {"quiz_samples": params.quiz_samples, "re_judge": True}, |
| 120 | + } |
| 121 | + ) |
| 122 | + pipeline.append( |
| 123 | + { |
| 124 | + "name": "partition", |
| 125 | + "deps": ["quiz_and_judge"], |
| 126 | + "params": { |
| 127 | + "method": params.partition_method, |
| 128 | + "method_params": partition_params, |
| 129 | + }, |
| 130 | + } |
| 131 | + ) |
| 132 | + else: |
| 133 | + pipeline.append( |
| 134 | + { |
| 135 | + "name": "partition", |
| 136 | + "params": { |
| 137 | + "method": params.partition_method, |
| 138 | + "method_params": partition_params, |
| 139 | + }, |
| 140 | + } |
| 141 | + ) |
| 142 | + pipeline.append( |
| 143 | + { |
| 144 | + "name": "generate", |
| 145 | + "params": { |
| 146 | + "method": params.mode, |
| 147 | + "data_format": params.data_format, |
| 148 | + }, |
| 149 | + } |
| 150 | + ) |
| 151 | + |
100 | 152 | config = { |
101 | 153 | "if_trainee_model": params.if_trainee_model, |
102 | 154 | "read": {"input_file": params.upload_file}, |
103 | | - "split": { |
104 | | - "chunk_size": params.chunk_size, |
105 | | - "chunk_overlap": params.chunk_overlap, |
106 | | - }, |
107 | | - "search": {"enabled": False}, |
108 | | - "quiz_and_judge": { |
109 | | - "enabled": params.if_trainee_model, |
110 | | - "quiz_samples": params.quiz_samples, |
111 | | - }, |
112 | | - "partition": { |
113 | | - "method": params.partition_method, |
114 | | - "method_params": partition_params, |
115 | | - }, |
116 | | - "generate": { |
117 | | - "mode": params.mode, |
118 | | - "data_format": params.data_format, |
119 | | - }, |
| 155 | + "pipeline": pipeline, |
120 | 156 | } |
121 | 157 |
|
122 | 158 | env = { |
@@ -145,20 +181,12 @@ def sum_tokens(client): |
145 | 181 | # Initialize GraphGen |
146 | 182 | graph_gen = init_graph_gen(config, env) |
147 | 183 | graph_gen.clear() |
148 | | - |
149 | 184 | graph_gen.progress_bar = progress |
150 | 185 |
|
151 | 186 | try: |
152 | | - # Process the data |
153 | | - graph_gen.insert(read_config=config["read"], split_config=config["split"]) |
154 | | - |
155 | | - if config["if_trainee_model"]: |
156 | | - graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"]) |
157 | | - |
158 | | - graph_gen.generate( |
159 | | - partition_config=config["partition"], |
160 | | - generate_config=config["generate"], |
161 | | - ) |
| 187 | + ctx = Context(config=config, graph_gen=graph_gen) |
| 188 | + ops = collect_ops(config, graph_gen) |
| 189 | + Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) |
162 | 190 |
|
163 | 191 | # Save output |
164 | 192 | output_data = graph_gen.qa_storage.data |
|
0 commit comments