Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,9 @@ Experience it on the [OpenXLab Application Center](https://openxlab.org.cn/apps/

### Workflow
![workflow](resources/images/flow.png)


## 🍀 Acknowledgements
- [SiliconCloud](https://siliconflow.cn): a wide range of LLM APIs, some of whose models are free to use
- [LightRAG](https://github.com/HKUDS/LightRAG): a simple and efficient graph retrieval solution
- [ROGRAG](https://github.com/tpoisonooo/ROGRAG): a robustly optimized GraphRAG framework
6 changes: 3 additions & 3 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ

cur_index = 1
doc_number = len(new_docs)
for doc_key, doc in tqdm_async(
new_docs.items(), desc="Chunking documents", unit="doc"
async for doc_key, doc in tqdm_async(
new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
):
chunks = {
compute_content_hash(dp["content"], prefix="chunk-"): {
Expand Down Expand Up @@ -117,7 +117,7 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ
logger.warning("All docs are already in the storage")
return {}
logger.info("[New Docs] inserting %d docs", len(new_docs))
for doc in tqdm_async(data, desc="Chunking documents", unit="doc"):
async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"):
doc_str = "".join([chunk['content'] for chunk in doc])
for chunk in doc:
chunk_key = compute_content_hash(chunk['content'], prefix="chunk-")
Expand Down
6 changes: 3 additions & 3 deletions graphgen/operators/extract_kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,16 @@ async def _process_single_content(chunk: Chunk, max_loop: int = 3):

results = []
chunk_number = len(chunks)
for result in tqdm_async(
async for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in chunks]),
total=len(chunks),
desc="Extracting entities and relationships from chunks",
desc="[3/4]Extracting entities and relationships from chunks",
unit="chunk",
):
try:
results.append(await result)
if progress_bar is not None:
progress_bar(len(results) / chunk_number, desc="Extracting entities and relationships from chunks")
progress_bar(len(results) / chunk_number, desc="[3/4]Extracting entities and relationships from chunks")
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)

Expand Down
14 changes: 7 additions & 7 deletions graphgen/operators/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,11 @@ async def _process_single_batch(

for result in tqdm_async(asyncio.as_completed(
[_process_single_batch(batch) for batch in processing_batches]
), total=len(processing_batches), desc="Generating QAs"):
), total=len(processing_batches), desc="[4/4]Generating QAs"):
try:
results.update(await result)
if progress_bar is not None:
progress_bar(len(results) / len(processing_batches), desc="Generating QAs")
progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)

Expand Down Expand Up @@ -398,12 +398,12 @@ async def _generate_question(
for result in tqdm_async(
asyncio.as_completed([_generate_question(task) for task in tasks]),
total=len(tasks),
desc="Generating QAs"
desc="[4/4]Generating QAs"
):
try:
results.update(await result)
if progress_bar is not None:
progress_bar(len(results) / len(tasks), desc="Generating QAs")
progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)
return results
Expand Down Expand Up @@ -507,15 +507,15 @@ async def _process_single_batch(
logger.error("Error occurred while processing batch: %s", e)
return {}

for result in tqdm_async(
async for result in tqdm_async(
asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
total=len(processing_batches),
desc="Generating QAs"
desc="[4/4]Generating QAs"
):
try:
results.update(await result)
if progress_bar is not None:
progress_bar(len(results) / len(processing_batches), desc="Generating QAs")
progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
except Exception as e: # pylint: disable=broad-except
logger.error("Error occurred while generating QA: %s", e)
return results
19 changes: 7 additions & 12 deletions webui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ def sum_tokens(client):
# Test API connection
test_api_connection(env["SYNTHESIZER_BASE_URL"],
env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"])
progress(0.1, "API Connection Successful")

# Initialize GraphGen
graph_gen = init_graph_gen(config, env)
graph_gen.clear()
progress(0.2, "Model Initialized")

graph_gen.progress_bar = progress

Expand Down Expand Up @@ -157,21 +155,17 @@ def sum_tokens(client):

# Process the data
graph_gen.insert(data, data_type)
progress(0.4, "Data inserted")

if config['if_trainee_model']:
# Generate quiz
graph_gen.quiz(max_samples=quiz_samples)
progress(0.6, "Quiz generated")

# Judge statements
graph_gen.judge()
progress(0.8, "Statements judged")
else:
graph_gen.traverse_strategy.edge_sampling = "random"
# Skip judge statements
graph_gen.judge(skip=True)
progress(0.8, "Statements judged")

# Traverse graph
graph_gen.traverse()
Expand Down Expand Up @@ -212,7 +206,6 @@ def sum_tokens(client):
except Exception as e:
raise gr.Error(f"DataFrame operation error: {str(e)}")

progress(1.0, "Graph traversed")
return output_file, gr.DataFrame(label='Token Stats',
headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
datatype=["str", "str", "str"],
Expand Down Expand Up @@ -378,7 +371,7 @@ def sum_tokens(client):
with gr.Column():
rpm = gr.Slider(
label="RPM",
minimum=500,
minimum=10,
maximum=10000,
value=1000,
step=100,
Expand All @@ -388,7 +381,7 @@ def sum_tokens(client):
tpm = gr.Slider(
label="TPM",
minimum=5000,
maximum=100000,
maximum=5000000,
value=50000,
step=1000,
interactive=True,
Expand Down Expand Up @@ -435,9 +428,11 @@ def sum_tokens(client):
test_api_connection,
inputs=[base_url, api_key, synthesizer_model],
outputs=[])
test_connection_btn.click(test_api_connection,
inputs=[base_url, api_key, trainee_model],
outputs=[])

if if_trainee_model.value:
test_connection_btn.click(test_api_connection,
inputs=[base_url, api_key, trainee_model],
outputs=[])

expand_method.change(lambda method:
(gr.update(visible=method == "max_width"),
Expand Down
2 changes: 1 addition & 1 deletion webui/translation.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
},
"zh": {
"Title": "✨开箱即用的LLM训练数据生成框架✨",
"Intro": "是一个基于知识图谱的合成数据生成框架,旨在解决知识密集型问答生成的挑战。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。",
"Intro": "是一个基于知识图谱的数据合成框架,旨在知识密集型任务中生成问答。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。",
"Use Trainee Model": "使用Trainee Model来识别知识盲区,使用硅基流动时请保持禁用",
"Base URL Info": "调用模型API的URL,默认使用硅基流动",
"Synthesizer Model Info": "用于构建知识图谱和生成问答的模型",
Expand Down