diff --git a/.env.example b/.env.example
index 94064796..5101b269 100644
--- a/.env.example
+++ b/.env.example
@@ -14,6 +14,14 @@ TRAINEE_MODEL=gpt-4o-mini
 TRAINEE_BASE_URL=
 TRAINEE_API_KEY=
 
+# azure_openai_api
+# SYNTHESIZER_BACKEND=azure_openai_api
+# SYNTHESIZER_MODEL is the same as your "Deployment name" in Azure
+# SYNTHESIZER_MODEL=
+# SYNTHESIZER_BASE_URL=https://<your-resource-name>.openai.azure.com
+# SYNTHESIZER_API_KEY=
+# SYNTHESIZER_API_VERSION=
+
 # # ollama_api
 # SYNTHESIZER_BACKEND=ollama_api
 # SYNTHESIZER_MODEL=gemma3
diff --git a/graphgen/models/llm/api/openai_client.py b/graphgen/models/llm/api/openai_client.py
index 448d6625..532b981c 100644
--- a/graphgen/models/llm/api/openai_client.py
+++ b/graphgen/models/llm/api/openai_client.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 import openai
-from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
+from openai import APIConnectionError, APITimeoutError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,17 +35,20 @@ def __init__(
         model: str = "gpt-4o-mini",
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
         json_mode: bool = False,
         seed: Optional[int] = None,
         topk_per_token: int = 5,  # number of topk tokens to generate for each token
         request_limit: bool = False,
         rpm: Optional[RPM] = None,
         tpm: Optional[TPM] = None,
+        backend: str = "openai_api",
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version  # required for Azure OpenAI
         self.base_url = base_url
         self.json_mode = json_mode
         self.seed = seed
@@ -56,13 +59,32 @@ def __init__(
         self.rpm = rpm or RPM()
         self.tpm = tpm or TPM()
 
+        assert (
+            backend in ("openai_api", "azure_openai_api")
+        ), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
+        self.backend = backend
+
         self.__post_init__()
 
     def __post_init__(self):
-        assert self.api_key is not None, "Please provide api key to access openai api."
-        self.client = AsyncOpenAI(
-            api_key=self.api_key or "dummy", base_url=self.base_url
-        )
+
+        api_name = self.backend.replace("_", " ")
+        assert self.api_key is not None, f"Please provide api key to access {api_name}."
+        if self.backend == "openai_api":
+            self.client = AsyncOpenAI(
+                api_key=self.api_key or "dummy", base_url=self.base_url
+            )
+        elif self.backend == "azure_openai_api":
+            assert self.api_version is not None, f"Please provide api_version for {api_name}."
+            assert self.base_url is not None, f"Please provide base_url for {api_name}."
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                azure_endpoint=self.base_url,
+                api_version=self.api_version,
+                azure_deployment=self.model,
+            )
+        else:
+            raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
 
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py
index 79a3618b..e294d2c3 100644
--- a/graphgen/operators/init/init_llm.py
+++ b/graphgen/operators/init/init_llm.py
@@ -27,10 +27,11 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
         from graphgen.models.llm.api.http_client import HTTPClient
 
         return HTTPClient(**config)
-    if backend == "openai_api":
+    if backend in ("openai_api", "azure_openai_api"):
         from graphgen.models.llm.api.openai_client import OpenAIClient
-
-        return OpenAIClient(**config)
+        # Pass the concrete backend through so OpenAIClient can distinguish
+        # between OpenAI and Azure OpenAI internally.
+        return OpenAIClient(**config, backend=backend)
     if backend == "ollama_api":
         from graphgen.models.llm.api.ollama_client import OllamaClient
 
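
For reviewers, a minimal sketch of how the new backend is wired end to end, assuming only the paths and signatures shown in this diff; the deployment name, endpoint, key, and API version below are placeholders, not values from the repo:

    # Sketch only: all credential/endpoint values are placeholders.
    from graphgen.operators.init.init_llm import create_llm_wrapper

    config = {
        "model": "my-deployment",  # the Azure "Deployment name"
        "api_key": "<azure-openai-key>",
        "base_url": "https://<your-resource-name>.openai.azure.com",
        "api_version": "2024-02-01",
    }

    # create_llm_wrapper forwards backend="azure_openai_api" to OpenAIClient,
    # whose __post_init__ then builds AsyncAzureOpenAI(azure_endpoint=base_url,
    # azure_deployment=model, api_version=api_version) instead of AsyncOpenAI.
    llm = create_llm_wrapper("azure_openai_api", config)

Routing both backends through OpenAIClient keeps a single tenacity-retry and RPM/TPM path; the backend string only switches which client class __post_init__ instantiates.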