Commit f072c2e

tests: add tests for ECEPartitioner
1 parent 9ed56a6 · commit f072c2e

4 files changed, +209 -12 lines changed

graphgen/configs/aggregated_config.yaml

Lines changed: 3 additions & 8 deletions
@@ -13,14 +13,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
   method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 5 # maximum depth for graph traversal
-    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 10 # max nodes and edges per community
+    max_tokens_per_community: 10240 # max tokens per community
+    unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
 generate:
   mode: aggregated # atomic, aggregated, multi_hop, cot
   data_format: ChatML # Alpaca, Sharegpt, ChatML
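Note: the three new method_params correspond one-to-one to the keyword arguments of ECEPartitioner.partition changed in the next file. A minimal sketch of forwarding them by hand; the config-loading glue, the run_partition name, and the "kg" namespace are illustrative assumptions, not the project's actual wiring:

import asyncio

import yaml  # assumes PyYAML is available

from graphgen.models import ECEPartitioner, NetworkXStorage


async def run_partition(config_path: str, working_dir: str) -> None:
    # Hypothetical glue: read the partition section and pass method_params
    # straight through as keyword arguments.
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    params = cfg["partition"]["method_params"]  # assumed config shape

    storage = NetworkXStorage(working_dir=working_dir, namespace="kg")
    communities = await ECEPartitioner().partition(
        storage,
        max_units_per_community=params["max_units_per_community"],
        max_tokens_per_community=params["max_tokens_per_community"],
        unit_sampling=params["unit_sampling"],
    )
    print(f"{len(communities)} communities")


asyncio.run(run_partition("graphgen/configs/aggregated_config.yaml", "./cache"))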

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 3 additions & 3 deletions
@@ -55,7 +55,7 @@ async def partition(
         g: BaseGraphStorage,
         max_units_per_community: int = 10,
         max_tokens_per_community: int = 10240,
-        edge_sampling: str = "random",
+        unit_sampling: str = "random",
         **kwargs: Any,
     ) -> List[Community]:
         nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
@@ -73,7 +73,7 @@ async def partition(
         used_e: Set[frozenset[str]] = set()
         communities: List = []
 
-        all_units = self._sort_units(all_units, edge_sampling)
+        all_units = self._sort_units(all_units, unit_sampling)
 
         async def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Community:
             nonlocal used_n, used_e
@@ -124,7 +124,7 @@ async def _add_unit(u):
                 if n_id not in used_n and n_id not in community_nodes:
                     neighbors.append(("n", n_id, node_dict[n_id]))
 
-            neighbors = self._sort_units(neighbors, edge_sampling)
+            neighbors = self._sort_units(neighbors, unit_sampling)
             for nb in neighbors:
                 if (
                     len(community_nodes) + len(community_edges)
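The renamed unit_sampling argument controls the order in which candidate units (nodes and edges) are visited during community growth. The project's _sort_units implementation is not part of this diff; the sketch below is an illustrative standalone version of the three strategies the config documents (random, min_loss, max_loss), using the (kind, id, data) unit tuples and the "loss" attribute seen in this file and in the tests, with a default loss of 0.0 as an assumption:

import random
from typing import Any, List, Tuple


def sort_units_sketch(
    units: List[Tuple[str, Any, dict]], unit_sampling: str = "random"
) -> List[Tuple[str, Any, dict]]:
    """Illustrative ordering of (kind, id, data) units; not the repository's _sort_units."""
    if unit_sampling == "random":
        shuffled = units[:]
        random.shuffle(shuffled)
        return shuffled
    if unit_sampling in ("min_loss", "max_loss"):
        # Units without a recorded comprehension loss sort as 0.0 (assumption).
        return sorted(
            units,
            key=lambda u: u[2].get("loss", 0.0),
            reverse=(unit_sampling == "max_loss"),
        )
    raise ValueError(f"Unknown unit_sampling strategy: {unit_sampling}")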

graphgen/operators/partition/partition_kg.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ async def partition_kg(
     elif method == "ece":
         logger.info("Partitioning knowledge graph using ECE method.")
         # TODO: before ECE partitioning, we need to:
-        # 1. 'quiz and judge' to get the comprehension loss
+        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
         # 2. pre-tokenize nodes and edges to get the token length
         edges = await kg_instance.get_all_edges()
         nodes = await kg_instance.get_all_nodes()
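The 'loss' and 'length' attributes named in this TODO are exactly what the new tests below populate by hand on each node and edge. A minimal sketch of that precondition, reusing only calls that appear in the tests; the "demo" namespace and attribute values are made up:

import asyncio
import tempfile

from graphgen.models import ECEPartitioner, NetworkXStorage


async def demo() -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="demo")
        # 'loss' would normally come from quiz-and-judge and 'length' from
        # pre-tokenization; here both are hard-coded, as in the tests.
        await storage.upsert_node("A", {"desc": "entity A", "length": 12, "loss": 0.4})
        await storage.upsert_node("B", {"desc": "entity B", "length": 9, "loss": 0.2})
        await storage.upsert_edge("A", "B", {"desc": "A relates to B", "length": 7, "loss": 0.3})
        communities = await ECEPartitioner().partition(storage, max_units_per_community=3)
        print(len(communities))


asyncio.run(demo())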
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
import tempfile

import pytest

from graphgen.bases.datatypes import Community
from graphgen.models import ECEPartitioner, NetworkXStorage


@pytest.mark.asyncio
async def test_ece_empty_graph():
    """ECE partitioning on an empty graph should return an empty community list."""
    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="empty")
        partitioner = ECEPartitioner()
        communities = await partitioner.partition(
            storage, max_units_per_community=5, unit_sampling="random"
        )
        assert communities == []


@pytest.mark.asyncio
async def test_ece_single_node():
    """A single node must be placed in exactly one community under any edge-sampling strategy."""
    nodes = [("A", {"desc": "alone", "length": 10, "loss": 0.1})]

    for strategy in ("random", "min_loss", "max_loss"):
        with tempfile.TemporaryDirectory() as tmpdir:
            storage = NetworkXStorage(
                working_dir=tmpdir, namespace=f"single_{strategy}"
            )
            for nid, ndata in nodes:
                await storage.upsert_node(nid, ndata)

            partitioner = ECEPartitioner()
            communities: list[Community] = await partitioner.partition(
                storage, max_units_per_community=5, unit_sampling=strategy
            )
            assert len(communities) == 1
            assert communities[0].nodes == ["A"]
            assert communities[0].edges == []


@pytest.mark.asyncio
async def test_ece_small_graph_random():
    """
    2x3 grid graph:
        0 — 1 — 2
        |   |   |
        3 — 4 — 5
    6 nodes & 7 edges, max_units=4 => at least 3 communities expected with random sampling.
    """
    nodes = [(str(i), {"desc": f"node{i}", "length": 10}) for i in range(6)]
    edges = [
        ("0", "1", {"desc": "e01", "loss": 0.1, "length": 5}),
        ("1", "2", {"desc": "e12", "loss": 0.2, "length": 5}),
        ("0", "3", {"desc": "e03", "loss": 0.3, "length": 5}),
        ("1", "4", {"desc": "e14", "loss": 0.4, "length": 5}),
        ("2", "5", {"desc": "e25", "loss": 0.5, "length": 5}),
        ("3", "4", {"desc": "e34", "loss": 0.6, "length": 5}),
        ("4", "5", {"desc": "e45", "loss": 0.7, "length": 5}),
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="small_random")
        for nid, ndata in nodes:
            await storage.upsert_node(nid, ndata)
        for src, tgt, edata in edges:
            await storage.upsert_edge(src, tgt, edata)

        partitioner = ECEPartitioner()
        communities: list[Community] = await partitioner.partition(
            storage, max_units_per_community=4, unit_sampling="random"
        )

        # Basic integrity checks
        all_nodes = set()
        all_edges = set()
        for c in communities:
            assert len(c.nodes) + len(c.edges) <= 4
            all_nodes.update(c.nodes)
            all_edges.update((u, v) if u < v else (v, u) for u, v in c.edges)
        assert all_nodes == {str(i) for i in range(6)}
        assert len(all_edges) == 7


@pytest.mark.asyncio
async def test_ece_small_graph_min_loss():
    """
    Same grid graph, but using min_loss sampling.
    Edges with lower loss should be preferred during community expansion.
    """
    nodes = [
        (str(i), {"desc": f"node{i}", "length": 10, "loss": i * 0.1}) for i in range(6)
    ]
    edges = [
        ("0", "1", {"desc": "e01", "loss": 0.05, "length": 5}),
        ("1", "2", {"desc": "e12", "loss": 0.10, "length": 5}),
        ("0", "3", {"desc": "e03", "loss": 0.15, "length": 5}),
        ("1", "4", {"desc": "e14", "loss": 0.20, "length": 5}),
        ("2", "5", {"desc": "e25", "loss": 0.25, "length": 5}),
        ("3", "4", {"desc": "e34", "loss": 0.30, "length": 5}),
        ("4", "5", {"desc": "e45", "loss": 0.35, "length": 5}),
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="small_min")
        for nid, ndata in nodes:
            await storage.upsert_node(nid, ndata)
        for src, tgt, edata in edges:
            await storage.upsert_edge(src, tgt, edata)

        partitioner = ECEPartitioner()
        communities: list[Community] = await partitioner.partition(
            storage, max_units_per_community=4, unit_sampling="min_loss"
        )

        all_nodes = set()
        all_edges = set()
        for c in communities:
            assert len(c.nodes) + len(c.edges) <= 4
            all_nodes.update(c.nodes)
            all_edges.update((u, v) if u < v else (v, u) for u, v in c.edges)
        assert all_nodes == {str(i) for i in range(6)}
        assert len(all_edges) == 7


@pytest.mark.asyncio
async def test_ece_small_graph_max_loss():
    """
    Same grid graph, but using max_loss sampling.
    Edges with higher loss should be preferred during community expansion.
    """
    nodes = [
        (str(i), {"desc": f"node{i}", "length": 10, "loss": (5 - i) * 0.1})
        for i in range(6)
    ]
    edges = [
        ("0", "1", {"desc": "e01", "loss": 0.35, "length": 5}),
        ("1", "2", {"desc": "e12", "loss": 0.30, "length": 5}),
        ("0", "3", {"desc": "e03", "loss": 0.25, "length": 5}),
        ("1", "4", {"desc": "e14", "loss": 0.20, "length": 5}),
        ("2", "5", {"desc": "e25", "loss": 0.15, "length": 5}),
        ("3", "4", {"desc": "e34", "loss": 0.10, "length": 5}),
        ("4", "5", {"desc": "e45", "loss": 0.05, "length": 5}),
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="small_max")
        for nid, ndata in nodes:
            await storage.upsert_node(nid, ndata)
        for src, tgt, edata in edges:
            await storage.upsert_edge(src, tgt, edata)

        partitioner = ECEPartitioner()
        communities: list[Community] = await partitioner.partition(
            storage, max_units_per_community=4, unit_sampling="max_loss"
        )

        all_nodes = set()
        all_edges = set()
        for c in communities:
            assert len(c.nodes) + len(c.edges) <= 4
            all_nodes.update(c.nodes)
            all_edges.update((u, v) if u < v else (v, u) for u, v in c.edges)
        assert all_nodes == {str(i) for i in range(6)}
        assert len(all_edges) == 7


@pytest.mark.asyncio
async def test_ece_max_tokens_limit():
    """Ensure max_tokens_per_community is respected."""
    # node id -> data
    node_data = {"A": {"length": 3000}, "B": {"length": 3000}, "C": {"length": 3000}}
    # edge list
    edges = [("A", "B", {"loss": 0.1, "length": 2000})]

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = NetworkXStorage(working_dir=tmpdir, namespace="token_limit")
        for nid, ndata in node_data.items():
            await storage.upsert_node(nid, ndata)
        for src, tgt, edata in edges:
            await storage.upsert_edge(src, tgt, edata)

        partitioner = ECEPartitioner()
        communities: list[Community] = await partitioner.partition(
            storage,
            max_units_per_community=10,
            max_tokens_per_community=5000,  # 1 node (3000) + 1 edge (2000) = 5000
            unit_sampling="random",
        )

        # With a 5000-token budget we need at least two communities
        assert len(communities) >= 2

        # helper: quick edge lookup
        edge_lens = {(u, v): d["length"] for u, v, d in edges}
        edge_lens.update({(v, u): d["length"] for u, v, d in edges})  # undirected

        for c in communities:
            node_tokens = sum(node_data[n]["length"] for n in c.nodes)
            edge_tokens = sum(edge_lens[e] for e in c.edges)
            assert node_tokens + edge_tokens <= 5000
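Note: all five tests are coroutine tests marked with @pytest.mark.asyncio, so running them requires an asyncio-capable pytest plugin such as pytest-asyncio to be installed in the test environment.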
