Skip to content

Commit f2796f5

Browse files
Feat: add kg evaluators (#135)
* wip: refactor data evaluators & add kg evaluators * feat: add KG quality evaluation module * refactor: removed repeated calculations and remove hardcoded params * add: add kg_evaluate config file for params * fix: correct relation acc evaluation logic * refactor: enhance KG evaluator to use llm-as judge; remove evaluate_kg_config * fix: fix format and clean up imports * wip: refactor evaluator structure * wip: add annotations * refactor: refactor proj structure & configs * wip: split prompts * refactor: refactor base_evaluator * refactor: refactor LengthEvaluator * refactor: refactor MTLDEvaluator * refactor: refactor NLTKHelper * refactor: refactor RewardEvaluator * refactor: refactor UniEvaluator * refactor: refactor evaluator structure * refactor: change evaluation methods in acc and consistency to sync * refactor: streamline evaluation functions for accuracy, consistency, and structure * wip: perf evaluate_service * perf: perf evaluate_service * fix: fix output node * merge * feat: add KGQualityEvaluator and integrate into EvaluateService for KG evaluations * refactor: remove KGQualityEvaluator and restructure KG evaluation integration * pylints * feat: add kg_structure evaluation * feat: add kg_accuracy & kg_consistency metrics --------- Co-authored-by: chenzihong-gavin <[email protected]> Co-authored-by: chenzihong <[email protected]>
1 parent a17bf7f commit f2796f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2167
-610
lines changed

examples/evaluate/evaluate.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/evaluate/evaluate_kg/kg_evaluation_config.yaml
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: kuzu # graph database backend, support: kuzu, networkx
4+
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/extract_demo.txt
14+
15+
- id: chunk
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 20480 # larger chunk size for better context
24+
chunk_overlap: 2000
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: evaluate
36+
op_name: evaluate
37+
type: aggregate
38+
save_output: true
39+
dependencies:
40+
- build_kg
41+
params:
42+
metrics:
43+
- kg_structure
44+
- kg_accuracy
45+
- kg_consistency
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/evaluate/evaluate_qa/qa_evaluation_config.yaml
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: kuzu # graph database backend, support: kuzu, networkx
4+
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
14+
15+
- id: chunk_documents
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read_files
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 1024 # chunk size for text splitting
24+
chunk_overlap: 100 # chunk overlap for text splitting
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk_documents
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: quiz
36+
op_name: quiz
37+
type: aggregate
38+
dependencies:
39+
- build_kg
40+
execution_params:
41+
replicas: 1
42+
batch_size: 128
43+
params:
44+
quiz_samples: 2 # number of quiz samples to generate
45+
concurrency_limit: 200
46+
47+
- id: judge
48+
op_name: judge
49+
type: map_batch
50+
dependencies:
51+
- quiz
52+
execution_params:
53+
replicas: 1
54+
batch_size: 128
55+
56+
- id: partition
57+
op_name: partition
58+
type: aggregate
59+
dependencies:
60+
- judge
61+
params:
62+
method: ece # ece is a custom partition method based on comprehension loss
63+
method_params:
64+
max_units_per_community: 20 # max nodes and edges per community
65+
min_units_per_community: 5 # min nodes and edges per community
66+
max_tokens_per_community: 10240 # max tokens per community
67+
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
68+
69+
- id: generate
70+
op_name: generate
71+
type: map_batch
72+
dependencies:
73+
- partition
74+
execution_params:
75+
replicas: 1
76+
batch_size: 128
77+
save_output: true
78+
params:
79+
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
80+
data_format: ChatML # Alpaca, Sharegpt, ChatML
81+
82+
- id: evaluate
83+
op_name: evaluate
84+
type: map_batch
85+
dependencies:
86+
- generate
87+
execution_params:
88+
replicas: 1
89+
batch_size: 128
90+
save_output: true
91+
params:
92+
metrics:
93+
- qa_length
94+
- qa_mtld
95+
- qa_reward_score
96+
- qa_uni_score
97+
mtld_params:
98+
threshold: 0.7

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
from .base_splitter import BaseSplitter
1010
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
1111
from .base_tokenizer import BaseTokenizer
12+
from .base_evaluator import BaseEvaluator
1213
from .datatypes import Chunk, Config, Node, QAPair, Token

graphgen/bases/base_evaluator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from abc import ABC, abstractmethod
2+
from .datatypes import QAPair
3+
4+
5+
class BaseEvaluator(ABC):
6+
@abstractmethod
7+
def evaluate(self, pair: QAPair) -> float:
8+
"""
9+
Evaluate the text and return a score.
10+
"""

graphgen/bases/base_storage.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from abc import ABC, abstractmethod
12
from dataclasses import dataclass
2-
from typing import Generic, TypeVar, Union
3+
from typing import Dict, Generic, List, Set, TypeVar, Union
34

45
T = TypeVar("T")
56

@@ -45,52 +46,90 @@ def reload(self):
4546
raise NotImplementedError
4647

4748

48-
class BaseGraphStorage(StorageNameSpace):
49+
class BaseGraphStorage(StorageNameSpace, ABC):
50+
@abstractmethod
51+
def is_directed(self) -> bool:
52+
pass
53+
54+
@abstractmethod
4955
def has_node(self, node_id: str) -> bool:
5056
raise NotImplementedError
5157

58+
@abstractmethod
5259
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
5360
raise NotImplementedError
5461

62+
@abstractmethod
5563
def node_degree(self, node_id: str) -> int:
5664
raise NotImplementedError
5765

58-
def edge_degree(self, src_id: str, tgt_id: str) -> int:
59-
raise NotImplementedError
66+
@abstractmethod
67+
def get_all_node_degrees(self) -> Dict[str, int]:
68+
pass
6069

70+
def get_isolated_nodes(self) -> List[str]:
71+
return [
72+
node_id
73+
for node_id, degree in self.get_all_node_degrees().items()
74+
if degree == 0
75+
]
76+
77+
@abstractmethod
6178
def get_node(self, node_id: str) -> Union[dict, None]:
6279
raise NotImplementedError
6380

81+
@abstractmethod
6482
def update_node(self, node_id: str, node_data: dict[str, str]):
6583
raise NotImplementedError
6684

85+
@abstractmethod
6786
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
6887
raise NotImplementedError
6988

89+
@abstractmethod
90+
def get_node_count(self) -> int:
91+
pass
92+
93+
@abstractmethod
7094
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
7195
raise NotImplementedError
7296

97+
@abstractmethod
7398
def update_edge(
7499
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
75100
):
76101
raise NotImplementedError
77102

103+
@abstractmethod
78104
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
79105
raise NotImplementedError
80106

107+
@abstractmethod
108+
def get_edge_count(self) -> int:
109+
pass
110+
111+
@abstractmethod
81112
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
82113
raise NotImplementedError
83114

115+
@abstractmethod
84116
def upsert_node(self, node_id: str, node_data: dict[str, str]):
85117
raise NotImplementedError
86118

119+
@abstractmethod
87120
def upsert_edge(
88121
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
89122
):
90123
raise NotImplementedError
91124

125+
@abstractmethod
92126
def delete_node(self, node_id: str):
93127
raise NotImplementedError
94128

129+
@abstractmethod
95130
def reload(self):
96131
raise NotImplementedError
132+
133+
@abstractmethod
134+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
135+
raise NotImplementedError

graphgen/common/init_storage.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Union
1+
from typing import Any, Dict, List, Set, Union
22

33
import ray
44

@@ -68,6 +68,21 @@ def __init__(self, backend: str, working_dir: str, namespace: str):
6868
def index_done_callback(self):
6969
return self.graph.index_done_callback()
7070

71+
def is_directed(self) -> bool:
72+
return self.graph.is_directed()
73+
74+
def get_all_node_degrees(self) -> Dict[str, int]:
75+
return self.graph.get_all_node_degrees()
76+
77+
def get_node_count(self) -> int:
78+
return self.graph.get_node_count()
79+
80+
def get_edge_count(self) -> int:
81+
return self.graph.get_edge_count()
82+
83+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
84+
return self.graph.get_connected_components(undirected)
85+
7186
def has_node(self, node_id: str) -> bool:
7287
return self.graph.has_node(node_id)
7388

@@ -165,6 +180,21 @@ def __init__(self, actor_handle: ray.actor.ActorHandle):
165180
def index_done_callback(self):
166181
return ray.get(self.actor.index_done_callback.remote())
167182

183+
def is_directed(self) -> bool:
184+
return ray.get(self.actor.is_directed.remote())
185+
186+
def get_all_node_degrees(self) -> Dict[str, int]:
187+
return ray.get(self.actor.get_all_node_degrees.remote())
188+
189+
def get_node_count(self) -> int:
190+
return ray.get(self.actor.get_node_count.remote())
191+
192+
def get_edge_count(self) -> int:
193+
return ray.get(self.actor.get_edge_count.remote())
194+
195+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
196+
return ray.get(self.actor.get_connected_components.remote(undirected))
197+
168198
def has_node(self, node_id: str) -> bool:
169199
return ray.get(self.actor.has_node.remote(node_id))
170200

@@ -239,10 +269,14 @@ def create_storage(backend: str, working_dir: str, namespace: str):
239269
try:
240270
actor_handle = ray.get_actor(actor_name)
241271
except ValueError:
242-
actor_handle = ray.remote(actor_class).options(
243-
name=actor_name,
244-
get_if_exists=True,
245-
).remote(backend, working_dir, namespace)
272+
actor_handle = (
273+
ray.remote(actor_class)
274+
.options(
275+
name=actor_name,
276+
get_if_exists=True,
277+
)
278+
.remote(backend, working_dir, namespace)
279+
)
246280
ray.get(actor_handle.ready.remote())
247281
return proxy_class(actor_handle)
248282

graphgen/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
271271

272272
for node in sorted_nodes:
273273
self._execute_node(node, initial_ds)
274+
if getattr(node, "save_output", False):
275+
self.datasets[node.id] = self.datasets[node.id].materialize()
274276

275277
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
276278
return {node.id: self.datasets[node.id] for node in output_nodes}

0 commit comments

Comments
 (0)