99from graphgen .bases .datatypes import Community
1010from graphgen .models .partitioner .bfs_partitioner import BFSPartitioner
1111
12+ NODE_UNIT : str = "n"
13+ EDGE_UNIT : str = "e"
14+
1215
1316@dataclass
1417class ECEPartitioner (BFSPartitioner ):
@@ -66,9 +69,9 @@ async def partition(
6669 node_dict = dict (nodes )
6770 edge_dict = {frozenset ((u , v )): d for u , v , d in edges }
6871
69- all_units : List [Tuple [str , Any , dict ]] = [( "n" , nid , d ) for nid , d in nodes ] + [
70- ("e" , frozenset (( u , v )), d ) for u , v , d in edges
71- ]
72+ all_units : List [Tuple [str , Any , dict ]] = [
73+ (NODE_UNIT , nid , d ) for nid , d in nodes
74+ ] + [( EDGE_UNIT , frozenset (( u , v )), d ) for u , v , d in edges ]
7275
7376 used_n : Set [str ] = set ()
7477 used_e : Set [frozenset [str ]] = set ()
@@ -89,7 +92,7 @@ async def _grow_community(
8992 async def _add_unit (u ):
9093 nonlocal token_sum
9194 t , i , d = u
92- if t == "n" :
95+ if t == NODE_UNIT : # node
9396 if i in used_n or i in community_nodes :
9497 return False
9598 community_nodes [i ] = d
@@ -117,15 +120,15 @@ async def _add_unit(u):
117120 cur_type , cur_id , _ = await queue .get ()
118121
119122 neighbors : List [Tuple [str , Any , dict ]] = []
120- if cur_type == "n" :
123+ if cur_type == NODE_UNIT :
121124 for nb_id in adj .get (cur_id , []):
122125 e_key = frozenset ((cur_id , nb_id ))
123126 if e_key not in used_e and e_key not in community_edges :
124- neighbors .append (("e" , e_key , edge_dict [e_key ]))
127+ neighbors .append ((EDGE_UNIT , e_key , edge_dict [e_key ]))
125128 else :
126129 for n_id in cur_id :
127130 if n_id not in used_n and n_id not in community_nodes :
128- neighbors .append (("n" , n_id , node_dict [n_id ]))
131+ neighbors .append ((NODE_UNIT , n_id , node_dict [n_id ]))
129132
130133 neighbors = self ._sort_units (neighbors , unit_sampling )
131134 for nb in neighbors :
@@ -149,7 +152,9 @@ async def _add_unit(u):
149152
150153 async for unit in tqdm_async (all_units , desc = "ECE partition" ):
151154 utype , uid , _ = unit
152- if (utype == "n" and uid in used_n ) or (utype == "e" and uid in used_e ):
155+ if (utype == NODE_UNIT and uid in used_n ) or (
156+ utype == EDGE_UNIT and uid in used_e
157+ ):
153158 continue
154159 comm = await _grow_community (unit )
155160 if comm is not None :
0 commit comments