diff --git a/.env.example b/.env.example
index c1102c1c..94064796 100644
--- a/.env.example
+++ b/.env.example
@@ -1,7 +1,29 @@
+# Tokenizer
TOKENIZER_MODEL=
-SYNTHESIZER_MODEL=
+
+# LLM
+# Supported backends: http_api, openai_api, ollama_api, huggingface, vllm (tgi, sglang and tensorrt are work in progress)
+
+# http_api / openai_api
+SYNTHESIZER_BACKEND=openai_api
+SYNTHESIZER_MODEL=gpt-4o-mini
SYNTHESIZER_BASE_URL=
SYNTHESIZER_API_KEY=
-TRAINEE_MODEL=
+TRAINEE_BACKEND=openai_api
+TRAINEE_MODEL=gpt-4o-mini
TRAINEE_BASE_URL=
TRAINEE_API_KEY=
+
+# # ollama_api
+# SYNTHESIZER_BACKEND=ollama_api
+# SYNTHESIZER_MODEL=gemma3
+# SYNTHESIZER_BASE_URL=http://localhost:11434
+#
+# Note: the ollama_api backend is not supported for TRAINEE yet, because the Ollama API does not return logprobs.
+
+# # huggingface
+# SYNTHESIZER_BACKEND=huggingface
+# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct
+#
+# TRAINEE_BACKEND=huggingface
+# TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct
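+
+# # vllm (local inference; an illustrative example, requires `pip install vllm`)
+# SYNTHESIZER_BACKEND=vllm
+# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct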
diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py
index ace331d5..ed452628 100644
--- a/graphgen/bases/__init__.py
+++ b/graphgen/bases/__init__.py
@@ -1,6 +1,6 @@
from .base_generator import BaseGenerator
from .base_kg_builder import BaseKGBuilder
-from .base_llm_client import BaseLLMClient
+from .base_llm_wrapper import BaseLLMWrapper
from .base_partitioner import BasePartitioner
from .base_reader import BaseReader
from .base_splitter import BaseSplitter
diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py
index d6148cee..85de5877 100644
--- a/graphgen/bases/base_generator.py
+++ b/graphgen/bases/base_generator.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
class BaseGenerator(ABC):
@@ -9,7 +9,7 @@ class BaseGenerator(ABC):
Generate QAs based on given prompts.
"""
- def __init__(self, llm_client: BaseLLMClient):
+ def __init__(self, llm_client: BaseLLMWrapper):
self.llm_client = llm_client
@staticmethod
diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py
index e234d8da..d8a5d66a 100644
--- a/graphgen/bases/base_kg_builder.py
+++ b/graphgen/bases/base_kg_builder.py
@@ -2,13 +2,13 @@
from collections import defaultdict
from typing import Dict, List, Tuple
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk
class BaseKGBuilder(ABC):
- def __init__(self, llm_client: BaseLLMClient):
+ def __init__(self, llm_client: BaseLLMWrapper):
self.llm_client = llm_client
self._nodes: Dict[str, List[dict]] = defaultdict(list)
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
diff --git a/graphgen/bases/base_llm_client.py b/graphgen/bases/base_llm_wrapper.py
similarity index 91%
rename from graphgen/bases/base_llm_client.py
rename to graphgen/bases/base_llm_wrapper.py
index 1abe5143..5a28fbb4 100644
--- a/graphgen/bases/base_llm_client.py
+++ b/graphgen/bases/base_llm_wrapper.py
@@ -8,7 +8,7 @@
from graphgen.bases.datatypes import Token
-class BaseLLMClient(abc.ABC):
+class BaseLLMWrapper(abc.ABC):
"""
LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
"""
@@ -66,3 +66,9 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str:
think_pattern = re.compile(rf"<{think_tag}>.*?{think_tag}>", re.DOTALL)
filtered_text = think_pattern.sub("", text).strip()
return filtered_text if filtered_text else text.strip()
+
+ def shutdown(self) -> None:
+ """Shutdown the LLM engine if applicable."""
+
+ def restart(self) -> None:
+ """Reinitialize the LLM engine if applicable."""
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 8b0559d6..e8258829 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -1,11 +1,11 @@
import asyncio
import os
import time
-from dataclasses import dataclass
from typing import Dict, cast
import gradio as gr
+from graphgen.bases import BaseLLMWrapper
from graphgen.bases.base_storage import StorageNameSpace
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
@@ -20,6 +20,7 @@
build_text_kg,
chunk_documents,
generate_qas,
+ init_llm,
judge_statement,
partition_kg,
quiz,
@@ -31,40 +32,28 @@
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-@dataclass
class GraphGen:
- unique_id: int = int(time.time())
- working_dir: str = os.path.join(sys_path, "cache")
-
- # llm
- tokenizer_instance: Tokenizer = None
- synthesizer_llm_client: OpenAIClient = None
- trainee_llm_client: OpenAIClient = None
-
- # webui
- progress_bar: gr.Progress = None
-
- def __post_init__(self):
- self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+ def __init__(
+ self,
+ unique_id: int = int(time.time()),
+ working_dir: str = os.path.join(sys_path, "cache"),
+ tokenizer_instance: Tokenizer = None,
+ synthesizer_llm_client: OpenAIClient = None,
+ trainee_llm_client: OpenAIClient = None,
+ progress_bar: gr.Progress = None,
+ ):
+ self.unique_id: int = unique_id
+ self.working_dir: str = working_dir
+
+ # llm
+ self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
model_name=os.getenv("TOKENIZER_MODEL")
)
- self.synthesizer_llm_client: OpenAIClient = (
- self.synthesizer_llm_client
- or OpenAIClient(
- model_name=os.getenv("SYNTHESIZER_MODEL"),
- api_key=os.getenv("SYNTHESIZER_API_KEY"),
- base_url=os.getenv("SYNTHESIZER_BASE_URL"),
- tokenizer=self.tokenizer_instance,
- )
- )
-
- self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
- model_name=os.getenv("TRAINEE_MODEL"),
- api_key=os.getenv("TRAINEE_API_KEY"),
- base_url=os.getenv("TRAINEE_BASE_URL"),
- tokenizer=self.tokenizer_instance,
+ self.synthesizer_llm_client: BaseLLMWrapper = (
+ synthesizer_llm_client or init_llm("synthesizer")
)
+ self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
@@ -86,6 +75,9 @@ def __post_init__(self):
namespace="qa",
)
+ # webui
+ self.progress_bar: gr.Progress = progress_bar
+
@async_to_sync_method
async def insert(self, read_config: Dict, split_config: Dict):
"""
@@ -272,6 +264,12 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
)
# TODO: assert trainee_llm_client is valid before judge
+ if not self.trainee_llm_client:
+ # TODO: shutdown existing synthesizer_llm_client properly
+ logger.info("No trainee LLM client provided, initializing a new one.")
+ self.synthesizer_llm_client.shutdown()
+ self.trainee_llm_client = init_llm("trainee")
+
re_judge = quiz_and_judge_config["re_judge"]
_update_relations = await judge_statement(
self.trainee_llm_client,
@@ -279,9 +277,16 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
self.rephrase_storage,
re_judge,
)
+
await self.rephrase_storage.index_done_callback()
await _update_relations.index_done_callback()
+ logger.info("Shutting down trainee LLM client.")
+ self.trainee_llm_client.shutdown()
+ self.trainee_llm_client = None
+ logger.info("Restarting synthesizer LLM client.")
+ self.synthesizer_llm_client.restart()
+
@async_to_sync_method
async def generate(self, partition_config: Dict, generate_config: Dict):
# Step 1: partition the graph
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index 37476034..08694166 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -7,8 +7,7 @@
VQAGenerator,
)
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
-from .llm.openai_client import OpenAIClient
-from .llm.topk_token_model import TopkTokenModel
+from .llm import HTTPClient, OllamaClient, OpenAIClient
from .partitioner import (
AnchorBFSPartitioner,
BFSPartitioner,
diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py
index cde42d27..2d7bff01 100644
--- a/graphgen/models/kg_builder/light_rag_kg_builder.py
+++ b/graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -2,7 +2,7 @@
from collections import Counter, defaultdict
from typing import Dict, List, Tuple
-from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import (
detect_main_language,
@@ -15,7 +15,7 @@
class LightRAGKGBuilder(BaseKGBuilder):
- def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
+ def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
super().__init__(llm_client)
self.max_loop = max_loop
diff --git a/graphgen/models/llm/__init__.py b/graphgen/models/llm/__init__.py
index e69de29b..c70395d5 100644
--- a/graphgen/models/llm/__init__.py
+++ b/graphgen/models/llm/__init__.py
@@ -0,0 +1,4 @@
+from .api.http_client import HTTPClient
+from .api.ollama_client import OllamaClient
+from .api.openai_client import OpenAIClient
+from .local.hf_wrapper import HuggingFaceWrapper
diff --git a/graphgen/models/llm/api/__init__.py b/graphgen/models/llm/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py
new file mode 100644
index 00000000..2c3b0acd
--- /dev/null
+++ b/graphgen/models/llm/api/http_client.py
@@ -0,0 +1,197 @@
+import asyncio
+import math
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+from tenacity import (
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+)
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+from graphgen.models.llm.limitter import RPM, TPM
+
+
+class HTTPClient(BaseLLMWrapper):
+ """
+ A generic async HTTP client for LLMs compatible with OpenAI's chat/completions format.
+ It uses aiohttp for making requests and includes retry logic and token usage tracking.
+ Usage example:
+ client = HTTPClient(
+        model="gpt-4o-mini",
+ base_url="http://localhost:8080",
+ api_key="your_api_key",
+ json_mode=True,
+ seed=42,
+ topk_per_token=5,
+ request_limit=True,
+ )
+
+ answer = await client.generate_answer("Hello, world!")
+ tokens = await client.generate_topk_per_token("Hello, world!")
+ """
+
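+    # Note: HTTPClient is a process-wide singleton; subsequent constructor calls
+    # return the same instance and skip re-initialization (see __init__).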
+ _instance: Optional["HTTPClient"] = None
+ _lock = asyncio.Lock()
+
+ def __new__(cls, **kwargs):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ def __init__(
+ self,
+ *,
+ model: str,
+ base_url: str,
+ api_key: Optional[str] = None,
+ json_mode: bool = False,
+ seed: Optional[int] = None,
+ topk_per_token: int = 5,
+ request_limit: bool = False,
+ rpm: Optional[RPM] = None,
+ tpm: Optional[TPM] = None,
+ **kwargs: Any,
+ ):
+ # Initialize only once in the singleton pattern
+ if getattr(self, "_initialized", False):
+ return
+ self._initialized: bool = True
+ super().__init__(**kwargs)
+ self.model_name = model
+ self.base_url = base_url.rstrip("/")
+ self.api_key = api_key
+ self.json_mode = json_mode
+ self.seed = seed
+ self.topk_per_token = topk_per_token
+ self.request_limit = request_limit
+ self.rpm = rpm or RPM()
+ self.tpm = tpm or TPM()
+
+ self.token_usage: List[Dict[str, int]] = []
+ self._session: Optional[aiohttp.ClientSession] = None
+
+ @property
+ def session(self) -> aiohttp.ClientSession:
+ if self._session is None or self._session.closed:
+ headers = (
+ {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
+ )
+ self._session = aiohttp.ClientSession(headers=headers)
+ return self._session
+
+ async def close(self):
+ if self._session and not self._session.closed:
+ await self._session.close()
+
+ def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]:
+ messages = []
+ if self.system_prompt:
+ messages.append({"role": "system", "content": self.system_prompt})
+
+        # History is expected as a list of chat messages ({"role": ..., "content": ...}).
+ if history and isinstance(history[0], dict):
+ messages.extend(history)
+
+ messages.append({"role": "user", "content": text})
+
+ body = {
+ "model": self.model_name,
+ "messages": messages,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "max_tokens": self.max_tokens,
+ }
+ if self.seed:
+ body["seed"] = self.seed
+ if self.json_mode:
+ body["response_format"] = {"type": "json_object"}
+ return body
+
+ @retry(
+ stop=stop_after_attempt(5),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
+ )
+ async def generate_answer(
+ self,
+ text: str,
+ history: Optional[List[str]] = None,
+ **extra: Any,
+ ) -> str:
+ body = self._build_body(text, history or [])
+ prompt_tokens = sum(
+ len(self.tokenizer.encode(m["content"])) for m in body["messages"]
+ )
+ est = prompt_tokens + body["max_tokens"]
+
+ if self.request_limit:
+ await self.rpm.wait(silent=True)
+ await self.tpm.wait(est, silent=True)
+
+ async with self.session.post(
+ f"{self.base_url}/chat/completions",
+ json=body,
+ timeout=aiohttp.ClientTimeout(total=60),
+ ) as resp:
+ resp.raise_for_status()
+ data = await resp.json()
+
+ msg = data["choices"][0]["message"]["content"]
+ if "usage" in data:
+ self.token_usage.append(
+ {
+ "prompt_tokens": data["usage"]["prompt_tokens"],
+ "completion_tokens": data["usage"]["completion_tokens"],
+ "total_tokens": data["usage"]["total_tokens"],
+ }
+ )
+ return self.filter_think_tags(msg)
+
+ @retry(
+ stop=stop_after_attempt(5),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
+ )
+ async def generate_topk_per_token(
+ self,
+ text: str,
+ history: Optional[List[str]] = None,
+ **extra: Any,
+ ) -> List[Token]:
+ body = self._build_body(text, history or [])
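+        # One generated token is enough to inspect the next-token distribution.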
+ body["max_tokens"] = 1
+ if self.topk_per_token > 0:
+ body["logprobs"] = True
+ body["top_logprobs"] = self.topk_per_token
+
+ async with self.session.post(
+ f"{self.base_url}/chat/completions",
+ json=body,
+ timeout=aiohttp.ClientTimeout(total=60),
+ ) as resp:
+ resp.raise_for_status()
+ data = await resp.json()
+
+ token_logprobs = data["choices"][0]["logprobs"]["content"]
+ tokens = []
+ for item in token_logprobs:
+ candidates = [
+ Token(t["token"], math.exp(t["logprob"])) for t in item["top_logprobs"]
+ ]
+ tokens.append(
+ Token(
+ item["token"], math.exp(item["logprob"]), top_candidates=candidates
+ )
+ )
+ return tokens
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ raise NotImplementedError(
+ "generate_inputs_prob is not implemented in HTTPClient"
+ )
diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py
new file mode 100644
index 00000000..9a4946a6
--- /dev/null
+++ b/graphgen/models/llm/api/ollama_client.py
@@ -0,0 +1,105 @@
+from typing import Any, Dict, List, Optional
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+from graphgen.models.llm.limitter import RPM, TPM
+
+
+class OllamaClient(BaseLLMWrapper):
+ """
+ Requires a local or remote Ollama server to be running (default port 11434).
+ The top_logprobs field is not yet implemented by the official API.
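+
+    Example (a minimal sketch, assuming an Ollama server is reachable at the URL):
+        client = OllamaClient(model="gemma3", base_url="http://localhost:11434")
+        answer = await client.generate_answer("Hello!")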
+ """
+
+ def __init__(
+ self,
+ *,
+ model: str = "gemma3",
+ base_url: str = "http://localhost:11434",
+ json_mode: bool = False,
+ seed: Optional[int] = None,
+ topk_per_token: int = 5,
+ request_limit: bool = False,
+ rpm: Optional[RPM] = None,
+ tpm: Optional[TPM] = None,
+ **kwargs: Any,
+ ):
+ try:
+ import ollama
+ except ImportError as e:
+ raise ImportError(
+ "Ollama SDK is not installed."
+ "It is required to use OllamaClient."
+ "Please install it with `pip install ollama`."
+ ) from e
+ super().__init__(**kwargs)
+ self.model_name = model
+ self.base_url = base_url
+ self.json_mode = json_mode
+ self.seed = seed
+ self.topk_per_token = topk_per_token
+ self.request_limit = request_limit
+ self.rpm = rpm or RPM()
+ self.tpm = tpm or TPM()
+ self.token_usage: List[Dict[str, int]] = []
+
+ self.client = ollama.AsyncClient(host=self.base_url)
+
+ async def generate_answer(
+ self,
+ text: str,
+ history: Optional[List[Dict[str, str]]] = None,
+ **extra: Any,
+ ) -> str:
+ messages = []
+ if self.system_prompt:
+ messages.append({"role": "system", "content": self.system_prompt})
+ if history:
+ messages.extend(history)
+ messages.append({"role": "user", "content": text})
+
+ options = {
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "num_predict": self.max_tokens,
+ }
+ if self.seed is not None:
+ options["seed"] = self.seed
+
+ prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
+ est = prompt_tokens + self.max_tokens
+ if self.request_limit:
+ await self.rpm.wait(silent=True)
+ await self.tpm.wait(est, silent=True)
+
+ response = await self.client.chat(
+ model=self.model_name,
+ messages=messages,
+ format="json" if self.json_mode else "",
+ options=options,
+ stream=False,
+ )
+
+ usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0)
+ self.token_usage.append(
+ {
+ "prompt_tokens": usage[0],
+ "completion_tokens": usage[1],
+ "total_tokens": sum(usage),
+ }
+ )
+ content = response["message"]["content"]
+ return self.filter_think_tags(content)
+
+ async def generate_topk_per_token(
+ self,
+ text: str,
+ history: Optional[List[Dict[str, str]]] = None,
+ **extra: Any,
+ ) -> List[Token]:
+ raise NotImplementedError("Ollama API does not support per-token top-k yet.")
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
+ ) -> List[Token]:
+ raise NotImplementedError("Ollama API does not support per-token logprobs yet.")
diff --git a/graphgen/models/llm/openai_client.py b/graphgen/models/llm/api/openai_client.py
similarity index 96%
rename from graphgen/models/llm/openai_client.py
rename to graphgen/models/llm/api/openai_client.py
index 34316937..5f9c131a 100644
--- a/graphgen/models/llm/openai_client.py
+++ b/graphgen/models/llm/api/openai_client.py
@@ -10,7 +10,7 @@
wait_exponential,
)
-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM
@@ -28,7 +28,7 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]:
return tokens
-class OpenAIClient(BaseLLMClient):
+class OpenAIClient(BaseLLMWrapper):
def __init__(
self,
*,
@@ -105,8 +105,8 @@ async def generate_topk_per_token(
kwargs["logprobs"] = True
kwargs["top_logprobs"] = self.topk_per_token
- # Limit max_tokens to 5 to avoid long completions
- kwargs["max_tokens"] = 5
+ # Limit max_tokens to 1 to avoid long completions
+ kwargs["max_tokens"] = 1
completion = await self.client.chat.completions.create( # pylint: disable=E1125
model=self.model_name, **kwargs
diff --git a/graphgen/models/llm/local/__init__.py b/graphgen/models/llm/local/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py
new file mode 100644
index 00000000..b0538aad
--- /dev/null
+++ b/graphgen/models/llm/local/hf_wrapper.py
@@ -0,0 +1,147 @@
+from typing import Any, List, Optional
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+
+
+class HuggingFaceWrapper(BaseLLMWrapper):
+ """
+ Async inference backend based on HuggingFace Transformers
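+
+    Example (a minimal sketch; the model name is illustrative):
+        wrapper = HuggingFaceWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")
+        answer = await wrapper.generate_answer("Hello!")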
+ """
+
+ def __init__(
+ self,
+ model: str,
+ torch_dtype="auto",
+ device_map="auto",
+ trust_remote_code=True,
+ temperature=0.0,
+ top_p=1.0,
+ topk=5,
+ **kwargs: Any,
+ ):
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+
+ try:
+ import torch
+ from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ GenerationConfig,
+ )
+ except ImportError as exc:
+ raise ImportError(
+ "HuggingFaceWrapper requires torch, transformers and accelerate. "
+ "Install them with: pip install torch transformers accelerate"
+ ) from exc
+
+ self.torch = torch
+ self.AutoTokenizer = AutoTokenizer
+ self.AutoModelForCausalLM = AutoModelForCausalLM
+ self.GenerationConfig = GenerationConfig
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model, trust_remote_code=trust_remote_code
+ )
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ )
+ self.model.eval()
+ self.temperature = temperature
+ self.top_p = top_p
+ self.topk = topk
+
+ @staticmethod
+ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
+ msgs = history or []
+ lines = []
+ for m in msgs:
+ if isinstance(m, dict):
+ role = m.get("role", "")
+ content = m.get("content", "")
+ lines.append(f"{role}: {content}")
+ else:
+ lines.append(str(m))
+ lines.append(prompt)
+ return "\n".join(lines)
+
+ async def generate_answer(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> str:
+ full = self._build_inputs(text, history)
+ inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
+
+ gen_kwargs = {
+ "max_new_tokens": extra.get("max_new_tokens", 512),
+ "do_sample": self.temperature > 0,
+ "temperature": self.temperature if self.temperature > 0 else 1.0,
+ "pad_token_id": self.tokenizer.eos_token_id,
+ }
+
+ # Add top_p and top_k only if temperature > 0
+ if self.temperature > 0:
+ gen_kwargs.update(top_p=self.top_p, top_k=self.topk)
+
+ gen_config = self.GenerationConfig(**gen_kwargs)
+
+ with self.torch.no_grad():
+ out = self.model.generate(**inputs, generation_config=gen_config)
+
+ gen = out[0, inputs.input_ids.shape[-1] :]
+ return self.tokenizer.decode(gen, skip_special_tokens=True)
+
+ async def generate_topk_per_token(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ full = self._build_inputs(text, history)
+ inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
+
+ with self.torch.no_grad():
+ out = self.model.generate(
+ **inputs,
+ max_new_tokens=1,
+ do_sample=False,
+ temperature=1.0,
+ return_dict_in_generate=True,
+ output_scores=True,
+ pad_token_id=self.tokenizer.eos_token_id,
+ )
+
+ scores = out.scores[0][0] # (vocab,)
+ probs = self.torch.softmax(scores, dim=-1)
+ top_probs, top_idx = self.torch.topk(probs, k=self.topk)
+
+ tokens = []
+ for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
+ tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
+ return tokens
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ full = self._build_inputs(text, history)
+ ids = self.tokenizer.encode(full)
+ logprobs = []
+
+ for i in range(1, len(ids) + 1):
+ trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1]
+ inputs = self.torch.tensor([trunc]).to(self.model.device)
+
+ with self.torch.no_grad():
+ logits = self.model(inputs).logits[0, -1, :]
+ probs = self.torch.softmax(logits, dim=-1)
+
+ true_id = ids[i - 1]
+ logprobs.append(
+ Token(
+ self.tokenizer.decode([true_id]),
+ float(probs[true_id].cpu()),
+ )
+ )
+ return logprobs
diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py
new file mode 100644
index 00000000..3781af2c
--- /dev/null
+++ b/graphgen/models/llm/local/sglang_wrapper.py
@@ -0,0 +1,148 @@
+import math
+from typing import Any, Dict, List, Optional
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+
+
+class SGLangWrapper(BaseLLMWrapper):
+ """
+ Async inference backend based on SGLang offline engine.
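+
+    Example (a minimal sketch; the model name is illustrative):
+        wrapper = SGLangWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")
+        answer = await wrapper.generate_answer("Hello!")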
+ """
+
+ def __init__(
+ self,
+ model: str,
+ temperature: float = 0.0,
+ top_p: float = 1.0,
+ topk: int = 5,
+ **kwargs: Any,
+ ):
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+ try:
+ import sglang as sgl
+ from sglang.utils import async_stream_and_merge, stream_and_merge
+ except ImportError as exc:
+ raise ImportError(
+ "SGLangWrapper requires sglang. Install it with: "
+ "uv pip install sglang --prerelease=allow"
+ ) from exc
+
+ self.model_path: str = model
+ self.temperature = temperature
+ self.top_p = top_p
+ self.topk = topk
+
+ # Initialise the offline engine
+ self.engine = sgl.Engine(model_path=self.model_path)
+
+ # Keep helpers for streaming
+ self.async_stream_and_merge = async_stream_and_merge
+ self.stream_and_merge = stream_and_merge
+
+ @staticmethod
+ def _build_sampling_params(
+ temperature: float,
+ top_p: float,
+ max_tokens: int,
+ topk: int,
+ logprobs: bool = False,
+ ) -> Dict[str, Any]:
+ """Build SGLang-compatible sampling-params dict."""
+ params = {
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_tokens,
+ }
+ if logprobs and topk > 0:
+ params["logprobs"] = topk
+ return params
+
+ def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
+ """Convert raw text (+ optional history) into a single prompt string."""
+ parts = []
+ if self.system_prompt:
+ parts.append(self.system_prompt)
+ if history:
+ assert len(history) % 2 == 0, "History must have even length (u/a turns)."
+ parts.extend([item["content"] for item in history])
+ parts.append(text)
+ return "\n".join(parts)
+
+ def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
+ tokens: List[Token] = []
+
+ meta = output.get("meta_info", {})
+ logprobs = meta.get("output_token_logprobs", [])
+ topks = meta.get("output_top_logprobs", [])
+
+ tokenizer = self.engine.tokenizer_manager.tokenizer
+
+ for idx, (lp, tid, _) in enumerate(logprobs):
+ prob = math.exp(lp)
+ tok_str = tokenizer.decode([tid])
+
+ top_candidates = []
+ if self.topk > 0 and idx < len(topks):
+ for t_lp, t_tid, _ in topks[idx][: self.topk]:
+ top_candidates.append(
+ Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
+ )
+
+ tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))
+
+ return tokens
+
+ async def generate_answer(
+ self,
+ text: str,
+ history: Optional[List[str]] = None,
+ **extra: Any,
+ ) -> str:
+ prompt = self._prep_prompt(text, history)
+ sampling_params = self._build_sampling_params(
+ temperature=self.temperature,
+ top_p=self.top_p,
+ max_tokens=self.max_tokens,
+ topk=0, # no logprobs needed for simple generation
+ )
+
+ outputs = await self.engine.async_generate([prompt], sampling_params)
+ return self.filter_think_tags(outputs[0]["text"])
+
+ async def generate_topk_per_token(
+ self,
+ text: str,
+ history: Optional[List[str]] = None,
+ **extra: Any,
+ ) -> List[Token]:
+ prompt = self._prep_prompt(text, history)
+ sampling_params = self._build_sampling_params(
+ temperature=self.temperature,
+ top_p=self.top_p,
+ max_tokens=1, # keep short for token-level analysis
+ topk=self.topk,
+ )
+
+ outputs = await self.engine.async_generate(
+            [prompt], sampling_params, return_logprob=True, top_logprobs_num=self.topk
+ )
+ return self._tokens_from_output(outputs[0])
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ raise NotImplementedError(
+ "SGLangWrapper does not support per-token logprobs yet."
+ )
+
+ def shutdown(self) -> None:
+ """Gracefully shutdown the SGLang engine."""
+ if hasattr(self, "engine"):
+ self.engine.shutdown()
+
+ def restart(self) -> None:
+ """Restart the SGLang engine."""
+ self.shutdown()
+ self.engine = self.engine.__class__(model_path=self.model_path)
diff --git a/graphgen/models/llm/local/tgi_wrapper.py b/graphgen/models/llm/local/tgi_wrapper.py
new file mode 100644
index 00000000..a722f6ea
--- /dev/null
+++ b/graphgen/models/llm/local/tgi_wrapper.py
@@ -0,0 +1,36 @@
+from typing import Any, List, Optional
+
+from graphgen.bases import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+
+
+# TODO: implement TGIWrapper methods
+class TGIWrapper(BaseLLMWrapper):
+ """
+ Async inference backend based on TGI (Text-Generation-Inference)
+ """
+
+ def __init__(
+ self,
+ model_url: str, # e.g. "http://localhost:8080"
+ temperature: float = 0.0,
+ top_p: float = 1.0,
+ topk: int = 5,
+ **kwargs: Any
+ ):
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+
+ async def generate_answer(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> str:
+ pass
+
+ async def generate_topk_per_token(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ pass
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ pass
diff --git a/graphgen/models/llm/ollama_client.py b/graphgen/models/llm/local/trt_wrapper.py
similarity index 66%
rename from graphgen/models/llm/ollama_client.py
rename to graphgen/models/llm/local/trt_wrapper.py
index 5d6e5d20..078f5ba9 100644
--- a/graphgen/models/llm/ollama_client.py
+++ b/graphgen/models/llm/local/trt_wrapper.py
@@ -1,10 +1,15 @@
-# TODO: implement ollama client
from typing import Any, List, Optional
-from graphgen.bases import BaseLLMClient, Token
+from graphgen.bases import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
-class OllamaClient(BaseLLMClient):
+# TODO: implement TensorRTWrapper methods
+class TensorRTWrapper(BaseLLMWrapper):
+ """
+ Async inference backend based on TensorRT-LLM
+ """
+
async def generate_answer(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
new file mode 100644
index 00000000..d3f6cfcc
--- /dev/null
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -0,0 +1,137 @@
+import math
+from typing import Any, List, Optional
+
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
+from graphgen.bases.datatypes import Token
+
+
+class VLLMWrapper(BaseLLMWrapper):
+ """
+ Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
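+
+    Example (a minimal sketch; the model name is illustrative):
+        wrapper = VLLMWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")
+        answer = await wrapper.generate_answer("Hello!")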
+ """
+
+ def __init__(
+ self,
+ model: str,
+ tensor_parallel_size: int = 1,
+ gpu_memory_utilization: float = 0.9,
+ temperature: float = 0.0,
+ top_p: float = 1.0,
+ topk: int = 5,
+ **kwargs: Any,
+ ):
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+
+ try:
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
+ except ImportError as exc:
+ raise ImportError(
+ "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
+ ) from exc
+
+ self.SamplingParams = SamplingParams
+
+ engine_args = AsyncEngineArgs(
+ model=model,
+ tensor_parallel_size=tensor_parallel_size,
+ gpu_memory_utilization=gpu_memory_utilization,
+ trust_remote_code=kwargs.get("trust_remote_code", True),
+ )
+ self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+ self.temperature = temperature
+ self.top_p = top_p
+ self.topk = topk
+
+ @staticmethod
+ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
+ msgs = history or []
+ lines = []
+ for m in msgs:
+ if isinstance(m, dict):
+ role = m.get("role", "")
+ content = m.get("content", "")
+ lines.append(f"{role}: {content}")
+ else:
+ lines.append(str(m))
+ lines.append(prompt)
+ return "\n".join(lines)
+
+ async def generate_answer(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> str:
+ full_prompt = self._build_inputs(text, history)
+
+ sp = self.SamplingParams(
+ temperature=self.temperature if self.temperature > 0 else 1.0,
+ top_p=self.top_p if self.temperature > 0 else 1.0,
+ max_tokens=extra.get("max_new_tokens", 512),
+ )
+
+ results = []
+ async for req_output in self.engine.generate(
+ full_prompt, sp, request_id="graphgen_req"
+ ):
+ results = req_output.outputs
+ return results[-1].text
+
+ async def generate_topk_per_token(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ full_prompt = self._build_inputs(text, history)
+
+ sp = self.SamplingParams(
+ temperature=0,
+ max_tokens=1,
+ logprobs=self.topk,
+ )
+
+ results = []
+ async for req_output in self.engine.generate(
+ full_prompt, sp, request_id="graphgen_topk"
+ ):
+ results = req_output.outputs
+ top_logprobs = results[-1].logprobs[0]
+
+ tokens = []
+ for _, logprob_obj in top_logprobs.items():
+ tok_str = logprob_obj.decoded_token
+            # Logprob.logprob is a plain float log-probability in vLLM.
+            prob = math.exp(logprob_obj.logprob)
+ tokens.append(Token(tok_str, prob))
+ tokens.sort(key=lambda x: -x.prob)
+ return tokens
+
+ async def generate_inputs_prob(
+ self, text: str, history: Optional[List[str]] = None, **extra: Any
+ ) -> List[Token]:
+ full_prompt = self._build_inputs(text, history)
+
+        # vLLM has no ready-made "mask one token and score it" interface, so we
+        # take the most direct route: feed the whole prompt once with prompt
+        # logprobs enabled, let vLLM return the logprob of every *input* position,
+        # and then read off the probability of each prompt token.
+        sp = self.SamplingParams(
+            temperature=0,
+            max_tokens=1,  # vLLM requires max_tokens >= 1; the generated token is ignored
+            prompt_logprobs=1,  # top-1 per position is enough
+ )
+
+        final_output = None
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_prob"
+        ):
+            final_output = req_output
+
+        # prompt_logprobs is a list whose length equals the number of prompt tokens;
+        # each element is a dict {token_id: logprob_obj}, or None for the first position.
+        # It lives on the RequestOutput itself, not on the per-sequence outputs.
+        prompt_logprobs = final_output.prompt_logprobs
+
+ tokens = []
+        for logprob_dict in prompt_logprobs:
+            if logprob_dict is None:
+                continue
+            # Take the first entry of each per-position dict (we requested top-1).
+            _, logprob_obj = next(iter(logprob_dict.items()))
+            tok_str = logprob_obj.decoded_token
+            prob = math.exp(logprob_obj.logprob)
+ tokens.append(Token(tok_str, prob))
+ return tokens
diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py
deleted file mode 100644
index e93ca01a..00000000
--- a/graphgen/models/llm/topk_token_model.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List, Optional
-
-from graphgen.bases import Token
-
-
-class TopkTokenModel(ABC):
- def __init__(
- self,
- do_sample: bool = False,
- temperature: float = 0,
- max_tokens: int = 4096,
- repetition_penalty: float = 1.05,
- num_beams: int = 1,
- topk: int = 50,
- topp: float = 0.95,
- topk_per_token: int = 5,
- ):
- self.do_sample = do_sample
- self.temperature = temperature
- self.max_tokens = max_tokens
- self.repetition_penalty = repetition_penalty
- self.num_beams = num_beams
- self.topk = topk
- self.topp = topp
- self.topk_per_token = topk_per_token
-
- @abstractmethod
- async def generate_topk_per_token(self, text: str) -> List[Token]:
- """
- Generate prob, text and candidates for each token of the model's output.
- This function is used to visualize the inference process.
- """
- raise NotImplementedError
-
- @abstractmethod
- async def generate_inputs_prob(
- self, text: str, history: Optional[List[str]] = None
- ) -> List[Token]:
- """
- Generate prob and text for each token of the input text.
- This function is used to visualize the ppl.
- """
- raise NotImplementedError
-
- @abstractmethod
- async def generate_answer(
- self, text: str, history: Optional[List[str]] = None
- ) -> str:
- """
- Generate answer from the model.
- """
- raise NotImplementedError
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index 2ad37e63..3e8e7ba9 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -1,5 +1,6 @@
from .build_kg import build_mm_kg, build_text_kg
from .generate import generate_qas
+from .init import init_llm
from .judge import judge_statement
from .partition import partition_kg
from .quiz import quiz
diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py
index 9301c2b9..624b10ad 100644
--- a/graphgen/operators/build_kg/build_mm_kg.py
+++ b/graphgen/operators/build_kg/build_mm_kg.py
@@ -3,14 +3,15 @@
import gradio as gr
+from graphgen.bases import BaseLLMWrapper
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk
-from graphgen.models import MMKGBuilder, OpenAIClient
+from graphgen.models import MMKGBuilder
from graphgen.utils import run_concurrent
async def build_mm_kg(
- llm_client: OpenAIClient,
+ llm_client: BaseLLMWrapper,
kg_instance: BaseGraphStorage,
chunks: List[Chunk],
progress_bar: gr.Progress = None,
diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py
index 3babe2e5..3c75f022 100644
--- a/graphgen/operators/build_kg/build_text_kg.py
+++ b/graphgen/operators/build_kg/build_text_kg.py
@@ -3,14 +3,15 @@
import gradio as gr
+from graphgen.bases import BaseLLMWrapper
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk
-from graphgen.models import LightRAGKGBuilder, OpenAIClient
+from graphgen.models import LightRAGKGBuilder
from graphgen.utils import run_concurrent
async def build_text_kg(
- llm_client: OpenAIClient,
+ llm_client: BaseLLMWrapper,
kg_instance: BaseGraphStorage,
chunks: List[Chunk],
progress_bar: gr.Progress = None,
diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py
index feadafc1..875e3bab 100644
--- a/graphgen/operators/generate/generate_qas.py
+++ b/graphgen/operators/generate/generate_qas.py
@@ -1,6 +1,6 @@
from typing import Any
-from graphgen.bases import BaseLLMClient
+from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
AggregatedGenerator,
AtomicGenerator,
@@ -12,7 +12,7 @@
async def generate_qas(
- llm_client: BaseLLMClient,
+ llm_client: BaseLLMWrapper,
batches: list[
tuple[
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py
new file mode 100644
index 00000000..ec604441
--- /dev/null
+++ b/graphgen/operators/init/__init__.py
@@ -0,0 +1 @@
+from .init_llm import init_llm
diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py
new file mode 100644
index 00000000..f7e33356
--- /dev/null
+++ b/graphgen/operators/init/init_llm.py
@@ -0,0 +1,84 @@
+import os
+from typing import Any, Dict, Optional
+
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import Tokenizer
+
+
+class LLMFactory:
+ """
+ A factory class to create LLM wrapper instances based on the specified backend.
+    Currently implemented backends:
+    - http_api: HTTPClient
+    - openai_api: OpenAIClient
+    - ollama_api: OllamaClient
+    - huggingface: HuggingFaceWrapper
+    - vllm: VLLMWrapper
+    Planned backends (not wired up yet): tgi, sglang, tensorrt
+ """
+
+ @staticmethod
+ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
+ # add tokenizer
+ tokenizer: Tokenizer = Tokenizer(
+ os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
+ )
+ config["tokenizer"] = tokenizer
+ if backend == "http_api":
+ from graphgen.models.llm.api.http_client import HTTPClient
+
+ return HTTPClient(**config)
+ if backend == "openai_api":
+ from graphgen.models.llm.api.openai_client import OpenAIClient
+
+ return OpenAIClient(**config)
+ if backend == "ollama_api":
+ from graphgen.models.llm.api.ollama_client import OllamaClient
+
+ return OllamaClient(**config)
+ if backend == "huggingface":
+ from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
+
+ return HuggingFaceWrapper(**config)
+ # if backend == "sglang":
+ # from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
+ #
+ # return SGLangWrapper(**config)
+
+ if backend == "vllm":
+ from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+ return VLLMWrapper(**config)
+
+ raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+
+
+def _load_env_group(prefix: str) -> Dict[str, Any]:
+ """
+ Collect environment variables with the given prefix into a dictionary,
+ stripping the prefix from the keys.
+ """
+ return {
+ k[len(prefix) :].lower(): v
+ for k, v in os.environ.items()
+ if k.startswith(prefix)
+ }
+
+
+def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
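+    """
+    Initialize an LLM wrapper for the given role ("synthesizer" or "trainee")
+    from environment variables prefixed with SYNTHESIZER_ / TRAINEE_.
+    Returns None when no variables with the matching prefix are set
+    (e.g. when no trainee model is configured).
+
+    Example (a sketch; the actual values come from .env, see .env.example):
+        synthesizer = init_llm("synthesizer")
+        answer = await synthesizer.generate_answer("Hello!")
+    """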
+ if model_type == "synthesizer":
+ prefix = "SYNTHESIZER_"
+ elif model_type == "trainee":
+ prefix = "TRAINEE_"
+ else:
+ raise NotImplementedError(f"Model type {model_type} is not implemented yet.")
+ config = _load_env_group(prefix)
+ # if config is empty, return None
+ if not config:
+ return None
+    backend = config.pop("backend", None)
+    if backend is None:
+        raise ValueError(
+            f"{prefix}BACKEND is not set; cannot initialize the {model_type} LLM."
+        )
+ llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
+ return llm_wrapper
diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py
index f7b0b963..d291d29a 100644
--- a/graphgen/operators/judge.py
+++ b/graphgen/operators/judge.py
@@ -3,13 +3,14 @@
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, yes_no_loss_entropy
async def judge_statement( # pylint: disable=too-many-statements
- trainee_llm_client: OpenAIClient,
+ trainee_llm_client: BaseLLMWrapper,
graph_storage: NetworkXStorage,
rephrase_storage: JsonKVStorage,
re_judge: bool = False,
diff --git a/graphgen/operators/quiz.py b/graphgen/operators/quiz.py
index a8623bfb..cd86ef2d 100644
--- a/graphgen/operators/quiz.py
+++ b/graphgen/operators/quiz.py
@@ -3,13 +3,14 @@
from tqdm.asyncio import tqdm as tqdm_async
-from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
from graphgen.utils import detect_main_language, logger
async def quiz(
- synth_llm_client: OpenAIClient,
+ synth_llm_client: BaseLLMWrapper,
graph_storage: NetworkXStorage,
rephrase_storage: JsonKVStorage,
max_samples: int = 1,
diff --git a/requirements.txt b/requirements.txt
index 44b3687a..82740f03 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,4 +25,4 @@ igraph
python-louvain
# For visualization
-matplotlib
\ No newline at end of file
+matplotlib
diff --git a/tests/integration_tests/models/llm/api/test_http_client.py b/tests/integration_tests/models/llm/api/test_http_client.py
new file mode 100644
index 00000000..d2996d1c
--- /dev/null
+++ b/tests/integration_tests/models/llm/api/test_http_client.py
@@ -0,0 +1,143 @@
+# pylint: disable=protected-access
+import math
+
+import pytest
+
+from graphgen.models.llm.api.http_client import HTTPClient
+
+
+class DummyTokenizer:
+ def encode(self, text: str):
+ # simple tokenization: split on spaces
+ return text.split()
+
+
+class _MockResponse:
+ def __init__(self, data):
+ self._data = data
+
+ def raise_for_status(self):
+ return None
+
+ async def json(self):
+ return self._data
+
+
+class _PostCtx:
+ def __init__(self, data):
+ self._resp = _MockResponse(data)
+
+ async def __aenter__(self):
+ return self._resp
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return False
+
+
+class MockSession:
+ def __init__(self, data):
+ self._data = data
+ self.closed = False
+
+ def post(self, *args, **kwargs):
+ return _PostCtx(self._data)
+
+ async def close(self):
+ self.closed = True
+
+
+class DummyLimiter:
+ def __init__(self):
+ self.calls = []
+
+ async def wait(self, *args, **kwargs):
+ self.calls.append((args, kwargs))
+
+
+@pytest.mark.asyncio
+async def test_generate_answer_records_usage_and_uses_limiters():
+ # arrange
+ data = {
+ "choices": [{"message": {"content": "Hello world!"}}],
+ "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
+ }
+ client = HTTPClient(model="m", base_url="http://test")
+ client._session = MockSession(data)
+ client.tokenizer = DummyTokenizer()
+ client.system_prompt = "sys"
+ client.temperature = 0.0
+ client.top_p = 1.0
+ client.max_tokens = 10
+    client.filter_think_tags = lambda s: s.replace("<think>", "").replace(
+        "</think>", ""
+    )
+ rpm = DummyLimiter()
+ tpm = DummyLimiter()
+ client.rpm = rpm
+ client.tpm = tpm
+ client.request_limit = True
+
+ # act
+ out = await client.generate_answer("hi", history=["u1", "a1"])
+
+ # assert
+ assert out == "Hello world!"
+ assert client.token_usage[-1] == {
+ "prompt_tokens": 3,
+ "completion_tokens": 2,
+ "total_tokens": 5,
+ }
+ assert len(rpm.calls) == 1
+ assert len(tpm.calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_generate_topk_per_token_parses_logprobs():
+ # arrange
+ # create two token items with top_logprobs
+ data = {
+ "choices": [
+ {
+ "logprobs": {
+ "content": [
+ {
+ "token": "A",
+ "logprob": math.log(0.6),
+ "top_logprobs": [
+ {"token": "A", "logprob": math.log(0.6)},
+ {"token": "B", "logprob": math.log(0.4)},
+ ],
+ },
+ {
+ "token": "B",
+ "logprob": math.log(0.2),
+ "top_logprobs": [
+ {"token": "B", "logprob": math.log(0.2)},
+ {"token": "C", "logprob": math.log(0.8)},
+ ],
+ },
+ ]
+ }
+ }
+ ]
+ }
+ client = HTTPClient(model="m", base_url="http://test")
+ client._session = MockSession(data)
+ client.tokenizer = DummyTokenizer()
+ client.system_prompt = None
+ client.temperature = 0.0
+ client.top_p = 1.0
+ client.max_tokens = 10
+ client.topk_per_token = 2
+
+ # act
+ tokens = await client.generate_topk_per_token("hi", history=[])
+
+ # assert
+ assert len(tokens) == 2
+ # check probabilities and top_candidates
+ assert abs(tokens[0].prob - 0.6) < 1e-9
+ assert abs(tokens[1].prob - 0.2) < 1e-9
+ assert len(tokens[0].top_candidates) == 2
+ assert tokens[0].top_candidates[0].text == "A"
+ assert tokens[0].top_candidates[1].text == "B"
diff --git a/tests/integration_tests/models/llm/api/test_ollama_client.py b/tests/integration_tests/models/llm/api/test_ollama_client.py
new file mode 100644
index 00000000..b20bc44c
--- /dev/null
+++ b/tests/integration_tests/models/llm/api/test_ollama_client.py
@@ -0,0 +1,91 @@
+# pylint: disable=redefined-outer-name
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from graphgen.models import OllamaClient
+
+
+# ----------------- fixture -----------------
+@pytest.fixture
+def mock_ollama_pkg():
+ """
+ mock ollama
+ """
+ ollama_mock = MagicMock()
+ ollama_mock.AsyncClient = AsyncMock
+ with patch.dict("sys.modules", {"ollama": ollama_mock}):
+ yield ollama_mock
+
+
+@pytest.fixture
+def ollama_client(mock_ollama_pkg) -> OllamaClient:
+ """
+ Returns a default-configured OllamaClient with client.chat mocked
+ """
+ cli = OllamaClient(model="gemma3", base_url="http://test:11434")
+ cli.tokenizer = MagicMock()
+ cli.tokenizer.encode = MagicMock(side_effect=lambda x: x.split())
+ cli.client.chat = AsyncMock(
+ return_value={
+ "message": {"content": "hi from ollama"},
+ "prompt_eval_count": 10,
+ "eval_count": 5,
+ }
+ )
+ return cli
+
+
+@pytest.mark.asyncio
+async def test_generate_answer_basic(ollama_client: OllamaClient):
+ ans = await ollama_client.generate_answer("hello")
+ assert ans == "hi from ollama"
+ ollama_client.client.chat.assert_awaited_once()
+ call = ollama_client.client.chat.call_args
+ assert call.kwargs["model"] == "gemma3"
+ assert call.kwargs["messages"][-1]["content"] == "hello"
+ assert call.kwargs["stream"] is False
+
+
+@pytest.mark.asyncio
+async def test_generate_answer_with_history(ollama_client: OllamaClient):
+ hist = [{"role": "user", "content": "prev"}]
+ await ollama_client.generate_answer("now", history=hist)
+ msgs = ollama_client.client.chat.call_args.kwargs["messages"]
+ assert msgs[-2]["content"] == "prev"
+ assert msgs[-1]["content"] == "now"
+
+
+@pytest.mark.asyncio
+async def test_token_usage_recorded(ollama_client: OllamaClient):
+ await ollama_client.generate_answer("test")
+ assert len(ollama_client.token_usage) == 1
+ assert ollama_client.token_usage[0]["prompt_tokens"] == 10
+ assert ollama_client.token_usage[0]["completion_tokens"] == 5
+ assert ollama_client.token_usage[0]["total_tokens"] == 15
+
+
+@pytest.mark.asyncio
+async def test_rpm_tpm_limiter_called(ollama_client: OllamaClient):
+ ollama_client.request_limit = True
+ with patch.object(ollama_client.rpm, "wait", AsyncMock()) as rpm_mock, patch.object(
+ ollama_client.tpm, "wait", AsyncMock()
+ ) as tpm_mock:
+
+ await ollama_client.generate_answer("limited")
+ rpm_mock.assert_awaited_once_with(silent=True)
+ tpm_mock.assert_awaited_once_with(
+ ollama_client.max_tokens + len("limited".split()), silent=True
+ )
+
+
+def test_import_error_when_ollama_missing():
+ with patch.dict("sys.modules", {"ollama": None}):
+ with pytest.raises(ImportError, match="Ollama SDK is not installed"):
+ OllamaClient()
+
+
+@pytest.mark.asyncio
+async def test_generate_inputs_prob_not_implemented(ollama_client: OllamaClient):
+ with pytest.raises(NotImplementedError):
+ await ollama_client.generate_inputs_prob("any")
diff --git a/tests/integration_tests/models/llm/local/test_hf_wrapper.py b/tests/integration_tests/models/llm/local/test_hf_wrapper.py
new file mode 100644
index 00000000..ae23ce11
--- /dev/null
+++ b/tests/integration_tests/models/llm/local/test_hf_wrapper.py
@@ -0,0 +1,43 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
+
+
+@pytest.fixture(autouse=True)
+def mock_hf(monkeypatch):
+ mock_tokenizer = MagicMock()
+ mock_tokenizer.pad_token = None
+    mock_tokenizer.eos_token = "</s>"
+ mock_tokenizer.eos_token_id = 0
+ mock_tokenizer.decode.return_value = "hello"
+ mock_tokenizer.encode.return_value = [1, 2, 3]
+    # hf_wrapper imports transformers and torch lazily inside __init__, so the
+    # patches target the libraries themselves rather than hf_wrapper attributes.
+    monkeypatch.setattr(
+        "transformers.AutoTokenizer.from_pretrained",
+        lambda *a, **kw: mock_tokenizer,
+    )
+
+ mock_model = MagicMock()
+ mock_model.device = "cpu"
+ mock_model.generate.return_value = MagicMock(
+ __getitem__=lambda s, k: [0, 1, 2, 3], shape=(1, 4)
+ )
+ mock_model.eval.return_value = None
+    monkeypatch.setattr(
+        "transformers.AutoModelForCausalLM.from_pretrained",
+        lambda *a, **kw: mock_model,
+    )
+
+    monkeypatch.setattr("torch.no_grad", MagicMock())
+
+ return mock_tokenizer, mock_model
+
+
+@pytest.mark.asyncio
+async def test_generate_answer():
+ wrapper = HuggingFaceWrapper("fake-model")
+ result = await wrapper.generate_answer("hi")
+ assert isinstance(result, str)