@@ -26,9 +26,10 @@ async def _get_node_info(
2626
2727def _get_level_n_edges_by_max_width (
2828 edge_adj_list : dict ,
29+ node_dict : dict ,
2930 edges : list ,
30- src_id : str ,
31- tgt_id : str ,
31+ nodes ,
32+ src_edge : tuple ,
3233 max_depth : int ,
3334 bidirectional : bool ,
3435 max_extra_edges : int ,
@@ -39,15 +40,18 @@ def _get_level_n_edges_by_max_width(
3940 n is decided by max_depth in traverse_strategy
4041
4142 :param edge_adj_list
43+ :param node_dict
4244 :param edges
43- :param src_id
44- :param tgt_id
45+ :param nodes
46+ :param src_edge
4547 :param max_depth
4648 :param bidirectional
4749 :param max_extra_edges
4850 :param edge_sampling
4951 :return: level n edges
5052 """
53+ src_id , tgt_id , _ = src_edge
54+
5155 level_n_edges = []
5256
5357 start_nodes = {tgt_id } if not bidirectional else {src_id , tgt_id }
@@ -66,7 +70,8 @@ def _get_level_n_edges_by_max_width(
6670 break
6771
6872 if len (candidate_edges ) >= max_extra_edges :
69- candidate_edges = _sort_edges (candidate_edges , edge_sampling )[:max_extra_edges ]
73+ er_tuples = [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for edge in candidate_edges ]
74+ candidate_edges = _sort_edges (er_tuples , edge_sampling )[:max_extra_edges ]
7075 for edge in candidate_edges :
7176 level_n_edges .append (edge )
7277 edge [2 ]["visited" ] = True
@@ -138,7 +143,8 @@ def _get_level_n_edges_by_max_tokens(
138143 if not candidate_edges :
139144 break
140145
141- candidate_edges = _sort_edges (candidate_edges , edge_sampling )
146+ er_tuples = [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for edge in candidate_edges ]
147+ candidate_edges = _sort_edges (er_tuples , edge_sampling )
142148 for edge in candidate_edges :
143149 max_tokens -= edge [2 ]["length" ]
144150 if not edge [0 ] in temp_nodes :
@@ -166,22 +172,24 @@ def _get_level_n_edges_by_max_tokens(
166172 return level_n_edges
167173
168174
169- def _sort_edges (edges : list , edge_sampling : str ) -> list :
175+ def _sort_edges (er_tuples : list , edge_sampling : str ) -> list :
170176 """
171177 Sort edges with edge sampling strategy
172178
173- :param edges: total edges
179+ :param er_tuples: [(nodes:list, edge:tuple)]
174180 :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
175181 :return: sorted edges
176182 """
177183 if edge_sampling == "random" :
178- random .shuffle ( edges )
184+ er_tuples = random .sample ( er_tuples , len ( er_tuples ) )
179185 elif edge_sampling == "min_loss" :
180- edges = sorted (edges , key = lambda x : x [2 ]["loss" ])
186+ er_tuples = sorted (er_tuples , key = lambda x : sum ( node [ 1 ][ "loss" ] for node in x [ 0 ]) + x [ 1 ] [2 ]["loss" ])
181187 elif edge_sampling == "max_loss" :
182- edges = sorted (edges , key = lambda x : x [2 ]["loss" ], reverse = True )
188+ er_tuples = sorted (er_tuples , key = lambda x : sum (node [1 ]["loss" ] for node in x [0 ]) + x [1 ][2 ]["loss" ],
189+ reverse = True )
183190 else :
184191 raise ValueError (f"Invalid edge sampling: { edge_sampling } " )
192+ edges = [edge for _ , edge in er_tuples ]
185193 return edges
186194
187195async def get_batches_with_strategy (
@@ -199,8 +207,6 @@ async def get_batches_with_strategy(
199207 max_depth = traverse_strategy .max_depth
200208 edge_sampling = traverse_strategy .edge_sampling
201209
202- edges = _sort_edges (edges , edge_sampling )
203-
204210 # 构建临接矩阵
205211 edge_adj_list = defaultdict (list )
206212 node_dict = {}
@@ -220,6 +226,9 @@ async def get_cached_node_info(node_id: str) -> dict:
220226 for i , (node_name , _ ) in enumerate (nodes ):
221227 node_dict [node_name ] = i
222228
229+ er_tuples = [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for edge in edges ]
230+ edges = _sort_edges (er_tuples , edge_sampling )
231+
223232 for edge in tqdm_async (edges , desc = "Preparing batches" ):
224233 if "visited" in edge [2 ] and edge [2 ]["visited" ]:
225234 continue
@@ -238,7 +247,7 @@ async def get_cached_node_info(node_id: str) -> dict:
238247
239248 if expand_method == "max_width" :
240249 level_n_edges = _get_level_n_edges_by_max_width (
241- edge_adj_list , edges , src_id , tgt_id , max_depth ,
250+ edge_adj_list , node_dict , edges , nodes , edge , max_depth ,
242251 traverse_strategy .bidirectional , traverse_strategy .max_extra_edges ,
243252 edge_sampling
244253 )
@@ -260,7 +269,6 @@ async def get_cached_node_info(node_id: str) -> dict:
260269
261270 processing_batches .append ((_process_nodes , _process_edges ))
262271
263- l
264272 # isolate nodes
265273 isolated_node_strategy = traverse_strategy .isolated_node_strategy
266274 if isolated_node_strategy == "add" :
0 commit comments