Merged
6 changes: 6 additions & 0 deletions README.md
@@ -105,6 +105,12 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe
```bash
python -m webui.app
```

For hot reload during development (the `gradio` CLI watches the file and restarts on changes; `PYTHONPATH=.` keeps repo-root imports resolvable), run
```bash
PYTHONPATH=. gradio webui/app.py
```


![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84)

7 changes: 7 additions & 0 deletions README_ZH.md
@@ -104,6 +104,13 @@ GraphGen first builds a fine-grained knowledge graph from the source text, then leverages the expected…
```bash
python -m webui.app
```

If you need hot reload during development, run

```bash
PYTHONPATH=. gradio webui/app.py
```


![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84)

2 changes: 1 addition & 1 deletion graphgen/configs/aggregated_config.yaml
@@ -16,7 +16,7 @@ partition: # graph partition configuration
max_units_per_community: 20 # max nodes and edges per community
min_units_per_community: 5 # min nodes and edges per community
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
generate:
mode: aggregated # atomic, aggregated, multi_hop, cot
data_format: ChatML # Alpaca, Sharegpt, ChatML
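
Reviewer note, for context: the comment fix above (and the identical one in `multi_hop_config.yaml` below) clarifies that sampling operates on *units*, i.e. nodes and edges, not just edges. A minimal sketch of what the three `unit_sampling` strategies could look like; this is illustrative only and assumes each unit carries a precomputed comprehension-loss score under a `"loss"` key, which is not necessarily GraphGen's actual data model:

```python
import random
from typing import Any


def order_units(units: list[dict[str, Any]], strategy: str) -> list[dict[str, Any]]:
    """Order candidate units (nodes/edges) for community building.

    Hypothetical sketch; assumes each unit dict exposes a "loss" field
    holding its comprehension loss.
    """
    if strategy == "random":
        # shuffled copy of the candidate pool
        return random.sample(units, len(units))
    if strategy == "max_loss":
        # hardest (highest-loss) units first
        return sorted(units, key=lambda u: u["loss"], reverse=True)
    if strategy == "min_loss":
        # easiest (lowest-loss) units first
        return sorted(units, key=lambda u: u["loss"])
    raise ValueError(f"unknown unit_sampling strategy: {strategy}")
```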
2 changes: 1 addition & 1 deletion graphgen/configs/multi_hop_config.yaml
@@ -16,7 +16,7 @@ partition: # graph partition configuration
max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: random # edge sampling strategy, support: random, max_loss, min_loss
unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
generate:
mode: multi_hop # strategy for generating multi-hop QA pairs
data_format: ChatML # Alpaca, Sharegpt, ChatML
5 changes: 4 additions & 1 deletion graphgen/graphgen.py
@@ -237,7 +237,10 @@ async def generate(self, partition_config: Dict, generate_config: Dict):

# Step 2: generate QA pairs
results = await generate_qas(
self.synthesizer_llm_client, batches, generate_config
self.synthesizer_llm_client,
batches,
generate_config,
progress_bar=self.progress_bar,
)

if not results:
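
Reviewer note: for readers wondering where `self.progress_bar` originates, presumably the WebUI hands a `gr.Progress` handle to the `GraphGen` instance before generation starts. A hypothetical sketch of that wiring; the `progress_bar` attribute name comes from this diff, but the handler shape and constructor are assumptions:

```python
import gradio as gr

from graphgen.graphgen import GraphGen  # module path per this diff; export name assumed


def on_generate(config: dict, progress=gr.Progress()):
    # Gradio injects a Progress instance when this handler is bound to an event.
    graph_gen = GraphGen(config)       # constructor signature assumed
    graph_gen.progress_bar = progress  # read by generate() -> generate_qas() -> run_concurrent()
    ...                                # kick off generation here (entry point elided)
```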
4 changes: 2 additions & 2 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -42,7 +42,7 @@ async def extract(

# step 2: initial glean
final_result = await self.llm_client.generate_answer(hint_prompt)
logger.debug("First extraction result: %s", final_result)
logger.info("First extraction result: %s", final_result)

# step3: iterative refinement
history = pack_history_conversations(hint_prompt, final_result)
@@ -57,7 +57,7 @@ async def extract(
glean_result = await self.llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
)
logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
logger.info("Loop %s glean: %s", loop_idx + 1, glean_result)

history += pack_history_conversations(
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
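
Reviewer note: one side effect of promoting these traces from `debug` to `info` is that full extraction payloads now print under the default log level. If that proves noisy in batch runs, the level can be raised again at configuration time without reverting this change; a standard-library sketch (the logger name is an assumption):

```python
import logging

# Silence per-chunk extraction traces while keeping warnings and errors.
logging.getLogger("graphgen").setLevel(logging.WARNING)  # assumed logger name
```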
4 changes: 2 additions & 2 deletions graphgen/models/partitioner/ece_partitioner.py
@@ -17,8 +17,8 @@
class ECEPartitioner(BFSPartitioner):
"""
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
We calculate ECE for edges in KG (represented as 'comprehension loss')
and group edges with similar ECE values into the same community.
We calculate ECE for units in KG (represented as 'comprehension loss')
and group units with similar ECE values into the same community.
1. Select a sampling strategy.
2. Choose a unit based on the sampling strategy.
3. Expand the community using BFS.
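
Reviewer note: to make the docstring's three steps concrete, here is a minimal sketch of steps 2-3, seed selection followed by BFS expansion, under the community-size caps from the configs above. It assumes a networkx-style graph exposing `neighbors()`; this is not the actual `ECEPartitioner` implementation:

```python
from collections import deque


def build_community(seed, graph, max_units: int = 20):
    """Grow one community from a sampled seed unit via BFS.

    Illustrative only; the real partitioner also honors
    min_units_per_community and max_tokens_per_community.
    """
    community, seen = [], {seed}
    queue = deque([seed])
    while queue and len(community) < max_units:
        unit = queue.popleft()
        community.append(unit)
        for neighbor in graph.neighbors(unit):  # networkx-style adjacency
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return community
```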
3 changes: 3 additions & 0 deletions graphgen/operators/generate/generate_qas.py
@@ -18,12 +18,14 @@ async def generate_qas(
]
],
generation_config: dict,
progress_bar=None,
) -> list[dict[str, Any]]:
"""
Generate question-answer pairs based on nodes and edges.
:param llm_client: LLM client
:param batches: batches of graph units to generate QAs from
:param generation_config: generation settings (mode, data_format, ...)
:param progress_bar: optional Gradio progress bar for WebUI feedback
:return: QA pairs
"""
mode = generation_config["mode"]
@@ -45,6 +47,7 @@
batches,
desc="[4/4]Generating QAs",
unit="batch",
progress_bar=progress_bar,
)

# format
120 changes: 104 additions & 16 deletions graphgen/utils/run_concurrent.py
@@ -10,6 +10,77 @@
R = TypeVar("R")


# async def run_concurrent(
# coro_fn: Callable[[T], Awaitable[R]],
# items: List[T],
# *,
# desc: str = "processing",
# unit: str = "item",
# progress_bar: Optional[gr.Progress] = None,
# ) -> List[R]:
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
#
# results = []
# async for future in tqdm_async(
# tasks, desc=desc, unit=unit
# ):
# try:
# result = await future
# results.append(result)
# except Exception as e: # pylint: disable=broad-except
# logger.exception("Task failed: %s", e)
#
# if progress_bar is not None:
# progress_bar((len(results)) / len(items), desc=desc)
#
# if progress_bar is not None:
# progress_bar(1.0, desc=desc)
# return results

# results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
#
# ok_results = []
# for idx, res in enumerate(results):
# if isinstance(res, Exception):
# logger.exception("Task failed: %s", res)
# if progress_bar:
# progress_bar((idx + 1) / len(items), desc=desc)
# continue
# ok_results.append(res)
# if progress_bar:
# progress_bar((idx + 1) / len(items), desc=desc)
#
# if progress_bar:
# progress_bar(1.0, desc=desc)
# return ok_results

# async def run_concurrent(
# coro_fn: Callable[[T], Awaitable[R]],
# items: List[T],
# *,
# desc: str = "processing",
# unit: str = "item",
# progress_bar: Optional[gr.Progress] = None,
# ) -> List[R]:
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
#
# results = []
# # Update the progress bar synchronously to avoid async conflicts
# for i, task in enumerate(asyncio.as_completed(tasks)):
# try:
# result = await task
# results.append(result)
# # Update the progress bar synchronously
# if progress_bar is not None:
# # Update progress in a synchronous context
# progress_bar((i + 1) / len(items), desc=desc)
# except Exception as e:
# logger.exception("Task failed: %s", e)
# results.append(e)
#
# return results


async def run_concurrent(
coro_fn: Callable[[T], Awaitable[R]],
items: List[T],
@@ -20,19 +91,36 @@ async def run_concurrent(
) -> List[R]:
tasks = [asyncio.create_task(coro_fn(it)) for it in items]

results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)

ok_results = []
for idx, res in enumerate(results):
if isinstance(res, Exception):
logger.exception("Task failed: %s", res)
if progress_bar:
progress_bar((idx + 1) / len(items), desc=desc)
continue
ok_results.append(res)
if progress_bar:
progress_bar((idx + 1) / len(items), desc=desc)

if progress_bar:
progress_bar(1.0, desc=desc)
return ok_results
completed_count = 0
results = []

pbar = tqdm_async(total=len(items), desc=desc, unit=unit)

if progress_bar is not None:
progress_bar(0.0, desc=f"{desc} (0/{len(items)})")

for future in asyncio.as_completed(tasks):
try:
result = await future
results.append(result)
except Exception as e: # pylint: disable=broad-except
logger.exception("Task failed: %s", e)
# record the failure as a placeholder so progress counting stays aligned;
# exceptions are filtered out below
results.append(e)

completed_count += 1
pbar.update(1)

if progress_bar is not None:
progress = completed_count / len(items)
progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})")

pbar.close()

if progress_bar is not None:
progress_bar(1.0, desc=f"{desc} (completed)")

# filter out exceptions
results = [res for res in results if not isinstance(res, Exception)]

return results
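
Reviewer note: a minimal usage sketch of the new implementation (the coroutine is a stand-in):

```python
import asyncio

from graphgen.utils.run_concurrent import run_concurrent  # module path per this diff


async def double(x: int) -> int:
    await asyncio.sleep(0.01)  # simulate an LLM / I/O call
    return 2 * x


async def main():
    results = await run_concurrent(double, list(range(10)), desc="demo", unit="item")
    print(sorted(results))


asyncio.run(main())
```

One behavioral difference worth flagging: the previous `tqdm_async.gather` version returned results in input order, whereas `asyncio.as_completed` yields them in completion order, so callers that relied on positional alignment with `items` would need to re-key the results.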