Skip to content

Commit 8024193

Browse files
Merge pull request #15 from Tendo33/main
refactor(graphgen): update imports and adapt GraphGen class
2 parents 14667ba + c824d92 commit 8024193

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

graphgen/graphgen.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
11
# Adapt from https://github.com/HKUDS/LightRAG
22

3-
import os
43
import asyncio
4+
import os
55
import time
6-
from typing import List, cast, Union
7-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
7+
from typing import List, Union, cast
88

9-
from tqdm.asyncio import tqdm as tqdm_async
109
import gradio as gr
10+
from tqdm.asyncio import tqdm as tqdm_async
1111

12-
from .models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
12+
from .models import (
13+
Chunk,
14+
JsonKVStorage,
15+
NetworkXStorage,
16+
OpenAIModel,
17+
Tokenizer,
18+
TraverseStrategy,
19+
WikiSearch,
20+
)
1321
from .models.storage.base_storage import StorageNameSpace
14-
from .utils import create_event_loop, logger, compute_content_hash
15-
from .operators import (extract_kg, search_wikipedia, quiz, judge_statement,
16-
skip_judge_statement, traverse_graph_by_edge,
17-
traverse_graph_atomically, traverse_graph_for_multi_hop)
18-
22+
from .operators import (
23+
extract_kg,
24+
judge_statement,
25+
quiz,
26+
search_wikipedia,
27+
skip_judge_statement,
28+
traverse_graph_atomically,
29+
traverse_graph_by_edge,
30+
traverse_graph_for_multi_hop,
31+
)
32+
from .utils import compute_content_hash, create_event_loop, logger
1933

2034
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
2135

@@ -35,10 +49,10 @@ class GraphGen:
3549

3650
# web search
3751
if_web_search: bool = False
38-
wiki_client: WikiSearch = WikiSearch()
52+
wiki_client: WikiSearch = field(default_factory=WikiSearch)
3953

4054
# traverse strategy
41-
traverse_strategy: TraverseStrategy = TraverseStrategy()
55+
traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)
4256

4357
# webui
4458
progress_bar: gr.Progress = None

0 commit comments

Comments
 (0)