Skip to content

Commit 42693df

Browse files
wip: refactor evaluator structure
1 parent 8ef5f47 commit 42693df

File tree

10 files changed

+220
-40
lines changed

10 files changed

+220
-40
lines changed

examples/evaluate/evaluate.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/evaluate/evaluate_kg/evaluate_kg.sh (new file; name inferred from the config path it invokes)

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Run the knowledge-graph quality evaluation pipeline.
python3 -m graphgen.run --config_file examples/evaluate/evaluate_kg/evaluate_kg_config.yaml
examples/evaluate/evaluate_kg/evaluate_kg_config.yaml (new file; path taken from the script above)

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
global_params:
  working_dir: cache
  graph_backend: kuzu # graph database backend, support: kuzu, networkx
  kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/extract_demo.txt

  - id: chunk
    op_name: chunk
    type: map_batch
    dependencies:
      - read
    execution_params:
      replicas: 4
    params:
      chunk_size: 20480 # larger chunk size for better context
      chunk_overlap: 2000

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      # fix: was "chunk_documents", which is not a node id in this pipeline
      # (that id belongs to the QA config); the chunking node above is "chunk"
      - chunk
    execution_params:
      replicas: 1
      batch_size: 128

  - id: evaluate
    op_name: evaluate
    type: aggregate
    dependencies:
      - build_kg
    params:
      metrics: # NOTE(review): left empty in this WIP commit — confirm intended metric list
examples/evaluate/evaluate_qa/evaluate_qa.sh (new file; name inferred from the config path it invokes)

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Run the QA-quality evaluation pipeline.
python3 -m graphgen.run --config_file examples/evaluate/evaluate_qa/evaluate_qa_config.yaml
examples/evaluate/evaluate_qa/evaluate_qa_config.yaml (new file; path taken from the script above)

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
global_params:
  working_dir: cache
  graph_backend: kuzu # graph database backend: kuzu or networkx
  kv_backend: rocksdb # key-value store backend: rocksdb or json_kv

nodes:
  # Each node id is unique within the pipeline and is what other nodes
  # reference in their `dependencies` lists.
  - id: read_files
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl # json, jsonl, txt, pdf supported; see examples/input_examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies: [read_files]
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024 # chunk size for text splitting
      chunk_overlap: 100 # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies: [chunk_documents]
    execution_params:
      replicas: 1
      batch_size: 128

  - id: quiz
    op_name: quiz
    type: aggregate
    dependencies: [build_kg]
    execution_params:
      replicas: 1
      batch_size: 128
    params:
      quiz_samples: 2 # number of quiz samples to generate
      concurrency_limit: 200

  - id: judge
    op_name: judge
    type: map_batch
    dependencies: [quiz]
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies: [judge]
    params:
      method: ece # custom partition method based on comprehension loss
      method_params:
        max_units_per_community: 20 # max nodes and edges per community
        min_units_per_community: 5 # min nodes and edges per community
        max_tokens_per_community: 10240 # max tokens per community
        unit_sampling: max_loss # unit sampling strategy: random, max_loss, min_loss

  - id: generate
    op_name: generate
    type: map_batch
    dependencies: [partition]
    execution_params:
      replicas: 1
      batch_size: 128
    params:
      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
      data_format: ChatML # Alpaca, Sharegpt, ChatML

  - id: evaluate
    op_name: evaluate
    type: map_batch
    dependencies: [generate]
    execution_params:
      replicas: 1
      batch_size: 128
    params:
      metrics: # NOTE(review): left empty in this WIP commit — confirm intended metric list

examples/evaluate_kg/evaluate_kg.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.

graphgen/operators/evaluate_kg/evaluate_kg.py renamed to graphgen/operators/evaluate/evaluate_kg.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import json
33
from pathlib import Path
4+
45
from dotenv import load_dotenv
56

67
from graphgen.models import KGQualityEvaluator
@@ -37,14 +38,22 @@ def _print_accuracy_summary(acc):
3738
precision = e.get("precision", {})
3839

3940
print(" Entity Extraction Quality:")
40-
print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
41-
f"{overall.get('median', 0):.3f} (median)")
42-
print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
43-
f"{accuracy.get('median', 0):.3f} (median)")
44-
print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
45-
f"{completeness.get('median', 0):.3f} (median)")
46-
print(f" Precision: {precision.get('mean', 0):.3f} (mean), "
47-
f"{precision.get('median', 0):.3f} (median)")
41+
print(
42+
f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
43+
f"{overall.get('median', 0):.3f} (median)"
44+
)
45+
print(
46+
f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
47+
f"{accuracy.get('median', 0):.3f} (median)"
48+
)
49+
print(
50+
f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
51+
f"{completeness.get('median', 0):.3f} (median)"
52+
)
53+
print(
54+
f" Precision: {precision.get('mean', 0):.3f} (mean), "
55+
f"{precision.get('median', 0):.3f} (median)"
56+
)
4857
print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}")
4958

5059
if "relation_accuracy" in acc:
@@ -55,14 +64,22 @@ def _print_accuracy_summary(acc):
5564
precision = r.get("precision", {})
5665

5766
print(" Relation Extraction Quality:")
58-
print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
59-
f"{overall.get('median', 0):.3f} (median)")
60-
print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
61-
f"{accuracy.get('median', 0):.3f} (median)")
62-
print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
63-
f"{completeness.get('median', 0):.3f} (median)")
64-
print(f" Precision: {precision.get('mean', 0):.3f} (mean), "
65-
f"{precision.get('median', 0):.3f} (median)")
67+
print(
68+
f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
69+
f"{overall.get('median', 0):.3f} (median)"
70+
)
71+
print(
72+
f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
73+
f"{accuracy.get('median', 0):.3f} (median)"
74+
)
75+
print(
76+
f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
77+
f"{completeness.get('median', 0):.3f} (median)"
78+
)
79+
print(
80+
f" Precision: {precision.get('mean', 0):.3f} (mean), "
81+
f"{precision.get('median', 0):.3f} (median)"
82+
)
6683
print(f" Total Chunks Evaluated: {r.get('total_chunks', 0)}")
6784
else:
6885
print(f"\n[Accuracy] Error: {acc['error']}")
@@ -73,19 +90,25 @@ def _print_consistency_summary(cons):
7390
if "error" not in cons:
7491
print("\n[Consistency]")
7592
print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}")
76-
print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / "
77-
f"{cons.get('total_entities', 0)}")
78-
entities_checked = cons.get('entities_checked', 0)
93+
print(
94+
f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / "
95+
f"{cons.get('total_entities', 0)}"
96+
)
97+
entities_checked = cons.get("entities_checked", 0)
7998
if entities_checked > 0:
80-
print(f" Entities Checked: {entities_checked} (entities with multiple sources)")
81-
conflicts = cons.get('conflicts', [])
99+
print(
100+
f" Entities Checked: {entities_checked} (entities with multiple sources)"
101+
)
102+
conflicts = cons.get("conflicts", [])
82103
if conflicts:
83104
print(f" Total Conflicts Found: {len(conflicts)}")
84105
# Show sample conflicts
85106
sample_conflicts = conflicts[:3]
86107
for conflict in sample_conflicts:
87-
print(f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} "
88-
f"(severity: {conflict.get('conflict_severity', 0):.2f})")
108+
print(
109+
f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} "
110+
f"(severity: {conflict.get('conflict_severity', 0):.2f})"
111+
)
89112
else:
90113
print(f"\n[Consistency] Error: {cons['error']}")
91114

@@ -103,15 +126,19 @@ def _print_structure_summary(struct):
103126
noise_check = thresholds.get("noise_ratio", {})
104127
noise_threshold = noise_check.get("threshold", "N/A")
105128
noise_pass = noise_check.get("pass", False)
106-
print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} "
107-
f"({'✓' if noise_pass else '✗'} < {noise_threshold})")
129+
print(
130+
f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} "
131+
f"({'✓' if noise_pass else '✗'} < {noise_threshold})"
132+
)
108133

109134
# Largest CC Ratio
110135
lcc_check = thresholds.get("largest_cc_ratio", {})
111136
lcc_threshold = lcc_check.get("threshold", "N/A")
112137
lcc_pass = lcc_check.get("pass", False)
113-
print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} "
114-
f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})")
138+
print(
139+
f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} "
140+
f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})"
141+
)
115142

116143
# Avg Degree
117144
avg_degree_check = thresholds.get("avg_degree", {})
@@ -122,16 +149,20 @@ def _print_structure_summary(struct):
122149
threshold_str = f"{avg_degree_threshold[0]}-{avg_degree_threshold[1]}"
123150
else:
124151
threshold_str = str(avg_degree_threshold)
125-
print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} "
126-
f"({'✓' if avg_degree_pass else '✗'} {threshold_str})")
152+
print(
153+
f" Avg Degree: {struct.get('avg_degree', 0):.2f} "
154+
f"({'✓' if avg_degree_pass else '✗'} {threshold_str})"
155+
)
127156

128157
# Power Law R²
129-
if struct.get('powerlaw_r2') is not None:
158+
if struct.get("powerlaw_r2") is not None:
130159
powerlaw_check = thresholds.get("powerlaw_r2", {})
131160
powerlaw_threshold = powerlaw_check.get("threshold", "N/A")
132161
powerlaw_pass = powerlaw_check.get("pass", False)
133-
print(f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} "
134-
f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})")
162+
print(
163+
f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} "
164+
f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})"
165+
)
135166
else:
136167
print(f"\n[Structural Robustness] Error: {struct['error']}")
137168

File renamed without changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pandas as pd
2+
3+
from graphgen.bases import BaseLLMWrapper, BaseOperator
4+
from graphgen.common import init_llm
5+
6+
7+
class EvaluateService(BaseOperator):
    """Pipeline operator that scores generated artifacts.

    Per the original docstring, this operator is meant to cover:
      1. KG quality evaluation
      2. QA quality evaluation

    The concrete scoring logic is not implemented yet in this WIP commit;
    see :meth:`evaluate`.
    """

    def __init__(self, working_dir: str = "cache"):
        super().__init__(working_dir=working_dir, op_name="evaluate_service")
        # LLM used for evaluation; "synthesizer" presumably selects the
        # synthesizer model configuration — TODO confirm against init_llm.
        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Evaluate one batch: rows in, one scored row per item out."""
        items = batch.to_dict(orient="records")
        return pd.DataFrame(self.evaluate(items))

    def evaluate(self, items: list[dict]) -> list[dict]:
        """Score *items* and return one result dict per input item.

        Fix: the original body was ``pass`` (returning ``None``), so
        ``process`` would call ``pd.DataFrame(None)`` and silently emit an
        empty frame, dropping the entire batch. Raise instead so the
        unfinished state fails loudly at the call site.
        """
        raise NotImplementedError(
            "EvaluateService.evaluate is not implemented yet"
        )

graphgen/operators/evaluate_kg/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)