Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion baselines/Genie/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def process_chunk(content: str):
load_dotenv()

llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
model=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
)
Expand Down
2 changes: 1 addition & 1 deletion baselines/LongForm/longform.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def process_chunk(content: str):
load_dotenv()

llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
model=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
)
Expand Down
2 changes: 1 addition & 1 deletion baselines/SELF-QA/self-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def process_chunk(content: str):
load_dotenv()

llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
model=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
)
Expand Down
2 changes: 1 addition & 1 deletion baselines/Wrap/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def process_chunk(content: str):
load_dotenv()

llm_client = OpenAIClient(
model_name=os.getenv("SYNTHESIZER_MODEL"),
model=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
)
Expand Down
8 changes: 4 additions & 4 deletions graphgen/models/llm/api/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OpenAIClient(BaseLLMWrapper):
def __init__(
self,
*,
model_name: str = "gpt-4o-mini",
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
json_mode: bool = False,
Expand All @@ -44,7 +44,7 @@ def __init__(
**kwargs: Any,
):
super().__init__(**kwargs)
self.model_name = model_name
self.model = model
self.api_key = api_key
self.base_url = base_url
self.json_mode = json_mode
Expand Down Expand Up @@ -109,7 +109,7 @@ async def generate_topk_per_token(
kwargs["max_tokens"] = 1

completion = await self.client.chat.completions.create( # pylint: disable=E1125
model=self.model_name, **kwargs
model=self.model, **kwargs
)

tokens = get_top_response_tokens(completion)
Expand Down Expand Up @@ -141,7 +141,7 @@ async def generate_answer(
await self.tpm.wait(estimated_tokens, silent=True)

completion = await self.client.chat.completions.create( # pylint: disable=E1125
model=self.model_name, **kwargs
model=self.model, **kwargs
)
if hasattr(completion, "usage"):
self.token_usage.append(
Expand Down
4 changes: 2 additions & 2 deletions webui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:

tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
synthesizer_llm_client = OpenAIClient(
model_name=env.get("SYNTHESIZER_MODEL", ""),
model=env.get("SYNTHESIZER_MODEL", ""),
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
api_key=env.get("SYNTHESIZER_API_KEY", ""),
request_limit=True,
Expand All @@ -51,7 +51,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
tokenizer=tokenizer_instance,
)
trainee_llm_client = OpenAIClient(
model_name=env.get("TRAINEE_MODEL", ""),
model=env.get("TRAINEE_MODEL", ""),
base_url=env.get("TRAINEE_BASE_URL", ""),
api_key=env.get("TRAINEE_API_KEY", ""),
request_limit=True,
Expand Down