@@ -52,153 +52,20 @@ def judge(self, items: list[dict]) -> None:
5252 desc = "Judging descriptions" ,
5353 unit = "description" ,
5454 )
55-
5655 # Update the graph storage with the computed losses
5756 for item in results :
58- print (item )
59- node_id = item .get ("node_id" )
60- edge_source = item .get ("edge_source" )
61- edge_target = item .get ("edge_target" )
57+ index = item ["index" ]
6258 loss = item ["loss" ]
63- if node_id is not None :
59+ if isinstance (index , str ):
60+ node_id = index
6461 node_data = self .graph_storage .get_node (node_id )
65- if node_data is not None :
62+ if node_data :
6663 node_data ["loss" ] = loss
6764 self .graph_storage .update_node (node_id , node_data )
68- elif edge_source is not None and edge_target is not None :
65+ elif isinstance (index , tuple ):
66+ edge_source , edge_target = index
6967 edge_data = self .graph_storage .get_edge (edge_source , edge_target )
70- if edge_data is not None :
68+ if edge_data :
7169 edge_data ["loss" ] = loss
7270 self .graph_storage .update_edge (edge_source , edge_target , edge_data )
7371 self .graph_storage .index_done_callback ()
74-
75-
76- # async def judge_statement( # pylint: disable=too-many-statements
77- # trainee_llm_client: BaseLLMWrapper,
78- # graph_storage: NetworkXStorage,
79- # rephrase_storage: JsonKVStorage,
80- # re_judge: bool = False,
81- # progress_bar: gr.Progress = None,
82- # ) -> NetworkXStorage:
83- # """
84- # Get all edges and nodes and judge them
85- #
86- # :param trainee_llm_client: judge the statements to get comprehension loss
87- # :param graph_storage: graph storage instance
88- # :param rephrase_storage: rephrase storage instance
89- # :param re_judge: re-judge the relations
90- # :param progress_bar
91- # :return:
92- # """
93- #
94- # async def _judge_single_relation(
95- # edge: tuple,
96- # ):
97- # source_id = edge[0]
98- # target_id = edge[1]
99- # edge_data = edge[2]
100- #
101- # if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
102- # logger.debug(
103- # "Edge %s -> %s already judged, loss: %s, skip",
104- # source_id,
105- # target_id,
106- # edge_data["loss"],
107- # )
108- # return source_id, target_id, edge_data
109- #
110- # description = edge_data["description"]
111- #
112- # try:
113- # descriptions = rephrase_storage.get_by_id(description)
114- # assert descriptions is not None
115- #
116- # judgements = []
117- # gts = [gt for _, gt in descriptions]
118- # for description, gt in descriptions:
119- # judgement = await trainee_llm_client.generate_topk_per_token(
120- # STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
121- # )
122- # judgements.append(judgement[0].top_candidates)
123- #
124- # loss = yes_no_loss_entropy(judgements, gts)
125- #
126- # logger.debug(
127- # "Edge %s -> %s description: %s loss: %s",
128- # source_id,
129- # target_id,
130- # description,
131- # loss,
132- # )
133- #
134- # edge_data["loss"] = loss
135- # except Exception as e: # pylint: disable=broad-except
136- # logger.error(
137- # "Error in judging relation %s -> %s: %s", source_id, target_id, e
138- # )
139- # logger.info("Use default loss 0.1")
140- # edge_data["loss"] = -math.log(0.1)
141- #
142- # graph_storage.update_edge(source_id, target_id, edge_data)
143- # return source_id, target_id, edge_data
144- #
145- # edges = graph_storage.get_all_edges()
146- #
147- # await run_concurrent(
148- # _judge_single_relation,
149- # edges,
150- # desc="Judging relations",
151- # unit="relation",
152- # progress_bar=progress_bar,
153- # )
154- #
155- # async def _judge_single_entity(
156- # node: tuple,
157- # ):
158- # node_id = node[0]
159- # node_data = node[1]
160- #
161- # if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
162- # logger.debug(
163- # "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
164- # )
165- # return node_id, node_data
166- #
167- # description = node_data["description"]
168- #
169- # try:
170- # descriptions = rephrase_storage.get_by_id(description)
171- # assert descriptions is not None
172- #
173- # judgements = []
174- # gts = [gt for _, gt in descriptions]
175- # for description, gt in descriptions:
176- # judgement = await trainee_llm_client.generate_topk_per_token(
177- # STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
178- # )
179- # judgements.append(judgement[0].top_candidates)
180- #
181- # loss = yes_no_loss_entropy(judgements, gts)
182- #
183- # logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
184- #
185- # node_data["loss"] = loss
186- # except Exception as e: # pylint: disable=broad-except
187- # logger.error("Error in judging entity %s: %s", node_id, e)
188- # logger.error("Use default loss 0.1")
189- # node_data["loss"] = -math.log(0.1)
190- #
191- # graph_storage.update_node(node_id, node_data)
192- # return node_id, node_data
193- #
194- # nodes = graph_storage.get_all_nodes()
195- #
196- # await run_concurrent(
197- # _judge_single_entity,
198- # nodes,
199- # desc="Judging entities",
200- # unit="entity",
201- # progress_bar=progress_bar,
202- # )
203- #
204- # return graph_storage
0 commit comments