From ba48ee53a062ed4d0a776fad2d093c23e0b8dce8 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 11:58:35 +0800 Subject: [PATCH 01/27] chore: delete duplicate model --- graphgen/models/__init__.py | 1 - graphgen/models/llm/topk_token_model.py | 53 ------------------------- 2 files changed, 54 deletions(-) delete mode 100644 graphgen/models/llm/topk_token_model.py diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 37476034..0d46dd98 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -8,7 +8,6 @@ ) from .kg_builder import LightRAGKGBuilder, MMKGBuilder from .llm.openai_client import OpenAIClient -from .llm.topk_token_model import TopkTokenModel from .partitioner import ( AnchorBFSPartitioner, BFSPartitioner, diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py deleted file mode 100644 index e93ca01a..00000000 --- a/graphgen/models/llm/topk_token_model.py +++ /dev/null @@ -1,53 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Optional - -from graphgen.bases import Token - - -class TopkTokenModel(ABC): - def __init__( - self, - do_sample: bool = False, - temperature: float = 0, - max_tokens: int = 4096, - repetition_penalty: float = 1.05, - num_beams: int = 1, - topk: int = 50, - topp: float = 0.95, - topk_per_token: int = 5, - ): - self.do_sample = do_sample - self.temperature = temperature - self.max_tokens = max_tokens - self.repetition_penalty = repetition_penalty - self.num_beams = num_beams - self.topk = topk - self.topp = topp - self.topk_per_token = topk_per_token - - @abstractmethod - async def generate_topk_per_token(self, text: str) -> List[Token]: - """ - Generate prob, text and candidates for each token of the model's output. - This function is used to visualize the inference process. - """ - raise NotImplementedError - - @abstractmethod - async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None - ) -> List[Token]: - """ - Generate prob and text for each token of the input text. - This function is used to visualize the ppl. - """ - raise NotImplementedError - - @abstractmethod - async def generate_answer( - self, text: str, history: Optional[List[str]] = None - ) -> str: - """ - Generate answer from the model. 
- """ - raise NotImplementedError From d1e0af501d66aff6311cbcbf6c1f4c7f613e28a5 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 13:50:38 +0800 Subject: [PATCH 02/27] feat: add huggingface wrapper --- graphgen/bases/__init__.py | 2 +- graphgen/bases/base_generator.py | 4 +- graphgen/bases/base_kg_builder.py | 4 +- ...base_llm_client.py => base_llm_wrapper.py} | 2 +- graphgen/models/__init__.py | 3 +- .../models/kg_builder/light_rag_kg_builder.py | 4 +- graphgen/models/llm/api/__init__.py | 0 .../models/llm/{ => api}/ollama_client.py | 4 +- .../models/llm/{ => api}/openai_client.py | 4 +- graphgen/models/llm/local/__init__.py | 0 graphgen/models/llm/local/hf_backend.py | 105 ++++++++++++++++++ graphgen/operators/generate/generate_qas.py | 4 +- 12 files changed, 121 insertions(+), 15 deletions(-) rename graphgen/bases/{base_llm_client.py => base_llm_wrapper.py} (98%) create mode 100644 graphgen/models/llm/api/__init__.py rename graphgen/models/llm/{ => api}/ollama_client.py (85%) rename graphgen/models/llm/{ => api}/openai_client.py (98%) create mode 100644 graphgen/models/llm/local/__init__.py create mode 100644 graphgen/models/llm/local/hf_backend.py diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index ace331d5..ed452628 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -1,6 +1,6 @@ from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder -from .base_llm_client import BaseLLMClient +from .base_llm_wrapper import BaseLLMWrapper from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_splitter import BaseSplitter diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py index d6148cee..85de5877 100644 --- a/graphgen/bases/base_generator.py +++ b/graphgen/bases/base_generator.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper class BaseGenerator(ABC): @@ -9,7 +9,7 @@ class BaseGenerator(ABC): Generate QAs based on given prompts. 
""" - def __init__(self, llm_client: BaseLLMClient): + def __init__(self, llm_client: BaseLLMWrapper): self.llm_client = llm_client @staticmethod diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py index e234d8da..d8a5d66a 100644 --- a/graphgen/bases/base_kg_builder.py +++ b/graphgen/bases/base_kg_builder.py @@ -2,13 +2,13 @@ from collections import defaultdict from typing import Dict, List, Tuple -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk class BaseKGBuilder(ABC): - def __init__(self, llm_client: BaseLLMClient): + def __init__(self, llm_client: BaseLLMWrapper): self.llm_client = llm_client self._nodes: Dict[str, List[dict]] = defaultdict(list) self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list) diff --git a/graphgen/bases/base_llm_client.py b/graphgen/bases/base_llm_wrapper.py similarity index 98% rename from graphgen/bases/base_llm_client.py rename to graphgen/bases/base_llm_wrapper.py index 1abe5143..b1f9cb0d 100644 --- a/graphgen/bases/base_llm_client.py +++ b/graphgen/bases/base_llm_wrapper.py @@ -8,7 +8,7 @@ from graphgen.bases.datatypes import Token -class BaseLLMClient(abc.ABC): +class BaseLLMWrapper(abc.ABC): """ LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...). """ diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 0d46dd98..d074ea6a 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,3 +1,5 @@ +from graphgen.models.llm.api.openai_client import OpenAIClient + from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator from .generator import ( AggregatedGenerator, @@ -7,7 +9,6 @@ VQAGenerator, ) from .kg_builder import LightRAGKGBuilder, MMKGBuilder -from .llm.openai_client import OpenAIClient from .partitioner import ( AnchorBFSPartitioner, BFSPartitioner, diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index cde42d27..2d7bff01 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -2,7 +2,7 @@ from collections import Counter, defaultdict from typing import Dict, List, Tuple -from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk +from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT from graphgen.utils import ( detect_main_language, @@ -15,7 +15,7 @@ class LightRAGKGBuilder(BaseKGBuilder): - def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3): + def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3): super().__init__(llm_client) self.max_loop = max_loop diff --git a/graphgen/models/llm/api/__init__.py b/graphgen/models/llm/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/llm/ollama_client.py b/graphgen/models/llm/api/ollama_client.py similarity index 85% rename from graphgen/models/llm/ollama_client.py rename to graphgen/models/llm/api/ollama_client.py index 5d6e5d20..738eac49 100644 --- a/graphgen/models/llm/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -1,10 +1,10 @@ # TODO: implement ollama client from typing import Any, List, Optional -from graphgen.bases import BaseLLMClient, Token +from graphgen.bases 
import BaseLLMWrapper, Token -class OllamaClient(BaseLLMClient): +class OllamaClient(BaseLLMWrapper): async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: diff --git a/graphgen/models/llm/openai_client.py b/graphgen/models/llm/api/openai_client.py similarity index 98% rename from graphgen/models/llm/openai_client.py rename to graphgen/models/llm/api/openai_client.py index 30ec39c8..5f9c131a 100644 --- a/graphgen/models/llm/openai_client.py +++ b/graphgen/models/llm/api/openai_client.py @@ -10,7 +10,7 @@ wait_exponential, ) -from graphgen.bases.base_llm_client import BaseLLMClient +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.datatypes import Token from graphgen.models.llm.limitter import RPM, TPM @@ -28,7 +28,7 @@ def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: return tokens -class OpenAIClient(BaseLLMClient): +class OpenAIClient(BaseLLMWrapper): def __init__( self, *, diff --git a/graphgen/models/llm/local/__init__.py b/graphgen/models/llm/local/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/llm/local/hf_backend.py b/graphgen/models/llm/local/hf_backend.py new file mode 100644 index 00000000..b3bef4f1 --- /dev/null +++ b/graphgen/models/llm/local/hf_backend.py @@ -0,0 +1,105 @@ +from typing import Any, List, Optional + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class HuggingFaceWrapper(BaseLLMWrapper): + def __init__( + self, + model_path: str, + torch_dtype="auto", + device_map="auto", + trust_remote_code=True, + temperature=0.0, + top_p=1.0, + topk=5, + **kwargs: Any + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch_dtype, + device_map=device_map, + trust_remote_code=trust_remote_code, + ) + self.model.eval() + self.temperature = temperature + self.top_p = top_p + self.topk = topk + + @staticmethod + def _build_inputs(prompt: str, history: Optional[List[str]] = None): + msgs = history or [] + msgs.append(prompt) + full = "\n".join(msgs) + return full + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + full = self._build_inputs(text, history) + inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) + max_new = 512 + with torch.no_grad(): + out = self.model.generate( + **inputs, + max_new_tokens=max_new, + temperature=self.temperature if self.temperature > 0 else 0.0, + top_p=self.top_p if self.temperature > 0 else 1.0, + do_sample=self.temperature > 0, + pad_token_id=self.tokenizer.eos_token_id, + ) + gen = out[0, inputs.input_ids.shape[-1] :] + return self.tokenizer.decode(gen, skip_special_tokens=True) + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = self._build_inputs(text, history) + inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) + with torch.no_grad(): + out = self.model.generate( + **inputs, + max_new_tokens=1, + temperature=0, + return_dict_in_generate=True, + output_scores=True, + ) + scores = 
out.scores[0][0] # vocab + probs = torch.softmax(scores, dim=-1) + top_probs, top_idx = torch.topk(probs, k=self.topk) + tokens = [] + for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()): + tokens.append(Token(self.tokenizer.decode([idx]), float(p))) + return tokens + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = self._build_inputs(text, history) + ids = self.tokenizer.encode(full) + logprobs = [] + for i in range(1, len(ids) + 1): + trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] + inputs = torch.tensor([trunc]).to(self.model.device) + with torch.no_grad(): + logits = self.model(inputs).logits[0, -1, :] + probs = torch.softmax(logits, dim=-1) + true_id = ids[i - 1] + logprobs.append( + Token( + self.tokenizer.decode([true_id]), + float(probs[true_id].cpu()), + ) + ) + return logprobs diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py index feadafc1..875e3bab 100644 --- a/graphgen/operators/generate/generate_qas.py +++ b/graphgen/operators/generate/generate_qas.py @@ -1,6 +1,6 @@ from typing import Any -from graphgen.bases import BaseLLMClient +from graphgen.bases import BaseLLMWrapper from graphgen.models import ( AggregatedGenerator, AtomicGenerator, @@ -12,7 +12,7 @@ async def generate_qas( - llm_client: BaseLLMClient, + llm_client: BaseLLMWrapper, batches: list[ tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] From a3a0c60277b9f09ec6f69a70b8fdb9b8cb6cd386 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 13:54:11 +0800 Subject: [PATCH 03/27] refactor: change file name --- graphgen/models/llm/local/{hf_backend.py => hf_wrapper.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename graphgen/models/llm/local/{hf_backend.py => hf_wrapper.py} (100%) diff --git a/graphgen/models/llm/local/hf_backend.py b/graphgen/models/llm/local/hf_wrapper.py similarity index 100% rename from graphgen/models/llm/local/hf_backend.py rename to graphgen/models/llm/local/hf_wrapper.py From 61766b50a8853afb728d8a67fbfea80adc93a147 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 15:04:06 +0800 Subject: [PATCH 04/27] feat: add ds_wrapper, trt_wrapper, sglang_wrapper --- graphgen/models/llm/local/ds_wrapper.py | 19 +++ graphgen/models/llm/local/hf_wrapper.py | 4 + graphgen/models/llm/local/sglang_wrapper.py | 114 ++++++++++++++++++ graphgen/models/llm/local/trt_wrapper.py | 93 ++++++++++++++ .../models/llm/local/test_ds_wrapper.py | 33 +++++ .../models/llm/local/test_hf_wrapper.py | 43 +++++++ 6 files changed, 306 insertions(+) create mode 100644 graphgen/models/llm/local/ds_wrapper.py create mode 100644 graphgen/models/llm/local/sglang_wrapper.py create mode 100644 graphgen/models/llm/local/trt_wrapper.py create mode 100644 tests/integration_tests/models/llm/local/test_ds_wrapper.py create mode 100644 tests/integration_tests/models/llm/local/test_hf_wrapper.py diff --git a/graphgen/models/llm/local/ds_wrapper.py b/graphgen/models/llm/local/ds_wrapper.py new file mode 100644 index 00000000..23db41ba --- /dev/null +++ b/graphgen/models/llm/local/ds_wrapper.py @@ -0,0 +1,19 @@ +from .hf_wrapper import HuggingFaceWrapper + + +class DeepSpeedBackend(HuggingFaceWrapper): + """ + Inference backend based on DeepSpeed + """ + + def __init__(self, *args, ds_config=None, **kwargs): + super().__init__(*args, **kwargs) + try: + import deepspeed + except ImportError as exc: + raise 
ImportError( + "Please install deepspeed to use DeepSpeedBackend: pip install deepspeed" + ) from exc + ds_config = ds_config or {"train_batch_size": 1, "fp16": {"enabled": True}} + self.model, _, _, _ = deepspeed.initialize(model=self.model, config=ds_config) + self.model.module.eval() diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py index b3bef4f1..e2e6f582 100644 --- a/graphgen/models/llm/local/hf_wrapper.py +++ b/graphgen/models/llm/local/hf_wrapper.py @@ -8,6 +8,10 @@ class HuggingFaceWrapper(BaseLLMWrapper): + """ + Async inference backend based on HuggingFace Transformers + """ + def __init__( self, model_path: str, diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py new file mode 100644 index 00000000..a0d6b8d5 --- /dev/null +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -0,0 +1,114 @@ +import math +from typing import Any, List, Optional + +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class SGLangBackend(BaseLLMWrapper): + """ + Async inference backend based on SGLang + """ + + def __init__( + self, + model_path: str, + tp_size: int = 1, + max_context_len: int = 4096, + server_url: Optional[str] = None, + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + try: + import sglang as sgl + from sglang.backend.runtime_endpoint import RuntimeEndpoint + except ImportError as exc: + raise ImportError( + "Please install sglang to use SGLangBackend: pip install sglang[all]>=0.4.4" + ) from exc + self.model_path = model_path + self.temperature = temperature + self.top_p = top_p + self.topk = topk + + # if server_url is given, connect to remote server; else launch local runtime + if server_url: + self.runtime = RuntimeEndpoint(server_url) + else: + sgl.set_default_backend( + sgl.Runtime( + model_path, tp_size=tp_size, max_context_len=max_context_len + ) + ) + self.runtime = sgl.get_default_backend() + + self.tokenizer = self.runtime.get_tokenizer() + + @staticmethod + def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: + if not history: + return prompt + return "\n".join(history) + "\n" + prompt + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + text = self._messages_to_str(text, history) + + output = await self.runtime.generate( + text, + max_new_tokens=512, + temperature=self.temperature if self.temperature > 0 else 0, + top_p=self.top_p, + stop=None, + ) + return output + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + text = self._messages_to_str(text, history) + + output_obj = await self.runtime.generate( + text, + max_new_tokens=1, + temperature=0, + return_logprob=True, + top_logprobs=self.topk, + logprob_start_len=0, + ) + + topk_list = output_obj["meta_info"]["top_logprobs"][ + 0 + ] # List[ (token_str, logprob), ... 
] + return [Token(tok, math.exp(logprob)) for tok, logprob in topk_list] + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + text = self._messages_to_str(text, history) + ids = self.tokenizer.encode(text) + if not ids: + return [] + + logprob_tokens: List[Token] = [] + + for i in range(1, len(ids) + 1): + trunc_ids = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] + trunc_text = self.tokenizer.decode(trunc_ids) + + output_obj = await self.runtime.generate( + trunc_text, + max_new_tokens=1, + temperature=0, + return_logprob=True, + top_logprobs=1, + logprob_start_len=len(trunc_ids) - 1, + ) + top1 = output_obj["meta_info"]["top_logprobs"][0][0] + logprob_tokens.append(Token(top1[0], math.exp(top1[1]))) + + return logprob_tokens diff --git a/graphgen/models/llm/local/trt_wrapper.py b/graphgen/models/llm/local/trt_wrapper.py new file mode 100644 index 00000000..be7223bd --- /dev/null +++ b/graphgen/models/llm/local/trt_wrapper.py @@ -0,0 +1,93 @@ +from typing import Any, List, Optional + +import numpy as np +from transformers import AutoTokenizer + +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class TensorRTBackend(BaseLLMWrapper): + """ + Async inference backend based on TensorRT-LLM + """ + + def __init__( + self, + engine_dir: str, + tokenizer_dir: str, + topk: int = 5, + temperature=0.0, + top_p=1.0, + **kwargs: Any + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + try: + from tensorrt_llm.runtime import ModelRunnerCpp + except ImportError as exc: + raise ImportError( + "Please install tensorrt-llm to use TensorRTBackend: pip install tensorrt-llm" + ) from exc + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + self.runner = ModelRunnerCpp.from_dir(engine_dir) + self.topk = topk + self.temperature = temperature + self.top_p = top_p + + def _parse_generation(self, output_ids) -> str: + return self.tokenizer.decode(output_ids[0], skip_special_tokens=True) + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + full = "\n".join(history or []) + "\n" + text + ids = self.tokenizer.encode(full) + output_ids = self.runner.generate( + [ids], + max_new_tokens=512, + temperature=self.temperature, + top_p=self.top_p, + eos_token_id=self.tokenizer.eos_token_id, + ) + return self._parse_generation(output_ids) + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = "\n".join(history or []) + "\n" + text + ids = self.tokenizer.encode(full) + *_, logits = self.runner.generate( + [ids], + max_new_tokens=1, + temperature=0, + output_logits=True, + ) + logits = logits[0, -1, :] + probs = np.softmax(logits) + top_idx = np.argpartition(probs, -self.topk)[-self.topk :] + top_idx = top_idx[np.argsort(probs[top_idx])[::-1]] + return [ + Token(self.tokenizer.decode([idx]), float(probs[idx])) for idx in top_idx + ] + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + full = "\n".join(history or []) + "\n" + text + ids = self.tokenizer.encode(full) + logprob_tokens = [] + for i in range(1, len(ids) + 1): + trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] + *_, logits = self.runner.generate( + [trunc], + max_new_tokens=1, + temperature=0, + output_logits=True, + ) + logits = logits[0, -1, :] + probs = np.softmax(logits) + true_id = 
ids[i - 1] + logprob_tokens.append( + Token(self.tokenizer.decode([true_id]), float(probs[true_id])) + ) + return logprob_tokens diff --git a/tests/integration_tests/models/llm/local/test_ds_wrapper.py b/tests/integration_tests/models/llm/local/test_ds_wrapper.py new file mode 100644 index 00000000..87df19c8 --- /dev/null +++ b/tests/integration_tests/models/llm/local/test_ds_wrapper.py @@ -0,0 +1,33 @@ +import sys + +import pytest + +from graphgen.models.llm.local.ds_wrapper import DeepSpeedBackend + + +def test_deepspeed_backend_init(monkeypatch): + class DummyModel: + def eval(self): + pass + + class DummyModule: + def __init__(self): + self.module = DummyModel() + + def dummy_initialize(model, config): + return DummyModule(), None, None, None + + monkeypatch.setitem( + sys.modules, + "deepspeed", + type("ds", (), {"initialize": staticmethod(dummy_initialize)})(), + ) + backend = DeepSpeedBackend(model=DummyModel()) + assert hasattr(backend.model, "module") + assert hasattr(backend.model.module, "eval") + + +def test_deepspeed_not_installed(monkeypatch): + monkeypatch.setitem(sys.modules, "deepspeed", None) + with pytest.raises(ImportError): + DeepSpeedBackend(model=object()) diff --git a/tests/integration_tests/models/llm/local/test_hf_wrapper.py b/tests/integration_tests/models/llm/local/test_hf_wrapper.py new file mode 100644 index 00000000..ae23ce11 --- /dev/null +++ b/tests/integration_tests/models/llm/local/test_hf_wrapper.py @@ -0,0 +1,43 @@ +from unittest.mock import MagicMock + +import pytest + +from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper + + +@pytest.fixture(autouse=True) +def mock_hf(monkeypatch): + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token = None + mock_tokenizer.eos_token = "" + mock_tokenizer.eos_token_id = 0 + mock_tokenizer.decode.return_value = "hello" + mock_tokenizer.encode.return_value = [1, 2, 3] + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.AutoTokenizer.from_pretrained", + lambda *a, **kw: mock_tokenizer, + ) + + mock_model = MagicMock() + mock_model.device = "cpu" + mock_model.generate.return_value = MagicMock( + __getitem__=lambda s, k: [0, 1, 2, 3], shape=(1, 4) + ) + mock_model.eval.return_value = None + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.AutoModelForCausalLM.from_pretrained", + lambda *a, **kw: mock_model, + ) + + monkeypatch.setattr( + "graphgen.models.llm.local.hf_wrapper.torch.no_grad", MagicMock() + ) + + return mock_tokenizer, mock_model + + +@pytest.mark.asyncio +async def test_generate_answer(): + wrapper = HuggingFaceWrapper("fake-model") + result = await wrapper.generate_answer("hi") + assert isinstance(result, str) From e6f45028c2dbea900eb58d81d127ec573b022482 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 15:27:36 +0800 Subject: [PATCH 05/27] refactor: refactor graphgen --- graphgen/graphgen.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 8b0559d6..dc1cd46d 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,7 +1,6 @@ import asyncio import os import time -from dataclasses import dataclass from typing import Dict, cast import gradio as gr @@ -31,26 +30,26 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -@dataclass class GraphGen: - unique_id: int = int(time.time()) - working_dir: str = os.path.join(sys_path, "cache") - - # llm - tokenizer_instance: Tokenizer = None - synthesizer_llm_client: 
OpenAIClient = None - trainee_llm_client: OpenAIClient = None - - # webui - progress_bar: gr.Progress = None - - def __post_init__(self): - self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer( + def __init__( + self, + unique_id: int = int(time.time()), + working_dir: str = os.path.join(sys_path, "cache"), + tokenizer_instance: Tokenizer = None, + synthesizer_llm_client: OpenAIClient = None, + trainee_llm_client: OpenAIClient = None, + progress_bar: gr.Progress = None, + ): + self.unique_id: int = unique_id + self.working_dir: str = working_dir + + # llm + self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( model_name=os.getenv("TOKENIZER_MODEL") ) self.synthesizer_llm_client: OpenAIClient = ( - self.synthesizer_llm_client + synthesizer_llm_client or OpenAIClient( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), @@ -59,7 +58,7 @@ def __post_init__(self): ) ) - self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient( + self.trainee_llm_client: OpenAIClient = trainee_llm_client or OpenAIClient( model_name=os.getenv("TRAINEE_MODEL"), api_key=os.getenv("TRAINEE_API_KEY"), base_url=os.getenv("TRAINEE_BASE_URL"), @@ -86,6 +85,9 @@ def __post_init__(self): namespace="qa", ) + # webui + self.progress_bar: gr.Progress = progress_bar + @async_to_sync_method async def insert(self, read_config: Dict, split_config: Dict): """ From 5a1fc4d4c791dc4d79fa967af5804839b44501f9 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 15:57:22 +0800 Subject: [PATCH 06/27] wip: add azure client --- graphgen/models/llm/api/azure_client.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 graphgen/models/llm/api/azure_client.py diff --git a/graphgen/models/llm/api/azure_client.py b/graphgen/models/llm/api/azure_client.py new file mode 100644 index 00000000..e69de29b From 276f88195f88e306f184c3b1fba446150fa0233a Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 16:58:16 +0800 Subject: [PATCH 07/27] feat: add ollama_wrapper, tgi_wrapper --- graphgen/models/llm/local/ollama_wrapper.py | 65 ++++++++++++++++ graphgen/models/llm/local/tgi_wrapper.py | 86 +++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 graphgen/models/llm/local/ollama_wrapper.py create mode 100644 graphgen/models/llm/local/tgi_wrapper.py diff --git a/graphgen/models/llm/local/ollama_wrapper.py b/graphgen/models/llm/local/ollama_wrapper.py new file mode 100644 index 00000000..cf3f36fc --- /dev/null +++ b/graphgen/models/llm/local/ollama_wrapper.py @@ -0,0 +1,65 @@ +from typing import Any, List, Optional + +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class OllamaBackend(BaseLLMWrapper): + """ + Async inference backend based on Ollama local server + """ + + def __init__( + self, + model: str, # e.g. 
"llama3.1:8b" + host: str = "http://localhost:11434", + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any + ): + try: + import ollama + except ImportError as exc: + raise ImportError( + "Please install ollama to use OllamaBackend: pip install ollama>=0.1.5" + ) from exc + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + self.client = ollama.AsyncClient(host=host) + self.model = model + self.topk = topk + + @staticmethod + def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: + if not history: + return prompt + return "\n".join(history) + "\n" + prompt + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + text = self._messages_to_str(text, history) + resp = await self.client.generate( + model=self.model, + prompt=text, + options={ + "temperature": self.temperature or 0, + "top_p": self.top_p if self.top_p < 1.0 else 1, + }, + stream=False, + ) + return resp["response"] + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError( + "Ollama backend does not support per-token top-k yet." + ) + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError( + "Ollama backend does not support per-token input probabilities yet." + ) diff --git a/graphgen/models/llm/local/tgi_wrapper.py b/graphgen/models/llm/local/tgi_wrapper.py new file mode 100644 index 00000000..cfb15608 --- /dev/null +++ b/graphgen/models/llm/local/tgi_wrapper.py @@ -0,0 +1,86 @@ +import math +from typing import Any, List, Optional + +from huggingface_hub import InferenceClient + +from graphgen.bases import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class TGIWrapper(BaseLLMWrapper): + """ + Async inference backend based on TGI (Text-Generation-Inference) + """ + + def __init__( + self, + model_url: str, # e.g. 
"http://localhost:8080" + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + self.client = InferenceClient(model=model_url, token=False) + self.topk = topk + self.model_url = model_url + + @staticmethod + def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: + if not history: + return prompt + return "\n".join(history) + "\n" + prompt + + async def generate_answer( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> str: + text = self._messages_to_str(text, history) + out = await self.client.text_generation( + text, + max_new_tokens=extra.get("max_new_tokens", 512), + temperature=self.temperature or None, + top_p=self.top_p if self.top_p < 1.0 else None, + stop_sequences=extra.get("stop", None), + details=False, + ) + return out + + async def generate_topk_per_token( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + text = self._messages_to_str(text, history) + out = await self.client.text_generation( + text, + max_new_tokens=1, + temperature=0, + details=True, + decoder_input_details=True, + ) + # TGI 返回的 tokens[0].logprob.topk 字段 + topk = out.details.tokens[0].logprob.topk + return [Token(t.token, math.exp(t.logprob)) for t in topk[: self.topk]] + + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + """ + TGI does not provide a direct interface for "conditional probability of each input token", + here we approximate it with "input prefix + next token". + To implement it strictly, you can use /generate_stream and truncate it bit by bit. + """ + text = self._messages_to_str(text, history) + ids = self.client.tokenizer.encode(text) + tokens: List[Token] = [] + for i in range(1, len(ids) + 1): + prefix_ids = ids[:i] + prefix = self.client.tokenizer.decode(prefix_ids) + out = await self.client.text_generation( + prefix, + max_new_tokens=1, + temperature=0, + details=True, + decoder_input_details=True, + ) + t = out.details.tokens[0] + tokens.append(Token(t.token, math.exp(t.logprob))) + return tokens From 846b924c7096e4aea546e341a17de5ba18108b98 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 17:21:01 +0800 Subject: [PATCH 08/27] feat: add azure_client, http_client, ollama_client --- graphgen/models/llm/api/azure_client.py | 165 +++++++++++++++++++++ graphgen/models/llm/api/http_client.py | 178 +++++++++++++++++++++++ graphgen/models/llm/api/ollama_client.py | 164 ++++++++++++++++++++- 3 files changed, 499 insertions(+), 8 deletions(-) create mode 100644 graphgen/models/llm/api/http_client.py diff --git a/graphgen/models/llm/api/azure_client.py b/graphgen/models/llm/api/azure_client.py index e69de29b..bbc7edc9 100644 --- a/graphgen/models/llm/api/azure_client.py +++ b/graphgen/models/llm/api/azure_client.py @@ -0,0 +1,165 @@ +import math +from typing import Any, Dict, List, Optional + +import openai +from openai import APIConnectionError, APITimeoutError, AsyncAzureOpenAI, RateLimitError +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token +from graphgen.models.llm.limitter import RPM, TPM + + +def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: + token_logprobs = response.choices[0].logprobs.content + tokens = [] + for 
token_prob in token_logprobs: + prob = math.exp(token_prob.logprob) + candidate_tokens = [ + Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs + ] + token = Token(token_prob.token, prob, top_candidates=candidate_tokens) + tokens.append(token) + return tokens + + +class AzureClient(BaseLLMWrapper): + """ + 直接复用 openai.AsyncAzureOpenAI,参数与 OpenAIClient 几乎一致, + 仅把 endpoint、api_version、deployment_name 换掉即可。 + """ + + def __init__( + self, + *, + model_name: str, # 对应 azure 的 deployment_name + azure_endpoint: str, + api_key: str, + api_version: str = "2024-02-15-preview", + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, + request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.model_name = model_name + self.azure_endpoint = azure_endpoint + self.api_key = api_key + self.api_version = api_version + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + self.request_limit = request_limit + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() + + self.token_usage: List[Dict[str, int]] = [] + self.__post_init__() + + def __post_init__(self): + self.client = AsyncAzureOpenAI( + azure_endpoint=self.azure_endpoint, + api_key=self.api_key, + api_version=self.api_version, + ) + + # _pre_generate 与 OpenAIClient 完全一致,直接抄 + def _pre_generate(self, text: str, history: List[str]) -> Dict[str, Any]: + kwargs = { + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + } + if self.seed: + kwargs["seed"] = self.seed + if self.json_mode: + kwargs["response_format"] = {"type": "json_object"} + + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": text}) + + if history: + assert len(history) % 2 == 0 + messages = history + messages + + kwargs["messages"] = messages + return kwargs + + # ---------------- generate_answer ---------------- + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + (RateLimitError, APIConnectionError, APITimeoutError) + ), + ) + async def generate_answer( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> str: + kwargs = self._pre_generate(text, history or []) + + prompt_tokens = sum( + len(self.tokenizer.encode(m["content"])) for m in kwargs["messages"] + ) + est = prompt_tokens + kwargs["max_tokens"] + + if self.request_limit: + await self.rpm.wait(silent=True) + await self.tpm.wait(est, silent=True) + + completion = await self.client.chat.completions.create( + model=self.model_name, **kwargs + ) + if hasattr(completion, "usage"): + self.token_usage.append( + { + "prompt_tokens": completion.usage.prompt_tokens, + "completion_tokens": completion.usage.completion_tokens, + "total_tokens": completion.usage.total_tokens, + } + ) + return self.filter_think_tags(completion.choices[0].message.content) + + # ---------------- generate_topk_per_token ---------------- + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + (RateLimitError, APIConnectionError, APITimeoutError) + ), + ) + async def generate_topk_per_token( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> List[Token]: + kwargs = self._pre_generate(text, history or []) + if self.topk_per_token > 0: + kwargs["logprobs"] = 
True + kwargs["top_logprobs"] = self.topk_per_token + kwargs["max_tokens"] = 5 + + completion = await self.client.chat.completions.create( + model=self.model_name, **kwargs + ) + return get_top_response_tokens(completion) + + # ---------------- generate_inputs_prob ---------------- + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py new file mode 100644 index 00000000..469f20ba --- /dev/null +++ b/graphgen/models/llm/api/http_client.py @@ -0,0 +1,178 @@ +import asyncio +import math +from typing import Any, Dict, List, Optional + +import aiohttp +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token +from graphgen.models.llm.limitter import RPM, TPM + + +class HTTPClient(BaseLLMWrapper): + """ + 最简“通用”实现:远端只要兼容 OpenAI 的 chat/completions 格式即可。 + 用 aiohttp 自行封装 retry、token 计数。 + """ + + def __init__( + self, + *, + model_name: str, + base_url: str, + api_key: Optional[str] = None, + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, + request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.model_name = model_name + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + self.request_limit = request_limit + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() + + self.token_usage: List[Dict[str, int]] = [] + self._session: Optional[aiohttp.ClientSession] = None + + def __post_init__(self): + pass + + @property + def session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + headers = {} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + self._session = aiohttp.ClientSession(headers=headers) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + + # ---------------- 内部 ---------------- + def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]: + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + if history: + assert len(history) % 2 == 0 + for i in range(0, len(history), 2): + messages.append({"role": "user", "content": history[i]}) + messages.append({"role": "assistant", "content": history[i + 1]}) + messages.append({"role": "user", "content": text}) + + body = { + "model": self.model_name, + "messages": messages, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + } + if self.seed: + body["seed"] = self.seed + if self.json_mode: + body["response_format"] = {"type": "json_object"} + return body + + # ---------------- generate_answer ---------------- + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), + ) + async def generate_answer( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> str: + body = self._build_body(text, history or []) + prompt_tokens = sum( + len(self.tokenizer.encode(m["content"])) for m in body["messages"] + ) + est = prompt_tokens 
+ body["max_tokens"] + + if self.request_limit: + await self.rpm.wait(silent=True) + await self.tpm.wait(est, silent=True) + + async with self.session.post( + f"{self.base_url}/v1/chat/completions", + json=body, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + msg = data["choices"][0]["message"]["content"] + if "usage" in data: + self.token_usage.append( + { + "prompt_tokens": data["usage"]["prompt_tokens"], + "completion_tokens": data["usage"]["completion_tokens"], + "total_tokens": data["usage"]["total_tokens"], + } + ) + return self.filter_think_tags(msg) + + # ---------------- generate_topk_per_token ---------------- + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), + ) + async def generate_topk_per_token( + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, + ) -> List[Token]: + body = self._build_body(text, history or []) + body["max_tokens"] = 5 + if self.topk_per_token > 0: + body["logprobs"] = True + body["top_logprobs"] = self.topk_per_token + + async with self.session.post( + f"{self.base_url}/v1/chat/completions", + json=body, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + # 与 openai 格式一致 + token_logprobs = data["choices"][0]["logprobs"]["content"] + tokens = [] + for item in token_logprobs: + candidates = [ + Token(t["token"], math.exp(t["logprob"])) for t in item["top_logprobs"] + ] + tokens.append( + Token( + item["token"], math.exp(item["logprob"]), top_candidates=candidates + ) + ) + return tokens + + # ---------------- generate_inputs_prob ---------------- + async def generate_inputs_prob( + self, text: str, history: Optional[List[str]] = None, **extra: Any + ) -> List[Token]: + raise NotImplementedError diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index 738eac49..c837191e 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -1,21 +1,169 @@ -# TODO: implement ollama client -from typing import Any, List, Optional +import asyncio +import math +from typing import Any, Dict, List, Optional -from graphgen.bases import BaseLLMWrapper, Token +import aiohttp +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token +from graphgen.models.llm.limitter import RPM, TPM class OllamaClient(BaseLLMWrapper): + """ + 要求本地/远端启动 ollama server(默认 11434 端口)。 + ollama 的 /api/chat 在 0.1.24+ 支持 stream=False + raw=true 时返回 logprobs, + 但 top_logprobs 字段目前官方未实现,因此 generate_topk_per_token 只能降级到 + 取单个 token 的 logprob;若未来官方支持再补全。 + """ + + def __init__( + self, + *, + model_name: str = "llama3.1", + base_url: str = "http://localhost:11434", + json_mode: bool = False, + seed: Optional[int] = None, + topk_per_token: int = 5, + request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.model_name = model_name + self.base_url = base_url.rstrip("/") + self.json_mode = json_mode + self.seed = seed + self.topk_per_token = topk_per_token + self.request_limit = request_limit + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() + + self.token_usage: List[Dict[str, int]] = [] + self._session: 
Optional[aiohttp.ClientSession] = None + + def __post_init__(self): + # 基类若未调,可手动触发 + pass + + @property + def session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + + # ---------------- 内部构造 ---------------- + def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + if history: + assert len(history) % 2 == 0 + for i in range(0, len(history), 2): + messages.append({"role": "user", "content": history[i]}) + messages.append({"role": "assistant", "content": history[i + 1]}) + messages.append({"role": "user", "content": text}) + + payload = { + "model": self.model_name, + "messages": messages, + "stream": False, + "options": { + "temperature": self.temperature, + "top_p": self.top_p, + "num_predict": self.max_tokens, + }, + } + if self.seed is not None: + payload["options"]["seed"] = self.seed + if self.json_mode: + payload["format"] = "json" + if self.topk_per_token > 0: + # ollama 0.1.24+ 支持 logprobs=true,但 top_logprobs 字段暂无 + payload["options"]["logprobs"] = True + return payload + + # ---------------- generate_answer ---------------- + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), + ) async def generate_answer( - self, text: str, history: Optional[List[str]] = None, **extra: Any + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> str: - pass + payload = self._build_payload(text, history or []) + # 简易 token 估算 + prompt_tokens = sum( + len(self.tokenizer.encode(m["content"])) for m in payload["messages"] + ) + est = prompt_tokens + self.max_tokens + + if self.request_limit: + await self.rpm.wait(silent=True) + await self.tpm.wait(est, silent=True) + + async with self.session.post( + f"{self.base_url}/api/chat", + json=payload, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + # ollama 返回 {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy} + content = data["message"]["content"] + self.token_usage.append( + { + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + + data.get("eval_count", 0), + } + ) + return self.filter_think_tags(content) + # ---------------- generate_topk_per_token ---------------- async def generate_topk_per_token( - self, text: str, history: Optional[List[str]] = None, **extra: Any + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> List[Token]: - pass + # ollama 目前无 top_logprobs,只能拿到每个 token 的 logprob + payload = self._build_payload(text, history or []) + payload["options"]["num_predict"] = 5 # 限制长度 + async with self.session.post( + f"{self.base_url}/api/chat", + json=payload, + timeout=aiohttp.ClientTimeout(total=60), + ) as resp: + resp.raise_for_status() + data = await resp.json() + # ollama 返回 logprobs 在 ["message"]["logprobs"]["content"] 列表 + # 每项 {"token":str, "logprob":float} + tokens = [] + for item in data.get("message", {}).get("logprobs", {}).get("content", []): + tokens.append(Token(item["token"], math.exp(item["logprob"]))) + return tokens + + # ---------------- 
generate_inputs_prob ---------------- async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - pass + raise NotImplementedError From f6bdaf6e53814c98e1ae42898e9dd0fcf049a8ed Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 27 Oct 2025 19:59:12 +0800 Subject: [PATCH 09/27] delete azure_client --- graphgen/models/llm/api/azure_client.py | 165 ------------------------ graphgen/models/llm/api/http_client.py | 21 ++- 2 files changed, 15 insertions(+), 171 deletions(-) delete mode 100644 graphgen/models/llm/api/azure_client.py diff --git a/graphgen/models/llm/api/azure_client.py b/graphgen/models/llm/api/azure_client.py deleted file mode 100644 index bbc7edc9..00000000 --- a/graphgen/models/llm/api/azure_client.py +++ /dev/null @@ -1,165 +0,0 @@ -import math -from typing import Any, Dict, List, Optional - -import openai -from openai import APIConnectionError, APITimeoutError, AsyncAzureOpenAI, RateLimitError -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from graphgen.bases.base_llm_wrapper import BaseLLMWrapper -from graphgen.bases.datatypes import Token -from graphgen.models.llm.limitter import RPM, TPM - - -def get_top_response_tokens(response: openai.ChatCompletion) -> List[Token]: - token_logprobs = response.choices[0].logprobs.content - tokens = [] - for token_prob in token_logprobs: - prob = math.exp(token_prob.logprob) - candidate_tokens = [ - Token(t.token, math.exp(t.logprob)) for t in token_prob.top_logprobs - ] - token = Token(token_prob.token, prob, top_candidates=candidate_tokens) - tokens.append(token) - return tokens - - -class AzureClient(BaseLLMWrapper): - """ - 直接复用 openai.AsyncAzureOpenAI,参数与 OpenAIClient 几乎一致, - 仅把 endpoint、api_version、deployment_name 换掉即可。 - """ - - def __init__( - self, - *, - model_name: str, # 对应 azure 的 deployment_name - azure_endpoint: str, - api_key: str, - api_version: str = "2024-02-15-preview", - json_mode: bool = False, - seed: Optional[int] = None, - topk_per_token: int = 5, - request_limit: bool = False, - rpm: Optional[RPM] = None, - tpm: Optional[TPM] = None, - **kwargs: Any, - ): - super().__init__(**kwargs) - self.model_name = model_name - self.azure_endpoint = azure_endpoint - self.api_key = api_key - self.api_version = api_version - self.json_mode = json_mode - self.seed = seed - self.topk_per_token = topk_per_token - self.request_limit = request_limit - self.rpm = rpm or RPM() - self.tpm = tpm or TPM() - - self.token_usage: List[Dict[str, int]] = [] - self.__post_init__() - - def __post_init__(self): - self.client = AsyncAzureOpenAI( - azure_endpoint=self.azure_endpoint, - api_key=self.api_key, - api_version=self.api_version, - ) - - # _pre_generate 与 OpenAIClient 完全一致,直接抄 - def _pre_generate(self, text: str, history: List[str]) -> Dict[str, Any]: - kwargs = { - "temperature": self.temperature, - "top_p": self.top_p, - "max_tokens": self.max_tokens, - } - if self.seed: - kwargs["seed"] = self.seed - if self.json_mode: - kwargs["response_format"] = {"type": "json_object"} - - messages = [] - if self.system_prompt: - messages.append({"role": "system", "content": self.system_prompt}) - messages.append({"role": "user", "content": text}) - - if history: - assert len(history) % 2 == 0 - messages = history + messages - - kwargs["messages"] = messages - return kwargs - - # ---------------- generate_answer ---------------- - @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=10), - 
retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) - ), - ) - async def generate_answer( - self, - text: str, - history: Optional[List[str]] = None, - **extra: Any, - ) -> str: - kwargs = self._pre_generate(text, history or []) - - prompt_tokens = sum( - len(self.tokenizer.encode(m["content"])) for m in kwargs["messages"] - ) - est = prompt_tokens + kwargs["max_tokens"] - - if self.request_limit: - await self.rpm.wait(silent=True) - await self.tpm.wait(est, silent=True) - - completion = await self.client.chat.completions.create( - model=self.model_name, **kwargs - ) - if hasattr(completion, "usage"): - self.token_usage.append( - { - "prompt_tokens": completion.usage.prompt_tokens, - "completion_tokens": completion.usage.completion_tokens, - "total_tokens": completion.usage.total_tokens, - } - ) - return self.filter_think_tags(completion.choices[0].message.content) - - # ---------------- generate_topk_per_token ---------------- - @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) - ), - ) - async def generate_topk_per_token( - self, - text: str, - history: Optional[List[str]] = None, - **extra: Any, - ) -> List[Token]: - kwargs = self._pre_generate(text, history or []) - if self.topk_per_token > 0: - kwargs["logprobs"] = True - kwargs["top_logprobs"] = self.topk_per_token - kwargs["max_tokens"] = 5 - - completion = await self.client.chat.completions.create( - model=self.model_name, **kwargs - ) - return get_top_response_tokens(completion) - - # ---------------- generate_inputs_prob ---------------- - async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> List[Token]: - raise NotImplementedError diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py index 469f20ba..606e5e45 100644 --- a/graphgen/models/llm/api/http_client.py +++ b/graphgen/models/llm/api/http_client.py @@ -17,8 +17,21 @@ class HTTPClient(BaseLLMWrapper): """ - 最简“通用”实现:远端只要兼容 OpenAI 的 chat/completions 格式即可。 - 用 aiohttp 自行封装 retry、token 计数。 + A generic async HTTP client for LLMs compatible with OpenAI's chat/completions format. + It uses aiohttp for making requests and includes retry logic and token usage tracking. 
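+    When request_limit=True, each call first waits on the configured RPM/TPM limiters before the request is sent.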
+ Usage example: + client = HTTPClient( + model_name="gpt-4o-mini", + base_url="http://localhost:8080", + api_key="your_api_key", + json_mode=True, + seed=42, + topk_per_token=5, + request_limit=True, + ) + + answer = await client.generate_answer("Hello, world!") + tokens = await client.generate_topk_per_token("Hello, world!") """ def __init__( @@ -65,7 +78,6 @@ async def close(self): if self._session and not self._session.closed: await self._session.close() - # ---------------- 内部 ---------------- def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]: messages = [] if self.system_prompt: @@ -131,7 +143,6 @@ async def generate_answer( ) return self.filter_think_tags(msg) - # ---------------- generate_topk_per_token ---------------- @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -157,7 +168,6 @@ async def generate_topk_per_token( resp.raise_for_status() data = await resp.json() - # 与 openai 格式一致 token_logprobs = data["choices"][0]["logprobs"]["content"] tokens = [] for item in token_logprobs: @@ -171,7 +181,6 @@ async def generate_topk_per_token( ) return tokens - # ---------------- generate_inputs_prob ---------------- async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: From ee2d35eaccbb88e8cce0e377cff74768134a40f2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 11:37:40 +0800 Subject: [PATCH 10/27] tests: add http_client test --- graphgen/models/llm/api/http_client.py | 8 +- requirements.txt | 2 +- .../models/llm/api/test_http_client.py | 143 ++++++++++++++++++ 3 files changed, 147 insertions(+), 6 deletions(-) create mode 100644 tests/integration_tests/models/llm/api/test_http_client.py diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py index 606e5e45..ed1991f4 100644 --- a/graphgen/models/llm/api/http_client.py +++ b/graphgen/models/llm/api/http_client.py @@ -62,9 +62,6 @@ def __init__( self.token_usage: List[Dict[str, int]] = [] self._session: Optional[aiohttp.ClientSession] = None - def __post_init__(self): - pass - @property def session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: @@ -102,7 +99,6 @@ def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]: body["response_format"] = {"type": "json_object"} return body - # ---------------- generate_answer ---------------- @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -184,4 +180,6 @@ async def generate_topk_per_token( async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - raise NotImplementedError + raise NotImplementedError( + "generate_inputs_prob is not implemented in HTTPClient" + ) diff --git a/requirements.txt b/requirements.txt index 44b3687a..82740f03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,4 @@ igraph python-louvain # For visualization -matplotlib \ No newline at end of file +matplotlib diff --git a/tests/integration_tests/models/llm/api/test_http_client.py b/tests/integration_tests/models/llm/api/test_http_client.py new file mode 100644 index 00000000..4949830d --- /dev/null +++ b/tests/integration_tests/models/llm/api/test_http_client.py @@ -0,0 +1,143 @@ +# pylint: disable=protected-access +import math + +import pytest + +from graphgen.models.llm.api.http_client import HTTPClient + + +class DummyTokenizer: + def encode(self, text: str): + # simple 
tokenization: split on spaces
+        return text.split()
+
+
+class _MockResponse:
+    def __init__(self, data):
+        self._data = data
+
+    def raise_for_status(self):
+        return None
+
+    async def json(self):
+        return self._data
+
+
+class _PostCtx:
+    def __init__(self, data):
+        self._resp = _MockResponse(data)
+
+    async def __aenter__(self):
+        return self._resp
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return False
+
+
+class MockSession:
+    def __init__(self, data):
+        self._data = data
+        self.closed = False
+
+    def post(self, *args, **kwargs):
+        return _PostCtx(self._data)
+
+    async def close(self):
+        self.closed = True
+
+
+class DummyLimiter:
+    def __init__(self):
+        self.calls = []
+
+    async def wait(self, *args, **kwargs):
+        self.calls.append((args, kwargs))
+
+
+@pytest.mark.asyncio
+async def test_generate_answer_records_usage_and_uses_limiters():
+    # arrange
+    data = {
+        "choices": [{"message": {"content": "Hello world!"}}],
+        "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
+    }
+    client = HTTPClient(model_name="m", base_url="http://test")
+    client._session = MockSession(data)
+    client.tokenizer = DummyTokenizer()
+    client.system_prompt = "sys"
+    client.temperature = 0.0
+    client.top_p = 1.0
+    client.max_tokens = 10
+    client.filter_think_tags = lambda s: s.replace("<think>", "").replace(
+        "</think>", ""
+    )
+    rpm = DummyLimiter()
+    tpm = DummyLimiter()
+    client.rpm = rpm
+    client.tpm = tpm
+    client.request_limit = True
+
+    # act
+    out = await client.generate_answer("hi", history=["u1", "a1"])
+
+    # assert
+    assert out == "Hello world!"
+    assert client.token_usage[-1] == {
+        "prompt_tokens": 3,
+        "completion_tokens": 2,
+        "total_tokens": 5,
+    }
+    assert len(rpm.calls) == 1
+    assert len(tpm.calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_generate_topk_per_token_parses_logprobs():
+    # arrange
+    # create two token items with top_logprobs
+    data = {
+        "choices": [
+            {
+                "logprobs": {
+                    "content": [
+                        {
+                            "token": "A",
+                            "logprob": math.log(0.6),
+                            "top_logprobs": [
+                                {"token": "A", "logprob": math.log(0.6)},
+                                {"token": "B", "logprob": math.log(0.4)},
+                            ],
+                        },
+                        {
+                            "token": "B",
+                            "logprob": math.log(0.2),
+                            "top_logprobs": [
+                                {"token": "B", "logprob": math.log(0.2)},
+                                {"token": "C", "logprob": math.log(0.8)},
+                            ],
+                        },
+                    ]
+                }
+            }
+        ]
+    }
+    client = HTTPClient(model_name="m", base_url="http://test")
+    client._session = MockSession(data)
+    client.tokenizer = DummyTokenizer()
+    client.system_prompt = None
+    client.temperature = 0.0
+    client.top_p = 1.0
+    client.max_tokens = 10
+    client.topk_per_token = 2
+
+    # act
+    tokens = await client.generate_topk_per_token("hi", history=[])
+
+    # assert
+    assert len(tokens) == 2
+    # check probabilities and top_candidates
+    assert abs(tokens[0].prob - 0.6) < 1e-9
+    assert abs(tokens[1].prob - 0.2) < 1e-9
+    assert len(tokens[0].top_candidates) == 2
+    assert tokens[0].top_candidates[0].text == "A"
+    assert tokens[0].top_candidates[1].text == "B"
From d02e5a29dfa52eae17b472bf4c8917681a8bf0fc Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 28 Oct 2025 14:34:52 +0800
Subject: [PATCH 11/27] docs: update .env example

---
 .env.example | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/.env.example b/.env.example
index c1102c1c..fd3634a1 100644
--- a/.env.example
+++ b/.env.example
@@ -1,7 +1,15 @@
+# Tokenizer
 TOKENIZER_MODEL=
+
+# LLM
+# Support different backends: http_api, openai_api, ollama_api, ollama, deepspeed, huggingface, tgi, sglang, tensorrt
+
+# http_api / openai_api
+SYNTHESIZER_BACKEND=openai_api
SYNTHESIZER_MODEL= SYNTHESIZER_BASE_URL= SYNTHESIZER_API_KEY= +TRAINEE_BACKEND=openai_api TRAINEE_MODEL= TRAINEE_BASE_URL= TRAINEE_API_KEY= From 614283fce5dbfe5df42394372c2a029c656197d3 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 14:38:49 +0800 Subject: [PATCH 12/27] feat: switch llm backend(http_api) --- graphgen/graphgen.py | 18 ++---- graphgen/models/__init__.py | 3 +- graphgen/models/llm/__init__.py | 3 + graphgen/models/llm/api/http_client.py | 36 +++++++---- graphgen/models/llm/api/ollama_client.py | 6 -- graphgen/operators/__init__.py | 1 + graphgen/operators/init/__init__.py | 1 + graphgen/operators/init/init_llm.py | 59 +++++++++++++++++++ .../models/llm/api/test_http_client.py | 4 +- 9 files changed, 96 insertions(+), 35 deletions(-) create mode 100644 graphgen/operators/init/__init__.py create mode 100644 graphgen/operators/init/init_llm.py diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index dc1cd46d..635272e5 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -19,6 +19,7 @@ build_text_kg, chunk_documents, generate_qas, + init_llm, judge_statement, partition_kg, quiz, @@ -48,21 +49,12 @@ def __init__( model_name=os.getenv("TOKENIZER_MODEL") ) - self.synthesizer_llm_client: OpenAIClient = ( - synthesizer_llm_client - or OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), - api_key=os.getenv("SYNTHESIZER_API_KEY"), - base_url=os.getenv("SYNTHESIZER_BASE_URL"), - tokenizer=self.tokenizer_instance, - ) + self.synthesizer_llm_client: OpenAIClient = synthesizer_llm_client or init_llm( + "synthesizer" ) - self.trainee_llm_client: OpenAIClient = trainee_llm_client or OpenAIClient( - model_name=os.getenv("TRAINEE_MODEL"), - api_key=os.getenv("TRAINEE_API_KEY"), - base_url=os.getenv("TRAINEE_BASE_URL"), - tokenizer=self.tokenizer_instance, + self.trainee_llm_client: OpenAIClient = trainee_llm_client or init_llm( + "trainee" ) self.full_docs_storage: JsonKVStorage = JsonKVStorage( diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index d074ea6a..08694166 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,5 +1,3 @@ -from graphgen.models.llm.api.openai_client import OpenAIClient - from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator from .generator import ( AggregatedGenerator, @@ -9,6 +7,7 @@ VQAGenerator, ) from .kg_builder import LightRAGKGBuilder, MMKGBuilder +from .llm import HTTPClient, OllamaClient, OpenAIClient from .partitioner import ( AnchorBFSPartitioner, BFSPartitioner, diff --git a/graphgen/models/llm/__init__.py b/graphgen/models/llm/__init__.py index e69de29b..68769477 100644 --- a/graphgen/models/llm/__init__.py +++ b/graphgen/models/llm/__init__.py @@ -0,0 +1,3 @@ +from .api.http_client import HTTPClient +from .api.ollama_client import OllamaClient +from .api.openai_client import OpenAIClient diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py index ed1991f4..c49018a5 100644 --- a/graphgen/models/llm/api/http_client.py +++ b/graphgen/models/llm/api/http_client.py @@ -34,10 +34,18 @@ class HTTPClient(BaseLLMWrapper): tokens = await client.generate_topk_per_token("Hello, world!") """ + _instance: Optional["HTTPClient"] = None + _lock = asyncio.Lock() + + def __new__(cls, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__( self, *, - model_name: str, + model: str, base_url: str, api_key: Optional[str] = None, json_mode: bool = False, 
@@ -48,8 +56,12 @@ def __init__( tpm: Optional[TPM] = None, **kwargs: Any, ): + # Initialize only once in the singleton pattern + if getattr(self, "_initialized", False): + return + self._initialized: bool = True super().__init__(**kwargs) - self.model_name = model_name + self.model_name = model self.base_url = base_url.rstrip("/") self.api_key = api_key self.json_mode = json_mode @@ -65,9 +77,9 @@ def __init__( @property def session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: - headers = {} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" + headers = ( + {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + ) self._session = aiohttp.ClientSession(headers=headers) return self._session @@ -79,11 +91,11 @@ def _build_body(self, text: str, history: List[str]) -> Dict[str, Any]: messages = [] if self.system_prompt: messages.append({"role": "system", "content": self.system_prompt}) - if history: - assert len(history) % 2 == 0 - for i in range(0, len(history), 2): - messages.append({"role": "user", "content": history[i]}) - messages.append({"role": "assistant", "content": history[i + 1]}) + + # chatml format: alternating user and assistant messages + if history and isinstance(history[0], dict): + messages.extend(history) + messages.append({"role": "user", "content": text}) body = { @@ -121,7 +133,7 @@ async def generate_answer( await self.tpm.wait(est, silent=True) async with self.session.post( - f"{self.base_url}/v1/chat/completions", + f"{self.base_url}/chat/completions", json=body, timeout=aiohttp.ClientTimeout(total=60), ) as resp: @@ -157,7 +169,7 @@ async def generate_topk_per_token( body["top_logprobs"] = self.topk_per_token async with self.session.post( - f"{self.base_url}/v1/chat/completions", + f"{self.base_url}/chat/completions", json=body, timeout=aiohttp.ClientTimeout(total=60), ) as resp: diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index c837191e..5e60e3a3 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -49,10 +49,6 @@ def __init__( self.token_usage: List[Dict[str, int]] = [] self._session: Optional[aiohttp.ClientSession] = None - def __post_init__(self): - # 基类若未调,可手动触发 - pass - @property def session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: @@ -63,7 +59,6 @@ async def close(self): if self._session and not self._session.closed: await self._session.close() - # ---------------- 内部构造 ---------------- def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: messages = [] if self.system_prompt: @@ -94,7 +89,6 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: payload["options"]["logprobs"] = True return payload - # ---------------- generate_answer ---------------- @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 2ad37e63..3e8e7ba9 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,5 +1,6 @@ from .build_kg import build_mm_kg, build_text_kg from .generate import generate_qas +from .init import init_llm from .judge import judge_statement from .partition import partition_kg from .quiz import quiz diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py new file mode 100644 index 00000000..ec604441 --- /dev/null +++ 
b/graphgen/operators/init/__init__.py @@ -0,0 +1 @@ +from .init_llm import init_llm diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py new file mode 100644 index 00000000..7072eb17 --- /dev/null +++ b/graphgen/operators/init/init_llm.py @@ -0,0 +1,59 @@ +import os +from typing import Any, Dict + +from graphgen.bases import BaseLLMWrapper +from graphgen.models import HTTPClient, OpenAIClient, Tokenizer + + +class LLMFactory: + """ + A factory class to create LLM wrapper instances based on the specified backend. + Supported backends include: + - http_api: HTTPClient + - openai_api: OpenAIClient + - ollama_api: OllamaClient + - ollama: OllamaWrapper + - deepspeed: DeepSpeedWrapper + - huggingface: HuggingFaceWrapper + - tgi: TGIWrapper + - sglang: SGLangWrapper + - tensorrt: TensorRTWrapper + """ + + @staticmethod + def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: + # add tokenizer + tokenizer: Tokenizer = Tokenizer( + os.environ.get("TOKENIZER_MODEL", "cl100k_base"), + ) + config["tokenizer"] = tokenizer + if backend == "http_api": + return HTTPClient(**config) + if backend == "openai_api": + return OpenAIClient(**config) + raise NotImplementedError(f"Backend {backend} is not implemented yet.") + + +def _load_env_group(prefix: str) -> Dict[str, Any]: + """ + Collect environment variables with the given prefix into a dictionary, + stripping the prefix from the keys. + """ + return { + k[len(prefix) :].lower(): v + for k, v in os.environ.items() + if k.startswith(prefix) + } + + +def init_llm(model_type: str) -> BaseLLMWrapper: + if model_type == "synthesizer": + prefix = "SYNTHESIZER_" + elif model_type == "trainee": + prefix = "TRAINEE_" + else: + raise NotImplementedError(f"Model type {model_type} is not implemented yet.") + config = _load_env_group(prefix) + backend = config.pop("backend") + llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) + return llm_wrapper diff --git a/tests/integration_tests/models/llm/api/test_http_client.py b/tests/integration_tests/models/llm/api/test_http_client.py index 4949830d..d2996d1c 100644 --- a/tests/integration_tests/models/llm/api/test_http_client.py +++ b/tests/integration_tests/models/llm/api/test_http_client.py @@ -61,7 +61,7 @@ async def test_generate_answer_records_usage_and_uses_limiters(): "choices": [{"message": {"content": "Hello world!"}}], "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}, } - client = HTTPClient(model_name="m", base_url="http://test") + client = HTTPClient(model="m", base_url="http://test") client._session = MockSession(data) client.tokenizer = DummyTokenizer() client.system_prompt = "sys" @@ -121,7 +121,7 @@ async def test_generate_topk_per_token_parses_logprobs(): } ] } - client = HTTPClient(model_name="m", base_url="http://test") + client = HTTPClient(model="m", base_url="http://test") client._session = MockSession(data) client.tokenizer = DummyTokenizer() client.system_prompt = None From abc8dc22f58100e38573173f7629090050334ca0 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 15:37:59 +0800 Subject: [PATCH 13/27] fix: fix ollama_client --- graphgen/models/llm/api/ollama_client.py | 30 ++++++++---------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index 5e60e3a3..faec503e 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -17,10 +17,9 @@ 
class OllamaClient(BaseLLMWrapper): """ - 要求本地/远端启动 ollama server(默认 11434 端口)。 - ollama 的 /api/chat 在 0.1.24+ 支持 stream=False + raw=true 时返回 logprobs, - 但 top_logprobs 字段目前官方未实现,因此 generate_topk_per_token 只能降级到 - 取单个 token 的 logprob;若未来官方支持再补全。 + Requires a local or remote Ollama server to be running (default port 11434). + The /api/chat endpoint in Ollama 0.1.24+ supports stream=False + and raw=true to return logprobs, but the top_logprobs field is not yet implemented by the official API. """ def __init__( @@ -63,12 +62,10 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: messages = [] if self.system_prompt: messages.append({"role": "system", "content": self.system_prompt}) - if history: - assert len(history) % 2 == 0 - for i in range(0, len(history), 2): - messages.append({"role": "user", "content": history[i]}) - messages.append({"role": "assistant", "content": history[i + 1]}) - messages.append({"role": "user", "content": text}) + + # chatml format: alternating user and assistant messages + if history and isinstance(history[0], dict): + messages.extend(history) payload = { "model": self.model_name, @@ -85,7 +82,6 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: if self.json_mode: payload["format"] = "json" if self.topk_per_token > 0: - # ollama 0.1.24+ 支持 logprobs=true,但 top_logprobs 字段暂无 payload["options"]["logprobs"] = True return payload @@ -101,7 +97,6 @@ async def generate_answer( **extra: Any, ) -> str: payload = self._build_payload(text, history or []) - # 简易 token 估算 prompt_tokens = sum( len(self.tokenizer.encode(m["content"])) for m in payload["messages"] ) @@ -119,7 +114,7 @@ async def generate_answer( resp.raise_for_status() data = await resp.json() - # ollama 返回 {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy} + # {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy} content = data["message"]["content"] self.token_usage.append( { @@ -131,16 +126,14 @@ async def generate_answer( ) return self.filter_think_tags(content) - # ---------------- generate_topk_per_token ---------------- async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any, ) -> List[Token]: - # ollama 目前无 top_logprobs,只能拿到每个 token 的 logprob payload = self._build_payload(text, history or []) - payload["options"]["num_predict"] = 5 # 限制长度 + payload["options"]["num_predict"] = 5 async with self.session.post( f"{self.base_url}/api/chat", json=payload, @@ -149,15 +142,12 @@ async def generate_topk_per_token( resp.raise_for_status() data = await resp.json() - # ollama 返回 logprobs 在 ["message"]["logprobs"]["content"] 列表 - # 每项 {"token":str, "logprob":float} tokens = [] for item in data.get("message", {}).get("logprobs", {}).get("content", []): tokens.append(Token(item["token"], math.exp(item["logprob"]))) return tokens - # ---------------- generate_inputs_prob ---------------- async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - raise NotImplementedError + raise NotImplementedError("Ollama API does not support per-token logprobs yet.") From ebf9d1cf15925f1f4d4d01c23a0f6d6cf110407c Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 15:55:03 +0800 Subject: [PATCH 14/27] wip: fix ollama_client --- .env.example | 4 ++-- graphgen/models/llm/api/ollama_client.py | 2 ++ graphgen/operators/init/init_llm.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.env.example 
b/.env.example index fd3634a1..7e321cd8 100644 --- a/.env.example +++ b/.env.example @@ -6,10 +6,10 @@ TOKENIZER_MODEL= # http_api / openai_api SYNTHESIZER_BACKEND=openai_api -SYNTHESIZER_MODEL= +SYNTHESIZER_MODEL=gpt-4o-mini SYNTHESIZER_BASE_URL= SYNTHESIZER_API_KEY= TRAINEE_BACKEND=openai_api -TRAINEE_MODEL= +TRAINEE_MODEL=gpt-4o-mini TRAINEE_BASE_URL= TRAINEE_API_KEY= diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index faec503e..6feacae2 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -67,6 +67,8 @@ def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: if history and isinstance(history[0], dict): messages.extend(history) + messages.append({"role": "user", "content": text}) + payload = { "model": self.model_name, "messages": messages, diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py index 7072eb17..c355426a 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/operators/init/init_llm.py @@ -2,7 +2,7 @@ from typing import Any, Dict from graphgen.bases import BaseLLMWrapper -from graphgen.models import HTTPClient, OpenAIClient, Tokenizer +from graphgen.models import HTTPClient, OllamaClient, OpenAIClient, Tokenizer class LLMFactory: @@ -31,6 +31,8 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: return HTTPClient(**config) if backend == "openai_api": return OpenAIClient(**config) + if backend == "ollama_api": + return OllamaClient(**config) raise NotImplementedError(f"Backend {backend} is not implemented yet.") From fac999756f05735d72a415f44b6406c962bf59a0 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 16:52:14 +0800 Subject: [PATCH 15/27] tests: add ollama_client tests --- graphgen/models/llm/api/ollama_client.py | 142 +++++++----------- .../models/llm/api/test_ollama_client.py | 112 ++++++++++++++ 2 files changed, 170 insertions(+), 84 deletions(-) create mode 100644 tests/integration_tests/models/llm/api/test_ollama_client.py diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index 6feacae2..6ab256d0 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -1,15 +1,6 @@ -import asyncio import math from typing import Any, Dict, List, Optional -import aiohttp -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.datatypes import Token from graphgen.models.llm.limitter import RPM, TPM @@ -25,7 +16,7 @@ class OllamaClient(BaseLLMWrapper): def __init__( self, *, - model_name: str = "llama3.1", + model: str = "gemma3", base_url: str = "http://localhost:11434", json_mode: bool = False, seed: Optional[int] = None, @@ -35,121 +26,104 @@ def __init__( tpm: Optional[TPM] = None, **kwargs: Any, ): + try: + import ollama + except ImportError as e: + raise ImportError( + "Ollama SDK is not installed." + "It is required to use OllamaClient." + "Please install it with `pip install ollama`." 
+ ) from e super().__init__(**kwargs) - self.model_name = model_name - self.base_url = base_url.rstrip("/") + self.model_name = model + self.base_url = base_url self.json_mode = json_mode self.seed = seed self.topk_per_token = topk_per_token self.request_limit = request_limit self.rpm = rpm or RPM() self.tpm = tpm or TPM() - self.token_usage: List[Dict[str, int]] = [] - self._session: Optional[aiohttp.ClientSession] = None - @property - def session(self) -> aiohttp.ClientSession: - if self._session is None or self._session.closed: - self._session = aiohttp.ClientSession() - return self._session + self.client = ollama.AsyncClient(host=self.base_url) - async def close(self): - if self._session and not self._session.closed: - await self._session.close() - - def _build_payload(self, text: str, history: List[str]) -> Dict[str, Any]: + async def generate_answer( + self, + text: str, + history: Optional[List[Dict[str, str]]] = None, + **extra: Any, + ) -> str: messages = [] if self.system_prompt: messages.append({"role": "system", "content": self.system_prompt}) - - # chatml format: alternating user and assistant messages - if history and isinstance(history[0], dict): + if history: messages.extend(history) - messages.append({"role": "user", "content": text}) - payload = { - "model": self.model_name, - "messages": messages, - "stream": False, - "options": { - "temperature": self.temperature, - "top_p": self.top_p, - "num_predict": self.max_tokens, - }, + options = { + "temperature": self.temperature, + "top_p": self.top_p, + "num_predict": self.max_tokens, } if self.seed is not None: - payload["options"]["seed"] = self.seed - if self.json_mode: - payload["format"] = "json" - if self.topk_per_token > 0: - payload["options"]["logprobs"] = True - return payload + options["seed"] = self.seed - @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)), - ) - async def generate_answer( - self, - text: str, - history: Optional[List[str]] = None, - **extra: Any, - ) -> str: - payload = self._build_payload(text, history or []) - prompt_tokens = sum( - len(self.tokenizer.encode(m["content"])) for m in payload["messages"] - ) + prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages) est = prompt_tokens + self.max_tokens - if self.request_limit: await self.rpm.wait(silent=True) await self.tpm.wait(est, silent=True) - async with self.session.post( - f"{self.base_url}/api/chat", - json=payload, - timeout=aiohttp.ClientTimeout(total=60), - ) as resp: - resp.raise_for_status() - data = await resp.json() + response = await self.client.chat( + model=self.model_name, + messages=messages, + format="json" if self.json_mode else "", + options=options, + stream=False, + ) - # {"message":{"content":"..."}, "prompt_eval_count":xx, "eval_count":yy} - content = data["message"]["content"] + usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0) self.token_usage.append( { - "prompt_tokens": data.get("prompt_eval_count", 0), - "completion_tokens": data.get("eval_count", 0), - "total_tokens": data.get("prompt_eval_count", 0) - + data.get("eval_count", 0), + "prompt_tokens": usage[0], + "completion_tokens": usage[1], + "total_tokens": sum(usage), } ) + content = response["message"]["content"] return self.filter_think_tags(content) async def generate_topk_per_token( self, text: str, - history: Optional[List[str]] = None, + history: Optional[List[Dict[str, str]]] = None, 
**extra: Any, ) -> List[Token]: - payload = self._build_payload(text, history or []) - payload["options"]["num_predict"] = 5 - async with self.session.post( - f"{self.base_url}/api/chat", - json=payload, - timeout=aiohttp.ClientTimeout(total=60), - ) as resp: - resp.raise_for_status() - data = await resp.json() + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + if history: + messages.extend(history) + messages.append({"role": "user", "content": text}) + + response = await self.client.chat( + model=self.model_name, + messages=messages, + options={ + "temperature": self.temperature, + "top_p": self.top_p, + "num_predict": 5, + "logprobs": True, + }, + stream=False, + ) tokens = [] - for item in data.get("message", {}).get("logprobs", {}).get("content", []): + for item in response.get("message", {}).get("logprobs", {}).get("content", []): tokens.append(Token(item["token"], math.exp(item["logprob"]))) return tokens async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None, **extra: Any + self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any ) -> List[Token]: raise NotImplementedError("Ollama API does not support per-token logprobs yet.") diff --git a/tests/integration_tests/models/llm/api/test_ollama_client.py b/tests/integration_tests/models/llm/api/test_ollama_client.py new file mode 100644 index 00000000..c01667eb --- /dev/null +++ b/tests/integration_tests/models/llm/api/test_ollama_client.py @@ -0,0 +1,112 @@ +# pylint: disable=redefined-outer-name + +import math +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from graphgen.models import OllamaClient + + +# ----------------- fixture ----------------- +@pytest.fixture +def mock_ollama_pkg(): + """ + mock ollama + """ + ollama_mock = MagicMock() + ollama_mock.AsyncClient = AsyncMock + with patch.dict("sys.modules", {"ollama": ollama_mock}): + yield ollama_mock + + +@pytest.fixture +def ollama_client(mock_ollama_pkg) -> OllamaClient: + """ + Returns a default-configured OllamaClient with client.chat mocked + """ + cli = OllamaClient(model="gemma3", base_url="http://test:11434") + cli.tokenizer = MagicMock() + cli.tokenizer.encode = MagicMock(side_effect=lambda x: x.split()) + cli.client.chat = AsyncMock( + return_value={ + "message": {"content": "hi from ollama"}, + "prompt_eval_count": 10, + "eval_count": 5, + } + ) + return cli + + +@pytest.mark.asyncio +async def test_generate_answer_basic(ollama_client: OllamaClient): + ans = await ollama_client.generate_answer("hello") + assert ans == "hi from ollama" + ollama_client.client.chat.assert_awaited_once() + call = ollama_client.client.chat.call_args + assert call.kwargs["model"] == "gemma3" + assert call.kwargs["messages"][-1]["content"] == "hello" + assert call.kwargs["stream"] is False + + +@pytest.mark.asyncio +async def test_generate_answer_with_history(ollama_client: OllamaClient): + hist = [{"role": "user", "content": "prev"}] + await ollama_client.generate_answer("now", history=hist) + msgs = ollama_client.client.chat.call_args.kwargs["messages"] + assert msgs[-2]["content"] == "prev" + assert msgs[-1]["content"] == "now" + + +@pytest.mark.asyncio +async def test_token_usage_recorded(ollama_client: OllamaClient): + await ollama_client.generate_answer("test") + assert len(ollama_client.token_usage) == 1 + assert ollama_client.token_usage[0]["prompt_tokens"] == 10 + assert ollama_client.token_usage[0]["completion_tokens"] == 5 + assert 
ollama_client.token_usage[0]["total_tokens"] == 15 + + +@pytest.mark.asyncio +async def test_rpm_tpm_limiter_called(ollama_client: OllamaClient): + ollama_client.request_limit = True + with patch.object(ollama_client.rpm, "wait", AsyncMock()) as rpm_mock, patch.object( + ollama_client.tpm, "wait", AsyncMock() + ) as tpm_mock: + + await ollama_client.generate_answer("limited") + rpm_mock.assert_awaited_once_with(silent=True) + tpm_mock.assert_awaited_once_with( + ollama_client.max_tokens + len("limited".split()), silent=True + ) + + +@pytest.mark.asyncio +async def test_generate_topk_per_token(ollama_client: OllamaClient): + ollama_client.client.chat.return_value = { + "message": { + "logprobs": { + "content": [ + {"token": "hello", "logprob": -0.1}, + {"token": "world", "logprob": -0.2}, + ] + } + } + } + tokens = await ollama_client.generate_topk_per_token("test") + assert len(tokens) == 2 + assert tokens[0].text == "hello" + assert math.isclose(tokens[0].prob, math.exp(-0.1)) + assert tokens[1].text == "world" + + +def test_import_error_when_ollama_missing(): + with patch.dict("sys.modules", {"ollama": None}): + with pytest.raises(ImportError, match="Ollama SDK is not installed"): + OllamaClient() + + +@pytest.mark.asyncio +async def test_generate_inputs_prob_not_implemented(ollama_client: OllamaClient): + with pytest.raises(NotImplementedError): + await ollama_client.generate_inputs_prob("any") From f5a45940b1a558f1e19426f0edec1db52a34fae9 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 17:09:53 +0800 Subject: [PATCH 16/27] fix: fix generate_topk_per_token in ollmam_client --- graphgen/models/llm/api/ollama_client.py | 28 ++---------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/graphgen/models/llm/api/ollama_client.py b/graphgen/models/llm/api/ollama_client.py index 6ab256d0..9a4946a6 100644 --- a/graphgen/models/llm/api/ollama_client.py +++ b/graphgen/models/llm/api/ollama_client.py @@ -1,4 +1,3 @@ -import math from typing import Any, Dict, List, Optional from graphgen.bases.base_llm_wrapper import BaseLLMWrapper @@ -9,8 +8,7 @@ class OllamaClient(BaseLLMWrapper): """ Requires a local or remote Ollama server to be running (default port 11434). - The /api/chat endpoint in Ollama 0.1.24+ supports stream=False - and raw=true to return logprobs, but the top_logprobs field is not yet implemented by the official API. + The top_logprobs field is not yet implemented by the official API. 
""" def __init__( @@ -99,29 +97,7 @@ async def generate_topk_per_token( history: Optional[List[Dict[str, str]]] = None, **extra: Any, ) -> List[Token]: - messages = [] - if self.system_prompt: - messages.append({"role": "system", "content": self.system_prompt}) - if history: - messages.extend(history) - messages.append({"role": "user", "content": text}) - - response = await self.client.chat( - model=self.model_name, - messages=messages, - options={ - "temperature": self.temperature, - "top_p": self.top_p, - "num_predict": 5, - "logprobs": True, - }, - stream=False, - ) - - tokens = [] - for item in response.get("message", {}).get("logprobs", {}).get("content", []): - tokens.append(Token(item["token"], math.exp(item["logprob"]))) - return tokens + raise NotImplementedError("Ollama API does not support per-token top-k yet.") async def generate_inputs_prob( self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any From d4beb5288dbe292a97643e2d7efea46486377ba5 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 17:11:31 +0800 Subject: [PATCH 17/27] fix: delete useless tests --- .../models/llm/api/test_ollama_client.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/tests/integration_tests/models/llm/api/test_ollama_client.py b/tests/integration_tests/models/llm/api/test_ollama_client.py index c01667eb..b20bc44c 100644 --- a/tests/integration_tests/models/llm/api/test_ollama_client.py +++ b/tests/integration_tests/models/llm/api/test_ollama_client.py @@ -1,6 +1,4 @@ # pylint: disable=redefined-outer-name - -import math from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -81,25 +79,6 @@ async def test_rpm_tpm_limiter_called(ollama_client: OllamaClient): ) -@pytest.mark.asyncio -async def test_generate_topk_per_token(ollama_client: OllamaClient): - ollama_client.client.chat.return_value = { - "message": { - "logprobs": { - "content": [ - {"token": "hello", "logprob": -0.1}, - {"token": "world", "logprob": -0.2}, - ] - } - } - } - tokens = await ollama_client.generate_topk_per_token("test") - assert len(tokens) == 2 - assert tokens[0].text == "hello" - assert math.isclose(tokens[0].prob, math.exp(-0.1)) - assert tokens[1].text == "world" - - def test_import_error_when_ollama_missing(): with patch.dict("sys.modules", {"ollama": None}): with pytest.raises(ImportError, match="Ollama SDK is not installed"): From c8055f1770ba7822cbfad74b044ed6454d338cfc Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 17:44:00 +0800 Subject: [PATCH 18/27] fix: fix transformers warning not using GenerationConfig --- graphgen/models/llm/__init__.py | 1 + graphgen/models/llm/local/hf_wrapper.py | 80 ++++++++++++++++--------- graphgen/operators/init/init_llm.py | 12 +++- 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/graphgen/models/llm/__init__.py b/graphgen/models/llm/__init__.py index 68769477..c70395d5 100644 --- a/graphgen/models/llm/__init__.py +++ b/graphgen/models/llm/__init__.py @@ -1,3 +1,4 @@ from .api.http_client import HTTPClient from .api.ollama_client import OllamaClient from .api.openai_client import OpenAIClient +from .local.hf_wrapper import HuggingFaceWrapper diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py index e2e6f582..74c143b0 100644 --- a/graphgen/models/llm/local/hf_wrapper.py +++ b/graphgen/models/llm/local/hf_wrapper.py @@ -1,8 +1,5 @@ from typing import Any, List, Optional -import torch -from transformers import 
AutoModelForCausalLM, AutoTokenizer - from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.datatypes import Token @@ -14,24 +11,43 @@ class HuggingFaceWrapper(BaseLLMWrapper): def __init__( self, - model_path: str, + model: str, torch_dtype="auto", device_map="auto", trust_remote_code=True, temperature=0.0, top_p=1.0, topk=5, - **kwargs: Any + **kwargs: Any, ): super().__init__(temperature=temperature, top_p=top_p, **kwargs) + + try: + import torch + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + ) + except ImportError as exc: + raise ImportError( + "HuggingFaceWrapper requires torch and transformers. " + "Install them with: pip install torch transformers" + ) from exc + + self.torch = torch + self.AutoTokenizer = AutoTokenizer + self.AutoModelForCausalLM = AutoModelForCausalLM + self.GenerationConfig = GenerationConfig + self.tokenizer = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=trust_remote_code + model, trust_remote_code=trust_remote_code ) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( - model_path, + model, torch_dtype=torch_dtype, device_map=device_map, trust_remote_code=trust_remote_code, @@ -42,27 +58,28 @@ def __init__( self.topk = topk @staticmethod - def _build_inputs(prompt: str, history: Optional[List[str]] = None): + def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str: msgs = history or [] msgs.append(prompt) - full = "\n".join(msgs) - return full + return "\n".join(msgs) async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: full = self._build_inputs(text, history) inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) - max_new = 512 - with torch.no_grad(): - out = self.model.generate( - **inputs, - max_new_tokens=max_new, - temperature=self.temperature if self.temperature > 0 else 0.0, - top_p=self.top_p if self.temperature > 0 else 1.0, - do_sample=self.temperature > 0, - pad_token_id=self.tokenizer.eos_token_id, - ) + + gen_config = self.GenerationConfig( + max_new_tokens=extra.get("max_new_tokens", 512), + temperature=self.temperature if self.temperature > 0 else 1.0, + top_p=self.top_p, + do_sample=self.temperature > 0, # temperature==0 => greedy + pad_token_id=self.tokenizer.eos_token_id, + ) + + with self.torch.no_grad(): + out = self.model.generate(**inputs, generation_config=gen_config) + gen = out[0, inputs.input_ids.shape[-1] :] return self.tokenizer.decode(gen, skip_special_tokens=True) @@ -71,17 +88,21 @@ async def generate_topk_per_token( ) -> List[Token]: full = self._build_inputs(text, history) inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) - with torch.no_grad(): + + with self.torch.no_grad(): out = self.model.generate( **inputs, max_new_tokens=1, - temperature=0, + temperature=0.0, return_dict_in_generate=True, output_scores=True, + pad_token_id=self.tokenizer.eos_token_id, ) - scores = out.scores[0][0] # vocab - probs = torch.softmax(scores, dim=-1) - top_probs, top_idx = torch.topk(probs, k=self.topk) + + scores = out.scores[0][0] # (vocab,) + probs = self.torch.softmax(scores, dim=-1) + top_probs, top_idx = self.torch.topk(probs, k=self.topk) + tokens = [] for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()): tokens.append(Token(self.tokenizer.decode([idx]), float(p))) @@ -93,12 +114,15 @@ async def generate_inputs_prob( full = 
self._build_inputs(text, history) ids = self.tokenizer.encode(full) logprobs = [] + for i in range(1, len(ids) + 1): trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] - inputs = torch.tensor([trunc]).to(self.model.device) - with torch.no_grad(): + inputs = self.torch.tensor([trunc]).to(self.model.device) + + with self.torch.no_grad(): logits = self.model(inputs).logits[0, -1, :] - probs = torch.softmax(logits, dim=-1) + probs = self.torch.softmax(logits, dim=-1) + true_id = ids[i - 1] logprobs.append( Token( diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py index c355426a..0576f2f5 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/operators/init/init_llm.py @@ -2,7 +2,7 @@ from typing import Any, Dict from graphgen.bases import BaseLLMWrapper -from graphgen.models import HTTPClient, OllamaClient, OpenAIClient, Tokenizer +from graphgen.models import Tokenizer class LLMFactory: @@ -28,11 +28,21 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: ) config["tokenizer"] = tokenizer if backend == "http_api": + from graphgen.models.llm.api.http_client import HTTPClient + return HTTPClient(**config) if backend == "openai_api": + from graphgen.models.llm.api.openai_client import OpenAIClient + return OpenAIClient(**config) if backend == "ollama_api": + from graphgen.models.llm.api.ollama_client import OllamaClient + return OllamaClient(**config) + if backend == "huggingface": + from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper + + return HuggingFaceWrapper(**config) raise NotImplementedError(f"Backend {backend} is not implemented yet.") From ae9b28b041aa4fbe87fdb862bf2015dbd972a203 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 19:11:48 +0800 Subject: [PATCH 19/27] fix: fix _build_inputs in hf_wrapper --- graphgen/models/llm/local/hf_wrapper.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py index 74c143b0..c33a5b03 100644 --- a/graphgen/models/llm/local/hf_wrapper.py +++ b/graphgen/models/llm/local/hf_wrapper.py @@ -60,8 +60,16 @@ def __init__( @staticmethod def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str: msgs = history or [] - msgs.append(prompt) - return "\n".join(msgs) + lines = [] + for m in msgs: + if isinstance(m, dict): + role = m.get("role", "") + content = m.get("content", "") + lines.append(f"{role}: {content}") + else: + lines.append(str(m)) + lines.append(prompt) + return "\n".join(lines) async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any From 292d9868d9f46e3e01527a186418cbf987303b8b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 19:43:32 +0800 Subject: [PATCH 20/27] fix: fix gen_kwargs --- graphgen/models/llm/local/hf_wrapper.py | 26 +++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/graphgen/models/llm/local/hf_wrapper.py b/graphgen/models/llm/local/hf_wrapper.py index c33a5b03..b0538aad 100644 --- a/graphgen/models/llm/local/hf_wrapper.py +++ b/graphgen/models/llm/local/hf_wrapper.py @@ -31,8 +31,8 @@ def __init__( ) except ImportError as exc: raise ImportError( - "HuggingFaceWrapper requires torch and transformers. " - "Install them with: pip install torch transformers" + "HuggingFaceWrapper requires torch, transformers and accelerate. 
" + "Install them with: pip install torch transformers accelerate" ) from exc self.torch = torch @@ -77,13 +77,18 @@ async def generate_answer( full = self._build_inputs(text, history) inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device) - gen_config = self.GenerationConfig( - max_new_tokens=extra.get("max_new_tokens", 512), - temperature=self.temperature if self.temperature > 0 else 1.0, - top_p=self.top_p, - do_sample=self.temperature > 0, # temperature==0 => greedy - pad_token_id=self.tokenizer.eos_token_id, - ) + gen_kwargs = { + "max_new_tokens": extra.get("max_new_tokens", 512), + "do_sample": self.temperature > 0, + "temperature": self.temperature if self.temperature > 0 else 1.0, + "pad_token_id": self.tokenizer.eos_token_id, + } + + # Add top_p and top_k only if temperature > 0 + if self.temperature > 0: + gen_kwargs.update(top_p=self.top_p, top_k=self.topk) + + gen_config = self.GenerationConfig(**gen_kwargs) with self.torch.no_grad(): out = self.model.generate(**inputs, generation_config=gen_config) @@ -101,7 +106,8 @@ async def generate_topk_per_token( out = self.model.generate( **inputs, max_new_tokens=1, - temperature=0.0, + do_sample=False, + temperature=1.0, return_dict_in_generate=True, output_scores=True, pad_token_id=self.tokenizer.eos_token_id, From f562ee29798e9bc7f434b7d4feabe099d918c013 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 20:00:24 +0800 Subject: [PATCH 21/27] chore: delete ds_wrapper --- graphgen/models/llm/local/ds_wrapper.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 graphgen/models/llm/local/ds_wrapper.py diff --git a/graphgen/models/llm/local/ds_wrapper.py b/graphgen/models/llm/local/ds_wrapper.py deleted file mode 100644 index 23db41ba..00000000 --- a/graphgen/models/llm/local/ds_wrapper.py +++ /dev/null @@ -1,19 +0,0 @@ -from .hf_wrapper import HuggingFaceWrapper - - -class DeepSpeedBackend(HuggingFaceWrapper): - """ - Inference backend based on DeepSpeed - """ - - def __init__(self, *args, ds_config=None, **kwargs): - super().__init__(*args, **kwargs) - try: - import deepspeed - except ImportError as exc: - raise ImportError( - "Please install deepspeed to use DeepSpeedBackend: pip install deepspeed" - ) from exc - ds_config = ds_config or {"train_batch_size": 1, "fp16": {"enabled": True}} - self.model, _, _, _ = deepspeed.initialize(model=self.model, config=ds_config) - self.model.module.eval() From 399ef452f165724c7fc1819f607bf10feefba668 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 28 Oct 2025 20:23:34 +0800 Subject: [PATCH 22/27] feat: add vllm_wrapper --- graphgen/models/llm/local/sglang_wrapper.py | 2 +- graphgen/models/llm/local/tgi_wrapper.py | 58 +------ graphgen/models/llm/local/vllm_wrapper.py | 153 ++++++++++++++++++ .../models/llm/local/test_ds_wrapper.py | 33 ---- 4 files changed, 158 insertions(+), 88 deletions(-) create mode 100644 graphgen/models/llm/local/vllm_wrapper.py delete mode 100644 tests/integration_tests/models/llm/local/test_ds_wrapper.py diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py index a0d6b8d5..aa48cbfc 100644 --- a/graphgen/models/llm/local/sglang_wrapper.py +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -5,7 +5,7 @@ from graphgen.bases.datatypes import Token -class SGLangBackend(BaseLLMWrapper): +class SGLangWrapper(BaseLLMWrapper): """ Async inference backend based on SGLang """ diff --git a/graphgen/models/llm/local/tgi_wrapper.py 
b/graphgen/models/llm/local/tgi_wrapper.py index cfb15608..8dac68c0 100644 --- a/graphgen/models/llm/local/tgi_wrapper.py +++ b/graphgen/models/llm/local/tgi_wrapper.py @@ -1,8 +1,5 @@ -import math from typing import Any, List, Optional -from huggingface_hub import InferenceClient - from graphgen.bases import BaseLLMWrapper from graphgen.bases.datatypes import Token @@ -21,66 +18,19 @@ def __init__( **kwargs: Any ): super().__init__(temperature=temperature, top_p=top_p, **kwargs) - self.client = InferenceClient(model=model_url, token=False) - self.topk = topk - self.model_url = model_url - - @staticmethod - def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: - if not history: - return prompt - return "\n".join(history) + "\n" + prompt + # TODO: implement tgi wrapper async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: - text = self._messages_to_str(text, history) - out = await self.client.text_generation( - text, - max_new_tokens=extra.get("max_new_tokens", 512), - temperature=self.temperature or None, - top_p=self.top_p if self.top_p < 1.0 else None, - stop_sequences=extra.get("stop", None), - details=False, - ) - return out + pass async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - text = self._messages_to_str(text, history) - out = await self.client.text_generation( - text, - max_new_tokens=1, - temperature=0, - details=True, - decoder_input_details=True, - ) - # TGI 返回的 tokens[0].logprob.topk 字段 - topk = out.details.tokens[0].logprob.topk - return [Token(t.token, math.exp(t.logprob)) for t in topk[: self.topk]] + pass async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - """ - TGI does not provide a direct interface for "conditional probability of each input token", - here we approximate it with "input prefix + next token". - To implement it strictly, you can use /generate_stream and truncate it bit by bit. - """ - text = self._messages_to_str(text, history) - ids = self.client.tokenizer.encode(text) - tokens: List[Token] = [] - for i in range(1, len(ids) + 1): - prefix_ids = ids[:i] - prefix = self.client.tokenizer.decode(prefix_ids) - out = await self.client.text_generation( - prefix, - max_new_tokens=1, - temperature=0, - details=True, - decoder_input_details=True, - ) - t = out.details.tokens[0] - tokens.append(Token(t.token, math.exp(t.logprob))) - return tokens + pass diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py new file mode 100644 index 00000000..b20c07ac --- /dev/null +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -0,0 +1,153 @@ +from typing import Any, List, Optional + +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper +from graphgen.bases.datatypes import Token + + +class VLLMWrapper(BaseLLMWrapper): + """ + Async inference backend based on vLLM (https://github.com/vllm-project/vllm) + """ + + def __init__( + self, + model: str, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + temperature: float = 0.0, + top_p: float = 1.0, + topk: int = 5, + **kwargs: Any, + ): + super().__init__(temperature=temperature, top_p=top_p, **kwargs) + + try: + from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams + except ImportError as exc: + raise ImportError( + "VLLMWrapper requires vllm. 
Install it with: pip install vllm"
+            ) from exc
+
+        self.SamplingParams = SamplingParams
+
+        engine_args = AsyncEngineArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            gpu_memory_utilization=gpu_memory_utilization,
+            trust_remote_code=kwargs.get("trust_remote_code", True),
+        )
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+        self.temperature = temperature
+        self.top_p = top_p
+        self.topk = topk
+
+    # ------------------------------------------------------------------
+    # helper: flatten the history into a multi-turn prompt (consistent with HFWrapper)
+    # ------------------------------------------------------------------
+    @staticmethod
+    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
+        msgs = history or []
+        lines = []
+        for m in msgs:
+            if isinstance(m, dict):
+                role = m.get("role", "")
+                content = m.get("content", "")
+                lines.append(f"{role}: {content}")
+            else:
+                lines.append(str(m))
+        lines.append(prompt)
+        return "\n".join(lines)
+
+    # ------------------------------------------------------------------
+    # 1. Regular generation
+    # ------------------------------------------------------------------
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        full_prompt = self._build_inputs(text, history)
+
+        sp = self.SamplingParams(
+            temperature=self.temperature if self.temperature > 0 else 1.0,
+            top_p=self.top_p if self.temperature > 0 else 1.0,
+            max_tokens=extra.get("max_new_tokens", 512),
+        )
+
+        # vLLM's async generation interface
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_req"
+        ):
+            results = req_output.outputs
+        # keep the outputs of the final update
+        return results[-1].text
+
+    # ------------------------------------------------------------------
+    # 2. Generate exactly 1 new token and return its top-k probabilities
+    # ------------------------------------------------------------------
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        full_prompt = self._build_inputs(text, history)
+
+        # force greedy decoding (temperature=0) and request logprobs
+        sp = self.SamplingParams(
+            temperature=0,
+            max_tokens=1,
+            logprobs=self.topk,  # vLLM returns the top-k logprobs
+        )
+
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_topk"
+        ):
+            results = req_output.outputs
+        top_logprobs = results[-1].logprobs[0]  # top-k for the first newly generated token
+
+        tokens = []
+        for _, logprob_obj in top_logprobs.items():
+            tok_str = logprob_obj.decoded_token
+            prob = float(logprob_obj.logprob.exp())
+            tokens.append(Token(tok_str, prob))
+        # sort by probability in descending order
+        tokens.sort(key=lambda x: -x.prob)
+        return tokens
+
+    # ------------------------------------------------------------------
+    # 3. Compute, token by token, the probability the model assigns to each input token (same semantics as HFWrapper)
+    # ------------------------------------------------------------------
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        full_prompt = self._build_inputs(text, history)
+
+        # vLLM has no ready-made "mask one token, then compute its prob" API,
+        # so we take the most direct approach: feed the whole prompt at once
+        # with prompt_logprobs enabled, let vLLM return the logprob of every
+        # position in the *input*, then pick out the corresponding tokens.
+        sp = self.SamplingParams(
+            temperature=0,
+            max_tokens=0,  # do not generate any new tokens
+            prompt_logprobs=1,  # top-1 is enough
+        )
+
+        results = []
+        async for req_output in self.engine.generate(
+            full_prompt, sp, request_id="graphgen_prob"
+        ):
+            results = req_output.outputs
+
+        # prompt_logprobs is a list whose length equals the number of prompt tokens;
+        # each element is a dict{token_id: logprob_obj}, or None for the first position
+        prompt_logprobs = results[-1].prompt_logprobs
+
+        tokens = []
+        for _, logprob_dict in enumerate(prompt_logprobs):
+            if logprob_dict is None:
+                continue
+            # each dict holds a single entry here because we only asked for top-1
+            _, logprob_obj = next(iter(logprob_dict.items()))
+            tok_str = logprob_obj.decoded_token
+            prob = float(logprob_obj.logprob.exp())
+            tokens.append(Token(tok_str, prob))
+        return tokens
diff --git a/tests/integration_tests/models/llm/local/test_ds_wrapper.py b/tests/integration_tests/models/llm/local/test_ds_wrapper.py
deleted file mode 100644
index 87df19c8..00000000
--- a/tests/integration_tests/models/llm/local/test_ds_wrapper.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import sys
-
-import pytest
-
-from graphgen.models.llm.local.ds_wrapper import DeepSpeedBackend
-
-
-def test_deepspeed_backend_init(monkeypatch):
-    class DummyModel:
-        def eval(self):
-            pass
-
-    class DummyModule:
-        def __init__(self):
-            self.module = DummyModel()
-
-    def dummy_initialize(model, config):
-        return DummyModule(), None, None, None
-
-    monkeypatch.setitem(
-        sys.modules,
-        "deepspeed",
-        type("ds", (), {"initialize": staticmethod(dummy_initialize)})(),
-    )
-    backend = DeepSpeedBackend(model=DummyModel())
-    assert hasattr(backend.model, "module")
-    assert hasattr(backend.model.module, "eval")
-
-
-def test_deepspeed_not_installed(monkeypatch):
-    monkeypatch.setitem(sys.modules, "deepspeed", None)
-    with pytest.raises(ImportError):
-        DeepSpeedBackend(model=object())
From 569db1dbcde157870fd336231b3f3d9113a4842c Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Wed, 29 Oct 2025 11:13:39 +0800
Subject: [PATCH 23/27] wip:sglang backend

---
 graphgen/models/llm/local/sglang_wrapper.py | 8 +++-----
 graphgen/models/llm/local/trt_wrapper.py    | 2 +-
 graphgen/operators/init/init_llm.py         | 4 ++++
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py
index aa48cbfc..99439aa2 100644
--- a/graphgen/models/llm/local/sglang_wrapper.py
+++ b/graphgen/models/llm/local/sglang_wrapper.py
@@ -12,7 +12,7 @@ class SGLangWrapper(BaseLLMWrapper):
 
     def __init__(
         self,
-        model_path: str,
+        model: str,
         tp_size: int = 1,
         max_context_len: int = 4096,
         server_url: Optional[str] = None,
@@ -29,7 +29,7 @@ def __init__(
             raise ImportError(
                 "Please install sglang to use SGLangBackend: pip install sglang[all]>=0.4.4"
             ) from exc
-        self.model_path = model_path
+        self.model_path = model
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
@@ -39,9 +39,7 @@ def __init__(
             self.runtime = RuntimeEndpoint(server_url)
         else:
             sgl.set_default_backend(
-                sgl.Runtime(
-                    model_path, tp_size=tp_size, max_context_len=max_context_len
-                )
+                sgl.Runtime(model, tp_size=tp_size,
max_context_len=max_context_len)
             )
         self.runtime = sgl.get_default_backend()
 
diff --git a/graphgen/models/llm/local/trt_wrapper.py b/graphgen/models/llm/local/trt_wrapper.py
index be7223bd..2a4f8f58 100644
--- a/graphgen/models/llm/local/trt_wrapper.py
+++ b/graphgen/models/llm/local/trt_wrapper.py
@@ -7,7 +7,7 @@ from graphgen.bases.datatypes import Token
 
 
-class TensorRTBackend(BaseLLMWrapper):
+class TensorRTWrapper(BaseLLMWrapper):
     """
     Async inference backend based on TensorRT-LLM
     """
diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py
index 0576f2f5..63f3fd7a 100644
--- a/graphgen/operators/init/init_llm.py
+++ b/graphgen/operators/init/init_llm.py
@@ -43,6 +43,10 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
         from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
 
         return HuggingFaceWrapper(**config)
+    if backend == "sglang":
+        from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
+
+        return SGLangWrapper(**config)
     raise NotImplementedError(f"Backend {backend} is not implemented yet.")
 
 
From e86cbbf60799d8d5941bad2ceecb6c9bc9cb457b Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Wed, 29 Oct 2025 15:51:15 +0800
Subject: [PATCH 24/27] fix: change llm_wrapper type

---
 graphgen/bases/base_llm_wrapper.py           |  6 ++
 graphgen/models/llm/local/ollama_wrapper.py  | 65 --------------------
 graphgen/operators/build_kg/build_mm_kg.py   |  5 +-
 graphgen/operators/build_kg/build_text_kg.py |  5 +-
 graphgen/operators/init/init_llm.py          |  7 ++-
 graphgen/operators/judge.py                  |  5 +-
 graphgen/operators/quiz.py                   |  5 +-
 7 files changed, 23 insertions(+), 75 deletions(-)
 delete mode 100644 graphgen/models/llm/local/ollama_wrapper.py

diff --git a/graphgen/bases/base_llm_wrapper.py b/graphgen/bases/base_llm_wrapper.py
index b1f9cb0d..5a28fbb4 100644
--- a/graphgen/bases/base_llm_wrapper.py
+++ b/graphgen/bases/base_llm_wrapper.py
@@ -66,3 +66,9 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str:
         think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
         filtered_text = think_pattern.sub("", text).strip()
         return filtered_text if filtered_text else text.strip()
+
+    def shutdown(self) -> None:
+        """Shutdown the LLM engine if applicable."""
+
+    def restart(self) -> None:
+        """Reinitialize the LLM engine if applicable."""
diff --git a/graphgen/models/llm/local/ollama_wrapper.py b/graphgen/models/llm/local/ollama_wrapper.py
deleted file mode 100644
index cf3f36fc..00000000
--- a/graphgen/models/llm/local/ollama_wrapper.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from typing import Any, List, Optional
-
-from graphgen.bases import BaseLLMWrapper
-from graphgen.bases.datatypes import Token
-
-
-class OllamaBackend(BaseLLMWrapper):
-    """
-    Async inference backend based on Ollama local server
-    """
-
-    def __init__(
-        self,
-        model: str,  # e.g.
"llama3.1:8b" - host: str = "http://localhost:11434", - temperature: float = 0.0, - top_p: float = 1.0, - topk: int = 5, - **kwargs: Any - ): - try: - import ollama - except ImportError as exc: - raise ImportError( - "Please install ollama to use OllamaBackend: pip install ollama>=0.1.5" - ) from exc - super().__init__(temperature=temperature, top_p=top_p, **kwargs) - self.client = ollama.AsyncClient(host=host) - self.model = model - self.topk = topk - - @staticmethod - def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: - if not history: - return prompt - return "\n".join(history) + "\n" + prompt - - async def generate_answer( - self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> str: - text = self._messages_to_str(text, history) - resp = await self.client.generate( - model=self.model, - prompt=text, - options={ - "temperature": self.temperature or 0, - "top_p": self.top_p if self.top_p < 1.0 else 1, - }, - stream=False, - ) - return resp["response"] - - async def generate_topk_per_token( - self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> List[Token]: - raise NotImplementedError( - "Ollama backend does not support per-token top-k yet." - ) - - async def generate_inputs_prob( - self, text: str, history: Optional[List[str]] = None, **extra: Any - ) -> List[Token]: - raise NotImplementedError( - "Ollama backend does not support per-token input probabilities yet." - ) diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py index 9301c2b9..624b10ad 100644 --- a/graphgen/operators/build_kg/build_mm_kg.py +++ b/graphgen/operators/build_kg/build_mm_kg.py @@ -3,14 +3,15 @@ import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk -from graphgen.models import MMKGBuilder, OpenAIClient +from graphgen.models import MMKGBuilder from graphgen.utils import run_concurrent async def build_mm_kg( - llm_client: OpenAIClient, + llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], progress_bar: gr.Progress = None, diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 3babe2e5..3c75f022 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -3,14 +3,15 @@ import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk -from graphgen.models import LightRAGKGBuilder, OpenAIClient +from graphgen.models import LightRAGKGBuilder from graphgen.utils import run_concurrent async def build_text_kg( - llm_client: OpenAIClient, + llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], progress_bar: gr.Progress = None, diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py index 63f3fd7a..51bb1de8 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/operators/init/init_llm.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict +from typing import Any, Dict, Optional from graphgen.bases import BaseLLMWrapper from graphgen.models import Tokenizer @@ -62,7 +62,7 @@ def _load_env_group(prefix: str) -> Dict[str, Any]: } -def init_llm(model_type: str) -> BaseLLMWrapper: +def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: if model_type == "synthesizer": prefix = "SYNTHESIZER_" elif model_type == 
"trainee": @@ -70,6 +70,9 @@ def init_llm(model_type: str) -> BaseLLMWrapper: else: raise NotImplementedError(f"Model type {model_type} is not implemented yet.") config = _load_env_group(prefix) + # if config is empty, return None + if not config: + return None backend = config.pop("backend") llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) return llm_wrapper diff --git a/graphgen/operators/judge.py b/graphgen/operators/judge.py index f7b0b963..d291d29a 100644 --- a/graphgen/operators/judge.py +++ b/graphgen/operators/judge.py @@ -3,13 +3,14 @@ from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT from graphgen.utils import logger, yes_no_loss_entropy async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: OpenAIClient, + trainee_llm_client: BaseLLMWrapper, graph_storage: NetworkXStorage, rephrase_storage: JsonKVStorage, re_judge: bool = False, diff --git a/graphgen/operators/quiz.py b/graphgen/operators/quiz.py index a8623bfb..cd86ef2d 100644 --- a/graphgen/operators/quiz.py +++ b/graphgen/operators/quiz.py @@ -3,13 +3,14 @@ from tqdm.asyncio import tqdm as tqdm_async -from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT from graphgen.utils import detect_main_language, logger async def quiz( - synth_llm_client: OpenAIClient, + synth_llm_client: BaseLLMWrapper, graph_storage: NetworkXStorage, rephrase_storage: JsonKVStorage, max_samples: int = 1, From 095d18a3b9a57afab9c636080cab8ecb40510211 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 29 Oct 2025 16:06:37 +0800 Subject: [PATCH 25/27] wip: add sglang_wrapper --- .env.example | 17 +- graphgen/graphgen.py | 23 ++- graphgen/models/llm/local/sglang_wrapper.py | 197 +++++++++++++------- graphgen/models/llm/local/tgi_wrapper.py | 2 +- graphgen/models/llm/local/trt_wrapper.py | 75 +------- graphgen/models/llm/local/vllm_wrapper.py | 22 +-- graphgen/operators/init/init_llm.py | 12 +- 7 files changed, 180 insertions(+), 168 deletions(-) diff --git a/.env.example b/.env.example index 7e321cd8..835471f2 100644 --- a/.env.example +++ b/.env.example @@ -2,7 +2,7 @@ TOKENIZER_MODEL= # LLM -# Support different backends: http_api, openai_api, ollama_api, ollama, deepspeed, huggingface, tgi, sglang, tensorrt +# Support different backends: http_api, openai_api, ollama_api, ollama, huggingface, tgi, sglang, tensorrt # http_api / openai_api SYNTHESIZER_BACKEND=openai_api @@ -13,3 +13,18 @@ TRAINEE_BACKEND=openai_api TRAINEE_MODEL=gpt-4o-mini TRAINEE_BASE_URL= TRAINEE_API_KEY= + +# # ollama_api +# SYNTHESIZER_BACKEND=ollama_api +# SYNTHESIZER_MODEL=gemma3 +# SYNTHESIZER_BASE_URL=http://localhost:11434 +# +# Note: TRAINEE with ollama_api backend is not supported yet as ollama_api does not support logprobs. 
+ +# # huggingface +# SYNTHESIZER_BACKEND=huggingface +# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct +# +# TRAINEE_BACKEND=huggingface +# TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct + diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 635272e5..e8258829 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -5,6 +5,7 @@ import gradio as gr +from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import StorageNameSpace from graphgen.bases.datatypes import Chunk from graphgen.models import ( @@ -49,13 +50,10 @@ def __init__( model_name=os.getenv("TOKENIZER_MODEL") ) - self.synthesizer_llm_client: OpenAIClient = synthesizer_llm_client or init_llm( - "synthesizer" - ) - - self.trainee_llm_client: OpenAIClient = trainee_llm_client or init_llm( - "trainee" + self.synthesizer_llm_client: BaseLLMWrapper = ( + synthesizer_llm_client or init_llm("synthesizer") ) + self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client self.full_docs_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="full_docs" @@ -266,6 +264,12 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): ) # TODO: assert trainee_llm_client is valid before judge + if not self.trainee_llm_client: + # TODO: shutdown existing synthesizer_llm_client properly + logger.info("No trainee LLM client provided, initializing a new one.") + self.synthesizer_llm_client.shutdown() + self.trainee_llm_client = init_llm("trainee") + re_judge = quiz_and_judge_config["re_judge"] _update_relations = await judge_statement( self.trainee_llm_client, @@ -273,9 +277,16 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): self.rephrase_storage, re_judge, ) + await self.rephrase_storage.index_done_callback() await _update_relations.index_done_callback() + logger.info("Shutting down trainee LLM client.") + self.trainee_llm_client.shutdown() + self.trainee_llm_client = None + logger.info("Restarting synthesizer LLM client.") + self.synthesizer_llm_client.restart() + @async_to_sync_method async def generate(self, partition_config: Dict, generate_config: Dict): # Step 1: partition the graph diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py index 99439aa2..b01e4fba 100644 --- a/graphgen/models/llm/local/sglang_wrapper.py +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -1,112 +1,175 @@ import math -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional -from graphgen.bases import BaseLLMWrapper +from graphgen.bases.base_llm_wrapper import BaseLLMWrapper from graphgen.bases.datatypes import Token +# TODO: implement SGLangWrapper methods class SGLangWrapper(BaseLLMWrapper): """ - Async inference backend based on SGLang + Async inference backend based on SGLang offline engine. """ def __init__( self, model: str, - tp_size: int = 1, - max_context_len: int = 4096, - server_url: Optional[str] = None, temperature: float = 0.0, top_p: float = 1.0, topk: int = 5, - **kwargs: Any + **kwargs: Any, ): super().__init__(temperature=temperature, top_p=top_p, **kwargs) try: import sglang as sgl - from sglang.backend.runtime_endpoint import RuntimeEndpoint + from sglang.utils import async_stream_and_merge, stream_and_merge except ImportError as exc: raise ImportError( - "Please install sglang to use SGLangBackend: pip install sglang[all]>=0.4.4" + "SGLangWrapper requires sglang. 
Install it with: " + "uv pip install sglang --prerelease=allow" ) from exc - self.model_path = model + + self.model_path: str = model self.temperature = temperature self.top_p = top_p self.topk = topk - # if server_url is given, connect to remote server; else launch local runtime - if server_url: - self.runtime = RuntimeEndpoint(server_url) - else: - sgl.set_default_backend( - sgl.Runtime(model, tp_size=tp_size, max_context_len=max_context_len) - ) - self.runtime = sgl.get_default_backend() + # Initialise the offline engine + self.engine = sgl.Engine(model_path=self.model_path) - self.tokenizer = self.runtime.get_tokenizer() + # Keep helpers for streaming + self.async_stream_and_merge = async_stream_and_merge + self.stream_and_merge = stream_and_merge @staticmethod - def _messages_to_str(prompt: str, history: Optional[List[str]] = None) -> str: - if not history: - return prompt - return "\n".join(history) + "\n" + prompt + def _build_sampling_params( + temperature: float, + top_p: float, + max_tokens: int, + topk: int, + logprobs: bool = False, + ) -> Dict[str, Any]: + """Build SGLang-compatible sampling-params dict.""" + params = { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_tokens, + } + if logprobs and topk > 0: + params["logprobs"] = topk + return params + + def _prep_prompt(self, text: str, history: Optional[List[str]] = None) -> str: + """Convert raw text (+ optional history) into a single prompt string.""" + parts = [] + if self.system_prompt: + parts.append(self.system_prompt) + if history: + assert len(history) % 2 == 0, "History must have even length (u/a turns)." + parts.extend(history) + parts.append(text) + return "\n".join(parts) + + def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]: + """ + Convert SGLang logprobs output into List[Token]. 
+ SGLang returns: + output['logprobs'][t] -> { + "token": , + "logprob": , + "top_k_tokens": [...], + "top_k_logprobs": [...], + } + """ + tokens: List[Token] = [] + if "logprobs" not in output or not output["logprobs"]: + return tokens + + for entry in output["logprobs"]: + token_str = entry["token"] + logprob = entry["logprob"] + prob = math.exp(logprob) + + top_candidates = [] + if self.topk > 0 and "top_k_tokens" in entry: + for tok, lp in zip(entry["top_k_tokens"], entry["top_k_logprobs"]): + top_candidates.append(Token(tok, math.exp(lp))) + + tokens.append(Token(token_str, prob, top_candidates=top_candidates)) + return tokens async def generate_answer( - self, text: str, history: Optional[List[str]] = None, **extra: Any + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> str: - text = self._messages_to_str(text, history) - - output = await self.runtime.generate( - text, - max_new_tokens=512, - temperature=self.temperature if self.temperature > 0 else 0, + prompt = self._prep_prompt(text, history) + sampling_params = self._build_sampling_params( + temperature=self.temperature, top_p=self.top_p, - stop=None, + max_tokens=self.max_tokens, + topk=0, # no logprobs needed for simple generation ) - return output + + outputs = self.engine.generate([prompt], sampling_params) + return self.filter_think_tags(outputs[0]["text"]) async def generate_topk_per_token( - self, text: str, history: Optional[List[str]] = None, **extra: Any + self, + text: str, + history: Optional[List[str]] = None, + **extra: Any, ) -> List[Token]: - text = self._messages_to_str(text, history) - - output_obj = await self.runtime.generate( - text, - max_new_tokens=1, - temperature=0, - return_logprob=True, - top_logprobs=self.topk, - logprob_start_len=0, + prompt = self._prep_prompt(text, history) + sampling_params = self._build_sampling_params( + temperature=self.temperature, + top_p=self.top_p, + max_tokens=5, # keep short for token-level analysis + topk=self.topk, + logprobs=True, ) - topk_list = output_obj["meta_info"]["top_logprobs"][ - 0 - ] # List[ (token_str, logprob), ... ] - return [Token(tok, math.exp(logprob)) for tok, logprob in topk_list] + outputs = self.engine.generate([prompt], sampling_params) + return self._tokens_from_output(outputs[0]) async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - text = self._messages_to_str(text, history) - ids = self.tokenizer.encode(text) - if not ids: - return [] - - logprob_tokens: List[Token] = [] - - for i in range(1, len(ids) + 1): - trunc_ids = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] - trunc_text = self.tokenizer.decode(trunc_ids) - - output_obj = await self.runtime.generate( - trunc_text, - max_new_tokens=1, - temperature=0, - return_logprob=True, - top_logprobs=1, - logprob_start_len=len(trunc_ids) - 1, + """ + Return per-token probabilities for the *input* sequence. + SGLang offline engine does not expose this directly; we emulate by + generating 0 new tokens with logprobs enabled (returns prompt logprobs). 
+ """ + prompt = self._prep_prompt(text, history) + sampling_params = self._build_sampling_params( + temperature=0.0, # deterministic + top_p=1.0, + max_tokens=0, # generate nothing + topk=self.topk, + logprobs=True, + ) + + outputs = self.engine.generate([prompt], sampling_params) + # SGLang returns prompt logprobs under key 'prompt_logprobs' when max_new_tokens=0 + prompt_logprobs = outputs[0].get("prompt_logprobs", []) + tokens: List[Token] = [] + for entry in prompt_logprobs: + tokens.append( + Token( + text=entry["token"], + prob=math.exp(entry["logprob"]), + top_candidates=[], # SGLang does not give top-k for prompt tokens + ) ) - top1 = output_obj["meta_info"]["top_logprobs"][0][0] - logprob_tokens.append(Token(top1[0], math.exp(top1[1]))) + return tokens + + def shutdown(self) -> None: + """Gracefully shutdown the SGLang engine.""" + if hasattr(self, "engine"): + self.engine.shutdown() - return logprob_tokens + def restart(self) -> None: + """Restart the SGLang engine.""" + self.shutdown() + self.engine = self.engine.__class__(model_path=self.model_path) diff --git a/graphgen/models/llm/local/tgi_wrapper.py b/graphgen/models/llm/local/tgi_wrapper.py index 8dac68c0..a722f6ea 100644 --- a/graphgen/models/llm/local/tgi_wrapper.py +++ b/graphgen/models/llm/local/tgi_wrapper.py @@ -4,6 +4,7 @@ from graphgen.bases.datatypes import Token +# TODO: implement TGIWrapper methods class TGIWrapper(BaseLLMWrapper): """ Async inference backend based on TGI (Text-Generation-Inference) @@ -18,7 +19,6 @@ def __init__( **kwargs: Any ): super().__init__(temperature=temperature, top_p=top_p, **kwargs) - # TODO: implement tgi wrapper async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any diff --git a/graphgen/models/llm/local/trt_wrapper.py b/graphgen/models/llm/local/trt_wrapper.py index 2a4f8f58..078f5ba9 100644 --- a/graphgen/models/llm/local/trt_wrapper.py +++ b/graphgen/models/llm/local/trt_wrapper.py @@ -1,93 +1,26 @@ from typing import Any, List, Optional -import numpy as np -from transformers import AutoTokenizer - from graphgen.bases import BaseLLMWrapper from graphgen.bases.datatypes import Token +# TODO: implement TensorRTWrapper methods class TensorRTWrapper(BaseLLMWrapper): """ Async inference backend based on TensorRT-LLM """ - def __init__( - self, - engine_dir: str, - tokenizer_dir: str, - topk: int = 5, - temperature=0.0, - top_p=1.0, - **kwargs: Any - ): - super().__init__(temperature=temperature, top_p=top_p, **kwargs) - try: - from tensorrt_llm.runtime import ModelRunnerCpp - except ImportError as exc: - raise ImportError( - "Please install tensorrt-llm to use TensorRTBackend: pip install tensorrt-llm" - ) from exc - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) - self.runner = ModelRunnerCpp.from_dir(engine_dir) - self.topk = topk - self.temperature = temperature - self.top_p = top_p - - def _parse_generation(self, output_ids) -> str: - return self.tokenizer.decode(output_ids[0], skip_special_tokens=True) - async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: - full = "\n".join(history or []) + "\n" + text - ids = self.tokenizer.encode(full) - output_ids = self.runner.generate( - [ids], - max_new_tokens=512, - temperature=self.temperature, - top_p=self.top_p, - eos_token_id=self.tokenizer.eos_token_id, - ) - return self._parse_generation(output_ids) + pass async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - 
full = "\n".join(history or []) + "\n" + text - ids = self.tokenizer.encode(full) - *_, logits = self.runner.generate( - [ids], - max_new_tokens=1, - temperature=0, - output_logits=True, - ) - logits = logits[0, -1, :] - probs = np.softmax(logits) - top_idx = np.argpartition(probs, -self.topk)[-self.topk :] - top_idx = top_idx[np.argsort(probs[top_idx])[::-1]] - return [ - Token(self.tokenizer.decode([idx]), float(probs[idx])) for idx in top_idx - ] + pass async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - full = "\n".join(history or []) + "\n" + text - ids = self.tokenizer.encode(full) - logprob_tokens = [] - for i in range(1, len(ids) + 1): - trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1] - *_, logits = self.runner.generate( - [trunc], - max_new_tokens=1, - temperature=0, - output_logits=True, - ) - logits = logits[0, -1, :] - probs = np.softmax(logits) - true_id = ids[i - 1] - logprob_tokens.append( - Token(self.tokenizer.decode([true_id]), float(probs[true_id])) - ) - return logprob_tokens + pass diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py index b20c07ac..d3f6cfcc 100644 --- a/graphgen/models/llm/local/vllm_wrapper.py +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -25,7 +25,7 @@ def __init__( from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams except ImportError as exc: raise ImportError( - "VLLMWrapper requires vllm. Install it with: pip install vllm" + "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto" ) from exc self.SamplingParams = SamplingParams @@ -42,9 +42,6 @@ def __init__( self.top_p = top_p self.topk = topk - # ------------------------------------------------------------------ - # helper:把 history 拼成多轮格式(与 HFWrapper 保持一致) - # ------------------------------------------------------------------ @staticmethod def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str: msgs = history or [] @@ -59,9 +56,6 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str: lines.append(prompt) return "\n".join(lines) - # ------------------------------------------------------------------ - # 1. 常规生成 - # ------------------------------------------------------------------ async def generate_answer( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: @@ -73,28 +67,22 @@ async def generate_answer( max_tokens=extra.get("max_new_tokens", 512), ) - # vLLM 的异步接口 results = [] async for req_output in self.engine.generate( full_prompt, sp, request_id="graphgen_req" ): results = req_output.outputs - # 取最后一次返回 return results[-1].text - # ------------------------------------------------------------------ - # 2. 
只生成 1 个新 token,返回 top-k 概率 - # ------------------------------------------------------------------ async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: full_prompt = self._build_inputs(text, history) - # 强制 greedy(temperature=0)并返回 logprobs sp = self.SamplingParams( temperature=0, max_tokens=1, - logprobs=self.topk, # vLLM 会给出 top-k 的 logprob + logprobs=self.topk, ) results = [] @@ -102,20 +90,16 @@ async def generate_topk_per_token( full_prompt, sp, request_id="graphgen_topk" ): results = req_output.outputs - top_logprobs = results[-1].logprobs[0] # 第 1 个新生成 token 的 top-k + top_logprobs = results[-1].logprobs[0] tokens = [] for _, logprob_obj in top_logprobs.items(): tok_str = logprob_obj.decoded_token prob = float(logprob_obj.logprob.exp()) tokens.append(Token(tok_str, prob)) - # 按概率从高到低排序 tokens.sort(key=lambda x: -x.prob) return tokens - # ------------------------------------------------------------------ - # 3. 逐 token 计算“被模型预测到”的概率(与 HFWrapper 语义对齐) - # ------------------------------------------------------------------ async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py index 51bb1de8..f7e33356 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/operators/init/init_llm.py @@ -43,10 +43,16 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper return HuggingFaceWrapper(**config) - if backend == "sglang": - from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper + # if backend == "sglang": + # from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper + # + # return SGLangWrapper(**config) + + if backend == "vllm": + from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper + + return VLLMWrapper(**config) - return SGLangWrapper(**config) raise NotImplementedError(f"Backend {backend} is not implemented yet.") From 8c92ffdf9666734b8c5c0364ac3ca4236f9069df Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 29 Oct 2025 17:50:39 +0800 Subject: [PATCH 26/27] docs: update .env.example --- .env.example | 1 - 1 file changed, 1 deletion(-) diff --git a/.env.example b/.env.example index 835471f2..94064796 100644 --- a/.env.example +++ b/.env.example @@ -27,4 +27,3 @@ TRAINEE_API_KEY= # # TRAINEE_BACKEND=huggingface # TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct - From 03e6d23bb1f019bd17875529f17aca716ec47091 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 29 Oct 2025 19:19:56 +0800 Subject: [PATCH 27/27] fix: fix parsing token_logprobs in sglang_wrapper --- graphgen/models/llm/api/http_client.py | 2 +- graphgen/models/llm/api/openai_client.py | 4 +- graphgen/models/llm/local/sglang_wrapper.py | 79 +++++++-------------- 3 files changed, 29 insertions(+), 56 deletions(-) diff --git a/graphgen/models/llm/api/http_client.py b/graphgen/models/llm/api/http_client.py index c49018a5..2c3b0acd 100644 --- a/graphgen/models/llm/api/http_client.py +++ b/graphgen/models/llm/api/http_client.py @@ -163,7 +163,7 @@ async def generate_topk_per_token( **extra: Any, ) -> List[Token]: body = self._build_body(text, history or []) - body["max_tokens"] = 5 + body["max_tokens"] = 1 if self.topk_per_token > 0: body["logprobs"] = True body["top_logprobs"] = self.topk_per_token diff --git a/graphgen/models/llm/api/openai_client.py 
b/graphgen/models/llm/api/openai_client.py index 18aece36..5f9c131a 100644 --- a/graphgen/models/llm/api/openai_client.py +++ b/graphgen/models/llm/api/openai_client.py @@ -105,8 +105,8 @@ async def generate_topk_per_token( kwargs["logprobs"] = True kwargs["top_logprobs"] = self.topk_per_token - # Limit max_tokens to 5 to avoid long completions - kwargs["max_tokens"] = 5 + # Limit max_tokens to 1 to avoid long completions + kwargs["max_tokens"] = 1 completion = await self.client.chat.completions.create( # pylint: disable=E1125 model=self.model_name, **kwargs diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py index b01e4fba..3781af2c 100644 --- a/graphgen/models/llm/local/sglang_wrapper.py +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -5,7 +5,6 @@ from graphgen.bases.datatypes import Token -# TODO: implement SGLangWrapper methods class SGLangWrapper(BaseLLMWrapper): """ Async inference backend based on SGLang offline engine. @@ -59,43 +58,39 @@ def _build_sampling_params( params["logprobs"] = topk return params - def _prep_prompt(self, text: str, history: Optional[List[str]] = None) -> str: + def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str: """Convert raw text (+ optional history) into a single prompt string.""" parts = [] if self.system_prompt: parts.append(self.system_prompt) if history: assert len(history) % 2 == 0, "History must have even length (u/a turns)." - parts.extend(history) + parts.extend([item["content"] for item in history]) parts.append(text) return "\n".join(parts) def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]: - """ - Convert SGLang logprobs output into List[Token]. - SGLang returns: - output['logprobs'][t] -> { - "token": , - "logprob": , - "top_k_tokens": [...], - "top_k_logprobs": [...], - } - """ tokens: List[Token] = [] - if "logprobs" not in output or not output["logprobs"]: - return tokens - for entry in output["logprobs"]: - token_str = entry["token"] - logprob = entry["logprob"] - prob = math.exp(logprob) + meta = output.get("meta_info", {}) + logprobs = meta.get("output_token_logprobs", []) + topks = meta.get("output_top_logprobs", []) + + tokenizer = self.engine.tokenizer_manager.tokenizer + + for idx, (lp, tid, _) in enumerate(logprobs): + prob = math.exp(lp) + tok_str = tokenizer.decode([tid]) top_candidates = [] - if self.topk > 0 and "top_k_tokens" in entry: - for tok, lp in zip(entry["top_k_tokens"], entry["top_k_logprobs"]): - top_candidates.append(Token(tok, math.exp(lp))) + if self.topk > 0 and idx < len(topks): + for t_lp, t_tid, _ in topks[idx][: self.topk]: + top_candidates.append( + Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp)) + ) + + tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates)) - tokens.append(Token(token_str, prob, top_candidates=top_candidates)) return tokens async def generate_answer( @@ -112,7 +107,7 @@ async def generate_answer( topk=0, # no logprobs needed for simple generation ) - outputs = self.engine.generate([prompt], sampling_params) + outputs = await self.engine.async_generate([prompt], sampling_params) return self.filter_think_tags(outputs[0]["text"]) async def generate_topk_per_token( @@ -125,45 +120,23 @@ async def generate_topk_per_token( sampling_params = self._build_sampling_params( temperature=self.temperature, top_p=self.top_p, - max_tokens=5, # keep short for token-level analysis + max_tokens=1, # keep short for token-level analysis topk=self.topk, - logprobs=True, ) - 
outputs = self.engine.generate([prompt], sampling_params)
+        outputs = await self.engine.async_generate(
+            [prompt], sampling_params, return_logprob=True, top_logprobs_num=5
+        )
         return self._tokens_from_output(outputs[0])
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        """
-        Return per-token probabilities for the *input* sequence.
-        SGLang offline engine does not expose this directly; we emulate by
-        generating 0 new tokens with logprobs enabled (returns prompt logprobs).
-        """
-        prompt = self._prep_prompt(text, history)
-        sampling_params = self._build_sampling_params(
-            temperature=0.0,  # deterministic
-            top_p=1.0,
-            max_tokens=0,  # generate nothing
-            topk=self.topk,
-            logprobs=True,
+        raise NotImplementedError(
+            "SGLangWrapper does not support per-token logprobs yet."
         )
 
-        outputs = self.engine.generate([prompt], sampling_params)
-        # SGLang returns prompt logprobs under key 'prompt_logprobs' when max_new_tokens=0
-        prompt_logprobs = outputs[0].get("prompt_logprobs", [])
-        tokens: List[Token] = []
-        for entry in prompt_logprobs:
-            tokens.append(
-                Token(
-                    text=entry["token"],
-                    prob=math.exp(entry["logprob"]),
-                    top_candidates=[],  # SGLang does not give top-k for prompt tokens
-                )
-            )
-        return tokens
-
     def shutdown(self) -> None:
         """Gracefully shutdown the SGLang engine."""
         if hasattr(self, "engine"):