Skip to content

Commit af23e20

Browse files
fix(models): add difficulty order
1 parent bbb66ad commit af23e20

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

graphgen/operators/traverse_graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,20 +171,21 @@ async def _process_single_batch(
171171
losses.append(loss)
172172
q1, q2 = get_loss_tercile(losses)
173173

174+
difficulty_order = traverse_strategy.difficulty_order
174175
for i, batch in enumerate(processing_batches):
175176
if len(batch[1]) == 0:
176-
processing_batches[i] = (batch[0], batch[1], "easy")
177+
processing_batches[i] = (batch[0], batch[1], difficulty_order[0])
177178
continue
178179
loss = sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
179180
if loss < q1:
180181
# easy
181-
processing_batches[i] = (batch[0], batch[1], "easy")
182+
processing_batches[i] = (batch[0], batch[1], difficulty_order[0])
182183
elif loss < q2:
183184
# medium
184-
processing_batches[i] = (batch[0], batch[1], "medium")
185+
processing_batches[i] = (batch[0], batch[1], difficulty_order[1])
185186
else:
186187
# hard
187-
processing_batches[i] = (batch[0], batch[1], "hard")
188+
processing_batches[i] = (batch[0], batch[1], difficulty_order[2])
188189

189190
for result in tqdm_async(asyncio.as_completed(
190191
[_process_single_batch(batch) for batch in processing_batches]

models/strategy/travserse_strategy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22

33
from models.strategy.base_strategy import BaseStrategy
44

@@ -19,6 +19,8 @@ class TraverseStrategy(BaseStrategy):
1919
edge_sampling: str = "max_loss" # "max_loss" or "min_loss" or "random"
2020
# 孤立节点的处理策略
2121
isolated_node_strategy: str = "add" # "add" or "ignore"
22+
# 难度顺序 ["easy", "medium", "hard"], ["hard", "medium", "easy"], ["medium", "medium", "medium"]
23+
difficulty_order: list = field(default_factory=lambda: ["easy", "medium", "hard"])
2224

2325
def to_yaml(self):
2426
return {
@@ -29,6 +31,7 @@ def to_yaml(self):
2931
"max_tokens": self.max_tokens,
3032
"max_depth": self.max_depth,
3133
"edge_sampling": self.edge_sampling,
32-
"isolated_node_strategy": self.isolated_node_strategy
34+
"isolated_node_strategy": self.isolated_node_strategy,
35+
"difficulty_order": self.difficulty_order
3336
}
3437
}

0 commit comments

Comments
 (0)