Skip to content

Commit f79895d

Browse files
fix: fix lint problems
1 parent f2f9d7f commit f79895d

File tree

1 file changed

+45
-29
lines changed

1 file changed

+45
-29
lines changed

baselines/EntiGraph/entigraph.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
<Discussion on how the two entities interact within the article>
5353
"""
5454

55-
OPENAI_API_SYSTEM_QUALITY_QA_SFT = """You are an assistant to help read a article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with multiple tags of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact.
55+
OPENAI_API_SYSTEM_QUALITY_QA_SFT = """You are an assistant to help read a article and then rephrase it in a question answering format.
56+
The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with multiple tags of "Question: ..." followed by "Answer: ...".
57+
Remember to keep the meaning and every content of the article intact.
5658
5759
Here is the format you should follow for your response:
5860
Question: <Question>
@@ -116,13 +118,13 @@ async def generate_entities(self, content: str) -> Dict:
116118
response_str = response_str.split("```json")[1].split("```")[0].strip()
117119
elif "```" in response_str:
118120
response_str = response_str.split("```")[1].split("```")[0].strip()
119-
121+
120122
# Find start and end of json
121123
start = response_str.find("{")
122124
end = response_str.rfind("}")
123125
if start != -1 and end != -1:
124126
response_str = response_str[start : end + 1]
125-
127+
126128
response = json.loads(response_str)
127129
if "entities" in response and "summary" in response:
128130
return response
@@ -146,11 +148,11 @@ async def generate_qa_sft(self, content: str) -> str:
146148
prompt = OPENAI_API_SYSTEM_QUALITY_QA_SFT.format(doc=content)
147149
return await self.client_qa.generate_answer(prompt)
148150

149-
def generate(self, docs: List[str]) -> List[dict]:
151+
def generate(self, input_docs: List[str]) -> List[dict]:
150152
loop = create_event_loop()
151-
return loop.run_until_complete(self.async_generate(docs))
153+
return loop.run_until_complete(self.async_generate(input_docs))
152154

153-
async def async_generate(self, docs: List[str]) -> dict:
155+
async def async_generate(self, input_docs: List[str]) -> dict:
154156
semaphore = asyncio.Semaphore(self.max_concurrent)
155157

156158
# 1. Generate Entities
@@ -167,8 +169,8 @@ async def process_entities(doc_text):
167169

168170
entities_results = []
169171
for result in tqdm_async(
170-
asyncio.as_completed([process_entities(doc) for doc in docs]),
171-
total=len(docs),
172+
asyncio.as_completed([process_entities(doc) for doc in input_docs]),
173+
total=len(input_docs),
172174
desc="Generating entities"
173175
):
174176
res = await result
@@ -213,7 +215,7 @@ async def process_relation(pair):
213215

214216
# 3. Generate QA SFT
215217
final_results = {}
216-
218+
217219
async def process_qa(text):
218220
async with semaphore:
219221
try:
@@ -255,31 +257,45 @@ def _post_process_synthetic_data(data: str) -> dict:
255257
return qas
256258

257259

260+
def load_from_json(file_obj) -> List[str]:
    """Load document texts from a stream holding a single JSON value.

    Accepted layouts:
      * a list of record dicts, each carrying a "content" key;
      * a list of lists of such records (one level of nesting);
      * a single top-level record dict carrying a "content" key.

    Entries that are not dicts, or that lack a "content" key, are ignored.

    Args:
        file_obj: Readable text stream positioned at the start of the JSON.

    Returns:
        The "content" values in the order they appear in the input.

    Raises:
        json.JSONDecodeError: If the stream does not hold one valid JSON value.
    """
    data = json.load(file_obj)

    # A file with exactly one JSONL record is also a valid JSON document, so
    # parsing succeeds here and a caller that only falls back to JSONL on
    # JSONDecodeError would never see the record. Treat a lone object as a
    # one-element list instead of dropping it.
    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        return []

    documents = []
    for item in data:
        # Normalise flat records and one-level nested groups to the same shape.
        group = item if isinstance(item, list) else [item]
        documents.extend(
            record["content"]
            for record in group
            if isinstance(record, dict) and "content" in record
        )
    return documents
273+
274+
275+
def load_from_jsonl(file_obj) -> List[str]:
    """Load document texts from a JSON Lines stream.

    The stream is rewound to its start, then read line by line. Each line
    that parses to a dict carrying a "content" key contributes that value.
    Blank lines, lines that are not valid JSON, and records without a
    "content" key are skipped.

    Args:
        file_obj: Seekable, readable text stream.

    Returns:
        The "content" values in file order.
    """
    documents = []
    # Rewind: the stream may already have been consumed by an earlier parse.
    file_obj.seek(0)
    for index, line in enumerate(file_obj):
        # json.loads rejects a str that starts with a UTF-8 BOM, which would
        # silently drop the first record of a BOM-prefixed file.
        if index == 0 and line.startswith("\ufeff"):
            line = line[1:]
        if not line.strip():
            continue
        try:
            item = json.loads(line)
        except json.JSONDecodeError:
            # Best effort: one malformed line must not discard the rest.
            continue
        if isinstance(item, dict) and "content" in item:
            documents.append(item["content"])
    return documents
289+
290+
258291
def load_and_dedup_data(input_file: str) -> List[str]:
259292
documents = []
260-
with open(input_file, "r", encoding="utf-8") as f:
293+
with open(input_file, "r", encoding="utf-8") as file_obj:
261294
try:
262-
data = json.load(f)
263-
if isinstance(data, list):
264-
for item in data:
265-
if isinstance(item, list):
266-
for chunk in item:
267-
if isinstance(chunk, dict) and "content" in chunk:
268-
documents.append(chunk["content"])
269-
elif isinstance(item, dict) and "content" in item:
270-
documents.append(item["content"])
295+
documents = load_from_json(file_obj)
271296
except json.JSONDecodeError:
272297
# Try JSONL
273-
f.seek(0)
274-
for line in f:
275-
if not line.strip():
276-
continue
277-
try:
278-
item = json.loads(line)
279-
if isinstance(item, dict) and "content" in item:
280-
documents.append(item["content"])
281-
except json.JSONDecodeError:
282-
continue
298+
documents = load_from_jsonl(file_obj)
283299

284300
# Dedup
285301
deduped = {}

0 commit comments

Comments
 (0)