From af5e7a237e45d1ec7d5bc1e3c4d5b15b366b6d55 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Thu, 20 Nov 2025 10:52:40 +0800 Subject: [PATCH] fix: OpenAIClient parameter from model_name to model to resolve key mismatch --- baselines/Genie/genie.py | 2 +- baselines/LongForm/longform.py | 2 +- baselines/SELF-QA/self-qa.py | 2 +- baselines/Wrap/wrap.py | 2 +- graphgen/models/llm/api/openai_client.py | 8 ++++---- webui/app.py | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/baselines/Genie/genie.py b/baselines/Genie/genie.py index 1dfd8e93..af542e49 100644 --- a/baselines/Genie/genie.py +++ b/baselines/Genie/genie.py @@ -122,7 +122,7 @@ async def process_chunk(content: str): load_dotenv() llm_client = OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), + model=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), ) diff --git a/baselines/LongForm/longform.py b/baselines/LongForm/longform.py index 8467556f..500483f3 100644 --- a/baselines/LongForm/longform.py +++ b/baselines/LongForm/longform.py @@ -89,7 +89,7 @@ async def process_chunk(content: str): load_dotenv() llm_client = OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), + model=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), ) diff --git a/baselines/SELF-QA/self-qa.py b/baselines/SELF-QA/self-qa.py index 1f96cff9..d0eac4b6 100644 --- a/baselines/SELF-QA/self-qa.py +++ b/baselines/SELF-QA/self-qa.py @@ -156,7 +156,7 @@ async def process_chunk(content: str): load_dotenv() llm_client = OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), + model=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), ) diff --git a/baselines/Wrap/wrap.py b/baselines/Wrap/wrap.py index cecbaadd..4d898a0e 100644 --- a/baselines/Wrap/wrap.py +++ b/baselines/Wrap/wrap.py @@ -109,7 +109,7 @@ async def process_chunk(content: str): load_dotenv() llm_client = OpenAIClient( - model_name=os.getenv("SYNTHESIZER_MODEL"), + model=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL"), ) diff --git a/graphgen/models/llm/api/openai_client.py b/graphgen/models/llm/api/openai_client.py index 5f9c131a..448d6625 100644 --- a/graphgen/models/llm/api/openai_client.py +++ b/graphgen/models/llm/api/openai_client.py @@ -32,7 +32,7 @@ class OpenAIClient(BaseLLMWrapper): def __init__( self, *, - model_name: str = "gpt-4o-mini", + model: str = "gpt-4o-mini", api_key: Optional[str] = None, base_url: Optional[str] = None, json_mode: bool = False, @@ -44,7 +44,7 @@ def __init__( **kwargs: Any, ): super().__init__(**kwargs) - self.model_name = model_name + self.model = model self.api_key = api_key self.base_url = base_url self.json_mode = json_mode @@ -109,7 +109,7 @@ async def generate_topk_per_token( kwargs["max_tokens"] = 1 completion = await self.client.chat.completions.create( # pylint: disable=E1125 - model=self.model_name, **kwargs + model=self.model, **kwargs ) tokens = get_top_response_tokens(completion) @@ -141,7 +141,7 @@ async def generate_answer( await self.tpm.wait(estimated_tokens, silent=True) completion = await self.client.chat.completions.create( # pylint: disable=E1125 - model=self.model_name, **kwargs + model=self.model, **kwargs ) if hasattr(completion, "usage"): self.token_usage.append( diff --git a/webui/app.py b/webui/app.py index 2e74e203..d0f45f9f 100644 --- a/webui/app.py +++ b/webui/app.py @@ -42,7 +42,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base")) synthesizer_llm_client = OpenAIClient( - model_name=env.get("SYNTHESIZER_MODEL", ""), + model=env.get("SYNTHESIZER_MODEL", ""), base_url=env.get("SYNTHESIZER_BASE_URL", ""), api_key=env.get("SYNTHESIZER_API_KEY", ""), request_limit=True, @@ -51,7 +51,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: tokenizer=tokenizer_instance, ) trainee_llm_client = OpenAIClient( - model_name=env.get("TRAINEE_MODEL", ""), + model=env.get("TRAINEE_MODEL", ""), base_url=env.get("TRAINEE_BASE_URL", ""), api_key=env.get("TRAINEE_API_KEY", ""), request_limit=True,