Skip to content

Commit c5958f8

Browse files
fix: add param min_units_per_community
1 parent 55667d7 commit c5958f8

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ max-public-methods=20
308308
max-returns=6
309309

310310
# Maximum number of statements in function / method body.
311-
max-statements=50
311+
max-statements=60
312312

313313
# Minimum number of public methods for a class (see R0903).
314314
min-public-methods=2

graphgen/configs/aggregated_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ partition: # graph partition configuration
1414
method: ece # ece is a custom partition method based on comprehension loss
1515
method_params:
1616
max_units_per_community: 20 # max nodes and edges per community
17+
min_units_per_community: 5 # min nodes and edges per community
1718
max_tokens_per_community: 10240 # max tokens per community
1819
unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
1920
generate:

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import random
33
from dataclasses import dataclass
4-
from typing import Any, Dict, List, Set, Tuple
4+
from typing import Any, Dict, List, Optional, Set, Tuple
55

66
from tqdm.asyncio import tqdm as tqdm_async
77

@@ -54,6 +54,7 @@ async def partition(
5454
self,
5555
g: BaseGraphStorage,
5656
max_units_per_community: int = 10,
57+
min_units_per_community: int = 1,
5758
max_tokens_per_community: int = 10240,
5859
unit_sampling: str = "random",
5960
**kwargs: Any,
@@ -75,7 +76,9 @@ async def partition(
7576

7677
all_units = self._sort_units(all_units, unit_sampling)
7778

78-
async def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Community:
79+
async def _grow_community(
80+
seed_unit: Tuple[str, Any, dict]
81+
) -> Optional[Community]:
7982
nonlocal used_n, used_e
8083

8184
community_nodes: Dict[str, dict] = {}
@@ -135,6 +138,9 @@ async def _add_unit(u):
135138
if await _add_unit(nb):
136139
await queue.put(nb)
137140

141+
if len(community_nodes) + len(community_edges) < min_units_per_community:
142+
return None
143+
138144
return Community(
139145
id=len(communities),
140146
nodes=list(community_nodes.keys()),
@@ -145,6 +151,8 @@ async def _add_unit(u):
145151
utype, uid, _ = unit
146152
if (utype == "n" and uid in used_n) or (utype == "e" and uid in used_e):
147153
continue
148-
communities.append(await _grow_community(unit))
154+
comm = await _grow_community(unit)
155+
if comm is not None:
156+
communities.append(comm)
149157

150158
return communities

graphgen/operators/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
from graphgen.operators.partition.traverse_graph import (
2-
traverse_graph_for_aggregated,
3-
traverse_graph_for_multi_hop,
4-
)
5-
61
from .build_kg import build_kg
72
from .generate import generate_qas
83
from .judge import judge_statement

0 commit comments

Comments
 (0)