Skip to content

Commit d7d6c2a

Browse files
fix: fix transferring quiz data to JudgeService
1 parent c55fc09 commit d7d6c2a

File tree

2 files changed

+20
-153
lines changed

2 files changed

+20
-153
lines changed

graphgen/operators/judge/judge_service.py

Lines changed: 7 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -52,153 +52,20 @@ def judge(self, items: list[dict]) -> None:
5252
desc="Judging descriptions",
5353
unit="description",
5454
)
55-
5655
# Update the graph storage with the computed losses
5756
for item in results:
58-
print(item)
59-
node_id = item.get("node_id")
60-
edge_source = item.get("edge_source")
61-
edge_target = item.get("edge_target")
57+
index = item["index"]
6258
loss = item["loss"]
63-
if node_id is not None:
59+
if isinstance(index, str):
60+
node_id = index
6461
node_data = self.graph_storage.get_node(node_id)
65-
if node_data is not None:
62+
if node_data:
6663
node_data["loss"] = loss
6764
self.graph_storage.update_node(node_id, node_data)
68-
elif edge_source is not None and edge_target is not None:
65+
elif isinstance(index, tuple):
66+
edge_source, edge_target = index
6967
edge_data = self.graph_storage.get_edge(edge_source, edge_target)
70-
if edge_data is not None:
68+
if edge_data:
7169
edge_data["loss"] = loss
7270
self.graph_storage.update_edge(edge_source, edge_target, edge_data)
7371
self.graph_storage.index_done_callback()
74-
75-
76-
# async def judge_statement( # pylint: disable=too-many-statements
77-
# trainee_llm_client: BaseLLMWrapper,
78-
# graph_storage: NetworkXStorage,
79-
# rephrase_storage: JsonKVStorage,
80-
# re_judge: bool = False,
81-
# progress_bar: gr.Progress = None,
82-
# ) -> NetworkXStorage:
83-
# """
84-
# Get all edges and nodes and judge them
85-
#
86-
# :param trainee_llm_client: judge the statements to get comprehension loss
87-
# :param graph_storage: graph storage instance
88-
# :param rephrase_storage: rephrase storage instance
89-
# :param re_judge: re-judge the relations
90-
# :param progress_bar
91-
# :return:
92-
# """
93-
#
94-
# async def _judge_single_relation(
95-
# edge: tuple,
96-
# ):
97-
# source_id = edge[0]
98-
# target_id = edge[1]
99-
# edge_data = edge[2]
100-
#
101-
# if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
102-
# logger.debug(
103-
# "Edge %s -> %s already judged, loss: %s, skip",
104-
# source_id,
105-
# target_id,
106-
# edge_data["loss"],
107-
# )
108-
# return source_id, target_id, edge_data
109-
#
110-
# description = edge_data["description"]
111-
#
112-
# try:
113-
# descriptions = rephrase_storage.get_by_id(description)
114-
# assert descriptions is not None
115-
#
116-
# judgements = []
117-
# gts = [gt for _, gt in descriptions]
118-
# for description, gt in descriptions:
119-
# judgement = await trainee_llm_client.generate_topk_per_token(
120-
# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
121-
# )
122-
# judgements.append(judgement[0].top_candidates)
123-
#
124-
# loss = yes_no_loss_entropy(judgements, gts)
125-
#
126-
# logger.debug(
127-
# "Edge %s -> %s description: %s loss: %s",
128-
# source_id,
129-
# target_id,
130-
# description,
131-
# loss,
132-
# )
133-
#
134-
# edge_data["loss"] = loss
135-
# except Exception as e: # pylint: disable=broad-except
136-
# logger.error(
137-
# "Error in judging relation %s -> %s: %s", source_id, target_id, e
138-
# )
139-
# logger.info("Use default loss 0.1")
140-
# edge_data["loss"] = -math.log(0.1)
141-
#
142-
# graph_storage.update_edge(source_id, target_id, edge_data)
143-
# return source_id, target_id, edge_data
144-
#
145-
# edges = graph_storage.get_all_edges()
146-
#
147-
# await run_concurrent(
148-
# _judge_single_relation,
149-
# edges,
150-
# desc="Judging relations",
151-
# unit="relation",
152-
# progress_bar=progress_bar,
153-
# )
154-
#
155-
# async def _judge_single_entity(
156-
# node: tuple,
157-
# ):
158-
# node_id = node[0]
159-
# node_data = node[1]
160-
#
161-
# if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
162-
# logger.debug(
163-
# "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
164-
# )
165-
# return node_id, node_data
166-
#
167-
# description = node_data["description"]
168-
#
169-
# try:
170-
# descriptions = rephrase_storage.get_by_id(description)
171-
# assert descriptions is not None
172-
#
173-
# judgements = []
174-
# gts = [gt for _, gt in descriptions]
175-
# for description, gt in descriptions:
176-
# judgement = await trainee_llm_client.generate_topk_per_token(
177-
# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
178-
# )
179-
# judgements.append(judgement[0].top_candidates)
180-
#
181-
# loss = yes_no_loss_entropy(judgements, gts)
182-
#
183-
# logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
184-
#
185-
# node_data["loss"] = loss
186-
# except Exception as e: # pylint: disable=broad-except
187-
# logger.error("Error in judging entity %s: %s", node_id, e)
188-
# logger.error("Use default loss 0.1")
189-
# node_data["loss"] = -math.log(0.1)
190-
#
191-
# graph_storage.update_node(node_id, node_data)
192-
# return node_id, node_data
193-
#
194-
# nodes = graph_storage.get_all_nodes()
195-
#
196-
# await run_concurrent(
197-
# _judge_single_entity,
198-
# nodes,
199-
# desc="Judging entities",
200-
# unit="entity",
201-
# progress_bar=progress_bar,
202-
# )
203-
#
204-
# return graph_storage

graphgen/operators/quiz/quiz_service.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,31 @@ def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
3434
self.graph_storage.reload()
3535
yield from self.quiz()
3636

37-
async def _process_single_quiz(self, item: str) -> dict | None:
37+
async def _process_single_quiz(self, item: tuple) -> dict | None:
3838
# if quiz in quiz_storage exists already, directly get it
39-
_description_id = compute_content_hash(item)
39+
index, desc = item
40+
_description_id = compute_content_hash(desc, prefix="quiz-")
4041
if self.quiz_storage.get_by_id(_description_id):
4142
return None
4243

4344
tasks = []
4445
for i in range(self.quiz_samples):
4546
if i > 0:
46-
tasks.append((item, "TEMPLATE", "yes"))
47-
tasks.append((item, "ANTI_TEMPLATE", "no"))
47+
tasks.append((desc, "TEMPLATE", "yes"))
48+
tasks.append((desc, "ANTI_TEMPLATE", "no"))
4849
try:
4950
quizzes = []
50-
for description, template_type, gt in tasks:
51-
prompt = self.generator.build_prompt_for_description(
52-
description, template_type
53-
)
51+
for d, template_type, gt in tasks:
52+
prompt = self.generator.build_prompt_for_description(d, template_type)
5453
new_description = await self.llm_client.generate_answer(
5554
prompt, temperature=1
5655
)
5756
rephrased_text = self.generator.parse_rephrased_text(new_description)
5857
quizzes.append((rephrased_text, gt))
5958
return {
6059
"_description_id": _description_id,
61-
"description": item,
60+
"description": desc,
61+
"index": index,
6262
"quizzes": quizzes,
6363
}
6464
except Exception as e:
@@ -76,13 +76,13 @@ def quiz(self) -> Iterable[pd.DataFrame]:
7676

7777
for edge in edges:
7878
edge_data = edge[2]
79-
description = edge_data["description"]
80-
items.append(description)
79+
desc = edge_data["description"]
80+
items.append(((edge[0], edge[1]), desc))
8181

8282
for node in nodes:
8383
node_data = node[1]
84-
description = node_data["description"]
85-
items.append(description)
84+
desc = node_data["description"]
85+
items.append((node[0], desc))
8686

8787
logger.info("Total descriptions to quiz: %d", len(items))
8888

0 commit comments

Comments
 (0)