Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ba48ee5
chore: delete duplicate model
ChenZiHong-Gavin Oct 27, 2025
d1e0af5
feat: add huggingface wrapper
ChenZiHong-Gavin Oct 27, 2025
a3a0c60
refactor: change file name
ChenZiHong-Gavin Oct 27, 2025
61766b5
feat: add ds_wrapper, trt_wrapper, sglang_wrapper
ChenZiHong-Gavin Oct 27, 2025
e6f4502
refactor: refactor graphgen
ChenZiHong-Gavin Oct 27, 2025
5a1fc4d
wip: add azure client
ChenZiHong-Gavin Oct 27, 2025
0276402
Merge branch 'main' of https://github.com/open-sciencelab/GraphGen in…
ChenZiHong-Gavin Oct 27, 2025
276f881
feat: add ollama_wrapper, tgi_wrapper
ChenZiHong-Gavin Oct 27, 2025
846b924
feat: add azure_client, http_client, ollama_client
ChenZiHong-Gavin Oct 27, 2025
f6bdaf6
delete azure_client
ChenZiHong-Gavin Oct 27, 2025
ee2d35e
tests: add http_client test
ChenZiHong-Gavin Oct 28, 2025
d02e5a2
docs: update .env example
ChenZiHong-Gavin Oct 28, 2025
614283f
feat: switch llm backend(http_api)
ChenZiHong-Gavin Oct 28, 2025
abc8dc2
fix: fix ollama_client
ChenZiHong-Gavin Oct 28, 2025
ebf9d1c
wip: fix ollama_client
ChenZiHong-Gavin Oct 28, 2025
fac9997
tests: add ollama_client tests
ChenZiHong-Gavin Oct 28, 2025
f5a4594
fix: fix generate_topk_per_token in ollmam_client
ChenZiHong-Gavin Oct 28, 2025
d4beb52
fix: delete useless tests
ChenZiHong-Gavin Oct 28, 2025
c8055f1
fix: fix transformers warning not using GenerationConfig
ChenZiHong-Gavin Oct 28, 2025
ae9b28b
fix: fix _build_inputs in hf_wrapper
ChenZiHong-Gavin Oct 28, 2025
292d986
fix: fix gen_kwargs
ChenZiHong-Gavin Oct 28, 2025
f562ee2
chore: delete ds_wrapper
ChenZiHong-Gavin Oct 28, 2025
399ef45
feat: add vllm_wrapper
ChenZiHong-Gavin Oct 28, 2025
569db1d
wip:sglang backend
ChenZiHong-Gavin Oct 29, 2025
e86cbbf
fix: change llm_wrapper type
ChenZiHong-Gavin Oct 29, 2025
095d18a
wip: add sglang_wrapper
ChenZiHong-Gavin Oct 29, 2025
8c92ffd
docs: update .env.example
ChenZiHong-Gavin Oct 29, 2025
03e6d23
fix: fix parsing token_logprobs in sglang_wrapper
ChenZiHong-Gavin Oct 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
# Tokenizer
TOKENIZER_MODEL=
SYNTHESIZER_MODEL=

# LLM
# Support different backends: http_api, openai_api, ollama_api, ollama, huggingface, tgi, sglang, tensorrt

# http_api / openai_api
SYNTHESIZER_BACKEND=openai_api
SYNTHESIZER_MODEL=gpt-4o-mini
SYNTHESIZER_BASE_URL=
SYNTHESIZER_API_KEY=
TRAINEE_MODEL=
TRAINEE_BACKEND=openai_api
TRAINEE_MODEL=gpt-4o-mini
TRAINEE_BASE_URL=
TRAINEE_API_KEY=

# # ollama_api
# SYNTHESIZER_BACKEND=ollama_api
# SYNTHESIZER_MODEL=gemma3
# SYNTHESIZER_BASE_URL=http://localhost:11434
#
# Note: TRAINEE with ollama_api backend is not supported yet as ollama_api does not support logprobs.

# # huggingface
# SYNTHESIZER_BACKEND=huggingface
# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct
#
# TRAINEE_BACKEND=huggingface
# TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct
2 changes: 1 addition & 1 deletion graphgen/bases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base_generator import BaseGenerator
from .base_kg_builder import BaseKGBuilder
from .base_llm_client import BaseLLMClient
from .base_llm_wrapper import BaseLLMWrapper
from .base_partitioner import BasePartitioner
from .base_reader import BaseReader
from .base_splitter import BaseSplitter
Expand Down
4 changes: 2 additions & 2 deletions graphgen/bases/base_generator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper


class BaseGenerator(ABC):
"""
Generate QAs based on given prompts.
"""

def __init__(self, llm_client: BaseLLMClient):
def __init__(self, llm_client: BaseLLMWrapper):
self.llm_client = llm_client

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions graphgen/bases/base_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from collections import defaultdict
from typing import Dict, List, Tuple

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk


class BaseKGBuilder(ABC):
def __init__(self, llm_client: BaseLLMClient):
def __init__(self, llm_client: BaseLLMWrapper):
self.llm_client = llm_client
self._nodes: Dict[str, List[dict]] = defaultdict(list)
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from graphgen.bases.datatypes import Token


class BaseLLMClient(abc.ABC):
class BaseLLMWrapper(abc.ABC):
"""
LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
"""
Expand Down Expand Up @@ -66,3 +66,9 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str:
think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
filtered_text = think_pattern.sub("", text).strip()
return filtered_text if filtered_text else text.strip()

def shutdown(self) -> None:
"""Shutdown the LLM engine if applicable."""

def restart(self) -> None:
"""Reinitialize the LLM engine if applicable."""
65 changes: 35 additions & 30 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import os
import time
from dataclasses import dataclass
from typing import Dict, cast

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.bases.base_storage import StorageNameSpace
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
Expand All @@ -20,6 +20,7 @@
build_text_kg,
chunk_documents,
generate_qas,
init_llm,
judge_statement,
partition_kg,
quiz,
Expand All @@ -31,40 +32,28 @@
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


@dataclass
class GraphGen:
unique_id: int = int(time.time())
working_dir: str = os.path.join(sys_path, "cache")

# llm
tokenizer_instance: Tokenizer = None
synthesizer_llm_client: OpenAIClient = None
trainee_llm_client: OpenAIClient = None

# webui
progress_bar: gr.Progress = None

def __post_init__(self):
self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
def __init__(
self,
unique_id: int = int(time.time()),
working_dir: str = os.path.join(sys_path, "cache"),
tokenizer_instance: Tokenizer = None,
synthesizer_llm_client: OpenAIClient = None,
trainee_llm_client: OpenAIClient = None,
progress_bar: gr.Progress = None,
):
self.unique_id: int = unique_id
self.working_dir: str = working_dir

# llm
self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
model_name=os.getenv("TOKENIZER_MODEL")
)

self.synthesizer_llm_client: OpenAIClient = (
self.synthesizer_llm_client
or OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
tokenizer=self.tokenizer_instance,
)
)

self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
model_name=os.getenv("TRAINEE_MODEL"),
api_key=os.getenv("TRAINEE_API_KEY"),
base_url=os.getenv("TRAINEE_BASE_URL"),
tokenizer=self.tokenizer_instance,
self.synthesizer_llm_client: BaseLLMWrapper = (
synthesizer_llm_client or init_llm("synthesizer")
)
self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client

self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
Expand All @@ -86,6 +75,9 @@ def __post_init__(self):
namespace="qa",
)

# webui
self.progress_bar: gr.Progress = progress_bar

@async_to_sync_method
async def insert(self, read_config: Dict, split_config: Dict):
"""
Expand Down Expand Up @@ -272,16 +264,29 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
)

# TODO: assert trainee_llm_client is valid before judge
if not self.trainee_llm_client:
# TODO: shutdown existing synthesizer_llm_client properly
logger.info("No trainee LLM client provided, initializing a new one.")
self.synthesizer_llm_client.shutdown()
self.trainee_llm_client = init_llm("trainee")

re_judge = quiz_and_judge_config["re_judge"]
_update_relations = await judge_statement(
self.trainee_llm_client,
self.graph_storage,
self.rephrase_storage,
re_judge,
)

await self.rephrase_storage.index_done_callback()
await _update_relations.index_done_callback()

logger.info("Shutting down trainee LLM client.")
self.trainee_llm_client.shutdown()
self.trainee_llm_client = None
logger.info("Restarting synthesizer LLM client.")
self.synthesizer_llm_client.restart()

@async_to_sync_method
async def generate(self, partition_config: Dict, generate_config: Dict):
# Step 1: partition the graph
Expand Down
3 changes: 1 addition & 2 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
VQAGenerator,
)
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
from .llm.openai_client import OpenAIClient
from .llm.topk_token_model import TopkTokenModel
from .llm import HTTPClient, OllamaClient, OpenAIClient
from .partitioner import (
AnchorBFSPartitioner,
BFSPartitioner,
Expand Down
4 changes: 2 additions & 2 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import (
detect_main_language,
Expand All @@ -15,7 +15,7 @@


class LightRAGKGBuilder(BaseKGBuilder):
def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
super().__init__(llm_client)
self.max_loop = max_loop

Expand Down
4 changes: 4 additions & 0 deletions graphgen/models/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .api.http_client import HTTPClient
from .api.ollama_client import OllamaClient
from .api.openai_client import OpenAIClient
from .local.hf_wrapper import HuggingFaceWrapper
Empty file.
Loading