diff --git a/baselines/BDS/bds.py b/baselines/BDS/bds.py
index 33c36718..419fdcc3 100644
--- a/baselines/BDS/bds.py
+++ b/baselines/BDS/bds.py
@@ -1,15 +1,15 @@
 import argparse
 import asyncio
 import json
-import os
-from dataclasses import dataclass
 from typing import List
 
 import networkx as nx
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import NetworkXStorage, OpenAIClient, Tokenizer
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import NetworkXStorage
+from graphgen.operators import init_llm
 from graphgen.utils import create_event_loop
 
 QA_GENERATION_PROMPT = """
@@ -52,10 +52,12 @@ def _post_process(text: str) -> dict:
         return {}
 
 
-@dataclass
 class BDS:
-    llm_client: OpenAIClient = None
-    max_concurrent: int = 1000
+    def __init__(self, llm_client: BaseLLMWrapper = None, max_concurrent: int = 1000):
+        self.llm_client: BaseLLMWrapper = llm_client or init_llm(
+            "synthesizer"
+        )
+        self.max_concurrent: int = max_concurrent
 
     def generate(self, tasks: List[dict]) -> List[dict]:
         loop = create_event_loop()
@@ -102,16 +104,7 @@ async def job(item):
 
     load_dotenv()
 
-    tokenizer_instance: Tokenizer = Tokenizer(
-        model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
-    )
-    llm_client = OpenAIClient(
-        model_name=os.getenv("SYNTHESIZER_MODEL"),
-        api_key=os.getenv("SYNTHESIZER_API_KEY"),
-        base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-        tokenizer_instance=tokenizer_instance,
-    )
-    bds = BDS(llm_client=llm_client)
+    bds = BDS()
 
     graph = NetworkXStorage.load_nx_graph(args.input_file)
 