diff --git a/examples/evaluate/evaluate.sh b/examples/evaluate/evaluate.sh deleted file mode 100644 index 2b352669..00000000 --- a/examples/evaluate/evaluate.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.evaluate --folder cache/data \ - --reward "OpenAssistant/reward-model-deberta-v3-large-v2,BAAI/IndustryCorpus2_DataRater" \ - --uni MingZhong/unieval-sum \ diff --git a/examples/evaluate/evaluate_kg/evaluate_kg.sh b/examples/evaluate/evaluate_kg/evaluate_kg.sh new file mode 100644 index 00000000..2bf2f37e --- /dev/null +++ b/examples/evaluate/evaluate_kg/evaluate_kg.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_kg/kg_evaluation_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml new file mode 100644 index 00000000..d86d01b1 --- /dev/null +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -0,0 +1,45 @@ +global_params: + working_dir: cache + graph_backend: kuzu # graph database backend, support: kuzu, networkx + kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/extract_demo.txt + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 20480 # larger chunk size for better context + chunk_overlap: 2000 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: evaluate + op_name: evaluate + type: aggregate + save_output: true + dependencies: + - build_kg + params: + metrics: + - kg_structure + - kg_accuracy + - kg_consistency diff --git a/examples/evaluate/evaluate_qa/evaluate_qa.sh b/examples/evaluate/evaluate_qa/evaluate_qa.sh new file mode 100644 index 00000000..5bfe392c --- /dev/null +++ b/examples/evaluate/evaluate_qa/evaluate_qa.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_qa/qa_evaluation_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml new file mode 100644 index 00000000..459f9fad --- /dev/null +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -0,0 +1,98 @@ +global_params: + working_dir: cache + graph_backend: kuzu # graph database backend, support: kuzu, networkx + kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv + +nodes: + - id: read_files # id is unique in the pipeline, and can be referenced by other steps + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples + + - id: chunk_documents + op_name: chunk + type: map_batch + dependencies: + - read_files + execution_params: + replicas: 4 + params: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk_documents + execution_params: + replicas: 1 + batch_size: 128 + + - id: quiz + op_name: quiz + type: aggregate + dependencies: + - build_kg + execution_params: + replicas: 1 + batch_size: 128 + params: + quiz_samples: 2 # number of quiz samples to generate + concurrency_limit: 200 + + - id: judge + op_name: judge + type: map_batch + dependencies: + - quiz + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - judge + params: + method: ece # ece is a custom partition method based on comprehension loss + method_params: + max_units_per_community: 20 # max nodes and edges per community + min_units_per_community: 5 # min nodes and edges per community + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + save_output: true + params: + method: aggregated # atomic, aggregated, multi_hop, cot, vqa + data_format: ChatML # Alpaca, Sharegpt, ChatML + + - id: evaluate + op_name: evaluate + type: map_batch + dependencies: + - generate + execution_params: + replicas: 1 + batch_size: 128 + save_output: true + params: + metrics: + - qa_length + - qa_mtld + - qa_reward_score + - qa_uni_score + mtld_params: + threshold: 0.7 diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 41136974..0727b3fa 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -9,4 +9,5 @@ from .base_splitter import BaseSplitter from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer +from .base_evaluator import BaseEvaluator from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_evaluator.py b/graphgen/bases/base_evaluator.py new file mode 100644 index 00000000..3cc5df18 --- /dev/null +++ b/graphgen/bases/base_evaluator.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod +from .datatypes import QAPair + + +class BaseEvaluator(ABC): + @abstractmethod + def evaluate(self, pair: QAPair) -> float: + """ + Evaluate the text and return a score. + """ diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index ff7d2d1a..e72c5869 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -1,5 +1,6 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, TypeVar, Union +from typing import Dict, Generic, List, Set, TypeVar, Union T = TypeVar("T") @@ -45,52 +46,90 @@ def reload(self): raise NotImplementedError -class BaseGraphStorage(StorageNameSpace): +class BaseGraphStorage(StorageNameSpace, ABC): + @abstractmethod + def is_directed(self) -> bool: + pass + + @abstractmethod def has_node(self, node_id: str) -> bool: raise NotImplementedError + @abstractmethod def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError + @abstractmethod def node_degree(self, node_id: str) -> int: raise NotImplementedError - def edge_degree(self, src_id: str, tgt_id: str) -> int: - raise NotImplementedError + @abstractmethod + def get_all_node_degrees(self) -> Dict[str, int]: + pass + def get_isolated_nodes(self) -> List[str]: + return [ + node_id + for node_id, degree in self.get_all_node_degrees().items() + if degree == 0 + ] + + @abstractmethod def get_node(self, node_id: str) -> Union[dict, None]: raise NotImplementedError + @abstractmethod def update_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError + @abstractmethod def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: raise NotImplementedError + @abstractmethod + def get_node_count(self) -> int: + pass + + @abstractmethod def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: raise NotImplementedError + @abstractmethod def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError + @abstractmethod def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: raise NotImplementedError + @abstractmethod + def get_edge_count(self) -> int: + pass + + @abstractmethod def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: raise NotImplementedError + @abstractmethod def upsert_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError + @abstractmethod def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError + @abstractmethod def delete_node(self, node_id: str): raise NotImplementedError + @abstractmethod def reload(self): raise NotImplementedError + + @abstractmethod + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + raise NotImplementedError diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py index 56528e7a..aaffb630 100644 --- a/graphgen/common/init_storage.py +++ b/graphgen/common/init_storage.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict, List, Set, Union import ray @@ -68,6 +68,21 @@ def __init__(self, backend: str, working_dir: str, namespace: str): def index_done_callback(self): return self.graph.index_done_callback() + def is_directed(self) -> bool: + return self.graph.is_directed() + + def get_all_node_degrees(self) -> Dict[str, int]: + return self.graph.get_all_node_degrees() + + def get_node_count(self) -> int: + return self.graph.get_node_count() + + def get_edge_count(self) -> int: + return self.graph.get_edge_count() + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + return self.graph.get_connected_components(undirected) + def has_node(self, node_id: str) -> bool: return self.graph.has_node(node_id) @@ -165,6 +180,21 @@ def __init__(self, actor_handle: ray.actor.ActorHandle): def index_done_callback(self): return ray.get(self.actor.index_done_callback.remote()) + def is_directed(self) -> bool: + return ray.get(self.actor.is_directed.remote()) + + def get_all_node_degrees(self) -> Dict[str, int]: + return ray.get(self.actor.get_all_node_degrees.remote()) + + def get_node_count(self) -> int: + return ray.get(self.actor.get_node_count.remote()) + + def get_edge_count(self) -> int: + return ray.get(self.actor.get_edge_count.remote()) + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + return ray.get(self.actor.get_connected_components.remote(undirected)) + def has_node(self, node_id: str) -> bool: return ray.get(self.actor.has_node.remote(node_id)) @@ -239,10 +269,14 @@ def create_storage(backend: str, working_dir: str, namespace: str): try: actor_handle = ray.get_actor(actor_name) except ValueError: - actor_handle = ray.remote(actor_class).options( - name=actor_name, - get_if_exists=True, - ).remote(backend, working_dir, namespace) + actor_handle = ( + ray.remote(actor_class) + .options( + name=actor_name, + get_if_exists=True, + ) + .remote(backend, working_dir, namespace) + ) ray.get(actor_handle.ready.remote()) return proxy_class(actor_handle) diff --git a/graphgen/engine.py b/graphgen/engine.py index 26bcff58..7b871a61 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -271,6 +271,8 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: for node in sorted_nodes: self._execute_node(node, initial_ds) + if getattr(node, "save_output", False): + self.datasets[node.id] = self.datasets[node.id].materialize() output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)] return {node.id: self.datasets[node.id] for node in output_nodes} diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 21344d74..43d38bed 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,4 +1,12 @@ -from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .evaluator import ( + AccuracyEvaluator, + ConsistencyEvaluator, + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + StructureEvaluator, + UniEvaluator, +) from .generator import ( AggregatedGenerator, AtomicGenerator, diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index a9b445b4..6091aeb5 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,4 +1,2 @@ -from .length_evaluator import LengthEvaluator -from .mtld_evaluator import MTLDEvaluator -from .reward_evaluator import RewardEvaluator -from .uni_evaluator import UniEvaluator +from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator +from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator diff --git a/graphgen/models/evaluator/base_evaluator.py b/graphgen/models/evaluator/base_evaluator.py deleted file mode 100644 index e24cfa43..00000000 --- a/graphgen/models/evaluator/base_evaluator.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.bases.datatypes import QAPair -from graphgen.utils import create_event_loop - - -class BaseEvaluator: - def __init__(self, max_concurrent: int = 100): - self.max_concurrent = max_concurrent - self.results: list[float] = None - - def evaluate(self, pairs: list[QAPair]) -> list[float]: - """ - Evaluate the text and return a score. - """ - return create_event_loop().run_until_complete(self.async_evaluate(pairs)) - - async def async_evaluate(self, pairs: list[QAPair]) -> list[float]: - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def evaluate_with_semaphore(pair): - async with semaphore: # 获取Semaphore - return await self.evaluate_single(pair) - - results = [] - for result in tqdm_async( - asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]), - total=len(pairs), - ): - results.append(await result) - return results - - async def evaluate_single(self, pair: QAPair) -> float: - raise NotImplementedError() - - def get_average_score(self, pairs: list[QAPair]) -> float: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - self.results = results - return sum(self.results) / len(pairs) - - def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: - """ - Get the min and max score of a batch of texts. - """ - if self.results is None: - self.get_average_score(pairs) - return min(self.results), max(self.results) diff --git a/graphgen/models/evaluator/kg/README.md b/graphgen/models/evaluator/kg/README.md new file mode 100644 index 00000000..10e26f6b --- /dev/null +++ b/graphgen/models/evaluator/kg/README.md @@ -0,0 +1,237 @@ +# KG Quality Evaluation Module + +This module provides comprehensive quality evaluation for knowledge graphs built by GraphGen. + +## Module Structure + +The evaluation functionality is organized into modular components: + +- **`accuracy_evaluator.py`**: Entity/relation extraction quality evaluation using LLM-as-a-Judge +- **`consistency_evaluator.py`**: Attribute value conflict detection +- **`structure_evaluator.py`**: Graph structural robustness metrics + +The evaluation components are integrated in `graphgen/operators/evaluate/evaluate_kg.py`, which provides functions to create and use these evaluators. + +## Features + +### 1. Accuracy Assessment +- **Entity Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of entity extraction from chunks + - Evaluates accuracy (correctness of extracted entities) + - Evaluates completeness (whether important entities are missed) + - Evaluates precision (naming accuracy and specificity) +- **Relation Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of relation extraction from chunks + - Evaluates accuracy (correctness of extracted relations) + - Evaluates completeness (whether important relations are missed) + - Evaluates precision (relation description accuracy) +- Provides multi-dimensional quality scores (0-1 scale) with detailed reasoning for each chunk + +### 2. Consistency Assessment +- **Semantic Conflict Detection**: Uses LLM-as-a-Judge to detect semantic conflicts in entity attributes + - **Entity Type Conflicts**: Detects when the same entity is extracted with different types across chunks + - **Entity Description Conflicts**: Detects when entity descriptions from different chunks are semantically inconsistent + - **Relation Conflicts**: Detects when the same entity pair has conflicting relation descriptions +- Only evaluates entities with multiple source chunks (entities appearing in multiple chunks) +- Uses LLM to extract entity attributes from each chunk and compare them semantically +- Calculates conflict rate: `conflict_entities_count / total_entities` +- Returns detailed conflict information including conflict severity and reasoning + +### 3. Structural Robustness Assessment +- **Noise Ratio**: Isolated nodes / total nodes (threshold: < 15%) +- **Largest Connected Component Ratio**: Largest CC nodes / total nodes (threshold: > 90%) +- **Average Node Degree**: Average degree across all nodes (threshold: 2-5) +- **Power Law Distribution R²**: Degree distribution fit (threshold: > 0.75) + +## Usage + +### Command Line Usage + +```bash +# Run all evaluations +python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache + +# Run specific evaluation +python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache --accuracy_only + +# Specify backends +python -m graphgen.operators.evaluate_kg.evaluate_kg \ + --working_dir cache \ + --graph_backend networkx \ + --kv_backend json_kv +``` + +### Shell Script Usage + +```bash +# Basic usage +bash examples/evaluate_kg/evaluate_kg.sh + +# With custom options +bash examples/evaluate_kg/evaluate_kg.sh \ + --working_dir cache \ + --accuracy_only +``` + +## Configuration + +All evaluation thresholds use default values defined in the evaluator classes: + +- **Structure thresholds**: Defined in `StructureEvaluator` with defaults: + - `noise_ratio_threshold`: 0.15 + - `largest_cc_ratio_threshold`: 0.90 + - `avg_degree_min`: 2.0 + - `avg_degree_max`: 5.0 + - `powerlaw_r2_threshold`: 0.75 + +**Note**: Accuracy evaluation automatically loads chunks from the chunk storage and evaluates the quality of entity/relation extraction using LLM-as-a-Judge. No configuration file is needed. + +## Requirements + +- **NetworkX**: Required for structural evaluation +- **scipy**: Required for power law distribution fitting +- **numpy**: Required for numerical calculations +- **LLM Client**: Required for accuracy evaluation (configure via `TRAINEE_*` env vars) + +## Output Format + +The evaluation returns a dictionary with the following structure: + +```python +{ + "accuracy": { + "entity_accuracy": { + "overall_score": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "accuracy": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "completeness": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "precision": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "total_chunks": int, + "detailed_results": [ + { + "chunk_id": str, + "chunk_content": str, + "extracted_entities_count": int, + "accuracy": float, + "completeness": float, + "precision": float, + "overall_score": float, + "accuracy_reasoning": str, + "completeness_reasoning": str, + "precision_reasoning": str, + "issues": [str] + }, + ... + ] + }, + "relation_accuracy": { + "overall_score": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "accuracy": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "completeness": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "precision": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "total_chunks": int, + "detailed_results": [ + { + "chunk_id": str, + "chunk_content": str, + "extracted_relations_count": int, + "accuracy": float, + "completeness": float, + "precision": float, + "overall_score": float, + "accuracy_reasoning": str, + "completeness_reasoning": str, + "precision_reasoning": str, + "issues": [str] + }, + ... + ] + } + }, + "consistency": { + "conflict_rate": float, + "conflict_entities_count": int, + "total_entities": int, + "entities_checked": int, + "conflicts": [ + { + "entity_id": str, + "conflict_type": str, # "entity_type" or "description" + "conflict_severity": float, # 0-1, severity of the conflict + "conflict_reasoning": str, + "conflicting_values": [str], + "recommended_value": str, # for entity_type conflicts + "conflict_details": str # for description conflicts + }, + ... + ] + }, + "structure": { + "total_nodes": int, + "total_edges": int, + "noise_ratio": float, + "largest_cc_ratio": float, + "avg_degree": float, + "powerlaw_r2": float | None, + "thresholds": { + "noise_ratio": { "value": float, "threshold": float, "pass": bool }, + ... + } + } +} +``` + +## Notes + +- Accuracy evaluation uses LLM-as-a-Judge to evaluate extraction quality from chunks +- Accuracy evaluation automatically loads chunks from chunk storage (no need for source_text_paths) +- The evaluator associates extracted entities/relations with their source chunks using the `source_id` field +- Structural evaluation automatically converts Kuzu storage to NetworkX for analysis +- All evaluations include error handling and will return error messages if something fails +- The evaluator automatically loads graph and chunk storage from the working directory +- LLM evaluation may take time for large numbers of chunks (controlled by `max_concurrent` parameter) diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py new file mode 100644 index 00000000..375cbc50 --- /dev/null +++ b/graphgen/models/evaluator/kg/__init__.py @@ -0,0 +1,18 @@ +""" +Knowledge Graph Quality Evaluator + +This module provides comprehensive quality evaluation for knowledge graphs, +1. accuracy assessment (entity/relation/triple validation), +2. consistency assessment (attribute conflict detection), and structural +3. robustness assessment (noise ratio, connectivity, degree distribution). +""" + +from .accuracy_evaluator import AccuracyEvaluator +from .consistency_evaluator import ConsistencyEvaluator +from .structure_evaluator import StructureEvaluator + +__all__ = [ + "AccuracyEvaluator", + "ConsistencyEvaluator", + "StructureEvaluator", +] diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py new file mode 100644 index 00000000..9663b6f8 --- /dev/null +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -0,0 +1,350 @@ +import asyncio +import json +import re +from typing import Any, Dict, List + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.bases.datatypes import Chunk +from graphgen.templates import ACCURACY_EVALUATION_PROMPT +from graphgen.utils import detect_main_language, logger + + +class AccuracyEvaluator: + """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge. + + For each chunk, uses LLM to evaluate the quality of extracted entities and relations + by comparing them with the original chunk content. Provides multi-dimensional quality + scores (accuracy, completeness, precision). + """ + + def __init__( + self, + graph_storage: BaseGraphStorage, + chunk_storage: BaseKVStorage, + llm_client: BaseLLMWrapper, + ): + self.graph_storage = graph_storage + self.chunk_storage = chunk_storage + self.llm_client = llm_client + + def evaluate(self) -> Dict[str, Any]: + """Evaluate entity and relation extraction quality using LLM-as-a-Judge. + + Returns: + Dictionary containing entity_accuracy and relation_accuracy metrics. + """ + # 1. Load all chunks from storage + chunks = self._load_chunks_from_storage() + + if not chunks: + logger.warning("No chunks found in storage") + return {"error": "No chunks found in storage"} + + logger.info(f"Found {len(chunks)} chunks to evaluate") + + # 2. Evaluate each chunk + entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks) + + # 3. Aggregate results + return self._aggregate_evaluation_results( + entity_evaluations, relation_evaluations + ) + + def _load_chunks_from_storage(self) -> List[Chunk]: + """Load all chunks from chunk storage.""" + chunks = [] + all_chunk_data = self.chunk_storage.get_all() + + for chunk_id, chunk_data in all_chunk_data.items(): + try: + chunk = Chunk.from_dict(chunk_id, chunk_data) + chunks.append(chunk) + except Exception as e: + logger.warning(f"Failed to load chunk {chunk_id}: {e}") + continue + + return chunks + + def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: + """Get all entities extracted from the specified chunk.""" + entities = [] + all_nodes = self.graph_storage.get_all_nodes() or [] + + for node_id, node_data in all_nodes: + if not isinstance(node_data, dict): + continue + source_ids = node_data.get("source_id", "").split("") + # Check if this chunk_id is in the source_ids + if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: + entities.append( + { + "entity_name": node_data.get("entity_name", node_id), + "entity_type": node_data.get("entity_type", ""), + "description": node_data.get("description", ""), + } + ) + + return entities + + def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: + """Get all relations extracted from the specified chunk.""" + relations = [] + all_edges = self.graph_storage.get_all_edges() or [] + + for src_id, dst_id, edge_data in all_edges: + if not isinstance(edge_data, dict): + continue + source_ids = edge_data.get("source_id", "").split("") + # Check if this chunk_id is in the source_ids + if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: + src_node = self.graph_storage.get_node(src_id) or {} + dst_node = self.graph_storage.get_node(dst_id) or {} + relations.append( + { + "source_entity": src_node.get("entity_name", src_id), + "target_entity": dst_node.get("entity_name", dst_id), + "relationship_summary": edge_data.get("description", ""), + } + ) + + return relations + + def _evaluate_all_chunks( + self, chunks: List[Chunk] + ) -> tuple[List[Dict], List[Dict]]: + """Evaluate all chunks sequentially.""" + entity_evaluations = [] + relation_evaluations = [] + + for chunk in chunks: + try: + entities = self._get_extracted_entities_for_chunk(chunk.id) + relations = self._get_extracted_relations_for_chunk(chunk.id) + + entity_eval = self._evaluate_entity_extraction(chunk, entities) + relation_eval = self._evaluate_relation_extraction(chunk, relations) + + entity_evaluations.append(entity_eval) + relation_evaluations.append(relation_eval) + except Exception as e: + logger.error(f"Failed to evaluate chunk {chunk.id}: {e}") + continue + + return entity_evaluations, relation_evaluations + + def _evaluate_entity_extraction( + self, chunk: Chunk, extracted_entities: List[Dict] + ) -> Dict[str, Any]: + """Use LLM to evaluate entity extraction quality.""" + try: + lang = detect_main_language(chunk.content) + + prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format( + chunk_content=chunk.content, + extracted_entities=json.dumps( + extracted_entities, ensure_ascii=False, indent=2 + ), + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Try to parse JSON response + try: + evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" + ) + # Return default evaluation + evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"], + } + + # Validate and calculate overall_score if not provided + if "overall_score" not in evaluation_result: + accuracy = float(evaluation_result.get("accuracy", 0.0)) + completeness = float(evaluation_result.get("completeness", 0.0)) + precision = float(evaluation_result.get("precision", 0.0)) + evaluation_result["overall_score"] = ( + 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + ) + + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] + if chunk.content + else "", # First 200 chars for debugging + "extracted_entities_count": len(extracted_entities), + **evaluation_result, + } + except Exception as e: + logger.error( + f"Error evaluating entity extraction for chunk {chunk.id}: {e}" + ) + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_entities_count": len(extracted_entities), + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": f"Evaluation failed: {str(e)}", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [f"Evaluation error: {str(e)}"], + } + + def _evaluate_relation_extraction( + self, chunk: Chunk, extracted_relations: List[Dict] + ) -> Dict[str, Any]: + """Use LLM to evaluate relation extraction quality.""" + try: + lang = detect_main_language(chunk.content) + prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format( + chunk_content=chunk.content, + extracted_relations=json.dumps( + extracted_relations, ensure_ascii=False, indent=2 + ), + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Try to parse JSON response + try: + evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" + ) + # Return default evaluation + evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"], + } + + # Validate and calculate overall_score if not provided + if "overall_score" not in evaluation_result: + accuracy = float(evaluation_result.get("accuracy", 0.0)) + completeness = float(evaluation_result.get("completeness", 0.0)) + precision = float(evaluation_result.get("precision", 0.0)) + evaluation_result["overall_score"] = ( + 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + ) + + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_relations_count": len(extracted_relations), + **evaluation_result, + } + except Exception as e: + logger.error( + f"Error evaluating relation extraction for chunk {chunk.id}: {e}" + ) + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_relations_count": len(extracted_relations), + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": f"Evaluation failed: {str(e)}", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [f"Evaluation error: {str(e)}"], + } + + @staticmethod + def _aggregate_evaluation_results( + entity_evaluations: List[Dict], relation_evaluations: List[Dict] + ) -> Dict[str, Any]: + """Aggregate evaluation results from all chunks.""" + + def calculate_stats(scores: List[float]) -> Dict[str, float]: + if not scores: + return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0} + sorted_scores = sorted(scores) + n = len(scores) + mean = sum(scores) / n + median = ( + sorted_scores[n // 2] + if n % 2 == 1 + else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 + ) + variance = sum((x - mean) ** 2 for x in scores) / n + std = variance**0.5 + + return { + "mean": mean, + "median": median, + "min": min(scores), + "max": max(scores), + "std": std, + } + + # Extract scores + entity_overall_scores = [ + e.get("overall_score", 0.0) for e in entity_evaluations + ] + entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations] + entity_completeness_scores = [ + e.get("completeness", 0.0) for e in entity_evaluations + ] + entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations] + + relation_overall_scores = [ + r.get("overall_score", 0.0) for r in relation_evaluations + ] + relation_accuracy_scores = [ + r.get("accuracy", 0.0) for r in relation_evaluations + ] + relation_completeness_scores = [ + r.get("completeness", 0.0) for r in relation_evaluations + ] + relation_precision_scores = [ + r.get("precision", 0.0) for r in relation_evaluations + ] + + return { + "entity_accuracy": { + "overall_score": calculate_stats(entity_overall_scores), + "accuracy": calculate_stats(entity_accuracy_scores), + "completeness": calculate_stats(entity_completeness_scores), + "precision": calculate_stats(entity_precision_scores), + "total_chunks": len(entity_evaluations), + "detailed_results": entity_evaluations, + }, + "relation_accuracy": { + "overall_score": calculate_stats(relation_overall_scores), + "accuracy": calculate_stats(relation_accuracy_scores), + "completeness": calculate_stats(relation_completeness_scores), + "precision": calculate_stats(relation_precision_scores), + "total_chunks": len(relation_evaluations), + "detailed_results": relation_evaluations, + }, + } diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py new file mode 100644 index 00000000..069e7591 --- /dev/null +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -0,0 +1,380 @@ +import asyncio +import json +import re +from typing import Any, Dict, List + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.bases.datatypes import Chunk +from graphgen.templates.evaluation.kg.consistency_evaluation import ( + ENTITY_DESCRIPTION_CONFLICT_PROMPT, + ENTITY_EXTRACTION_PROMPT, + ENTITY_TYPE_CONFLICT_PROMPT, + RELATION_CONFLICT_PROMPT, +) +from graphgen.utils import logger + + +class ConsistencyEvaluator: + """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge. + + For entities with multiple source chunks, compares entity_type and description + extracted from different chunks to detect semantic conflicts. + """ + + def __init__( + self, + graph_storage: BaseGraphStorage, + chunk_storage: BaseKVStorage, + llm_client: BaseLLMWrapper, + ): + self.graph_storage = graph_storage + self.chunk_storage = chunk_storage + self.llm_client = llm_client + + def evaluate(self) -> Dict[str, Any]: + """Evaluate consistency by detecting semantic conflicts.""" + all_nodes = self.graph_storage.get_all_nodes() or [] + if not all_nodes: + return {"error": "Empty graph"} + + return self._evaluate_consistency(all_nodes) + + def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: + """Evaluate consistency by detecting semantic conflicts.""" + # Filter entities with multiple source chunks + entities_with_multiple_sources = [] + for node_id, node_data in all_nodes: + if not isinstance(node_data, dict): + continue + source_ids = node_data.get("source_id", "").split("") + source_ids = [sid.strip() for sid in source_ids if sid.strip()] + if len(source_ids) > 1: # Only check entities from multiple chunks + entities_with_multiple_sources.append((node_id, node_data, source_ids)) + + if not entities_with_multiple_sources: + logger.info( + "No entities with multiple sources found, skipping consistency check" + ) + return { + "conflict_rate": 0.0, + "conflict_entities_count": 0, + "total_entities": len(all_nodes), + "conflicts": [], + } + + logger.info( + f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources" + ) + + # Evaluate entities sequentially + conflicts = [] + conflict_entities = set() + + for entity_info in entities_with_multiple_sources: + try: + entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info) + if entity_conflicts: + conflicts.extend(entity_conflicts) + conflict_entities.add(entity_id) + except Exception as e: + logger.error( + f"Failed to evaluate entity {entity_info[0]}: {e}" + ) + continue + + total_entities = len(all_nodes) + conflict_rate = ( + len(conflict_entities) / total_entities if total_entities > 0 else 0 + ) + + return { + "conflict_rate": conflict_rate, + "conflict_entities_count": len(conflict_entities), + "total_entities": total_entities, + "entities_checked": len(entities_with_multiple_sources), + "conflicts": conflicts[:100], # Limit to first 100 conflicts + } + + def _clean_entity_id(self, entity_id: str) -> str: + """Clean entity ID by removing surrounding quotes.""" + clean_id = entity_id.strip() + if (clean_id.startswith('"') and clean_id.endswith('"')) or ( + clean_id.startswith("'") and clean_id.endswith("'") + ): + clean_id = clean_id[1:-1].strip() + return clean_id + + def _evaluate_entity_consistency( + self, entity_info: tuple + ) -> tuple[str, List[Dict]]: + """Evaluate consistency for a single entity.""" + entity_id, _node_data, source_ids = entity_info + # Clean entity_id for display + clean_entity_id = self._clean_entity_id(entity_id) + conflicts = [] + + # Get chunks for this entity + chunks = self._get_entity_chunks(source_ids) + if len(chunks) < 2: + return entity_id, [] + + # Extract entity attributes from each chunk + entity_extractions = {} + for chunk in chunks: + extraction = self._extract_entity_from_chunk(entity_id, chunk) + if extraction: + entity_extractions[chunk.id] = extraction + + if len(entity_extractions) < 2: + return entity_id, [] + + # Check entity type consistency + type_extractions = { + chunk_id: ext.get("entity_type", "") + for chunk_id, ext in entity_extractions.items() + } + type_conflict = self._check_entity_type_consistency( + entity_id, type_extractions + ) + if type_conflict and type_conflict.get("has_conflict", False): + conflicts.append( + { + "entity_id": clean_entity_id, + "conflict_type": "entity_type", + "conflict_severity": type_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": type_conflict.get("conflict_reasoning", ""), + "conflicting_values": type_conflict.get("conflicting_types", []), + "recommended_value": type_conflict.get("recommended_type", ""), + } + ) + + # Check entity description consistency + descriptions = { + chunk_id: ext.get("description", "") + for chunk_id, ext in entity_extractions.items() + } + desc_conflict = self._check_entity_description_consistency( + entity_id, descriptions + ) + if desc_conflict and desc_conflict.get("has_conflict", False): + conflicts.append( + { + "entity_id": clean_entity_id, + "conflict_type": "description", + "conflict_severity": desc_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""), + "conflicting_values": desc_conflict.get( + "conflicting_descriptions", [] + ), + "conflict_details": desc_conflict.get("conflict_details", ""), + } + ) + + return entity_id, conflicts + + def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]: + """Get all chunks related to an entity.""" + chunks = [] + for chunk_id in source_ids: + chunk_data = self.chunk_storage.get_by_id(chunk_id) + if chunk_data: + try: + chunk = Chunk.from_dict(chunk_id, chunk_data) + chunks.append(chunk) + except Exception as e: + logger.warning(f"Failed to load chunk {chunk_id}: {e}") + continue + return chunks + + def _extract_entity_from_chunk( + self, entity_id: str, chunk: Chunk + ) -> Dict[str, str]: + """Extract entity attributes from a chunk using LLM.""" + try: + # Clean entity_id: remove surrounding quotes if present + clean_entity_id = self._clean_entity_id(entity_id) + + prompt = ENTITY_EXTRACTION_PROMPT.format( + entity_name=clean_entity_id, + chunk_content=chunk.content[:2000] + if chunk.content + else "", # Limit content length + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Try to parse JSON response + try: + extraction = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + extraction = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}" + ) + return {} + + # Normalize entity_type to lowercase and validate + entity_type = extraction.get("entity_type", "").lower().strip() + # Valid preset types + valid_types = { + "concept", + "date", + "location", + "keyword", + "organization", + "person", + "event", + "work", + "nature", + "artificial", + "science", + "technology", + "mission", + "gene", + } + # If entity_type is not in valid types, default to "concept" + if entity_type not in valid_types: + if entity_type: # If LLM provided a type but it's invalid + logger.warning( + f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, " + f"defaulting to 'concept'" + ) + entity_type = "concept" + + return { + "entity_type": entity_type, + "description": extraction.get("description", ""), + } + except Exception as e: + logger.error( + f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}" + ) + return {} + + def _check_entity_type_consistency( + self, entity_id: str, type_extractions: Dict[str, str] + ) -> Dict[str, Any]: + """Check entity type consistency using LLM.""" + if len(set(type_extractions.values())) <= 1: + # All types are the same, no conflict + return {"has_conflict": False} + + try: + type_list = [ + f"Chunk {chunk_id}: {entity_type}" + for chunk_id, entity_type in type_extractions.items() + if entity_type + ] + + prompt = ENTITY_TYPE_CONFLICT_PROMPT.format( + entity_name=entity_id, type_extractions="\n".join(type_list) + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse conflict detection response for {entity_id}" + ) + return {"has_conflict": False} + + return result + except Exception as e: + logger.error(f"Error checking type consistency for {entity_id}: {e}") + return {"has_conflict": False} + + def _check_entity_description_consistency( + self, entity_id: str, descriptions: Dict[str, str] + ) -> Dict[str, Any]: + """Check entity description consistency using LLM.""" + # Filter out empty descriptions + valid_descriptions = {k: v for k, v in descriptions.items() if v} + if len(valid_descriptions) < 2: + return {"has_conflict": False} + + if len(set(valid_descriptions.values())) <= 1: + # All descriptions are the same, no conflict + return {"has_conflict": False} + + try: + desc_list = [ + f"Chunk {chunk_id}: {description}" + for chunk_id, description in valid_descriptions.items() + ] + + prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format( + entity_name=entity_id, descriptions="\n".join(desc_list) + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse conflict detection response for {entity_id}" + ) + return {"has_conflict": False} + + return result + except Exception as e: + logger.error(f"Error checking description consistency for {entity_id}: {e}") + return {"has_conflict": False} + + def _check_relation_consistency( + self, src_id: str, dst_id: str, relation_extractions: Dict[str, str] + ) -> Dict[str, Any]: + """Check relation consistency using LLM.""" + if len(set(relation_extractions.values())) <= 1: + return {"has_conflict": False} + + try: + rel_list = [ + f"Chunk {chunk_id}: {relation}" + for chunk_id, relation in relation_extractions.items() + if relation + ] + + prompt = RELATION_CONFLICT_PROMPT.format( + source_entity=src_id, + target_entity=dst_id, + relation_descriptions="\n".join(rel_list), + ) + + response = asyncio.run(self.llm_client.generate_answer(prompt)) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning( + f"Failed to parse relation conflict response for {src_id}->{dst_id}" + ) + return {"has_conflict": False} + + return result + except Exception as e: + logger.error( + f"Error checking relation consistency for {src_id}->{dst_id}: {e}" + ) + return {"has_conflict": False} diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py new file mode 100644 index 00000000..d9fa45a9 --- /dev/null +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -0,0 +1,97 @@ +from typing import Any, Dict, Optional + +import numpy as np +from scipy import stats + +from graphgen.bases import BaseGraphStorage +from graphgen.utils import logger + + +class StructureEvaluator: + """Evaluates structural robustness of the graph.""" + + def __init__( + self, + graph_storage: BaseGraphStorage, + noise_ratio_threshold: float = 0.15, + largest_cc_ratio_threshold: float = 0.90, + avg_degree_min: float = 2.0, + avg_degree_max: float = 5.0, + powerlaw_r2_threshold: float = 0.75, + ): + self.graph_storage = graph_storage + self.noise_ratio_threshold = noise_ratio_threshold + self.largest_cc_ratio_threshold = largest_cc_ratio_threshold + self.avg_degree_min = avg_degree_min + self.avg_degree_max = avg_degree_max + self.powerlaw_r2_threshold = powerlaw_r2_threshold + + def evaluate(self) -> Dict[str, Any]: + """ + Evaluate the structural robustness of the graph. + :return: + """ + storage = self.graph_storage + + total_nodes = storage.get_node_count() + if total_nodes == 0: + return {"error": "Empty graph"} + + total_edges = storage.get_edge_count() + degree_map = storage.get_all_node_degrees() + + # Noise ratio: isolated nodes / total nodes + isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0] + noise_ratio = len(isolated_nodes) / total_nodes + + # Largest connected component + components = storage.get_connected_components(undirected=True) + largest_cc_ratio = ( + len(max(components, key=len)) / total_nodes if components else 0 + ) + + avg_degree = sum(degree_map.values()) / total_nodes + powerlaw_r2 = self._calculate_powerlaw_r2(degree_map) + + results = { + "total_nodes": total_nodes, + "total_edges": total_edges, + "noise_ratio": noise_ratio, + "largest_cc_ratio": largest_cc_ratio, + "avg_degree": avg_degree, + "powerlaw_r2": powerlaw_r2, + "is_robust": ( + noise_ratio < self.noise_ratio_threshold + and largest_cc_ratio > self.largest_cc_ratio_threshold + and self.avg_degree_min <= avg_degree <= self.avg_degree_max + and ( + powerlaw_r2 is not None and powerlaw_r2 > self.powerlaw_r2_threshold + ) + ), + } + + return results + + @staticmethod + def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]: + degrees = [deg for deg in degree_map.values() if deg > 0] + + if len(degrees) < 10: + logger.warning("Insufficient nodes for power law fitting") + return None + + try: + # Fit power law: log(y) = a * log(x) + b + log_degrees = np.log(degrees) + sorted_log_degrees = np.sort(log_degrees) + x = np.arange(1, len(sorted_log_degrees) + 1) + log_x = np.log(x) + + # Linear regression on log-log scale + r_value, *_ = stats.linregress(log_x, sorted_log_degrees) + r2 = r_value**2 + + return float(r2) + except Exception as e: + logger.error(f"Power law R² calculation failed: {e}") + return None diff --git a/graphgen/models/evaluator/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py deleted file mode 100644 index d5c33211..00000000 --- a/graphgen/models/evaluator/length_evaluator.py +++ /dev/null @@ -1,19 +0,0 @@ -from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluator.base_evaluator import BaseEvaluator -from graphgen.models.tokenizer import Tokenizer -from graphgen.utils import create_event_loop - - -class LengthEvaluator(BaseEvaluator): - def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100): - super().__init__(max_concurrent) - self.tokenizer_name = tokenizer_name - self.tokenizer = Tokenizer(model_name=self.tokenizer_name) - - async def evaluate_single(self, pair: QAPair) -> float: - loop = create_event_loop() - return await loop.run_in_executor(None, self._calculate_length, pair.answer) - - def _calculate_length(self, text: str) -> float: - tokens = self.tokenizer.encode(text) - return len(tokens) diff --git a/graphgen/models/evaluator/qa/__init__.py b/graphgen/models/evaluator/qa/__init__.py new file mode 100644 index 00000000..a9b445b4 --- /dev/null +++ b/graphgen/models/evaluator/qa/__init__.py @@ -0,0 +1,4 @@ +from .length_evaluator import LengthEvaluator +from .mtld_evaluator import MTLDEvaluator +from .reward_evaluator import RewardEvaluator +from .uni_evaluator import UniEvaluator diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py new file mode 100644 index 00000000..266edfb6 --- /dev/null +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -0,0 +1,18 @@ + +import os +from graphgen.bases import BaseEvaluator, QAPair +from graphgen.models.tokenizer import Tokenizer + + +class LengthEvaluator(BaseEvaluator): + def __init__(self, tokenizer_name: str = None): + tokenizer_model = tokenizer_name or os.environ.get("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer: Tokenizer = Tokenizer(tokenizer_model) + + def evaluate(self, pair: QAPair) -> float: + """ + Evaluate the length of the qa pair. + """ + content = pair.question + pair.answer + tokens = self.tokenizer.encode(content) + return len(tokens) diff --git a/graphgen/models/evaluator/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py similarity index 59% rename from graphgen/models/evaluator/mtld_evaluator.py rename to graphgen/models/evaluator/qa/mtld_evaluator.py index c106d86c..e4e18d32 100644 --- a/graphgen/models/evaluator/mtld_evaluator.py +++ b/graphgen/models/evaluator/qa/mtld_evaluator.py @@ -1,38 +1,33 @@ from typing import Set -from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluator.base_evaluator import BaseEvaluator -from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language - -nltk_helper = NLTKHelper() +from graphgen.bases import BaseEvaluator, QAPair +from graphgen.utils import NLTKHelper, detect_main_language class MTLDEvaluator(BaseEvaluator): """ - 衡量文本词汇多样性的指标 + Metrics for measuring the lexical diversity of text. """ - def __init__(self, max_concurrent: int = 100): - super().__init__(max_concurrent) - self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english")) - self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese")) - - async def evaluate_single(self, pair: QAPair) -> float: - loop = create_event_loop() - return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer) + def __init__(self, threshold: float = 0.72): + self.nltk_helper = NLTKHelper() + self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en")) + self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh")) + self.threshold = threshold - def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: + def evaluate(self, pair: QAPair) -> float: """ - 计算MTLD (向前和向后的平均值) + Calculate the MTLD (Mean Token Length Diversity) score for a given text. min is 1.0 higher is better """ + text = pair.answer if not text or not text.strip(): return 0.0 lang = detect_main_language(text) - tokens = nltk_helper.word_tokenize(text, lang) + tokens = self.nltk_helper.word_tokenize(text, lang) stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en filtered_tokens = [word for word in tokens if word not in stopwords] @@ -41,13 +36,13 @@ def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: if not filtered_tokens: return 0 - # 计算向前的MTLD - forward_factors = self._compute_factors(filtered_tokens, threshold) + # Compute forward factors + forward_factors = self._compute_factors(filtered_tokens, self.threshold) - # 计算向后的MTLD - backward_factors = self._compute_factors(filtered_tokens[::-1], threshold) + # Compute backward factors + backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold) - # 取平均值 + # Compute average factors return (forward_factors + backward_factors) / 2 @staticmethod @@ -66,7 +61,7 @@ def _compute_factors(tokens: list, threshold: float) -> float: current_segment = [] unique_words = set() - # 处理最后一个不完整片段 + # handle last segment if current_segment: ttr = len(unique_words) / len(current_segment) if ttr <= threshold: diff --git a/graphgen/models/evaluator/qa/reward_evaluator.py b/graphgen/models/evaluator/qa/reward_evaluator.py new file mode 100644 index 00000000..a7fcbc22 --- /dev/null +++ b/graphgen/models/evaluator/qa/reward_evaluator.py @@ -0,0 +1,66 @@ +from typing import Optional +from graphgen.bases import BaseEvaluator, QAPair + + +class RewardEvaluator(BaseEvaluator): + """ + Reward Model Evaluator for single QAPair evaluation. + """ + + def __init__( + self, + reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2", + max_length: int = 2560, + device: Optional[str] = None, + ): + """ + Initialize the reward evaluator. + + Args: + reward_name: Model name or path on HuggingFace Hub + max_length: Maximum token length for the model + device: Device to run the model on. If None, auto-detect CUDA/CPU. + """ + self.reward_name = reward_name + self.max_length = max_length + + import torch + from transformers import AutoModelForSequenceClassification, AutoTokenizer + self.torch = torch + + # Set device (auto-detect if not specified) + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + try: + self.tokenizer = AutoTokenizer.from_pretrained(reward_name) + self.model = AutoModelForSequenceClassification.from_pretrained(reward_name) + self.model.to(self.device) + self.model.eval() + except Exception as e: + raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e + + def evaluate(self, pair: QAPair) -> float: + """ + Evaluate a single question-answer pair using the reward model. + + Args: + pair: QAPair containing question and answer strings + + Returns: + Score as a float + """ + # Tokenize + inputs = self.tokenizer( + pair.question, + pair.answer, + return_tensors="pt", + max_length=self.max_length, + truncation=True, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get score + with self.torch.no_grad(): + score = self.model(**inputs).logits[0].item() + + return score diff --git a/graphgen/models/evaluator/qa/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py new file mode 100644 index 00000000..38406512 --- /dev/null +++ b/graphgen/models/evaluator/qa/uni_evaluator.py @@ -0,0 +1,105 @@ +# https://github.com/maszhongming/UniEval/tree/main +from typing import Optional, List +from graphgen.bases import BaseEvaluator, QAPair + + +class UniEvaluator(BaseEvaluator): + """ + UniEvaluator for single QAPair evaluation across quality dimensions. + + Dimensions: naturalness, coherence, understandability + + Usage: + evaluator = UniEvaluator() + pair = QAPair(question="...", answer="...") + scores = evaluator.evaluate(pair) + # {"naturalness": 0.85, "coherence": 0.92, "understandability": 0.88} + """ + + DEFAULT_MODEL: str = "MingZhong/unieval-sum" + DEFAULT_DIMS: List[str] = ["naturalness", "coherence", "understandability"] + DEFAULT_MAX_LENGTH: int = 2560 + + def __init__( + self, + model_name: Optional[str] = None, + max_length: Optional[int] = None, + device: Optional[str] = None, + ): + """ + Args: + model_name: HuggingFace model name/path + max_length: Tokenizer max sequence length + device: 'cuda', 'cpu', or None for auto-detect + """ + import torch + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + self.torch = torch + + self.model_name = model_name or self.DEFAULT_MODEL + self.max_length = max_length or self.DEFAULT_MAX_LENGTH + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # Load model & tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) + self.model.to(self.device) + self.model.eval() + + # Pre-compute Yes/No token IDs + self._yes_id = self.tokenizer("Yes")["input_ids"][0] + self._no_id = self.tokenizer("No")["input_ids"][0] + + @staticmethod + def _build_input_text(dimension: str, question: str, answer: str) -> str: + """Construct input text for specified dimension.""" + if dimension == "naturalness": + return f"question: Is this a natural response? response: {answer}" + if dimension == "coherence": + return f"question: Is this a coherent response? response: {answer} history: {question}" + if dimension == "understandability": + return f"question: Is this an understandable response? response: {answer}" + raise NotImplementedError(f"Unsupported dimension '{dimension}'") + + def evaluate( + self, + pair: QAPair, + dimensions: Optional[List[str]] = None, + ) -> dict[str, float]: + """Evaluate a single QAPair across specified dimensions.""" + dimensions = dimensions or self.DEFAULT_DIMS + + # Validate dimensions + invalid = set(dimensions) - set(self.DEFAULT_DIMS) + if invalid: + raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}") + + results = {} + no_token = self.torch.tensor([[self._no_id]], device=self.device) + + for dim in dimensions: + # Tokenize input + src = self.tokenizer( + self._build_input_text(dim, pair.question, pair.answer), + max_length=self.max_length, + truncation=True, + return_tensors="pt", + ) + src_tokens = src["input_ids"].to(self.device) + src_mask = src["attention_mask"].to(self.device) + + # Score + with self.torch.no_grad(): + logits = self.model( + input_ids=src_tokens, + attention_mask=src_mask, + labels=no_token, + use_cache=False, + ).logits[:, 0, :] # [1, vocab_size] + + probs = self.torch.softmax(logits, dim=-1)[0] + score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id]) + + results[dim] = score.item() + + return results diff --git a/graphgen/models/evaluator/reward_evaluator.py b/graphgen/models/evaluator/reward_evaluator.py deleted file mode 100644 index 4d2c2fb9..00000000 --- a/graphgen/models/evaluator/reward_evaluator.py +++ /dev/null @@ -1,107 +0,0 @@ -from dataclasses import dataclass - -from tqdm import tqdm - -from graphgen.bases.datatypes import QAPair - - -@dataclass -class RewardEvaluator: - """ - Reward Model Evaluator. - OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好 - """ - - reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2" - max_length: int = 2560 - results: list[float] = None - - def __post_init__(self): - import torch - - self.num_gpus = torch.cuda.device_count() - - @staticmethod - def process_chunk(rank, pairs, reward_name, max_length, return_dict): - import torch - from transformers import AutoModelForSequenceClassification, AutoTokenizer - - device = f"cuda:{rank}" - torch.cuda.set_device(rank) - - rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name) - tokenizer = AutoTokenizer.from_pretrained(reward_name) - rank_model.to(device) - rank_model.eval() - - results = [] - with torch.no_grad(): - for pair in tqdm(pairs): - inputs = tokenizer( - pair.question, - pair.answer, - return_tensors="pt", - max_length=max_length, - truncation=True, - ) - inputs = {k: v.to(device) for k, v in inputs.items()} - score = rank_model(**inputs).logits[0].item() - results.append(score) - - return_dict[rank] = results - - def evaluate(self, pairs: list[QAPair]) -> list[float]: - import torch.multiprocessing as mp - - chunk_size = len(pairs) // self.num_gpus - chunks = [] - for i in range(self.num_gpus): - start = i * chunk_size - end = start + chunk_size - if i == self.num_gpus - 1: - end = len(pairs) - chunks.append(pairs[start:end]) - - # multi-process - manager = mp.Manager() - return_dict = manager.dict() - processes = [] - - for rank, chunk in enumerate(chunks): - p = mp.Process( - target=self.process_chunk, - args=(rank, chunk, self.reward_name, self.max_length, return_dict), - ) - p.start() - processes.append(p) - - for p in processes: - p.join() - - # 合并结果 - results = [] - for rank in range(len(chunks)): - results.extend(return_dict[rank]) - - for p in processes: - if p.is_alive(): - p.terminate() - p.join() - - return results - - def get_average_score(self, pairs: list[QAPair]) -> float: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - self.results = results - return sum(self.results) / len(pairs) - - def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: - """ - Get the min and max score of a batch of texts. - """ - if self.results is None: - self.get_average_score(pairs) - return min(self.results), max(self.results) diff --git a/graphgen/models/evaluator/uni_evaluator.py b/graphgen/models/evaluator/uni_evaluator.py deleted file mode 100644 index 20fa3517..00000000 --- a/graphgen/models/evaluator/uni_evaluator.py +++ /dev/null @@ -1,183 +0,0 @@ -# https://github.com/maszhongming/UniEval/tree/main - -from dataclasses import dataclass, field - -from tqdm import tqdm - -from graphgen.bases.datatypes import QAPair - - -def _add_questions(dimension: str, question: str, answer: str): - if dimension == "naturalness": - cur_input = ( - "question: Is this a natural response in the dialogue? response: " - + answer - ) - elif dimension == "coherence": - cur_input = ( - "question: Is this a coherent response given the dialogue history? response: " - + answer - + " dialogue history: " - + question - ) - elif dimension == "understandability": - cur_input = ( - "question: Is this an understandable response in the dialogue? response: " - + answer - ) - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. Please customize it first." - ) - return cur_input - - -@dataclass -class UniEvaluator: - model_name: str = "MingZhong/unieval-sum" - dimensions: list = field( - default_factory=lambda: ["naturalness", "coherence", "understandability"] - ) - max_length: int = 2560 - results: dict = None - - def __post_init__(self): - import torch - - self.num_gpus = torch.cuda.device_count() - self.results = {} - - @staticmethod - def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict): - import torch - from transformers import AutoModelForSeq2SeqLM, AutoTokenizer - - device = f"cuda:{rank}" - torch.cuda.set_device(rank) - - rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name) - rank_model.to(device) - rank_model.eval() - - softmax = torch.nn.Softmax(dim=1) - - pos_id = tokenizer("Yes")["input_ids"][0] - neg_id = tokenizer("No")["input_ids"][0] - - results = [] - with torch.no_grad(): - for pair in tqdm(pairs): - text = _add_questions(dimension, pair.question, pair.answer) - - tgt = "No" - - encoded_src = tokenizer( - text, - max_length=max_length, - truncation=True, - padding=True, - return_tensors="pt", - ) - encoded_tgt = tokenizer( - tgt, - max_length=max_length, - truncation=True, - padding=True, - return_tensors="pt", - ) - - src_tokens = encoded_src["input_ids"].to(device) - src_mask = encoded_src["attention_mask"].to(device) - - tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1) - - output = rank_model( - input_ids=src_tokens, - attention_mask=src_mask, - labels=tgt_tokens, - use_cache=False, - ) - - logits = output.logits.view(-1, rank_model.config.vocab_size) - - pos_score = softmax(logits)[:, pos_id] # Yes - neg_score = softmax(logits)[:, neg_id] - score = pos_score / (pos_score + neg_score) - - results.append(score.item()) - - return_dict[rank] = results - - def evaluate(self, pairs: list[QAPair]) -> list[dict]: - import torch.multiprocessing as mp - - final_results = [] - for dimension in self.dimensions: - chunk_size = len(pairs) // self.num_gpus - chunks = [] - for i in range(self.num_gpus): - start = i * chunk_size - end = start + chunk_size - if i == self.num_gpus - 1: - end = len(pairs) - chunks.append(pairs[start:end]) - - # multi-process - manager = mp.Manager() - return_dict = manager.dict() - processes = [] - - for rank, chunk in enumerate(chunks): - p = mp.Process( - target=self.process_chunk, - args=( - rank, - chunk, - self.model_name, - self.max_length, - dimension, - return_dict, - ), - ) - p.start() - processes.append(p) - - for p in processes: - p.join() - - # 合并结果 - results = [] - for rank in range(len(chunks)): - results.extend(return_dict[rank]) - - for p in processes: - if p.is_alive(): - p.terminate() - p.join() - - final_results.append({dimension: results}) - return final_results - - def get_average_score(self, pairs: list[QAPair]) -> dict: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - final_results = {} - for result in results: - for key, value in result.items(): - final_results[key] = sum(value) / len(value) - self.results[key] = value - return final_results - - def get_min_max_score(self, pairs: list[QAPair]) -> dict: - """ - Get the min and max score of a batch of texts. - """ - if self.results is None: - self.get_average_score(pairs) - final_results = {} - for key, value in self.results.items(): - final_results[key] = min(value), max(value) - return final_results diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/models/storage/graph/kuzu_storage.py index db3e97ea..52b41519 100644 --- a/graphgen/models/storage/graph/kuzu_storage.py +++ b/graphgen/models/storage/graph/kuzu_storage.py @@ -1,7 +1,8 @@ import json import os +from collections import defaultdict from dataclasses import dataclass -from typing import Any +from typing import Any, Dict, List, Set try: import kuzu @@ -78,6 +79,94 @@ def _safe_json_loads(data_str: str) -> dict: print(f"Error decoding JSON: {e}") return {} + def is_directed(self) -> bool: + return True + + def get_all_node_degrees(self) -> Dict[str, int]: + query = """ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]-() + RETURN n.id, count(r) as degree + """ + + result = self._conn.execute(query) + degree_map = {} + while result.has_next(): + row = result.get_next() + if row and len(row) >= 2: + node_id, degree = row[0], row[1] + degree_map[node_id] = int(degree) + + return degree_map + + def get_isolated_nodes(self) -> List[str]: + query = """ + MATCH (n:Entity) + WHERE NOT (n)--() + RETURN n.id + """ + + result = self._conn.execute(query) + return [row[0] for row in result if row] + + def get_node_count(self) -> int: + result = self._conn.execute("MATCH (n:Entity) RETURN count(n)") + return result.get_next()[0] + + def get_edge_count(self) -> int: + result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)") + return result.get_next()[0] + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + parent = {} + rank = {} + + def find(x: str) -> str: + if parent[x] != x: + parent[x] = find(parent[x]) + return parent[x] + + def union(x: str, y: str): + root_x, root_y = find(x), find(y) + if root_x == root_y: + return + if rank[root_x] < rank[root_y]: + parent[root_x] = root_y + elif rank[root_x] > rank[root_y]: + parent[root_y] = root_x + else: + parent[root_y] = root_x + rank[root_x] += 1 + + all_nodes = self.get_all_node_degrees().keys() + for node_id in all_nodes: + parent[node_id] = node_id + rank[node_id] = 0 + + query = ( + """ + MATCH (a:Entity)-[e:Relation]-(b:Entity) + RETURN DISTINCT a.id, b.id + """ + if undirected + else """ + MATCH (a:Entity)-[e:Relation]->(b:Entity) + RETURN DISTINCT a.id, b.id + """ + ) + + result = self._conn.execute(query) + for row in result: + if row and len(row) >= 2: + union(row[0], row[1]) + + components_dict = defaultdict(set) + for node_id in all_nodes: + root = find(node_id) + components_dict[root].add(node_id) + + return list(components_dict.values()) + def has_node(self, node_id: str) -> bool: result = self._conn.execute( "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id} diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py index 7fb73b79..b043e9d2 100644 --- a/graphgen/models/storage/graph/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -1,7 +1,7 @@ import html import os from dataclasses import dataclass -from typing import Any, Optional, Union, cast +from typing import Any, Dict, List, Optional, Set, Union, cast import networkx as nx @@ -10,6 +10,31 @@ @dataclass class NetworkXStorage(BaseGraphStorage): + def is_directed(self) -> bool: + return self._graph.is_directed() + + def get_all_node_degrees(self) -> Dict[str, int]: + return { + str(node_id): int(self._graph.degree[node_id]) + for node_id in self._graph.nodes() + } + + def get_node_count(self) -> int: + return self._graph.number_of_nodes() + + def get_edge_count(self) -> int: + return self._graph.number_of_edges() + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + graph = self._graph + + if undirected and graph.is_directed(): + graph = graph.to_undirected() + + return [ + set(str(node) for node in comp) for comp in nx.connected_components(graph) + ] + @staticmethod def load_nx_graph(file_name) -> Optional[nx.Graph]: if os.path.exists(file_name): diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 5bb1261a..ab840cc5 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,5 +1,6 @@ from .build_kg import BuildKGService from .chunk import ChunkService +from .evaluate import EvaluateService from .extract import ExtractService from .generate import GenerateService from .judge import JudgeService @@ -8,6 +9,7 @@ from .read import read from .search import SearchService + operators = { "read": read, "chunk": ChunkService, @@ -18,4 +20,5 @@ "search": SearchService, "partition": PartitionService, "generate": GenerateService, + "evaluate": EvaluateService, } diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py index e69de29b..060c68d6 100644 --- a/graphgen/operators/evaluate/__init__.py +++ b/graphgen/operators/evaluate/__init__.py @@ -0,0 +1,3 @@ +from .evaluate_service import EvaluateService + +__all__ = ["EvaluateService"] diff --git a/graphgen/operators/evaluate/evaluate.py b/graphgen/operators/evaluate/evaluate.py deleted file mode 100644 index fdbfbf82..00000000 --- a/graphgen/operators/evaluate/evaluate.py +++ /dev/null @@ -1,177 +0,0 @@ -# TODO: this module needs refactoring to merge into GraphGen framework -"""Evaluate the quality of the generated text using various metrics""" - -import argparse -import json -import os - -import pandas as pd -from dotenv import load_dotenv - -from graphgen.bases.datatypes import QAPair -from graphgen.models import ( - LengthEvaluator, - MTLDEvaluator, - RewardEvaluator, - UniEvaluator, -) -from graphgen.utils import logger, set_logger - -sys_path = os.path.abspath(os.path.dirname(__file__)) -set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log")) - -load_dotenv() - - -def evaluate_length(corpus, tokenizer_name): - length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name) - logger.info("Length evaluator loaded") - scores = length_evaluator.get_average_score(corpus) - logger.info("Length scores: %s", scores) - return scores - - -def evaluate_mtld(corpus): - mtld_evaluator = MTLDEvaluator() - logger.info("MTLD evaluator loaded") - scores = mtld_evaluator.get_average_score(corpus) - logger.info("MTLD scores: %s", scores) - min_max_scores = mtld_evaluator.get_min_max_score(corpus) - logger.info("MTLD min max scores: %s", min_max_scores) - return scores, min_max_scores - - -def evaluate_reward(corpus, reward_model_names): - scores = [] - for reward_name in reward_model_names: - reward_evaluator = RewardEvaluator(reward_name=reward_name) - logger.info("Loaded reward model: %s", reward_name) - average_score = reward_evaluator.get_average_score(corpus) - logger.info("%s scores: %s", reward_name, average_score) - min_max_scores = reward_evaluator.get_min_max_score(corpus) - logger.info("%s min max scores: %s", reward_name, min_max_scores) - scores.append( - { - "reward_name": reward_name.split("/")[-1], - "score": average_score, - "min_max_scores": min_max_scores, - } - ) - del reward_evaluator - clean_gpu_cache() - return scores - - -def evaluate_uni(corpus, uni_model_name): - uni_evaluator = UniEvaluator(model_name=uni_model_name) - logger.info("Uni evaluator loaded with model %s", uni_model_name) - uni_scores = uni_evaluator.get_average_score(corpus) - for key, value in uni_scores.items(): - logger.info("Uni %s scores: %s", key, value) - min_max_scores = uni_evaluator.get_min_max_score(corpus) - for key, value in min_max_scores.items(): - logger.info("Uni %s min max scores: %s", key, value) - del uni_evaluator - clean_gpu_cache() - return ( - uni_scores["naturalness"], - uni_scores["coherence"], - uni_scores["understandability"], - min_max_scores["naturalness"], - min_max_scores["coherence"], - min_max_scores["understandability"], - ) - - -def clean_gpu_cache(): - import torch - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -if __name__ == "__main__": - import torch.multiprocessing as mp - - parser = argparse.ArgumentParser() - - parser.add_argument( - "--folder", type=str, default="cache/data", help="folder to load data" - ) - parser.add_argument( - "--output", type=str, default="cache/output", help="path to save output" - ) - - parser.add_argument( - "--tokenizer", type=str, default="cl100k_base", help="tokenizer name" - ) - parser.add_argument( - "--reward", - type=str, - default="OpenAssistant/reward-model-deberta-v3-large-v2", - help="Comma-separated list of reward models", - ) - parser.add_argument( - "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name" - ) - - args = parser.parse_args() - - if not os.path.exists(args.folder): - raise ValueError(f"Folder {args.folder} does not exist") - - if not os.path.exists(args.output): - os.makedirs(args.output) - - reward_models = args.reward.split(",") - - results = [] - - logger.info("Data loaded from %s", args.folder) - mp.set_start_method("spawn") - - for file in os.listdir(args.folder): - if file.endswith(".json"): - logger.info("Processing %s", file) - with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f: - data = json.load(f) - data = [ - QAPair(question=data[key]["question"], answer=data[key]["answer"]) - for key in data - ] - - length_scores = evaluate_length(data, args.tokenizer) - mtld_scores, min_max_mtld_scores = evaluate_mtld(data) - reward_scores = evaluate_reward(data, reward_models) - ( - uni_naturalness_scores, - uni_coherence_scores, - uni_understandability_scores, - min_max_uni_naturalness_scores, - min_max_uni_coherence_scores, - min_max_uni_understandability_scores, - ) = evaluate_uni(data, args.uni) - - result = { - "file": file, - "number": len(data), - "length": length_scores, - "mtld": mtld_scores, - "mtld_min_max": min_max_mtld_scores, - "uni_naturalness": uni_naturalness_scores, - "uni_coherence": uni_coherence_scores, - "uni_understandability": uni_understandability_scores, - "uni_naturalness_min_max": min_max_uni_naturalness_scores, - "uni_coherence_min_max": min_max_uni_coherence_scores, - "uni_understandability_min_max": min_max_uni_understandability_scores, - } - for reward_score in reward_scores: - result[reward_score["reward_name"]] = reward_score["score"] - result[f"{reward_score['reward_name']}_min_max"] = reward_score[ - "min_max_scores" - ] - - results.append(result) - - results = pd.DataFrame(results) - results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False) diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py new file mode 100644 index 00000000..b0875d7f --- /dev/null +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -0,0 +1,181 @@ +from typing import Any, Dict + +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair +from graphgen.common import init_llm, init_storage +from graphgen.utils import logger, run_concurrent + + +class EvaluateService(BaseOperator): + """ + 1. KG Quality Evaluation + 2. QA Quality Evaluation + """ + + def __init__( + self, + working_dir: str = "cache", + metrics: list[str] = None, + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + **kwargs, + ): + super().__init__(working_dir=working_dir, op_name="evaluate_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.metrics = metrics or [] + self.kwargs = kwargs + self.graph_storage = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + self.chunk_storage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="chunk" + ) + + # Initialize evaluators + self.qa_evaluators = {} + self.kg_evaluators = {} + self._init_evaluators() + + def _init_evaluators(self): + """Initialize QA and KG evaluators based on metrics.""" + for metric in self.metrics: + if metric == "qa_length": + from graphgen.models import LengthEvaluator + + self.qa_evaluators[metric] = LengthEvaluator() + elif metric == "qa_mtld": + from graphgen.models import MTLDEvaluator + + self.qa_evaluators[metric] = MTLDEvaluator( + **self.kwargs.get("mtld_params", {}) + ) + elif metric == "qa_reward_score": + from graphgen.models import RewardEvaluator + + self.qa_evaluators[metric] = RewardEvaluator( + **self.kwargs.get("reward_params", {}) + ) + elif metric == "qa_uni_score": + from graphgen.models import UniEvaluator + + self.qa_evaluators[metric] = UniEvaluator( + **self.kwargs.get("uni_params", {}) + ) + elif metric == "kg_accuracy": + from graphgen.models import AccuracyEvaluator + + self.kg_evaluators[metric] = AccuracyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + elif metric == "kg_consistency": + from graphgen.models import ConsistencyEvaluator + + self.kg_evaluators[metric] = ConsistencyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + elif metric == "kg_structure": + from graphgen.models import StructureEvaluator + + self.kg_evaluators[metric] = StructureEvaluator( + graph_storage=self.graph_storage, + **self.kwargs.get("structure_params", {}), + ) + else: + raise ValueError(f"Unknown QA metric: {metric}") + + async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]: + try: + qa_pair = QAPair( + question=str(item.get("question", "")), + answer=str(item.get("answer", "")), + ) + if not qa_pair.question or not qa_pair.answer: + self.logger.error("Empty question or answer, skipping.") + return {} + except Exception as e: + self.logger.error("Error in QAPair creation: %s", str(e)) + return {} + + for metric, evaluator in self.qa_evaluators.items(): + try: + score = evaluator.evaluate(qa_pair) + if isinstance(score, dict): + for sub_metric, sub_score in score.items(): + item[f"{metric}_{sub_metric}"] = float(sub_score) + else: + item[metric] = float(score) + except Exception as e: + self.logger.error("Error in %s evaluation: %s", metric, str(e)) + item[metric] = None + return item + + def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: + def transform_messages_format(items: list[dict]) -> list[dict]: + """ + Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...] + """ + transformed = [] + for item in items: + messages = item.get("messages", []) + question = next( + (m["content"] for m in messages if m.get("role") == "user"), "" + ) + answer = next( + (m["content"] for m in messages if m.get("role") == "assistant"), "" + ) + + transformed.append({"question": question, "answer": answer}) + return transformed + + if not items: + return [] + + if not self.qa_evaluators: + self.logger.warning("No QA evaluators initialized, skipping QA evaluation") + return [] + + items = transform_messages_format(items) + results = run_concurrent( + self._process_single_qa, + items, + desc="Evaluating QA items", + unit="item", + ) + + results = [item for item in results if item] + return results + + def _evaluate_kg(self) -> Dict[str, Any]: + results = {} + + for metric, evaluator in self.kg_evaluators.items(): + try: + self.logger.info("Running %s evaluation...", metric) + score = evaluator.evaluate() + results[metric] = score + except Exception as e: + self.logger.error("Error in %s evaluation: %s", metric, str(e)) + results[metric] = {"error": str(e)} + return results + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + # QA evaluation + if len(self.qa_evaluators) > 0: + items = batch.to_dict(orient="records") + results = self._evaluate_qa(items) + return pd.DataFrame(results) + + # KG evaluation + if len(self.kg_evaluators) > 0: + results = self._evaluate_kg() + # Convert dict to DataFrame (single row) + return pd.DataFrame([results]) + + # No metrics specified + logger.warning("No metrics specified, returning empty DataFrame") + return pd.DataFrame() diff --git a/graphgen/run.py b/graphgen/run.py index a1b65364..d3d47cd3 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -91,10 +91,11 @@ def main(): results = engine.execute(ds) for node_id, dataset in results.items(): - output_path = os.path.join(output_path, f"{node_id}") - os.makedirs(output_path, exist_ok=True) + logger.info("Saving results for node %s", node_id) + node_output_path = os.path.join(output_path, f"{node_id}") + os.makedirs(node_output_path, exist_ok=True) dataset.write_json( - output_path, + node_output_path, filename_provider=NodeFilenameProvider(node_id), pandas_json_args_fn=lambda: { "force_ascii": False, @@ -102,7 +103,7 @@ def main(): "lines": True, }, ) - logger.info("Node %s results saved to %s", node_id, output_path) + logger.info("Node %s results saved to %s", node_id, node_output_path) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. Data saved to %s", output_path) diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py index 0940e910..cbfa4e17 100644 --- a/graphgen/templates/__init__.py +++ b/graphgen/templates/__init__.py @@ -1,5 +1,6 @@ from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT +from .evaluation import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT from .generation import ( AGGREGATED_GENERATION_PROMPT, diff --git a/graphgen/templates/evaluation/__init__.py b/graphgen/templates/evaluation/__init__.py new file mode 100644 index 00000000..7c2676a5 --- /dev/null +++ b/graphgen/templates/evaluation/__init__.py @@ -0,0 +1 @@ +from .kg import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/__init__.py b/graphgen/templates/evaluation/kg/__init__.py new file mode 100644 index 00000000..db8edce6 --- /dev/null +++ b/graphgen/templates/evaluation/kg/__init__.py @@ -0,0 +1,2 @@ +from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT +from .consistency_evaluation import CONSISTENCY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/accuracy_evaluation.py b/graphgen/templates/evaluation/kg/accuracy_evaluation.py new file mode 100644 index 00000000..f98b8b0f --- /dev/null +++ b/graphgen/templates/evaluation/kg/accuracy_evaluation.py @@ -0,0 +1,156 @@ +ENTITY_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体 +3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的实体列表: +{extracted_entities} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + +ENTITY_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \ +Your task is to evaluate the quality of entity extraction from a given text block and extracted entity list. + +Evaluation Dimensions: +1. ACCURACY (Weight: 40%): Whether the extracted entities are correct, and if there are any false extractions or misidentifications +2. COMPLETENESS (Weight: 40%): Whether important entities from the text are missing +3. PRECISION (Weight: 20%): Whether the extracted entities are precise and accurately named + +Scoring Criteria (0-1 scale for each dimension): +- EXCELLENT (0.8-1.0): High-quality extraction +- GOOD (0.6-0.79): Good quality with minor issues +- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues +- POOR (0.0-0.39): Poor quality, needs improvement + +Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +Please evaluate the following: + +Original Text Block: +{chunk_content} + +Extracted Entity List: +{extracted_entities} + +Please return the evaluation result in JSON format: +{{ + "accuracy": , + "completeness": , + "precision": , + "overall_score": , + "accuracy_reasoning": "", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [""] +}} +""" + +RELATION_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系 +3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的关系列表: +{extracted_relations} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + +RELATION_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \ +Your task is to evaluate the quality of relation extraction from a given text block and extracted relation list. + +Evaluation Dimensions: +1. ACCURACY (Weight: 40%): Whether the extracted relations are correct and the relation descriptions are accurate +2. COMPLETENESS (Weight: 40%): Whether important relations from the text are missing +3. PRECISION (Weight: 20%): Whether the relation descriptions are precise and not overly broad + +Scoring Criteria (0-1 scale for each dimension): +- EXCELLENT (0.8-1.0): High-quality extraction +- GOOD (0.6-0.79): Good quality with minor issues +- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues +- POOR (0.0-0.39): Poor quality, needs improvement + +Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +Please evaluate the following: + +Original Text Block: +{chunk_content} + +Extracted Relation List: +{extracted_relations} + +Please return the evaluation result in JSON format: +{{ + "accuracy": , + "completeness": , + "precision": , + "overall_score": , + "accuracy_reasoning": "", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [""] +}} +""" + +ACCURACY_EVALUATION_PROMPT = { + "zh": { + "ENTITY": ENTITY_EVALUATION_PROMPT_ZH, + "RELATION": RELATION_EVALUATION_PROMPT_ZH, + }, + "en": { + "ENTITY": ENTITY_EVALUATION_PROMPT_EN, + "RELATION": RELATION_EVALUATION_PROMPT_EN, + }, +} diff --git a/graphgen/templates/evaluation/kg/consistency_evaluation.py b/graphgen/templates/evaluation/kg/consistency_evaluation.py new file mode 100644 index 00000000..1600ef94 --- /dev/null +++ b/graphgen/templates/evaluation/kg/consistency_evaluation.py @@ -0,0 +1,102 @@ +ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的类型提取结果: +{type_extractions} + +预设的实体类型列表(供参考): +concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene + +请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。 +注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_types": ["<存在冲突的类型对>"], + "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>" +}} +""" + +ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的描述: +{descriptions} + +请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_descriptions": ["<存在冲突的描述对>"], + "conflict_details": "<具体的冲突内容>" +}} +""" + +RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。 + +实体对:{source_entity} -> {target_entity} + +在不同文本块中的关系描述: +{relation_descriptions} + +请判断这些关系描述是否存在语义冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_relations": ["<存在冲突的关系描述对>"] +}} +""" + +ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。 + +**重要**:你只需要提取指定的实体,不要提取其他实体。 + +实体名称:{entity_name} + +文本块: +{chunk_content} + +请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息: + +1. entity_type: 实体类型,必须是以下预设类型之一(小写): + - concept: 概念 + - date: 日期 + - location: 地点 + - keyword: 关键词 + - organization: 组织 + - person: 人物 + - event: 事件 + - work: 作品/工作 + - nature: 自然 + - artificial: 人工 + - science: 科学 + - technology: 技术 + - mission: 任务 + - gene: 基因 + + 如果无法确定类型,请使用 "concept" 作为默认值。 + +2. description: 实体描述(简要描述该实体在文本中的作用和特征) + +请以 JSON 格式返回: +{{ + "entity_type": "<实体类型(必须是上述预设类型之一)>", + "description": "<实体描述>" +}} +""" + +CONSISTENCY_EVALUATION_PROMPT = { + "en": "", + "zh": "" +} diff --git a/graphgen/utils/help_nltk.py b/graphgen/utils/help_nltk.py index 2d2610ba..c7d5e301 100644 --- a/graphgen/utils/help_nltk.py +++ b/graphgen/utils/help_nltk.py @@ -1,39 +1,61 @@ +from functools import lru_cache import os -from typing import Dict, List, Optional +from typing import Dict, List, Final, Optional +import warnings import nltk import jieba -resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources") - +warnings.filterwarnings( + "ignore", + category=UserWarning, + module=r"jieba\._compat" +) class NLTKHelper: - _stopwords: Dict[str, Optional[List[str]]] = { - "english": None, - "chinese": None, + """ + NLTK helper class + """ + + SUPPORTED_LANGUAGES: Final[Dict[str, str]] = { + "en": "english", + "zh": "chinese" + } + _NLTK_PACKAGES: Final[Dict[str, str]] = { + "stopwords": "corpora", + "punkt_tab": "tokenizers" } - def __init__(self): + def __init__(self, nltk_data_path: Optional[str] = None): + self._nltk_path = nltk_data_path or os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "resources", + "nltk_data" + ) + nltk.data.path.append(self._nltk_path) jieba.initialize() + self._ensure_nltk_data("stopwords") + self._ensure_nltk_data("punkt_tab") + + def _ensure_nltk_data(self, package_name: str) -> None: + """ + ensure nltk data is downloaded + """ + try: + nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}") + except LookupError: + nltk.download(package_name, download_dir=self._nltk_path, quiet=True) + + @lru_cache(maxsize=2) def get_stopwords(self, lang: str) -> List[str]: - nltk.data.path.append(os.path.join(resource_path, "nltk_data")) - if self._stopwords[lang] is None: - try: - nltk.data.find("corpora/stopwords") - except LookupError: - nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data")) - - self._stopwords[lang] = nltk.corpus.stopwords.words(lang) - return self._stopwords[lang] - - @staticmethod - def word_tokenize(text: str, lang: str) -> List[str]: + if lang not in self.SUPPORTED_LANGUAGES: + raise ValueError(f"Language {lang} is not supported.") + return nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang]) + + def word_tokenize(self, text: str, lang: str) -> List[str]: + if lang not in self.SUPPORTED_LANGUAGES: + raise ValueError(f"Language {lang} is not supported.") if lang == "zh": return jieba.lcut(text) - nltk.data.path.append(os.path.join(resource_path, "nltk_data")) - try: - nltk.data.find("tokenizers/punkt_tab") - except LookupError: - nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data")) return nltk.word_tokenize(text) diff --git a/requirements.txt b/requirements.txt index 119c95e0..b0eb3966 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ aiohttp socksio pydantic ray==2.52.1 -kuzu pyarrow leidenalg @@ -32,9 +31,11 @@ python-louvain # storage rocksdict +kuzu # KG rdflib +scipy # Bioinformatics biopython diff --git a/webui/utils/count_tokens.py b/webui/utils/count_tokens.py index 82b5522c..3016ac5c 100644 --- a/webui/utils/count_tokens.py +++ b/webui/utils/count_tokens.py @@ -7,10 +7,12 @@ # pylint: disable=wrong-import-position root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root_dir) -from graphgen.models import Tokenizer def count_tokens(file, tokenizer_name, data_frame): + # Lazy import to avoid circular dependency + from graphgen.models import Tokenizer # pylint: disable=import-outside-toplevel + if not file or not os.path.exists(file): return data_frame