8 changes: 8 additions & 0 deletions .env.example
@@ -14,6 +14,14 @@ TRAINEE_MODEL=gpt-4o-mini
 TRAINEE_BASE_URL=
 TRAINEE_API_KEY=
 
+# azure_openai_api
+# SYNTHESIZER_BACKEND=azure_openai_api
+# The following is the same as your "Deployment name" in Azure
+# SYNTHESIZER_MODEL=<your-deployment-name>
+# SYNTHESIZER_BASE_URL=https://<your-resource-name>.openai.azure.com
+# SYNTHESIZER_API_KEY=
+# SYNTHESIZER_API_VERSION=<api-version>
+
 # # ollama_api
 # SYNTHESIZER_BACKEND=ollama_api
 # SYNTHESIZER_MODEL=gemma3
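For reference, a minimal sketch of how these `SYNTHESIZER_*` variables could be read and forwarded as constructor kwargs to the client below. The loader function is hypothetical; GraphGen's actual config plumbing may differ.

```python
# Hypothetical loader: collect the Azure OpenAI settings shown in .env.example
# above and return them as constructor kwargs. Not GraphGen's real code path.
import os

def synthesizer_config() -> dict:
    return {
        "backend": os.environ.get("SYNTHESIZER_BACKEND", "openai_api"),
        "model": os.environ["SYNTHESIZER_MODEL"],        # Azure "Deployment name"
        "base_url": os.environ["SYNTHESIZER_BASE_URL"],  # Azure resource endpoint
        "api_key": os.environ["SYNTHESIZER_API_KEY"],
        "api_version": os.environ["SYNTHESIZER_API_VERSION"],
    }
```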
32 changes: 27 additions & 5 deletions graphgen/models/llm/api/openai_client.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 import openai
-from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
+from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, AsyncAzureOpenAI, RateLimitError
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -35,17 +35,20 @@ def __init__(
         model: str = "gpt-4o-mini",
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
+        api_version: Optional[str] = None,
         json_mode: bool = False,
         seed: Optional[int] = None,
         topk_per_token: int = 5,  # number of topk tokens to generate for each token
         request_limit: bool = False,
         rpm: Optional[RPM] = None,
         tpm: Optional[TPM] = None,
+        backend: str = "openai_api",
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version  # required for Azure OpenAI
         self.base_url = base_url
         self.json_mode = json_mode
         self.seed = seed
@@ -56,13 +59,32 @@ def __init__(
         self.rpm = rpm or RPM()
         self.tpm = tpm or TPM()
 
+        assert (
+            backend in ("openai_api", "azure_openai_api")
+        ), f"Unsupported backend '{backend}'. Use 'openai_api' or 'azure_openai_api'."
+        self.backend = backend
+
         self.__post_init__()
 
     def __post_init__(self):
-        assert self.api_key is not None, "Please provide api key to access openai api."
-        self.client = AsyncOpenAI(
-            api_key=self.api_key or "dummy", base_url=self.base_url
-        )
+        api_name = self.backend.replace("_", " ")
+        assert self.api_key is not None, f"Please provide api key to access {api_name}."
+        if self.backend == "openai_api":
+            self.client = AsyncOpenAI(
+                api_key=self.api_key or "dummy", base_url=self.base_url
+            )
+        elif self.backend == "azure_openai_api":
+            assert self.api_version is not None, f"Please provide api_version for {api_name}."
+            assert self.base_url is not None, f"Please provide base_url for {api_name}."
+            self.client = AsyncAzureOpenAI(
+                api_key=self.api_key,
+                azure_endpoint=self.base_url,
+                api_version=self.api_version,
+                azure_deployment=self.model,
+            )
+        else:
+            raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
 
     def _pre_generate(self, text: str, history: List[str]) -> Dict:
         kwargs = {
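A hedged usage sketch of the new branch: constructing `OpenAIClient` with `backend="azure_openai_api"` should yield an `AsyncAzureOpenAI` under the hood. Values are placeholders; this assumes `BaseLLMWrapper` requires no additional kwargs, and it pokes at the underlying SDK client directly rather than going through GraphGen's own generate methods.

```python
import asyncio

from graphgen.models.llm.api.openai_client import OpenAIClient

# Placeholder credentials/deployment; __post_init__ builds an AsyncAzureOpenAI.
client = OpenAIClient(
    model="my-gpt4o-deployment",  # the Azure deployment name
    api_key="<azure-api-key>",
    base_url="https://my-resource.openai.azure.com",
    api_version="2024-02-01",
    backend="azure_openai_api",
)

async def main() -> None:
    # client.client is the raw AsyncAzureOpenAI instance created in __post_init__
    resp = await client.client.chat.completions.create(
        model=client.model,
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)

asyncio.run(main())
```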
7 changes: 4 additions & 3 deletions graphgen/operators/init/init_llm.py
@@ -27,10 +27,11 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
         from graphgen.models.llm.api.http_client import HTTPClient
 
         return HTTPClient(**config)
-    if backend == "openai_api":
+    if backend in ("openai_api", "azure_openai_api"):
         from graphgen.models.llm.api.openai_client import OpenAIClient
 
-        return OpenAIClient(**config)
+        # Pass the concrete backend to OpenAIClient so it can distinguish
+        # between OpenAI and Azure OpenAI internally.
+        return OpenAIClient(**config, backend=backend)
     if backend == "ollama_api":
         from graphgen.models.llm.api.ollama_client import OllamaClient
 
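Finally, a sketch of the dispatch from the caller's side: both backends now route through `create_llm_wrapper` to the same `OpenAIClient`, which branches internally on `backend`. Config keys mirror `.env.example`; all values are placeholders.

```python
from graphgen.operators.init.init_llm import create_llm_wrapper

# Azure OpenAI: model is the deployment name, api_version is required.
azure_llm = create_llm_wrapper(
    backend="azure_openai_api",
    config={
        "model": "my-gpt4o-deployment",
        "api_key": "<azure-api-key>",
        "base_url": "https://my-resource.openai.azure.com",
        "api_version": "2024-02-01",
    },
)

# Plain OpenAI: same entry point, no api_version needed.
openai_llm = create_llm_wrapper(
    backend="openai_api",
    config={"model": "gpt-4o-mini", "api_key": "<openai-api-key>"},
)
```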