diff --git a/baselines/EntiGraph/entigraph.py b/baselines/EntiGraph/entigraph.py
index d04546ef..9c985a29 100644
--- a/baselines/EntiGraph/entigraph.py
+++ b/baselines/EntiGraph/entigraph.py
@@ -3,225 +3,310 @@
 import json
 import os
 import random
-from hashlib import md5
+from dataclasses import dataclass
+from typing import List, Dict
+from dotenv import load_dotenv
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from baselines.EntiGraph.inference.devapi import gptqa
-from baselines.EntiGraph.tasks.baseline_task import BaselineTask
-
-
-def compute_content_hash(content, prefix: str = ""):
-    return prefix + md5(content.encode()).hexdigest()
-
-
-async def generate_entities(
-    document_content: str, system_message: str, openai_model: str
-):
-    prompt = f"""
-    ### Document Content:
-    {document_content}
-    """
-    can_read_entities = None
-
-    max_tries = 5
-    while not can_read_entities and max_tries > 0:
-        try:
-            completion = await gptqa(
-                prompt, openai_model, system_message, json_format=False
-            )
-            completion = completion[completion.find("{") : completion.rfind("}") + 1]
-            response = json.loads(completion)
-            can_read_entities = response["entities"]
-            return response
-        except Exception as e:  # pylint: disable=broad-except
-            print(f"Failed to generate entities: {str(e)}")
+from graphgen.models import OpenAIClient, Tokenizer
+from graphgen.utils import compute_content_hash, create_event_loop
+
+# Prompts from entigraph_utils/prompt_utils.py
+OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES = """
+As a knowledge analyzer, your task is to dissect and understand an article provided by the user. You are required to perform the following steps:
+1. Summarize the Article: Provide a concise summary of the entire article, capturing the main points and themes.
+2. Extract Entities: Identify and list all significant "nouns" or entities mentioned within the article. These entities should include but not limited to:
+    * People: Any individuals mentioned in the article, using the names or references provided.
+    * Places: Both specific locations and abstract spaces relevant to the content.
+    * Object: Any concrete object that is referenced by the provided content.
+    * Concepts: Any significant abstract ideas or themes that are central to the article's discussion.
+
+Try to exhaust as many entities as possible. Your response should be structured in a JSON format to organize the information effectively. Ensure that the summary is brief yet comprehensive, and the list of entities is detailed and accurate.
+
+Here is the format you should use for your response:
+
+{
+    "summary": "<A concise summary of the article>",
+    "entities": ["entity1", "entity2", ...]
+}
+"""
+
+OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS = """
+You will act as a knowledge analyzer tasked with dissecting an article provided by the user. Your role involves two main objectives:
+1. Rephrasing Content: The user will identify two specific entities mentioned in the article. You are required to rephrase the content of the article twice:
+    * Once, emphasizing the first entity.
+    * Again, emphasizing the second entity.
+2. Analyzing Interactions: Discuss how the two specified entities interact within the context of the article.
+
+Your responses should provide clear segregation between the rephrased content and the interaction analysis. Ensure each section of the output include sufficient context, ideally referencing the article's title to maintain clarity about the discussion's focus.
+Here is the format you should follow for your response:
+
+### Discussion of <title> in relation to <entity1>
+<Rephrased content focusing on the first entity>
+
+### Discussion of <title> in relation to <entity2>
+<Rephrased content focusing on the second entity>
+
+### Discussion of Interaction between <entity1> and <entity2> in context of <title>
+<Discussion on how the two entities interact within the article>
+"""
+
+OPENAI_API_SYSTEM_QUALITY_QA_SFT = """You are an assistant to help read \
+a article and then rephrase it in a question answering format. \
+The user will provide you with an article with its content. \
+You need to generate a paraphrase of the same article in question and answer format with \
+multiple tags of "Question: ..." followed by "Answer: ...".
+Remember to keep the meaning and every content of the article intact.
+
+Here is the format you should follow for your response:
+Question: <Question>
+Answer: <Answer>
+
+Here is the article you need to rephrase:
+{doc}
+"""
+
+
+@dataclass
+class EntiGraph:
+    model_name: str
+    api_key: str
+    base_url: str
+    max_concurrent: int = 1000
+
+    def __post_init__(self):
+        self.tokenizer = Tokenizer()
+
+        # Initialize specialized clients for different tasks to handle different system prompts and modes
+        self.client_entities = OpenAIClient(
+            model=self.model_name,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            tokenizer=self.tokenizer,
+            system_prompt=OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES,
+            json_mode=True
+        )
+
+        self.client_relations = OpenAIClient(
+            model=self.model_name,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            tokenizer=self.tokenizer,
+            system_prompt=OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS
+        )
+
+        self.client_qa = OpenAIClient(
+            model=self.model_name,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            tokenizer=self.tokenizer,
+            system_prompt="You are an assistant to help read a article \
+                and then rephrase it in a question answering format."
+        )
+
+    async def generate_entities(self, content: str) -> Dict:
+        prompt = f"""
+        ### Document Content:
+        {content}
+        """
+        max_tries = 5
+        while max_tries > 0:
+            try:
+                response_str = await self.client_entities.generate_answer(prompt)
+                if not response_str:
+                    return None
+                # Clean up json string if needed (sometimes markdown code blocks)
+                if "```json" in response_str:
+                    response_str = response_str.split("```json")[1].split("```")[0].strip()
+                elif "```" in response_str:
+                    response_str = response_str.split("```")[1].split("```")[0].strip()
+
+                # Find start and end of json
+                start = response_str.find("{")
+                end = response_str.rfind("}")
+                if start != -1 and end != -1:
+                    response_str = response_str[start : end + 1]
+
+                response = json.loads(response_str)
+                if "entities" in response and "summary" in response:
+                    return response
+            except Exception as e:
+                print(f"Failed to generate entities: {e}")
             max_tries -= 1
+        return None
+
+    async def generate_two_entity_relations(self, document: str, entity1: str, entity2: str) -> str:
+        prompt = f"""
+        ### Document Content:
+        {document}
+        ### Entities:
+        - {entity1}
+        - {entity2}
+        """
+        return await self.client_relations.generate_answer(prompt)
+
+    async def generate_qa_sft(self, content: str) -> str:
+        # We format the prompt using the template logic
+        prompt = OPENAI_API_SYSTEM_QUALITY_QA_SFT.format(doc=content)
+        return await self.client_qa.generate_answer(prompt)
+
+    def generate(self, input_docs: List[str]) -> dict:
+        loop = create_event_loop()
+        return loop.run_until_complete(self.async_generate(input_docs))
+
+    async def async_generate(self, input_docs: List[str]) -> dict:
+        semaphore = asyncio.Semaphore(self.max_concurrent)
+
+        # 1. Generate Entities
+        async def process_entities(doc_text):
+            async with semaphore:
+                res = await self.generate_entities(doc_text)
+                if res:
+                    return {
+                        "document": doc_text,
+                        "entities": res["entities"],
+                        "summary": res["summary"]
+                    }
+                return None
+
+        entities_results = []
+        for result in tqdm_async(
+            asyncio.as_completed([process_entities(doc) for doc in input_docs]),
+            total=len(input_docs),
+            desc="Generating entities"
+        ):
+            res = await result
+            if res:
+                entities_results.append(res)
+
+        # 2. Generate Relations (Pairs)
+        pair_list = []
+        random.seed(42)
+        for item in entities_results:
+            entities = item["entities"]
+            doc_text = item["document"]
+            temp_pairs = []
+            for i, entity_i in enumerate(entities):
+                for j in range(i + 1, len(entities)):
+                    temp_pairs.append((doc_text, entity_i, entities[j]))
+
+            # Sample max 10 pairs per document
+            pair_list.extend(random.sample(temp_pairs, min(len(temp_pairs), 10)))
+
+        async def process_relation(pair):
+            async with semaphore:
+                doc_text, e1, e2 = pair
+                try:
+                    return await self.generate_two_entity_relations(doc_text, e1, e2)
+                except Exception as e:
+                    print(f"Error generating relations: {e}")
+                    return None
 
-async def generate_two_entity_relations(
-    document_content: str,
-    entity1: str,
-    entity2: str,
-    system_message: str,
-    openai_model: str,
-):
-    prompt = f"""
-    ### Document Content:
-    {document_content}
-    ### Entities:
-    - {entity1}
-    - {entity2}
-    """
-    completion = await gptqa(prompt, openai_model, system_message)
-    return completion
-
-
-async def generate_three_entity_relations(
-    document_content: str,
-    entity1: str,
-    entity2: str,
-    entity3: str,
-    system_message: str,
-    openai_model: str,
-):
-    prompt = f"""
-    ### Document Content:
-    {document_content}
-    ### Entities:
-    - {entity1}
-    - {entity2}
-    - {entity3}
-    """
-    completion = await gptqa(prompt, openai_model, system_message)
-    return completion
-
-
-def _post_process_synthetic_data(data):
+        corpus = []
+        for result in tqdm_async(
+            asyncio.as_completed([process_relation(pair) for pair in pair_list]),
+            total=len(pair_list),
+            desc="Generating relations"
+        ):
+            res = await result
+            if res:
+                corpus.append(res)
+
+        # Combine summaries and relation discussions into the corpus for QA generation
+        full_corpus = [item["summary"] for item in entities_results] + corpus
+
+        # 3. Generate QA SFT
+        final_results = {}
+
+        async def process_qa(text):
+            async with semaphore:
+                try:
+                    qa_text = await self.generate_qa_sft(text)
+                    if qa_text:
+                        return _post_process_synthetic_data(qa_text)
+                except Exception as e:
+                    print(f"Error generating QA: {e}")
+                return {}
+
+        for result in tqdm_async(
+            asyncio.as_completed([process_qa(text) for text in full_corpus]),
+            total=len(full_corpus),
+            desc="Generating QA SFT"
+        ):
+            qas = await result
+            if qas:
+                final_results.update(qas)
+
+        return final_results
+
+
+def _post_process_synthetic_data(data: str) -> dict:
+    # Logic from original code
     block = data.split("\n\n")
     qas = {}
     for line in block:
         if "Question: " in line and "Answer: " in line:
-            question = line.split("Question: ")[1].split("Answer: ")[0]
-            answer = line.split("Answer: ")[1]
-            qas[compute_content_hash(question)] = {
-                "question": question,
-                "answer": answer,
-            }
-            break
+            try:
+                question = line.split("Question: ")[1].split("Answer: ")[0].strip()
+                answer = line.split("Answer: ")[1].strip()
+                if question and answer:
+                    qas[compute_content_hash(question)] = {
+                        "question": question,
+                        "answer": answer,
+                    }
+            except IndexError:
+                continue
     return qas
 
 
-async def generate_synthetic_data_for_document(input_file, data_type):
-    random.seed(42)
-    model_name = os.getenv("SYNTHESIZER_MODEL")
-    task = BaselineTask(input_file, data_type)
-
-    max_concurrent = 1000
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    async def generate_document_entities(doc):
-        async with semaphore:
-            try:
-                entities = await generate_entities(
-                    doc.text, task.openai_system_generate_entities, model_name
-                )
-                if not entities:
-                    return None
-                return {
-                    "document": doc.text,
-                    "entities": entities["entities"],
-                    "summary": entities["summary"],
-                }
-            except Exception as e:  # pylint: disable=broad-except
-                print(f"Error: {e}")
-                return None
+def load_from_json(file_obj) -> List[str]:
+    """Helper to load docs from a standard JSON list."""
+    documents = []
+    data = json.load(file_obj)
+    if isinstance(data, list):
+        for item in data:
+            if isinstance(item, list):
+                for chunk in item:
+                    if isinstance(chunk, dict) and "content" in chunk:
+                        documents.append(chunk["content"])
+            elif isinstance(item, dict) and "content" in item:
+                documents.append(item["content"])
+    return documents
+
+
+def load_from_jsonl(file_obj) -> List[str]:
+    """Helper to load docs from a JSONL file."""
+    documents = []
+    file_obj.seek(0)
+    for line in file_obj:
+        if not line.strip():
+            continue
+        try:
+            item = json.loads(line)
+            if isinstance(item, dict) and "content" in item:
+                documents.append(item["content"])
+        except json.JSONDecodeError:
+            continue
+    return documents
 
-    entities_list = []
-    for result in tqdm_async(
-        asyncio.as_completed(
-            [generate_document_entities(doc) for doc in task.documents]
-        ),
-        total=len(task.documents),
-        desc="Generating entities",
-    ):
-        result = await result
-        if result:
-            entities_list.append(result)
-
-    # iterate over triples of entities and generate relations
-    pair_list = []
-    for doc in entities_list:
-        entities = doc["entities"]
-        temp = []
-        for i, entity_i in enumerate(entities):
-            if i == len(entities) - 1:
-                break
-            for j in range(i + 1, len(entities)):
-                entity_j = entities[j]
-                pair = (doc["document"], entity_i, entity_j)
-                temp.append(pair)
-
-        # Compute all possible combinations of entities is impractical, so we randomly sample 10 pairs
-        pair_list.extend(random.sample(temp, min(len(temp), 10)))
-
-    async def process_two_entity_relations(pair):
-        async with semaphore:
-            try:
-                document, entity1, entity2 = pair
-                response = await generate_two_entity_relations(
-                    document,
-                    entity1,
-                    entity2,
-                    task.openai_system_generate_two_entity_relations,
-                    model_name,
-                )
-                return response
-            except Exception as e:  # pylint: disable=broad-except
-                print(f"Error: {e}")
-                return None
-
-    corpus = []
-    for result in tqdm_async(
-        asyncio.as_completed(
-            [process_two_entity_relations(pair) for pair in pair_list]
-        ),
-        total=len(pair_list),
-        desc="Generating two entity relations",
-    ):
-        result = await result
-        if result:
-            corpus.append(result)
-
-    # triple_list = []
-    # for doc in entities_list:
-    #     entities = doc['entities']
-    #     for i in range(len(entities)):
-    #         for j in range(i + 1, len(entities)):
-    #             for k in range(j + 1, len(entities)):
-    #                 triple = (doc['document'], entities[i], entities[j], entities[k])
-    #                 triple_list.append(triple)
-    #
-    # async def process_three_entity_relations(triple):
-    #     async with semaphore:
-    #         document, entity1, entity2, entity3 = triple
-    #         response = await generate_three_entity_relations(
-    #             document, entity1, entity2, entity3,
-    #             task.openai_system_generate_three_entity_relations,
-    #             model_name)
-    #         return response
-    #
-    # for result in tqdm_async(
-    #     asyncio.as_completed([process_three_entity_relations(triple) for triple in triple_list]),
-    #     total=len(triple_list),
-    #     desc="Generating three entity relations"
-    # ):
-    #     corpus.append(await result)
-
-    corpus = [doc["summary"] for doc in entities_list] + corpus
-
-    qa_sft_results = {}
-
-    async def generate_qa_sft(content):
-        async with semaphore:
-            completion = await gptqa(
-                content, model_name, task.openai_system_quality_qa_sft
-            )
-            return completion
-
-    for result in tqdm_async(
-        asyncio.as_completed([generate_qa_sft(content) for content in corpus]),
-        total=len(corpus),
-        desc="Generating QA SFT",
-    ):
+def load_and_dedup_data(input_file: str) -> List[str]:
+    documents = []
+    with open(input_file, "r", encoding="utf-8") as file_obj:
         try:
-            result = await result
-            if result:
-                qa_sft_results.update(_post_process_synthetic_data(result))
-        except Exception as e:  # pylint: disable=broad-except
-            print(f"Error: {e}")
+            documents = load_from_json(file_obj)
+        except json.JSONDecodeError:
+            # Try JSONL
+            documents = load_from_jsonl(file_obj)
 
-    return qa_sft_results
+    # Dedup
+    deduped = {}
+    for text in documents:
+        h = compute_content_hash(text)
+        if h not in deduped:
+            deduped[h] = text
+    return list(deduped.values())
 
 
 if __name__ == "__main__":
@@ -232,13 +317,6 @@ async def generate_qa_sft(content):
         default="resources/input_examples/json_demo.json",
         type=str,
     )
-    parser.add_argument(
-        "--data_type",
-        help="Data type of input file. (Raw context or chunked context)",
-        choices=["raw", "chunked"],
-        default="raw",
-        type=str,
-    )
     parser.add_argument(
         "--output_file",
         help="Output file path.",
@@ -248,10 +326,23 @@ async def generate_qa_sft(content):
 
     args = parser.parse_args()
 
-    results = asyncio.run(
-        generate_synthetic_data_for_document(args.input_file, args.data_type)
+    load_dotenv()
+
+    # Load data
+    docs = load_and_dedup_data(args.input_file)
+
+    entigraph = EntiGraph(
+        model_name=os.getenv("SYNTHESIZER_MODEL"),
+        api_key=os.getenv("SYNTHESIZER_API_KEY"),
+        base_url=os.getenv("SYNTHESIZER_BASE_URL"),
     )
+    results = entigraph.generate(docs)
 
+    # Save results
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
     with open(args.output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
diff --git a/baselines/EntiGraph/entigraph_utils/prompt_utils.py b/baselines/EntiGraph/entigraph_utils/prompt_utils.py
deleted file mode 100644
index 0b5ae886..00000000
--- a/baselines/EntiGraph/entigraph_utils/prompt_utils.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# pylint: disable=C0301
-
-QUALITY_FEW_SHOT_COT_PROMPT = """## Example 1
-### Question
-In the context of "Les Misérables", written by Victor Hugo in 1862, what is the main setting of the novel? There is only one correct choice.
-### Choices
-A. London
-B. Madrid
-C. Paris
-D. Rome
-### Thought Process and Answer
-Thought process: "Les Misérables" is primarily set in Paris, making C the correct choice. London, Madrid, and Rome are significant cities in other literary works but not in Victor Hugo's "Les Misérables". There is only one correct choice.
-Answer: C.
-
-## Example 2
-### Question
-In the context of "Brave New World", written by Aldous Huxley in 1932, what substance is widely used in the society to control citizens' happiness? There is only one correct choice.
-### Choices
-A. Gold
-B. Soma
-C. Silver
-D. Iron
-### Thought Process and Answer
-Thought process: In Aldous Huxley's "Brave New World," Soma is used as a means to maintain social control by ensuring citizens' happiness, making B the correct choice. Gold, Silver, and Iron are not the substances used for this purpose in the book.
-Answer: B.
-
-## Example 3
-### Question
-In the context of "Romeo and Juliet", written by William Shakespeare in the early 1590s, what are the names of the two feuding families? There is only one correct choice.
-Choices:
-A. Montague and Capulet
-B. Bennet and Darcy
-C. Linton and Earnshaw
-D. Bloom and Dedalus
-### Thought Process and Answer
-Thought process: In William Shakespeare's "Romeo and Juliet," the two feuding families are the Montagues and the Capulets, making A the correct choice. The Bennets and Darcys are in "Pride and Prejudice", the Lintons and Earnshaws in "Wuthering Heights", and Bloom and Dedalus in "Ulysses".
-Answer: A.
-
-## Example 4
-### Question
-In the context of "1984", written by George Orwell in 1949, what is the name of the totalitarian leader? There is only one correct choice.
-### Choices
-A. Big Brother
-B. O'Brien
-C. Winston Smith
-D. Emmanuel Goldstein
-### Thought Process and Answer
-Thought process: In George Orwell's "1984," the totalitarian leader is known as Big Brother, making A the correct choice. O'Brien is a character in the novel, Winston Smith is the protagonist, and Emmanuel Goldstein is a rebel leader.
-Answer: A.
-
-## Example 5
-### Question
-In the context of "Moby-Dick", written by Herman Melville in 1851, what is the name of the ship's captain obsessed with hunting the titular whale? There is only one correct choice.
-### Choices
-A. Captain Hook
-B. Captain Nemo
-C. Captain Flint
-D. Captain Ahab
-### Thought Process and Answer
-Thought process: In Herman Melville's "Moby-Dick," the ship's captain obsessed with hunting the whale is Captain Ahab, making D the correct choice. Captain Nemo is in "Twenty Thousand Leagues Under the Sea", Captain Flint in "Treasure Island", and Captain Hook in "Peter Pan".
-Answer: D.
-
-## Example 6
-"""
-
-OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES = """
-As a knowledge analyzer, your task is to dissect and understand an article provided by the user. You are required to perform the following steps:
-1. Summarize the Article: Provide a concise summary of the entire article, capturing the main points and themes.
-2. Extract Entities: Identify and list all significant "nouns" or entities mentioned within the article. These entities should include but not limited to:
-    * People: Any individuals mentioned in the article, using the names or references provided.
-    * Places: Both specific locations and abstract spaces relevant to the content.
-    * Object: Any concrete object that is referenced by the provided content.
-    * Concepts: Any significant abstract ideas or themes that are central to the article's discussion.
-
-Try to exhaust as many entities as possible. Your response should be structured in a JSON format to organize the information effectively. Ensure that the summary is brief yet comprehensive, and the list of entities is detailed and accurate.
-
-Here is the format you should use for your response:
-
-{
-    "summary": "<A concise summary of the article>",
-    "entities": ["entity1", "entity2", ...]
-}
-"""
-
-OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS = """
-You will act as a knowledge analyzer tasked with dissecting an article provided by the user. Your role involves two main objectives:
-1. Rephrasing Content: The user will identify two specific entities mentioned in the article. You are required to rephrase the content of the article twice:
-    * Once, emphasizing the first entity.
-    * Again, emphasizing the second entity.
-2. Analyzing Interactions: Discuss how the two specified entities interact within the context of the article.
-
-Your responses should provide clear segregation between the rephrased content and the interaction analysis. Ensure each section of the output include sufficient context, ideally referencing the article's title to maintain clarity about the discussion's focus.
-Here is the format you should follow for your response:
-
-### Discussion of <title> in relation to <entity1>
-<Rephrased content focusing on the first entity>
-
-### Discussion of <title> in relation to <entity2>
-<Rephrased content focusing on the second entity>
-
-### Discussion of Interaction between <entity1> and <entity2> in context of <title>
-<Discussion on how the two entities interact within the article>
-"""
-
-OPENAI_API_SYSTEM_QUALITY_GENERATE_THREE_ENTITY_RELATIONS = """
-You will act as a knowledge analyzer tasked with dissecting an article provided by the user. Your role involves three main objectives:
-1. Rephrasing Content: The user will identify three specific entities mentioned in the article. You are required to rephrase the content of the article three times:
-    * Once, emphasizing the first entity.
-    * Again, emphasizing the second entity.
-    * Lastly, emphasizing the third entity.
-2. Analyzing Interactions: Discuss how these three specified entities interact within the context of the article.
-
-Your responses should provide clear segregation between the rephrased content and the interaction analysis. Ensure each section of the output include sufficient context, ideally referencing the article's title to maintain clarity about the discussion's focus.
-Here is the format you should follow for your response:
-
-### Discussion of <title> in relation to <entity1>
-<Rephrased content focusing on the first entity>
-
-### Discussion of <title> in relation to <entity2>
-<Rephrased content focusing on the second entity>
-
-### Discussion of <title> in relation to <entity3>
-<Rephrased content focusing on the third entity>
-
-### Discussion of Interaction between <entity1>, <entity2> and <entity3> in context of <title>
-<Discussion on how the three entities interact within the article>
-"""
-
-OPENAI_API_SYSTEM_QUALITY_QA_SFT = """You are an assistant to help read a article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with multiple tags of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact.
-
-Here is the format you should follow for your response:
-Question: <Question>
-Answer: <Answer>
-
-Here is the article you need to rephrase:
-{doc}
-"""
diff --git a/baselines/EntiGraph/inference/devapi.py b/baselines/EntiGraph/inference/devapi.py
deleted file mode 100644
index 617ffbef..00000000
--- a/baselines/EntiGraph/inference/devapi.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import dotenv
-from openai import AsyncOpenAI
-
-dotenv.load_dotenv()
-
-async def gptqa(prompt: str,
-                openai_model_name: str,
-                system_message: str,
-                json_format: bool = False,
-                temp: float = 1.0):
-    client = AsyncOpenAI(
-        api_key=os.getenv("SYNTHESIZER_API_KEY"),
-        base_url=os.getenv("SYNTHESIZER_BASE_URL")
-    )
-    openai_model_name = openai_model_name or os.getenv("SYNTHESIZER_MODEL")
-
-    if json_format:
-        completion = await client.chat.completions.create(
-            model=openai_model_name,
-            temperature=temp,
-            response_format={ "type": "json_object" },
-            messages=[
-                {"role": "system",
-                 "content": system_message},
-                {"role": "user",
-                 "content": prompt},
-            ])
-    else:
-        completion = await client.chat.completions.create(
-            model=openai_model_name,
-            temperature=temp,
-            messages=[
-                {"role": "system",
-                 "content": system_message},
-                {"role": "user",
-                 "content": prompt},
-            ])
-    return completion.choices[0].message.content
diff --git a/baselines/EntiGraph/tasks/__init.py b/baselines/EntiGraph/tasks/__init.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/baselines/EntiGraph/tasks/baseline_task.py b/baselines/EntiGraph/tasks/baseline_task.py
deleted file mode 100644
index 846ad1cf..00000000
--- a/baselines/EntiGraph/tasks/baseline_task.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Rewrite from https://github.com/ZitongYang/Synthetic_Continued_Pretraining/blob/main/tasks/quality.py
-
-import json
-from hashlib import md5
-
-from baselines.EntiGraph.tasks.task_abc import Document, Task
-from baselines.EntiGraph.entigraph_utils.prompt_utils import (
-    OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES,
-    OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS,
-    OPENAI_API_SYSTEM_QUALITY_GENERATE_THREE_ENTITY_RELATIONS,
-    QUALITY_FEW_SHOT_COT_PROMPT,
-    OPENAI_API_SYSTEM_QUALITY_QA_SFT)
-
-class BaselineTask(Task):
-    openai_system_generate_entities = OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES
-    openai_system_generate_two_entity_relations = OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS
-    openai_system_generate_three_entity_relations = OPENAI_API_SYSTEM_QUALITY_GENERATE_THREE_ENTITY_RELATIONS
-    openai_system_quality_qa_sft = OPENAI_API_SYSTEM_QUALITY_QA_SFT
-    llama_cot_prompt = QUALITY_FEW_SHOT_COT_PROMPT
-
-    def __init__(self, input_file: str, data_type: str):
-        self._data = self._load_split(input_file, data_type)
-        self._create_documents()
-        self._dedup()
-
-    @staticmethod
-    def _load_split(input_file: str, data_type: str):
-        if data_type == 'raw':
-            with open(input_file, "r", encoding='utf-8') as f:
-                data = [json.loads(line) for line in f]
-            data = [[chunk] for chunk in data]
-        elif data_type == 'chunked':
-            with open(input_file, "r", encoding='utf-8') as f:
-                data = json.load(f)
-
-        documents = []
-        for doc in data:
-            for chunk in doc:
-                documents.append(chunk)
-        return documents
-
-    def _create_documents(self):
-        documents = []
-        for adict in self._data:
-            document = Document(text=adict['content'], questions=[])
-            documents.append(document)
-        super().__init__('baseline', documents)
-
-    def _dedup(self):
-        deuped_documents = {}
-        for document in self.documents:
-            key = compute_content_hash(document.text)
-            if key not in deuped_documents:
-                deuped_documents[key] = document
-
-        self.documents = list(deuped_documents.values())
-
-
-    def performance_stats(self):
-        pass
-
-    def load_attempts_json(self, file_path: str):
-        pass
-
-
-def compute_content_hash(content, prefix: str = ""):
-    return prefix + md5(content.encode()).hexdigest()
diff --git a/baselines/EntiGraph/tasks/task_abc.py b/baselines/EntiGraph/tasks/task_abc.py
deleted file mode 100644
index 35c21577..00000000
--- a/baselines/EntiGraph/tasks/task_abc.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from typing import List, Dict
-from abc import abstractmethod
-
-
-class Question:
-    def __init__(self, statement: Dict, answer: str, attempts: List[Dict], formatted_prompt: str = ""):
-        self.statement = statement
-        self.answer = answer
-        self.attempts = attempts
-        self.formatted_prompt = formatted_prompt
-
-    @abstractmethod
-    def prompt(self):
-        pass
-
-    @abstractmethod
-    def iscorrect(self, attempt_index: int = 0):
-        pass
-
-    @abstractmethod
-    def asdict(self):
-        pass
-
-    @abstractmethod
-    def llama_parse_answer(self):
-        pass
-
-
-class Document:
-    def __init__(self, text: str, questions: List[Dict]):
-        self.text = text
-        self.questions = questions
-
-    @property
-    @abstractmethod
-    def uid(self):
-        pass
-
-    @property
-    @abstractmethod
-    def content(self):
-        pass
-
-    @abstractmethod
-    def question_prompts(self, add_document_context: bool, add_thought_process: bool, sep_after_question: str):
-        pass
-
-    @abstractmethod
-    def asdict(self):
-        pass
-
-    def majority_vote(self, n_samples):
-        for question in self.questions:
-            question.majority_vote(n_samples)
-
-
-class Task:
-    openai_system_generate_entities: str
-    openai_system_generate_two_entity_relations: str
-    openai_system_generate_three_entity_relations: str
-    llama_cot_prompt: str
-
-    def __init__(self, name, documents: List[Document]):
-        self.name = name
-        self.documents = documents
-
-    @abstractmethod
-    def load_attempts_json(self, file_path: str):
-        pass
-
-    @abstractmethod
-    def performance_stats(self):
-        pass
-
-    def all_questions(self, add_document_context: bool, add_thought_process: bool, sep_after_question: str):
-        prompts = []
-        for document in self.documents:
-            prompts += document.question_prompts(add_document_context, add_thought_process, sep_after_question)
-
-        return prompts
-
-    @property
-    def all_document_contents(self):
-        return '\n'.join([document.content for document in self.documents])
-
-    def asdict(self):
-        return [document.asdict() for document in self.documents]
-
-    def majority_vote(self, n_samples: int = 1):
-        for document in self.documents:
-            document.majority_vote(n_samples)
diff --git a/baselines/Genie/genie.py b/baselines/Genie/genie.py
index 972b20fe..bfd219f1 100644
--- a/baselines/Genie/genie.py
+++ b/baselines/Genie/genie.py
@@ -8,7 +8,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIClient
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """Instruction: Given the next [document], create a [question] and [answer] pair that are grounded \
@@ -75,14 +75,20 @@ async def process_chunk(content: str):
 
         tasks = []
         for doc in docs:
-            for chunk in doc:
-                tasks.append(process_chunk(chunk["content"]))
+            if isinstance(doc, list):
+                for chunk in doc:
+                    tasks.append(process_chunk(chunk["content"]))
+            elif isinstance(doc, dict):
+                tasks.append(process_chunk(doc["content"]))
 
         for result in tqdm_async(
             asyncio.as_completed(tasks), total=len(tasks), desc="Generating using Genie"
         ):
             try:
-                question, answer = _post_process(await result)
+                response = await result
+                if response is None:
+                    continue
+                question, answer = _post_process(response)
                 if question and answer:
                     final_results[compute_content_hash(question)] = {
                         "question": question,
@@ -101,13 +107,6 @@ async def process_chunk(content: str):
         default="resources/input_examples/json_demo.json",
         type=str,
     )
-    parser.add_argument(
-        "--data_type",
-        help="Data type of input file. (Raw context or chunked context)",
-        choices=["raw", "chunked"],
-        default="raw",
-        type=str,
-    )
     parser.add_argument(
         "--output_file",
         help="Output file path.",
@@ -123,20 +122,20 @@ async def process_chunk(content: str):
         model=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+        tokenizer=Tokenizer(model_name=os.getenv("TOKENIZER_MODEL")),
     )
 
     genie = Genie(llm_client=llm_client)
 
-    if args.data_type == "raw":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = [json.loads(line) for line in f]
-        data = [[chunk] for chunk in data]
-    elif args.data_type == "chunked":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
+    with open(args.input_file, "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     results = genie.generate(data)
 
     # Save results
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
     with open(args.output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
diff --git a/baselines/LongForm/longform.py b/baselines/LongForm/longform.py
index abf4fd3e..83b22e8c 100644
--- a/baselines/LongForm/longform.py
+++ b/baselines/LongForm/longform.py
@@ -8,7 +8,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIClient
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """Instruction: X
@@ -33,7 +33,10 @@ async def async_generate(self, docs: List[List[dict]]) -> dict:
 
         async def process_chunk(content: str):
             async with semaphore:
-                question = await self.llm_client.generate_answer(content)
+                prompt = PROMPT_TEMPLATE.format(doc=content)
+                question = await self.llm_client.generate_answer(prompt)
+                if question is None:
+                    return {}
                 return {
                     compute_content_hash(question): {
                         "question": question,
@@ -43,8 +46,11 @@ async def process_chunk(content: str):
 
         tasks = []
         for doc in docs:
-            for chunk in doc:
-                tasks.append(process_chunk(chunk["content"]))
+            if isinstance(doc, list):
+                for chunk in doc:
+                    tasks.append(process_chunk(chunk["content"]))
+            elif isinstance(doc, dict):
+                tasks.append(process_chunk(doc["content"]))
 
         for result in tqdm_async(
             asyncio.as_completed(tasks),
@@ -53,7 +59,8 @@ async def process_chunk(content: str):
         ):
             try:
                 qa = await result
-                final_results.update(qa)
+                if qa:
+                    final_results.update(qa)
             except Exception as e:  # pylint: disable=broad-except
                 print(f"Error: {e}")
         return final_results
@@ -67,13 +74,6 @@ async def process_chunk(content: str):
         default="resources/input_examples/json_demo.json",
         type=str,
     )
-    parser.add_argument(
-        "--data_type",
-        help="Data type of input file. (Raw context or chunked context)",
-        choices=["raw", "chunked"],
-        default="raw",
-        type=str,
-    )
     parser.add_argument(
         "--output_file",
         help="Output file path.",
@@ -89,20 +89,20 @@ async def process_chunk(content: str):
         model=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+        tokenizer=Tokenizer(model_name=os.getenv("TOKENIZER_MODEL")),
     )
 
     longform = LongForm(llm_client=llm_client)
 
-    if args.data_type == "raw":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = [json.loads(line) for line in f]
-        data = [[chunk] for chunk in data]
-    elif args.data_type == "chunked":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
+    with open(args.input_file, "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     results = longform.generate(data)
 
     # Save results
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
     with open(args.output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
diff --git a/baselines/SELF-QA/self-qa.py b/baselines/SELF-QA/self-qa.py
index b222d970..7f2b71cd 100644
--- a/baselines/SELF-QA/self-qa.py
+++ b/baselines/SELF-QA/self-qa.py
@@ -8,7 +8,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIClient
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.utils import compute_content_hash, create_event_loop
 
 INSTRUCTION_GENERATION_PROMPT = """The background knowledge is:
@@ -71,6 +71,8 @@ async def process_chunk(content: str):
             async with semaphore:
                 prompt = INSTRUCTION_GENERATION_PROMPT.format(doc=content)
                 response = await self.llm_client.generate_answer(prompt)
+                if response is None:
+                    return []
 
                 try:
                     instruction_questions = _post_process_instructions(response)
@@ -90,7 +92,10 @@ async def process_chunk(content: str):
                     desc="Generating QAs",
                 ):
                     try:
-                        question, answer = _post_process_answers(await qa)
+                        qa_response = await qa
+                        if qa_response is None:
+                            continue
+                        question, answer = _post_process_answers(qa_response)
                         if question and answer:
                             qas.append(
                                 {
@@ -110,8 +115,11 @@ async def process_chunk(content: str):
 
         tasks = []
         for doc in docs:
-            for chunk in doc:
-                tasks.append(process_chunk(chunk["content"]))
+            if isinstance(doc, list):
+                for chunk in doc:
+                    tasks.append(process_chunk(chunk["content"]))
+            elif isinstance(doc, dict):
+                tasks.append(process_chunk(doc["content"]))
 
         for result in tqdm_async(
             asyncio.as_completed(tasks),
@@ -120,8 +128,9 @@ async def process_chunk(content: str):
         ):
             try:
                 qas = await result
-                for qa in qas:
-                    final_results.update(qa)
+                if qas:
+                    for qa in qas:
+                        final_results.update(qa)
             except Exception as e:  # pylint: disable=broad-except
                 print(f"Error: {e}")
         return final_results
@@ -135,13 +144,6 @@ async def process_chunk(content: str):
         default="resources/input_examples/json_demo.json",
         type=str,
     )
-    parser.add_argument(
-        "--data_type",
-        help="Data type of input file. (Raw context or chunked context)",
-        choices=["raw", "chunked"],
-        default="raw",
-        type=str,
-    )
     parser.add_argument(
         "--output_file",
         help="Output file path.",
@@ -157,20 +159,20 @@ async def process_chunk(content: str):
         model=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+        tokenizer=Tokenizer(model_name=os.getenv("TOKENIZER_MODEL")),
     )
 
     self_qa = SelfQA(llm_client=llm_client)
 
-    if args.data_type == "raw":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = [json.loads(line) for line in f]
-        data = [[chunk] for chunk in data]
-    elif args.data_type == "chunked":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
+    with open(args.input_file, "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     results = self_qa.generate(data)
 
     # Save results
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
     with open(args.output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
diff --git a/baselines/Wrap/wrap.py b/baselines/Wrap/wrap.py
index 90ce2a46..91c17540 100644
--- a/baselines/Wrap/wrap.py
+++ b/baselines/Wrap/wrap.py
@@ -8,7 +8,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIClient
+from graphgen.models import OpenAIClient, Tokenizer
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant.
@@ -36,7 +36,7 @@ def _post_process(content: str) -> list:
             question = item.split("Question:")[1].split("Answer:")[0].strip()
             answer = item.split("Answer:")[1].strip()
             qas.append((question, answer))
-        except Exception as e:  # pylint: disable=broad-except
+        except Exception as e:
             print(f"Error: {e}")
             continue
     return qas
@@ -62,8 +62,11 @@ async def process_chunk(content: str):
 
         tasks = []
         for doc in docs:
-            for chunk in doc:
-                tasks.append(process_chunk(chunk["content"]))
+            if isinstance(doc, list):
+                for chunk in doc:
+                    tasks.append(process_chunk(chunk["content"]))
+            elif isinstance(doc, dict):
+                tasks.append(process_chunk(doc["content"]))
 
         for result in tqdm_async(
             asyncio.as_completed(tasks), total=len(tasks), desc="Generating using Wrap"
@@ -75,7 +78,7 @@ async def process_chunk(content: str):
                         "question": qa[0],
                         "answer": qa[1],
                     }
-            except Exception as e:  # pylint: disable=broad-except
+            except Exception as e:
                 print(f"Error: {e}")
         return final_results
@@ -88,13 +91,6 @@ async def process_chunk(content: str):
         default="resources/input_examples/json_demo.json",
         type=str,
     )
-    parser.add_argument(
-        "--data_type",
-        help="Data type of input file. (Raw context or chunked context)",
-        choices=["raw", "chunked"],
-        default="raw",
-        type=str,
-    )
     parser.add_argument(
         "--output_file",
         help="Output file path.",
@@ -110,20 +106,20 @@ async def process_chunk(content: str):
         model=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+        tokenizer=Tokenizer(model_name=os.getenv("TOKENIZER_MODEL")),
     )
 
     wrap = Wrap(llm_client=llm_client)
 
-    if args.data_type == "raw":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = [json.loads(line) for line in f]
-        data = [[chunk] for chunk in data]
-    elif args.data_type == "chunked":
-        with open(args.input_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
+    with open(args.input_file, "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     results = wrap.generate(data)
 
     # Save results
+    output_dir = os.path.dirname(args.output_file)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
     with open(args.output_file, "w", encoding="utf-8") as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
diff --git a/examples/baselines/generate_all_baselines.sh b/examples/baselines/generate_all_baselines.sh
deleted file mode 100644
index 8536978e..00000000
--- a/examples/baselines/generate_all_baselines.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-# generate all baselines at one go
-
-bash scripts/baselines/generate_wrap.sh
-bash scripts/baselines/generate_selfqa.sh
-bash scripts/baselines/generate_longform.sh
-bash scripts/baselines/generate_genie.sh
-bash scripts/baselines/generate_entigraph.sh
\ No newline at end of file
diff --git a/examples/baselines/generate_bds.sh b/examples/baselines/generate_bds.sh
index d4bd9e8c..ef87ac05 100644
--- a/examples/baselines/generate_bds.sh
+++ b/examples/baselines/generate_bds.sh
@@ -1,2 +1 @@
-python3 -m baselines.BDS.bds --input_file resources/input_examples/graphml_demo.graphml \
-    --output_file cache/data/bds.json \
+python3 -m baselines.BDS.bds --input_file resources/input_examples/graphml_demo.graphml --output_file cache/data/bds.json
diff --git a/examples/baselines/generate_entigraph.sh b/examples/baselines/generate_entigraph.sh
index 8474c96f..fa17f992 100644
--- a/examples/baselines/generate_entigraph.sh
+++ b/examples/baselines/generate_entigraph.sh
@@ -1,3 +1 @@
-python3 -m baselines.EntiGraph.entigraph --input_file resources/input_examples/raw_demo.jsonl \
-    --data_type raw \
-    --output_file cache/data/entigraph.json \
+python3 -m baselines.EntiGraph.entigraph --input_file examples/input_examples/json_demo.json --output_file cache/data/entigraph.json
\ No newline at end of file
diff --git a/examples/baselines/generate_genie.sh b/examples/baselines/generate_genie.sh
index 3a06de10..d871e01e 100644
--- a/examples/baselines/generate_genie.sh
+++ b/examples/baselines/generate_genie.sh
@@ -1,3 +1 @@
-python3 -m baselines.Genie.genie --input_file resources/input_examples/raw_demo.jsonl \
-    --data_type raw \
-    --output_file cache/data/genie.json \
+python3 -m baselines.Genie.genie --input_file examples/input_examples/json_demo.json --output_file cache/data/genie.json
\ No newline at end of file
diff --git a/examples/baselines/generate_longform.sh b/examples/baselines/generate_longform.sh
index 62de848c..03281d68 100644
--- a/examples/baselines/generate_longform.sh
+++ b/examples/baselines/generate_longform.sh
@@ -1,3 +1 @@
-python3 -m baselines.LongForm.longform --input_file resources/input_examples/raw_demo.jsonl \
-    --data_type raw \
-    --output_file cache/data/longform.json \
+python3 -m baselines.LongForm.longform --input_file examples/input_examples/json_demo.json --output_file cache/data/longform.json
\ No newline at end of file
diff --git a/examples/baselines/generate_selfqa.sh b/examples/baselines/generate_selfqa.sh
index ef13e721..565fa502 100644
--- a/examples/baselines/generate_selfqa.sh
+++ b/examples/baselines/generate_selfqa.sh
@@ -1,3 +1 @@
-python3 -m baselines.SELF-QA.self-qa --input_file resources/input_examples/raw_demo.jsonl \
-    --data_type raw \
-    --output_file cache/data/self-qa.json \
+python3 -m baselines.SELF-QA.self-qa --input_file examples/input_examples/json_demo.json --output_file cache/data/self-qa.json
\ No newline at end of file
diff --git a/examples/baselines/generate_wrap.sh b/examples/baselines/generate_wrap.sh
index fcaf3933..ad4b37a9 100644
--- a/examples/baselines/generate_wrap.sh
+++ b/examples/baselines/generate_wrap.sh
@@ -1,3 +1 @@
-python3 -m baselines.Wrap.wrap --input_file resources/input_examples/raw_demo.jsonl \
-    --data_type raw \
-    --output_file cache/data/wrap.json \
+python3 -m baselines.Wrap.wrap --input_file examples/input_examples/json_demo.json --output_file cache/data/wrap.json
\ No newline at end of file
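
For reviewers who want to exercise the new input contract end to end, a minimal smoke-test sketch follows. The document texts and the temporary file path are illustrative assumptions, not part of the patch; the grounded parts are the "content" key accepted by load_from_json (both flat chunk dicts and lists of chunk dicts), the SYNTHESIZER_MODEL / SYNTHESIZER_API_KEY / SYNTHESIZER_BASE_URL environment variables read via load_dotenv(), and the --input_file / --output_file flags.

# Illustrative only: drives the refactored EntiGraph entrypoint end to end.
# Assumes the SYNTHESIZER_* variables are set in the environment or a .env file.
import json
import subprocess
import tempfile

# Both input shapes accepted by load_from_json: a flat dict with a
# "content" key, and a list of such chunk dicts.
docs = [
    {"content": "Les Misérables is a novel by Victor Hugo set in Paris."},
    [{"content": "Moby-Dick follows Captain Ahab's pursuit of the white whale."}],
]

with tempfile.NamedTemporaryFile(
    "w", suffix=".json", delete=False, encoding="utf-8"
) as f:
    json.dump(docs, f, ensure_ascii=False)
    input_path = f.name

subprocess.run(
    [
        "python3", "-m", "baselines.EntiGraph.entigraph",
        "--input_file", input_path,
        "--output_file", "cache/data/entigraph.json",
    ],
    check=True,
)

If the run succeeds, the output file holds a dict keyed by content hash, each value a {"question": ..., "answer": ...} pair, mirroring _post_process_synthetic_data.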