Skip to content

Commit 1d4ef67

Browse files
fix: use constants to distinguish nodes & edges
1 parent b4eed5b commit 1d4ef67

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from graphgen.bases.datatypes import Community
1010
from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner
1111

12+
NODE_UNIT: str = "n"
13+
EDGE_UNIT: str = "e"
14+
1215

1316
@dataclass
1417
class ECEPartitioner(BFSPartitioner):
@@ -66,9 +69,9 @@ async def partition(
6669
node_dict = dict(nodes)
6770
edge_dict = {frozenset((u, v)): d for u, v, d in edges}
6871

69-
all_units: List[Tuple[str, Any, dict]] = [("n", nid, d) for nid, d in nodes] + [
70-
("e", frozenset((u, v)), d) for u, v, d in edges
71-
]
72+
all_units: List[Tuple[str, Any, dict]] = [
73+
(NODE_UNIT, nid, d) for nid, d in nodes
74+
] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges]
7275

7376
used_n: Set[str] = set()
7477
used_e: Set[frozenset[str]] = set()
@@ -89,7 +92,7 @@ async def _grow_community(
8992
async def _add_unit(u):
9093
nonlocal token_sum
9194
t, i, d = u
92-
if t == "n":
95+
if t == NODE_UNIT: # node
9396
if i in used_n or i in community_nodes:
9497
return False
9598
community_nodes[i] = d
@@ -117,15 +120,15 @@ async def _add_unit(u):
117120
cur_type, cur_id, _ = await queue.get()
118121

119122
neighbors: List[Tuple[str, Any, dict]] = []
120-
if cur_type == "n":
123+
if cur_type == NODE_UNIT:
121124
for nb_id in adj.get(cur_id, []):
122125
e_key = frozenset((cur_id, nb_id))
123126
if e_key not in used_e and e_key not in community_edges:
124-
neighbors.append(("e", e_key, edge_dict[e_key]))
127+
neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key]))
125128
else:
126129
for n_id in cur_id:
127130
if n_id not in used_n and n_id not in community_nodes:
128-
neighbors.append(("n", n_id, node_dict[n_id]))
131+
neighbors.append((NODE_UNIT, n_id, node_dict[n_id]))
129132

130133
neighbors = self._sort_units(neighbors, unit_sampling)
131134
for nb in neighbors:
@@ -149,7 +152,9 @@ async def _add_unit(u):
149152

150153
async for unit in tqdm_async(all_units, desc="ECE partition"):
151154
utype, uid, _ = unit
152-
if (utype == "n" and uid in used_n) or (utype == "e" and uid in used_e):
155+
if (utype == NODE_UNIT and uid in used_n) or (
156+
utype == EDGE_UNIT and uid in used_e
157+
):
153158
continue
154159
comm = await _grow_community(unit)
155160
if comm is not None:

0 commit comments

Comments
 (0)