diff --git a/.env.example b/.env.example index c1102c1c..94064796 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,29 @@ +# Tokenizer TOKENIZER_MODEL= -SYNTHESIZER_MODEL= + +# LLM +# Support different backends: http_api, openai_api, ollama_api, ollama, huggingface, tgi, sglang, tensorrt + +# http_api / openai_api +SYNTHESIZER_BACKEND=openai_api +SYNTHESIZER_MODEL=gpt-4o-mini SYNTHESIZER_BASE_URL= SYNTHESIZER_API_KEY= -TRAINEE_MODEL= +TRAINEE_BACKEND=openai_api +TRAINEE_MODEL=gpt-4o-mini TRAINEE_BASE_URL= TRAINEE_API_KEY= + +# # ollama_api +# SYNTHESIZER_BACKEND=ollama_api +# SYNTHESIZER_MODEL=gemma3 +# SYNTHESIZER_BASE_URL=http://localhost:11434 +# +# Note: TRAINEE with ollama_api backend is not supported yet as ollama_api does not support logprobs. + +# # huggingface +# SYNTHESIZER_BACKEND=huggingface +# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct +# +# TRAINEE_BACKEND=huggingface +# TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index ace331d5..ed452628 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -1,6 +1,6 @@ from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder -from .base_llm_client import BaseLLMClient +from .base_llm_wrapper import BaseLLMWrapper from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_splitter import BaseSplitter diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py index d6148cee..85de5877 100644 --- a/graphgen/bases/base_generator.py +++ b/graphgen/bases/base_generator.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper class BaseGenerator(ABC): @@ -9,7 +9,7 @@ class BaseGenerator(ABC): Generate QAs based on given prompts. """ - def __init__(self, llm_client: BaseLLMClient): + def __init__(self, llm_client: BaseLLMWrapper): self.llm_client = llm_client @staticmethod diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py index e234d8da..d8a5d66a 100644 --- a/graphgen/bases/base_kg_builder.py +++ b/graphgen/bases/base_kg_builder.py @@ -2,13 +2,13 @@ from collections import defaultdict from typing import Dict, List, Tuple -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk class BaseKGBuilder(ABC): - def __init__(self, llm_client: BaseLLMClient): + def __init__(self, llm_client: BaseLLMWrapper): self.llm_client = llm_client self._nodes: Dict[str, List[dict]] = defaultdict(list) self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list) diff --git a/graphgen/bases/base_llm_client.py b/graphgen/bases/base_llm_wrapper.py similarity index 91% rename from graphgen/bases/base_llm_client.py rename to graphgen/bases/base_llm_wrapper.py index 1abe5143..5a28fbb4 100644 --- a/graphgen/bases/base_llm_client.py +++ b/graphgen/bases/base_llm_wrapper.py @@ -8,7 +8,7 @@ from graphgen.bases.datatypes import Token -class BaseLLMClient(abc.ABC): +class BaseLLMWrapper(abc.ABC): """ LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...). 
""" @@ -66,3 +66,9 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str: think_pattern = re.compile(rf"<{think_tag}>.*?", re.DOTALL) filtered_text = think_pattern.sub("", text).strip() return filtered_text if filtered_text else text.strip() + + def shutdown(self) -> None: + """Shutdown the LLM engine if applicable.""" + + def restart(self) -> None: + """Reinitialize the LLM engine if applicable.""" diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 8b0559d6..e8258829 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,11 +1,11 @@ import asyncio import os import time -from dataclasses import dataclass from typing import Dict, cast import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import StorageNameSpace from graphgen.bases.datatypes import Chunk from graphgen.models import ( @@ -20,6 +20,7 @@ build_text_kg, chunk_documents, generate_qas, + init_llm, judge_statement, partition_kg, quiz, @@ -31,40 +32,28 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -@dataclass class GraphGen: - unique_id: int = int(time.time()) - working_dir: str = os.path.join(sys_path, "cache") - - # llm - tokenizer_instance: Tokenizer = None - synthesizer_llm_client: OpenAIClient = None - trainee_llm_client: OpenAIClient = None - - # webui - progress_bar: gr.Progress = None - - def __post_init__(self): - self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer( + def __init__( + self, + unique_id: int = int(time.time()), + working_dir: str = os.path.join(sys_path, "cache"), + tokenizer_instance: Tokenizer = None, + synthesizer_llm_client: OpenAIClient = None, + trainee_llm_client: OpenAIClient = None, + progress_bar: gr.Progress = None, + ): + self.unique_id: int = unique_id + self.working_dir: str = working_dir + + # llm + self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( model_name=os.getenv("TOKENIZER_MODEL") ) - self.synthesizer_llm_client: OpenAIClient = ( - self.synthesizer_llm_client - or OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), - api_key=os.getenv("SYNTHESIZER_API_KEY"), - base_url=os.getenv("SYNTHESIZER_BASE_URL"), - tokenizer=self.tokenizer_instance, - ) - ) - - self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient( - model_name=os.getenv("TRAINEE_MODEL"), - api_key=os.getenv("TRAINEE_API_KEY"), - base_url=os.getenv("TRAINEE_BASE_URL"), - tokenizer=self.tokenizer_instance, + self.synthesizer_llm_client: BaseLLMWrapper = ( + synthesizer_llm_client or init_llm("synthesizer") ) + self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client self.full_docs_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="full_docs" @@ -86,6 +75,9 @@ def __post_init__(self): namespace="qa", ) + # webui + self.progress_bar: gr.Progress = progress_bar + @async_to_sync_method async def insert(self, read_config: Dict, split_config: Dict): """ @@ -272,6 +264,12 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): ) # TODO: assert trainee_llm_client is valid before judge + if not self.trainee_llm_client: + # TODO: shutdown existing synthesizer_llm_client properly + logger.info("No trainee LLM client provided, initializing a new one.") + self.synthesizer_llm_client.shutdown() + self.trainee_llm_client = init_llm("trainee") + re_judge = quiz_and_judge_config["re_judge"] _update_relations = await judge_statement( self.trainee_llm_client, @@ -279,9 +277,16 @@ async def quiz_and_judge(self, quiz_and_judge_config: 
Dict): self.rephrase_storage, re_judge, ) + await self.rephrase_storage.index_done_callback() await _update_relations.index_done_callback() + logger.info("Shutting down trainee LLM client.") + self.trainee_llm_client.shutdown() + self.trainee_llm_client = None + logger.info("Restarting synthesizer LLM client.") + self.synthesizer_llm_client.restart() + @async_to_sync_method async def generate(self, partition_config: Dict, generate_config: Dict): # Step 1: partition the graph diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 37476034..08694166 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -7,8 +7,7 @@ VQAGenerator, ) from .kg_builder import LightRAGKGBuilder, MMKGBuilder -from .llm.openai_client import OpenAIClient -from .llm.topk_token_model import TopkTokenModel +from .llm import HTTPClient, OllamaClient, OpenAIClient from .partitioner import ( AnchorBFSPartitioner, BFSPartitioner, diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index cde42d27..2d7bff01 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -2,7 +2,7 @@ from collections import Counter, defaultdict from typing import Dict, List, Tuple -from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk +from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT from graphgen.utils import ( detect_main_language, @@ -15,7 +15,7 @@ class LightRAGKGBuilder(BaseKGBuilder): - def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3): + def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3): super().__init__(llm_client) self.max_loop = max_loop diff --git a/graphgen/models/llm/__init__.py b/graphgen/models/llm/__init__.py index e69de29b..c70395d5 100644 --- a/graphgen/models/llm/__init__.py +++ b/graphgen/models/llm/__init__.py @@ -0,0 +1,4 @@ +from .api.http_client import HTTPClient +from .api.ollama_client import OllamaClient +from .api.openai_client import OpenAIClient +from .local.hf_wrapper import HuggingFaceWrapper diff --git a/graphgen/models/llm/api/__init__.py b/graphgen/models/llm/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py new file mode 100644 index 00000000..2c3b0acd --- /dev/null +++ b/graphgen/models/llm/api/http_client.py @@ -0,0 +1,197 @@ +import asyncio +import math +from typing import Any, Dict, List, Optional + +import aiohttp +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token +from graphgen.models.llm.limitter import RPM, TPM + + +class HTTPClient(BaseLLMWrapper): + """ + A generic async HTTP client for LLMs compatible with OpenAI's chat/completions format. + It uses aiohttp for making requests and includes retry logic and token usage tracking. 
+ Usage example: + client = HTTPClient( + model_name="gpt-4o-mini", + base_url="http://localhost:8080", + api_key="your_api_key", + json_mode=True, + seed=42, + topk_per_token=5, + request_limit=True, + ) + + answer = await client.generate_answer("Hello, world!") + tokens = await client.generate_topk_per_token("Hello, world!") + """ + + _instance: Optional["HTTPClient"] = None + _lock = asyncio.Lock() + + def __new__(cls, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__( + self, + *, + model: str, + base_url: str, + api_key: Optional[str] = None, + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, + request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, + **kwargs: Any, + ): + # Initialize only once in the singleton pattern + if getattr(self, "_initialized", False): + return + self._initialized: bool = True + super().__init__(**kwargs) + self.model_name = model + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + self.request_limit = request_limit + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() + + self.token_usage: List[Dict[str, int]] = [] + self._session: Optional[aiohttp.ClientSession] = None + + @property + def session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + headers = ( + {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + ) + self._session = aiohttp.ClientSession(headers=headers) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + + def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]: + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + # chatml format: alternating user and assistant messages + if history and isinstance(history[0], dict): + messages.extend(history) + + messages.append({"role": "user", "content": text}) + + body = { + "model": self.model_name, + "messages": messages, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + } + if self.seed: + body["seed"] = self.seed + if self.json_mode: + body["response_format"] = {"type": "json_object"} + return body + + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), + ) + async def generate_answer( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> str: + body = self._build_body(text, history or []) + prompt_tokens = sum( + len(self.tokenizer.encode(m["content"])) for m in body["messages"] + ) + est = prompt_tokens + body["max_tokens"] + + if self.request_limit: + await self.rpm.wait(silent=True) + await self.tpm.wait(est, silent=True) + + async with self.session.post( + f"{self.base_url}/chat/completions", + json=body, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + msg = data["choices"][0]["message"]["content"] + if "usage" in data: + self.token_usage.append( + { + "prompt_tokens": data["usage"]["prompt_tokens"], + "completion_tokens": data["usage"]["completion_tokens"], + "total_tokens": data["usage"]["total_tokens"], + } + ) + return self.filter_think_tags(msg) + + @retry( + stop=stop_after_attempt(5), + 
wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), + ) + async def generate_topk_per_token( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> List[Token]: + body = self._build_body(text, history or []) + body["max_tokens"] = 1 + if self.topk_per_token > 0: + body["logprobs"] = True + body["top_logprobs"] = self.topk_per_token + + async with self.session.post( + f"{self.base_url}/chat/completions", + json=body, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + token_logprobs = data["choices"][0]["logprobs"]["content"] + tokens = [] + for item in token_logprobs: + candidates = [ + Token(t["token"], math.exp(t["logprob"])) for t in item["top_logprobs"] + ] + tokens.append( + Token( + item["token"], math.exp(item["logprob"]), top_candidates=candidates + ) + ) + return tokens + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError( + "generate_inputs_prob is not implemented in HTTPClient" + ) diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py new file mode 100644 index 00000000..9a4946a6 --- /dev/null +++ b/graphgen/models/llm/api/ollama_client.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, List, Optional + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token +from graphgen.models.llm.limitter import RPM, TPM + + +class OllamaClient(BaseLLMWrapper): + """ + Requires a local or remote Ollama server to be running (default port 11434). + The top_logprobs field is not yet implemented by the official API. + """ + + def __init__( + self, + *, + model: str = "gemma3", + base_url: str = "http://localhost:11434", + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, + request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, + **kwargs: Any, + ): + try: + import ollama + except ImportError as e: + raise ImportError( + "Ollama SDK is not installed." + "It is required to use OllamaClient." + "Please install it with `pip install ollama`." 
+ ) from e + super().__init__(**kwargs) + self.model_name = model + self.base_url = base_url + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + self.request_limit = request_limit + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() + self.token_usage: List[Dict[str, int]] = [] + + self.client = ollama.AsyncClient(host=self.base_url) + + async def generate_answer( + self, + text: str, + history: Optional[List[Dict[str, str]]] = None, + **extra: Any, + ) -> str: + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + if history: + messages.extend(history) + messages.append({"role": "user", "content": text}) + + options = { + "temperature": self.temperature, + "top_p": self.top_p, + "num_predict": self.max_tokens, + } + if self.seed is not None: + options["seed"] = self.seed + + prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages) + est = prompt_tokens + self.max_tokens + if self.request_limit: + await self.rpm.wait(silent=True) + await self.tpm.wait(est, silent=True) + + response = await self.client.chat( + model=self.model_name, + messages=messages, + format="json" if self.json_mode else "", + options=options, + stream=False, + ) + + usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0) + self.token_usage.append( + { + "prompt_tokens": usage[0], + "completion_tokens": usage[1], + "total_tokens": sum(usage), + } + ) + content = response["message"]["content"] + return self.filter_think_tags(content) + + async def generate_topk_per_token( + self, + text: str, + history: Optional[List[Dict[str, str]]] = None, + **extra: Any, + ) -> List[Token]: + raise NotImplementedError("Ollama API does not support per-token top-k yet.") + + async def generate_inputs_prob( + self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError("Ollama API does not support per-token logprobs yet.") diff --git a/graphgen/models/llm/openai_client.py b/graphgen/models/llm/api/openai_client.py similarity index 96% rename from graphgen/models/llm/openai_client.py rename to graphgen/models/llm/api/openai_client.py index 34316937..5f9c131a 100644 --- a/graphgen/models/llm/openai_client.py +++ b/graphgen/models/llm/api/openai_client.py @@ -10,7 +10,7 @@ wait_exponential, ) -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.datatypes import Token from graphgen.models.llm.limitter import RPM, TPM @@ -28,7 +28,7 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: return tokens -class OpenAIClient(BaseLLMClient): +class OpenAIClient(BaseLLMWrapper): def __init__( self, *, @@ -105,8 +105,8 @@ async def generate_topk_per_token( kwargs["logprobs"] = True kwargs["top_logprobs"] = self.topk_per_token - # Limit max_tokens to 5 to avoid long completions - kwargs["max_tokens"] = 5 + # Limit max_tokens to 1 to avoid long completions + kwargs["max_tokens"] = 1 completion = await self.client.chat.completions.create( # pylint: disable=E1125 model=self.model_name, **kwargs diff --git a/graphgen/models/llm/local/__init__.py b/graphgen/models/llm/local/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py new file mode 100644 index 00000000..b0538aad --- /dev/null +++ b/graphgen/models/llm/local/hf_wrapper.py @@ -0,0 +1,147 @@ +from 
typing import Any, List, Optional + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class HuggingFaceWrapper(BaseLLMWrapper): + """ + Async inference backend based on HuggingFace Transformers + """ + + def __init__( + self, + model: str, + torch_dtype="auto", + device_map="auto", + trust_remote_code=True, + temperature=0.0, + top_p=1.0, + topk=5, + **kwargs: Any, + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + + try: + import torch + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + ) + except ImportError as exc: + raise ImportError( + "HuggingFaceWrapper requires torch, transformers and accelerate. " + "Install them with: pip install torch transformers accelerate" + ) from exc + + self.torch = torch + self.AutoTokenizer = AutoTokenizer + self.AutoModelForCausalLM = AutoModelForCausalLM + self.GenerationConfig = GenerationConfig + + self.tokenizer = AutoTokenizer.from_pretrained( + model, trust_remote_code=trust_remote_code + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.model = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype=torch_dtype, + device_map=device_map, + trust_remote_code=trust_remote_code, + ) + self.model.eval() + self.temperature = temperature + self.top_p = top_p + self.topk = topk + + @staticmethod + def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str: + msgs = history or [] + lines = [] + for m in msgs: + if isinstance(m, dict): + role = m.get("role", "") + content = m.get("content", "") + lines.append(f"{role}: {content}") + else: + lines.append(str(m)) + lines.append(prompt) + return "\n".join(lines) + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + full = self._build_inputs(text, history) + inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) + + gen_kwargs = { + "max_new_tokens": extra.get("max_new_tokens", 512), + "do_sample": self.temperature > 0, + "temperature": self.temperature if self.temperature > 0 else 1.0, + "pad_token_id": self.tokenizer.eos_token_id, + } + + # Add top_p and top_k only if temperature > 0 + if self.temperature > 0: + gen_kwargs.update(top_p=self.top_p, top_k=self.topk) + + gen_config = self.GenerationConfig(**gen_kwargs) + + with self.torch.no_grad(): + out = self.model.generate(**inputs, generation_config=gen_config) + + gen = out[0, inputs.input_ids.shape[-1] :] + return self.tokenizer.decode(gen, skip_special_tokens=True) + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = self._build_inputs(text, history) + inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) + + with self.torch.no_grad(): + out = self.model.generate( + **inputs, + max_new_tokens=1, + do_sample=False, + temperature=1.0, + return_dict_in_generate=True, + output_scores=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + scores = out.scores[0][0] # (vocab,) + probs = self.torch.softmax(scores, dim=-1) + top_probs, top_idx = self.torch.topk(probs, k=self.topk) + + tokens = [] + for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()): + tokens.append(Token(self.tokenizer.decode([idx]), float(p))) + return tokens + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = 
self._build_inputs(text, history) + ids = self.tokenizer.encode(full) + logprobs = [] + + for i in range(1, len(ids) + 1): + trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] + inputs = self.torch.tensor([trunc]).to(self.model.device) + + with self.torch.no_grad(): + logits = self.model(inputs).logits[0, -1, :] + probs = self.torch.softmax(logits, dim=-1) + + true_id = ids[i - 1] + logprobs.append( + Token( + self.tokenizer.decode([true_id]), + float(probs[true_id].cpu()), + ) + ) + return logprobs diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py new file mode 100644 index 00000000..3781af2c --- /dev/null +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -0,0 +1,148 @@ +import math +from typing import Any, Dict, List, Optional + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class SGLangWrapper(BaseLLMWrapper): + """ + Async inference backend based on SGLang offline engine. + """ + + def __init__( + self, + model: str, + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any, + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + try: + import sglang as sgl + from sglang.utils import async_stream_and_merge, stream_and_merge + except ImportError as exc: + raise ImportError( + "SGLangWrapper requires sglang. Install it with: " + "uv pip install sglang --prerelease=allow" + ) from exc + + self.model_path: str = model + self.temperature = temperature + self.top_p = top_p + self.topk = topk + + # Initialise the offline engine + self.engine = sgl.Engine(model_path=self.model_path) + + # Keep helpers for streaming + self.async_stream_and_merge = async_stream_and_merge + self.stream_and_merge = stream_and_merge + + @staticmethod + def _build_sampling_params( + temperature: float, + top_p: float, + max_tokens: int, + topk: int, + logprobs: bool = False, + ) -> Dict[str, Any]: + """Build SGLang-compatible sampling-params dict.""" + params = { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_tokens, + } + if logprobs and topk > 0: + params["logprobs"] = topk + return params + + def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str: + """Convert raw text (+ optional history) into a single prompt string.""" + parts = [] + if self.system_prompt: + parts.append(self.system_prompt) + if history: + assert len(history) % 2 == 0, "History must have even length (u/a turns)." 
+ parts.extend([item["content"] for item in history]) + parts.append(text) + return "\n".join(parts) + + def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]: + tokens: List[Token] = [] + + meta = output.get("meta_info", {}) + logprobs = meta.get("output_token_logprobs", []) + topks = meta.get("output_top_logprobs", []) + + tokenizer = self.engine.tokenizer_manager.tokenizer + + for idx, (lp, tid, _) in enumerate(logprobs): + prob = math.exp(lp) + tok_str = tokenizer.decode([tid]) + + top_candidates = [] + if self.topk > 0 and idx < len(topks): + for t_lp, t_tid, _ in topks[idx][: self.topk]: + top_candidates.append( + Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp)) + ) + + tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates)) + + return tokens + + async def generate_answer( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> str: + prompt = self._prep_prompt(text, history) + sampling_params = self._build_sampling_params( + temperature=self.temperature, + top_p=self.top_p, + max_tokens=self.max_tokens, + topk=0, # no logprobs needed for simple generation + ) + + outputs = await self.engine.async_generate([prompt], sampling_params) + return self.filter_think_tags(outputs[0]["text"]) + + async def generate_topk_per_token( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> List[Token]: + prompt = self._prep_prompt(text, history) + sampling_params = self._build_sampling_params( + temperature=self.temperature, + top_p=self.top_p, + max_tokens=1, # keep short for token-level analysis + topk=self.topk, + ) + + outputs = await self.engine.async_generate( + [prompt], sampling_params, return_logprob=True, top_logprobs_num=5 + ) + print(outputs) + return self._tokens_from_output(outputs[0]) + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError( + "SGLangWrapper does not support per-token logprobs yet." + ) + + def shutdown(self) -> None: + """Gracefully shutdown the SGLang engine.""" + if hasattr(self, "engine"): + self.engine.shutdown() + + def restart(self) -> None: + """Restart the SGLang engine.""" + self.shutdown() + self.engine = self.engine.__class__(model_path=self.model_path) diff --git a/graphgen/models/llm/local/tgi_wrapper.py b/graphgen/models/llm/local/tgi_wrapper.py new file mode 100644 index 00000000..a722f6ea --- /dev/null +++ b/graphgen/models/llm/local/tgi_wrapper.py @@ -0,0 +1,36 @@ +from typing import Any, List, Optional + +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +# TODO: implement TGIWrapper methods +class TGIWrapper(BaseLLMWrapper): + """ + Async inference backend based on TGI (Text-Generation-Inference) + """ + + def __init__( + self, + model_url: str, # e.g. 
"http://localhost:8080" + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + pass + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + pass + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + pass diff --git a/graphgen/models/llm/ollama_client.py b/graphgen/models/llm/local/trt_wrapper.py similarity index 66% rename from graphgen/models/llm/ollama_client.py rename to graphgen/models/llm/local/trt_wrapper.py index 5d6e5d20..078f5ba9 100644 --- a/graphgen/models/llm/ollama_client.py +++ b/graphgen/models/llm/local/trt_wrapper.py @@ -1,10 +1,15 @@ -# TODO: implement ollama client from typing import Any, List, Optional -from graphgen.bases import BaseLLMClient, Token +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token -class OllamaClient(BaseLLMClient): +# TODO: implement TensorRTWrapper methods +class TensorRTWrapper(BaseLLMWrapper): + """ + Async inference backend based on TensorRT-LLM + """ + async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py new file mode 100644 index 00000000..d3f6cfcc --- /dev/null +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -0,0 +1,137 @@ +from typing import Any, List, Optional + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class VLLMWrapper(BaseLLMWrapper): + """ + Async inference backend based on vLLM (https://github.com/vllm-project/vllm) + """ + + def __init__( + self, + model: str, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any, + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + + try: + from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams + except ImportError as exc: + raise ImportError( + "VLLMWrapper requires vllm. 
Install it with: uv pip install vllm --torch-backend=auto"
+            ) from exc
+
+        self.SamplingParams = SamplingParams
+
+        engine_args = AsyncEngineArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            gpu_memory_utilization=gpu_memory_utilization,
+            trust_remote_code=kwargs.get("trust_remote_code", True),
+        )
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+        self.temperature = temperature
+        self.top_p = top_p
+        self.topk = topk
+
+    @staticmethod
+    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
+        msgs = history or []
+        lines = []
+        for m in msgs:
+            if isinstance(m, dict):
+                role = m.get("role", "")
+                content = m.get("content", "")
+                lines.append(f"{role}: {content}")
+            else:
+                lines.append(str(m))
+        lines.append(prompt)
+        return "\n".join(lines)
+
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        full_prompt = self._build_inputs(text, history)
+
+        sp = self.SamplingParams(
+            temperature=self.temperature if self.temperature > 0 else 1.0,
+            top_p=self.top_p if self.temperature > 0 else 1.0,
+            max_tokens=extra.get("max_new_tokens", 512),
+        )
+
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_req"
+        ):
+            results = req_output.outputs
+        return results[-1].text
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        full_prompt = self._build_inputs(text, history)
+
+        sp = self.SamplingParams(
+            temperature=0,
+            max_tokens=1,
+            logprobs=self.topk,
+        )
+
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_topk"
+        ):
+            results = req_output.outputs
+        top_logprobs = results[-1].logprobs[0]
+
+        tokens = []
+        for _, logprob_obj in top_logprobs.items():
+            tok_str = logprob_obj.decoded_token
+            prob = float(logprob_obj.logprob.exp())
+            tokens.append(Token(tok_str, prob))
+        tokens.sort(key=lambda x: -x.prob)
+        return tokens
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        full_prompt = self._build_inputs(text, history)
+
+        # vLLM has no built-in "mask one token, then score it" interface, so we
+        # take the most direct route: feed the whole prompt in one pass with
+        # prompt_logprobs enabled, let vLLM return a logprob for every position
+        # of the *input* part, and then pick out each token's probability.
+        sp = self.SamplingParams(
+            temperature=0,
+            max_tokens=0,  # do not generate any new tokens
+            prompt_logprobs=1,  # top-1 is enough here
+        )
+
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_prob"
+        ):
+            results = req_output.outputs
+
+        # prompt_logprobs is a list with one entry per prompt token; each entry
+        # is a dict {token_id: logprob_obj}, or None for the first position.
+        prompt_logprobs = results[-1].prompt_logprobs
+
+        tokens = []
+        for _, logprob_dict in enumerate(prompt_logprobs):
+            if logprob_dict is None:
+                continue
+            # each dict holds a single entry here because we asked for top-1
+            _, logprob_obj = next(iter(logprob_dict.items()))
+            tok_str = logprob_obj.decoded_token
+            prob = float(logprob_obj.logprob.exp())
+            tokens.append(Token(tok_str, prob))
+        return tokens
diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py
deleted file mode 100644
index e93ca01a..00000000
--- a/graphgen/models/llm/topk_token_model.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List, Optional
-
-from graphgen.bases import Token
-
-
-class TopkTokenModel(ABC):
-    def __init__(
-        self,
-        do_sample: bool = False,
-        temperature: float = 0,
max_tokens: int = 4096, - repetition_penalty: float = 1.05, - num_beams: int = 1, - topk: int = 50, - topp: float = 0.95, - topk_per_token: int = 5, - ): - self.do_sample = do_sample - self.temperature = temperature - self.max_tokens = max_tokens - self.repetition_penalty = repetition_penalty - self.num_beams = num_beams - self.topk = topk - self.topp = topp - self.topk_per_token = topk_per_token - - @abstractmethod - async def generate_topk_per_token(self, text: str) -> List[Token]: - """ - Generate prob, text and candidates for each token of the model's output. - This function is used to visualize the inference process. - """ - raise NotImplementedError - - @abstractmethod - async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None - ) -> List[Token]: - """ - Generate prob and text for each token of the input text. - This function is used to visualize the ppl. - """ - raise NotImplementedError - - @abstractmethod - async def generate_answer( - self, text: str, history: Optional[List[str]] = None - ) -> str: - """ - Generate answer from the model. - """ - raise NotImplementedError diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 2ad37e63..3e8e7ba9 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,5 +1,6 @@ from .build_kg import build_mm_kg, build_text_kg from .generate import generate_qas +from .init import init_llm from .judge import judge_statement from .partition import partition_kg from .quiz import quiz diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py index 9301c2b9..624b10ad 100644 --- a/graphgen/operators/build_kg/build_mm_kg.py +++ b/graphgen/operators/build_kg/build_mm_kg.py @@ -3,14 +3,15 @@ import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk -from graphgen.models import MMKGBuilder, OpenAIClient +from graphgen.models import MMKGBuilder from graphgen.utils import run_concurrent async def build_mm_kg( - llm_client: OpenAIClient, + llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], progress_bar: gr.Progress = None, diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 3babe2e5..3c75f022 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -3,14 +3,15 @@ import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk -from graphgen.models import LightRAGKGBuilder, OpenAIClient +from graphgen.models import LightRAGKGBuilder from graphgen.utils import run_concurrent async def build_text_kg( - llm_client: OpenAIClient, + llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], progress_bar: gr.Progress = None, diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py index feadafc1..875e3bab 100644 --- a/graphgen/operators/generate/generate_qas.py +++ b/graphgen/operators/generate/generate_qas.py @@ -1,6 +1,6 @@ from typing import Any -from graphgen.bases import BaseLLMClient +from graphgen.bases import BaseLLMWrapper from graphgen.models import ( AggregatedGenerator, AtomicGenerator, @@ -12,7 +12,7 @@ async def generate_qas( - llm_client: BaseLLMClient, + llm_client: BaseLLMWrapper, batches: list[ tuple[ list[tuple[str, 
dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py new file mode 100644 index 00000000..ec604441 --- /dev/null +++ b/graphgen/operators/init/__init__.py @@ -0,0 +1 @@ +from .init_llm import init_llm diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py new file mode 100644 index 00000000..f7e33356 --- /dev/null +++ b/graphgen/operators/init/init_llm.py @@ -0,0 +1,84 @@ +import os +from typing import Any, Dict, Optional + +from graphgen.bases import BaseLLMWrapper +from graphgen.models import Tokenizer + + +class LLMFactory: + """ + A factory class to create LLM wrapper instances based on the specified backend. + Supported backends include: + - http_api: HTTPClient + - openai_api: OpenAIClient + - ollama_api: OllamaClient + - ollama: OllamaWrapper + - deepspeed: DeepSpeedWrapper + - huggingface: HuggingFaceWrapper + - tgi: TGIWrapper + - sglang: SGLangWrapper + - tensorrt: TensorRTWrapper + """ + + @staticmethod + def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: + # add tokenizer + tokenizer: Tokenizer = Tokenizer( + os.environ.get("TOKENIZER_MODEL", "cl100k_base"), + ) + config["tokenizer"] = tokenizer + if backend == "http_api": + from graphgen.models.llm.api.http_client import HTTPClient + + return HTTPClient(**config) + if backend == "openai_api": + from graphgen.models.llm.api.openai_client import OpenAIClient + + return OpenAIClient(**config) + if backend == "ollama_api": + from graphgen.models.llm.api.ollama_client import OllamaClient + + return OllamaClient(**config) + if backend == "huggingface": + from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper + + return HuggingFaceWrapper(**config) + # if backend == "sglang": + # from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper + # + # return SGLangWrapper(**config) + + if backend == "vllm": + from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper + + return VLLMWrapper(**config) + + raise NotImplementedError(f"Backend {backend} is not implemented yet.") + + +def _load_env_group(prefix: str) -> Dict[str, Any]: + """ + Collect environment variables with the given prefix into a dictionary, + stripping the prefix from the keys. 
+ """ + return { + k[len(prefix) :].lower(): v + for k, v in os.environ.items() + if k.startswith(prefix) + } + + +def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: + if model_type == "synthesizer": + prefix = "SYNTHESIZER_" + elif model_type == "trainee": + prefix = "TRAINEE_" + else: + raise NotImplementedError(f"Model type {model_type} is not implemented yet.") + config = _load_env_group(prefix) + # if config is empty, return None + if not config: + return None + backend = config.pop("backend") + llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) + return llm_wrapper diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py index f7b0b963..d291d29a 100644 --- a/graphgen/operators/judge.py +++ b/graphgen/operators/judge.py @@ -3,13 +3,14 @@ from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT from graphgen.utils import logger, yes_no_loss_entropy async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: OpenAIClient, + trainee_llm_client: BaseLLMWrapper, graph_storage: NetworkXStorage, rephrase_storage: JsonKVStorage, re_judge: bool = False, diff --git a/graphgen/operators/quiz.py b/graphgen/operators/quiz.py index a8623bfb..cd86ef2d 100644 --- a/graphgen/operators/quiz.py +++ b/graphgen/operators/quiz.py @@ -3,13 +3,14 @@ from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT from graphgen.utils import detect_main_language, logger async def quiz( - synth_llm_client: OpenAIClient, + synth_llm_client: BaseLLMWrapper, graph_storage: NetworkXStorage, rephrase_storage: JsonKVStorage, max_samples: int = 1, diff --git a/requirements.txt b/requirements.txt index 44b3687a..82740f03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,4 @@ igraph python-louvain # For visualization -matplotlib \ No newline at end of file +matplotlib diff --git a/tests/integration_tests/models/llm/api/test_http_client.py b/tests/integration_tests/models/llm/api/test_http_client.py new file mode 100644 index 00000000..d2996d1c --- /dev/null +++ b/tests/integration_tests/models/llm/api/test_http_client.py @@ -0,0 +1,143 @@ +# pylint: disable=protected-access +import math + +import pytest + +from graphgen.models.llm.api.http_client import HTTPClient + + +class DummyTokenizer: + def encode(self, text: str): + # simple tokenization: split on spaces + return text.split() + + +class _MockResponse: + def __init__(self, data): + self._data = data + + def raise_for_status(self): + return None + + async def json(self): + return self._data + + +class _PostCtx: + def __init__(self, data): + self._resp = _MockResponse(data) + + async def __aenter__(self): + return self._resp + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class MockSession: + def __init__(self, data): + self._data = data + self.closed = False + + def post(self, *args, **kwargs): + return _PostCtx(self._data) + + async def close(self): + self.closed = True + + +class DummyLimiter: + def __init__(self): + self.calls = [] + + async def wait(self, *args, **kwargs): + self.calls.append((args, kwargs)) + + +@pytest.mark.asyncio 
+async def test_generate_answer_records_usage_and_uses_limiters(): + # arrange + data = { + "choices": [{"message": {"content": "Hello world!"}}], + "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}, + } + client = HTTPClient(model="m", base_url="http://test") + client._session = MockSession(data) + client.tokenizer = DummyTokenizer() + client.system_prompt = "sys" + client.temperature = 0.0 + client.top_p = 1.0 + client.max_tokens = 10 + client.filter_think_tags = lambda s: s.replace("", "").replace( + "", "" + ) + rpm = DummyLimiter() + tpm = DummyLimiter() + client.rpm = rpm + client.tpm = tpm + client.request_limit = True + + # act + out = await client.generate_answer("hi", history=["u1", "a1"]) + + # assert + assert out == "Hello world!" + assert client.token_usage[-1] == { + "prompt_tokens": 3, + "completion_tokens": 2, + "total_tokens": 5, + } + assert len(rpm.calls) == 1 + assert len(tpm.calls) == 1 + + +@pytest.mark.asyncio +async def test_generate_topk_per_token_parses_logprobs(): + # arrange + # create two token items with top_logprobs + data = { + "choices": [ + { + "logprobs": { + "content": [ + { + "token": "A", + "logprob": math.log(0.6), + "top_logprobs": [ + {"token": "A", "logprob": math.log(0.6)}, + {"token": "B", "logprob": math.log(0.4)}, + ], + }, + { + "token": "B", + "logprob": math.log(0.2), + "top_logprobs": [ + {"token": "B", "logprob": math.log(0.2)}, + {"token": "C", "logprob": math.log(0.8)}, + ], + }, + ] + } + } + ] + } + client = HTTPClient(model="m", base_url="http://test") + client._session = MockSession(data) + client.tokenizer = DummyTokenizer() + client.system_prompt = None + client.temperature = 0.0 + client.top_p = 1.0 + client.max_tokens = 10 + client.topk_per_token = 2 + + # act + tokens = await client.generate_topk_per_token("hi", history=[]) + + # assert + assert len(tokens) == 2 + # check probabilities and top_candidates + assert abs(tokens[0].prob - 0.6) < 1e-9 + assert abs(tokens[1].prob - 0.2) < 1e-9 + assert len(tokens[0].top_candidates) == 2 + assert tokens[0].top_candidates[0].text == "A" + assert tokens[0].top_candidates[1].text == "B" diff --git a/tests/integration_tests/models/llm/api/test_ollama_client.py b/tests/integration_tests/models/llm/api/test_ollama_client.py new file mode 100644 index 00000000..b20bc44c --- /dev/null +++ b/tests/integration_tests/models/llm/api/test_ollama_client.py @@ -0,0 +1,91 @@ +# pylint: disable=redefined-outer-name +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from graphgen.models import OllamaClient + + +# ----------------- fixture ----------------- +@pytest.fixture +def mock_ollama_pkg(): + """ + mock ollama + """ + ollama_mock = MagicMock() + ollama_mock.AsyncClient = AsyncMock + with patch.dict("sys.modules", {"ollama": ollama_mock}): + yield ollama_mock + + +@pytest.fixture +def ollama_client(mock_ollama_pkg) -> OllamaClient: + """ + Returns a default-configured OllamaClient with client.chat mocked + """ + cli = OllamaClient(model="gemma3", base_url="http://test:11434") + cli.tokenizer = MagicMock() + cli.tokenizer.encode = MagicMock(side_effect=lambda x: x.split()) + cli.client.chat = AsyncMock( + return_value={ + "message": {"content": "hi from ollama"}, + "prompt_eval_count": 10, + "eval_count": 5, + } + ) + return cli + + +@pytest.mark.asyncio +async def test_generate_answer_basic(ollama_client: OllamaClient): + ans = await ollama_client.generate_answer("hello") + assert ans == "hi from ollama" + ollama_client.client.chat.assert_awaited_once() 
+ call = ollama_client.client.chat.call_args + assert call.kwargs["model"] == "gemma3" + assert call.kwargs["messages"][-1]["content"] == "hello" + assert call.kwargs["stream"] is False + + +@pytest.mark.asyncio +async def test_generate_answer_with_history(ollama_client: OllamaClient): + hist = [{"role": "user", "content": "prev"}] + await ollama_client.generate_answer("now", history=hist) + msgs = ollama_client.client.chat.call_args.kwargs["messages"] + assert msgs[-2]["content"] == "prev" + assert msgs[-1]["content"] == "now" + + +@pytest.mark.asyncio +async def test_token_usage_recorded(ollama_client: OllamaClient): + await ollama_client.generate_answer("test") + assert len(ollama_client.token_usage) == 1 + assert ollama_client.token_usage[0]["prompt_tokens"] == 10 + assert ollama_client.token_usage[0]["completion_tokens"] == 5 + assert ollama_client.token_usage[0]["total_tokens"] == 15 + + +@pytest.mark.asyncio +async def test_rpm_tpm_limiter_called(ollama_client: OllamaClient): + ollama_client.request_limit = True + with patch.object(ollama_client.rpm, "wait", AsyncMock()) as rpm_mock, patch.object( + ollama_client.tpm, "wait", AsyncMock() + ) as tpm_mock: + + await ollama_client.generate_answer("limited") + rpm_mock.assert_awaited_once_with(silent=True) + tpm_mock.assert_awaited_once_with( + ollama_client.max_tokens + len("limited".split()), silent=True + ) + + +def test_import_error_when_ollama_missing(): + with patch.dict("sys.modules", {"ollama": None}): + with pytest.raises(ImportError, match="Ollama SDK is not installed"): + OllamaClient() + + +@pytest.mark.asyncio +async def test_generate_inputs_prob_not_implemented(ollama_client: OllamaClient): + with pytest.raises(NotImplementedError): + await ollama_client.generate_inputs_prob("any") diff --git a/tests/integration_tests/models/llm/local/test_hf_wrapper.py b/tests/integration_tests/models/llm/local/test_hf_wrapper.py new file mode 100644 index 00000000..ae23ce11 --- /dev/null +++ b/tests/integration_tests/models/llm/local/test_hf_wrapper.py @@ -0,0 +1,43 @@ +from unittest.mock import MagicMock + +import pytest + +from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper + + +@pytest.fixture(autouse=True) +def mock_hf(monkeypatch): + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token = None + mock_tokenizer.eos_token = "" + mock_tokenizer.eos_token_id = 0 + mock_tokenizer.decode.return_value = "hello" + mock_tokenizer.encode.return_value = [1, 2, 3] + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.AutoTokenizer.from_pretrained", + lambda *a, **kw: mock_tokenizer, + ) + + mock_model = MagicMock() + mock_model.device = "cpu" + mock_model.generate.return_value = MagicMock( + __getitem__=lambda s, k: [0, 1, 2, 3], shape=(1, 4) + ) + mock_model.eval.return_value = None + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.AutoModelForCausalLM.from_pretrained", + lambda *a, **kw: mock_model, + ) + + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.torch.no_grad", MagicMock() + ) + + return mock_tokenizer, mock_model + + +@pytest.mark.asyncio +async def test_generate_answer(): + wrapper = HuggingFaceWrapper("fake-model") + result = await wrapper.generate_answer("hi") + assert isinstance(result, str)
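
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the patch): how the env-driven
# backend selection added in graphgen/operators/init/init_llm.py is meant to
# be used. The concrete values below are assumptions for demonstration only.
# ---------------------------------------------------------------------------
import os

from graphgen.operators import init_llm

# Pick a backend per role; see .env.example and LLMFactory for the backends
# currently wired up (http_api, openai_api, ollama_api, huggingface, vllm).
os.environ.setdefault("TOKENIZER_MODEL", "cl100k_base")
os.environ.setdefault("SYNTHESIZER_BACKEND", "openai_api")
os.environ.setdefault("SYNTHESIZER_MODEL", "gpt-4o-mini")
os.environ.setdefault("SYNTHESIZER_BASE_URL", "https://api.openai.com/v1")  # assumed endpoint
os.environ.setdefault("SYNTHESIZER_API_KEY", "sk-...")  # placeholder

# init_llm() strips the SYNTHESIZER_/TRAINEE_ prefix, lower-cases the remaining
# keys (SYNTHESIZER_BASE_URL -> base_url=...), pops "backend", and lets
# LLMFactory instantiate the matching wrapper with the tokenizer attached.
synthesizer = init_llm("synthesizer")

# With no TRAINEE_* variables set, init_llm("trainee") returns None; in that
# case GraphGen.quiz_and_judge() shuts down the synthesizer and creates the
# trainee client lazily right before judging, then restarts the synthesizer.
trainee = init_llm("trainee")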