diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py index ec3c7cc1..532cfc79 100644 --- a/graphgen/operators/build_kg/build_kg_service.py +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -12,12 +12,16 @@ class BuildKGService(BaseOperator): - def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): + def __init__( + self, working_dir: str = "cache", graph_backend: str = "kuzu", **build_kwargs + ): super().__init__(working_dir=working_dir, op_name="build_kg_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph" ) + self.build_kwargs = build_kwargs + self.max_loop: int = int(self.build_kwargs.get("max_loop", 3)) def process(self, batch: pd.DataFrame) -> pd.DataFrame: docs = batch.to_dict(orient="records") @@ -46,6 +50,7 @@ def build_kg(self, chunks: List[Chunk]) -> None: llm_client=self.llm_client, kg_instance=self.graph_storage, chunks=text_chunks, + max_loop=self.max_loop, ) if len(mm_chunks) == 0: logger.info("All multi-modal chunks are already in the storage") diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 1b5a8762..b599e5c2 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -12,15 +12,17 @@ def build_text_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], + max_loop: int = 3, ): """ :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks + :param max_loop: Maximum number of loops for entity and relationship extraction :return: """ - kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3) + kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=max_loop) results = run_concurrent( kg_builder.extract,