diff --git a/.pylintrc b/.pylintrc index 45c2b04b..ce85b1ba 100644 --- a/.pylintrc +++ b/.pylintrc @@ -452,6 +452,7 @@ disable=raw-checker-failed, R0917, # Too many positional arguments (6/5) (too-many-positional-arguments) C0103, E0401, + W0703, # Catching too general exception Exception # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 89778469..91d55fcd 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,8 +1,10 @@ import os from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import pandas as pd import requests +from ray.data import Dataset class BaseReader(ABC): @@ -14,52 +16,65 @@ def __init__(self, text_column: str = "content"): self.text_column = text_column @abstractmethod - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read data from the specified file path. - :param file_path: Path to the input file. - :return: List of dictionaries containing the data. + :param input_path: Path to the input file or list of file paths. + :return: Ray Dataset containing the read data. """ - @staticmethod - def filter(data: List[dict]) -> List[dict]: + def _should_keep_item(self, item: Dict[str, Any]) -> bool: + """ + Determine whether to keep the given item based on the text column. + + :param item: Dictionary representing a data entry. + :return: True if the item should be kept, False otherwise. """ - Filter out entries with empty or missing text in the specified column. + item_type = item.get("type") + assert item_type in [ + "text", + "image", + "table", + "equation", + "protein", + ], f"Unsupported item type: {item_type}" + if item_type == "text": + content = item.get(self.text_column, "").strip() + return bool(content) + return True - :param data: List of dictionaries containing the data. - :return: Filtered list of dictionaries. + def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """ + Validate data format. """ + if "type" not in batch.columns: + raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") - def _image_exists(path_or_url: str, timeout: int = 3) -> bool: - """ - Check if an image exists at the given local path or URL. - :param path_or_url: Local file path or remote URL of the image. - :param timeout: Timeout for remote URL requests in seconds. - :return: True if the image exists, False otherwise. 
- """ - if not path_or_url: - return False - if not path_or_url.startswith(("http://", "https://", "ftp://")): - path = path_or_url.replace("file://", "", 1) - path = os.path.abspath(path) - return os.path.isfile(path) - try: - resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) - return resp.status_code == 200 - except requests.RequestException: - return False + if "text" in batch["type"].values: + if self.text_column not in batch.columns: + raise ValueError( + f"Missing '{self.text_column}' column for text documents" + ) - filtered_data = [] - for item in data: - if item.get("type") == "text": - content = item.get("content", "").strip() - if content: - filtered_data.append(item) - elif item.get("type") in ("image", "table", "equation"): - img_path = item.get("img_path") - if _image_exists(img_path): - filtered_data.append(item) - else: - filtered_data.append(item) - return filtered_data + return batch + + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. + """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except requests.RequestException: + return False diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index b2d1ad3a..e6b31b79 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -33,7 +33,7 @@ def split_text(self, text: str) -> List[str]: """ Split the input text into smaller chunks. - :param text: The input text to be split. + :param text: The input text to be chunk. :return: A list of text chunks. """ @@ -111,7 +111,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: def _split_text_with_regex( text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. diff --git a/graphgen/configs/search_config.yaml b/graphgen/configs/search_config.yaml index 37e65818..ff110786 100644 --- a/graphgen/configs/search_config.yaml +++ b/graphgen/configs/search_config.yaml @@ -1,7 +1,7 @@ pipeline: - name: read params: - input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples + input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - name: search params: diff --git a/graphgen/engine.py b/graphgen/engine.py index dad75de5..6ddf565d 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,11 +1,24 @@ """ orchestration engine for GraphGen """ - +import inspect +import queue import threading import traceback +from enum import Enum, auto from functools import wraps -from typing import Any, Callable, List +from typing import Callable, Dict, List + + +class OpType(Enum): + STREAMING = auto() # once data from upstream arrives, process it immediately + BARRIER = auto() # wait for all upstream data to arrive before processing + BATCH = auto() # process data in batches when threshold is reached + + +# signals the end of a data stream +class EndOfStream: + pass class Context(dict): @@ -22,12 +35,21 @@ def get(self, k, default=None): class OpNode: def __init__( - self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any] + self, + name: str, + deps: List[str], + func: Callable, + op_type: OpType = OpType.BARRIER, # use barrier by default + batch_size: int = 32, # default batch size for BATCH operations ): - self.name, self.deps, self.func = name, deps, func + self.name = name + self.deps = deps + self.func = func + self.op_type = op_type + self.batch_size = batch_size -def op(name: str, deps=None): +def op(name: str, deps=None, op_type: OpType = OpType.BARRIER, batch_size: int = 32): deps = deps or [] def decorator(func): @@ -35,66 +57,197 @@ def decorator(func): def _wrapper(*args, **kwargs): return func(*args, **kwargs) - _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx)) + _wrapper.op_node = OpNode( + name, + deps, + func, + op_type=op_type, + batch_size=batch_size, + ) return _wrapper return decorator class Engine: - def __init__(self, max_workers: int = 4): - self.max_workers = max_workers + def __init__(self, queue_size: int = 100): + self.queue_size = queue_size - def run(self, ops: List[OpNode], ctx: Context): - name2op = {operation.name: operation for operation in ops} - - # topological sort - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(cur) - for child in [c for c, d in graph.items() if cur in d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) - - if len(topo) != len(ops): - raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." 
- ) + @staticmethod + def _topo_sort(name2op: Dict[str, OpNode]) -> List[str]: + adj = {n: [] for n in name2op} + in_degree = {n: 0 for n in name2op} + + for name, operation in name2op.items(): + for dep_name in operation.deps: + if dep_name not in name2op: + raise ValueError(f"Dependency {dep_name} of {name} not found") + adj[dep_name].append(name) + in_degree[name] += 1 + + # Kahn's algorithm for topological sorting + queue_nodes = [n for n in name2op if in_degree[n] == 0] + topo_order = [] + + while queue_nodes: + u = queue_nodes.pop(0) + topo_order.append(u) + + for v in adj[u]: + in_degree[v] -= 1 + if in_degree[v] == 0: + queue_nodes.append(v) + + # cycle detection + if len(topo_order) != len(name2op): + cycle_nodes = set(name2op.keys()) - set(topo_order) + raise ValueError(f"Cyclic dependency detected among: {cycle_nodes}") + return topo_order + + def _build_channels(self, name2op): + """Return channels / consumers_of / producer_counts""" + channels, consumers_of, producer_counts = {}, {}, {n: 0 for n in name2op} + for name, operator in name2op.items(): + consumers_of[name] = [] + for dep in operator.deps: + if dep not in name2op: + raise ValueError(f"Dependency {dep} of {name} not found") + channels[(dep, name)] = queue.Queue(maxsize=self.queue_size) + consumers_of[dep].append(name) + producer_counts[name] += 1 + return channels, consumers_of, producer_counts - # semaphore for max_workers - sem = threading.Semaphore(self.max_workers) - done = {n: threading.Event() for n in name2op} - exc = {} - - def _exec(n: str): - with sem: - for d in name2op[n].deps: - done[d].wait() - if any(d in exc for d in name2op[n].deps): - exc[n] = Exception("Skipped due to failed dependencies") - done[n].set() - return - try: - name2op[n].func(name2op[n], ctx) - except Exception: # pylint: disable=broad-except - exc[n] = traceback.format_exc() - done[n].set() - - ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo] - for t in ts: + def _run_workers(self, ordered_ops, channels, consumers_of, producer_counts, ctx): + """Run worker threads for each operation node.""" + exceptions, threads = {}, [] + for node in ordered_ops: + t = threading.Thread( + target=self._worker_loop, + args=(node, channels, consumers_of, producer_counts, ctx, exceptions), + daemon=True, + ) t.start() - for t in ts: + threads.append(t) + for t in threads: t.join() - if exc: - raise RuntimeError( - "Some operations failed:\n" - + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) + return exceptions + + def _worker_loop( + self, node, channels, consumers_of, producer_counts, ctx, exceptions + ): + op_name = node.name + + def input_generator(): + # if no dependencies, yield None once + if not node.deps: + yield None + return + + active_producers = producer_counts[op_name] + # collect all queues + input_queues = [channels[(dep_name, op_name)] for dep_name in node.deps] + + # loop until all producers are done + while active_producers > 0: + got_data = False + for q in input_queues: + try: + item = q.get(timeout=0.1) + if isinstance(item, EndOfStream): + active_producers -= 1 + else: + yield item + got_data = True + except queue.Empty: + continue + + if not got_data and active_producers > 0: + # barrier wait on the first active queue + item = input_queues[0].get() + if isinstance(item, EndOfStream): + active_producers -= 1 + else: + yield item + + in_stream = input_generator() + + try: + # execute the operation + result_iter = [] + if node.op_type == OpType.BARRIER: + # consume all input + buffered_inputs 
= list(in_stream) + res = node.func(self, ctx, inputs=buffered_inputs) + if res is not None: + result_iter = res if isinstance(res, (list, tuple)) else [res] + + elif node.op_type == OpType.STREAMING: + # process input one by one + res = node.func(self, ctx, input_stream=in_stream) + if res is not None: + result_iter = res + + elif node.op_type == OpType.BATCH: + # accumulate inputs into batches and process + batch = [] + for item in in_stream: + batch.append(item) + if len(batch) >= node.batch_size: + res = node.func(self, ctx, inputs=batch) + if res is not None: + result_iter.extend( + res if isinstance(res, (list, tuple)) else [res] + ) + batch = [] + # process remaining items + if batch: + res = node.func(self, ctx, inputs=batch) + if res is not None: + result_iter.extend( + res if isinstance(res, (list, tuple)) else [res] + ) + + else: + raise ValueError(f"Unknown OpType {node.op_type} for {op_name}") + + # output dispatch, send results to downstream consumers + if result_iter: + for item in result_iter: + for consumer_name in consumers_of[op_name]: + channels[(op_name, consumer_name)].put(item) + + except Exception: # pylint: disable=broad-except + traceback.print_exc() + exceptions[op_name] = traceback.format_exc() + + finally: + # signal end of stream to downstream consumers + for consumer_name in consumers_of[op_name]: + channels[(op_name, consumer_name)].put(EndOfStream()) + + def run(self, ops: List[OpNode], ctx: Context): + name2op = {op.name: op for op in ops} + + # Step 1: topo sort and validate + sorted_op_names = self._topo_sort(name2op) + + # Step 2: build channels and tracking structures + channels, consumers_of, producer_counts = self._build_channels(name2op) + + # Step3: start worker threads using topo order + ordered_ops = [name2op[name] for name in sorted_op_names] + exceptions = self._run_workers( + ordered_ops, channels, consumers_of, producer_counts, ctx + ) + + if exceptions: + error_msgs = "\n".join( + [ + f"Operation {name} failed with error:\n{msg}" + for name, msg in exceptions.items() + ] ) + raise RuntimeError(f"Engine encountered exceptions:\n{error_msgs}") def collect_ops(config: dict, graph_gen) -> List[OpNode]: @@ -110,12 +263,67 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: op_node = method.op_node # if there are runtime dependencies, override them - runtime_deps = stage.get("deps", op_node.deps) - op_node.deps = runtime_deps + deps = stage.get("deps", op_node.deps) + op_type = op_node.op_type + batch_size = stage.get("batch_size", op_node.batch_size) + + sig = inspect.signature(method) + accepts_input_stream = "input_stream" in sig.parameters + + if op_type == OpType.BARRIER: + if "params" in stage: + + def func(self, ctx, inputs, m=method, sc=stage): + return m(sc.get("params", {}), inputs=inputs) + + else: + + def func(self, ctx, inputs, m=method): + return m(inputs=inputs) + + elif op_type == OpType.STREAMING: + if "params" in stage: + if accepts_input_stream: + + def func(self, ctx, input_stream, m=method, sc=stage): + return m(sc.get("params", {}), input_stream=input_stream) + + else: + + def func(self, ctx, input_stream, m=method, sc=stage): + return m(sc.get("params", {})) + + else: + if accepts_input_stream: + + def func(self, ctx, input_stream, m=method): + return m(input_stream=input_stream) + + else: + + def func(self, ctx, input_stream, m=method): + return m() + + elif op_type == OpType.BATCH: + if "params" in stage: + + def func(self, ctx, inputs, m=method, sc=stage): + return m(sc.get("params", {}), inputs=inputs) + + else: 
+ + def func(self, ctx, inputs, m=method): + return m(inputs=inputs) - if "params" in stage: - op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {})) else: - op_node.func = lambda self, ctx, m=method: m() - ops.append(op_node) + raise ValueError(f"Unknown OpType {op_type} for operation {name}") + + new_node = OpNode( + name=name, + deps=deps, + func=func, + op_type=op_type, + batch_size=batch_size, + ) + ops.append(new_node) return ops diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 1bfb35cb..17b8ac42 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,12 +1,12 @@ import os import time -from typing import Dict +from typing import Dict, Iterator, List import gradio as gr from graphgen.bases import BaseLLMWrapper from graphgen.bases.datatypes import Chunk -from graphgen.engine import op +from graphgen.engine import OpType, op from graphgen.models import ( JsonKVStorage, JsonListStorage, @@ -68,7 +68,8 @@ def __init__( self.working_dir, namespace="graph" ) self.search_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="search" + os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), + namespace="search", ) self.rephrase_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="rephrase" @@ -88,80 +89,110 @@ def __init__( # webui self.progress_bar: gr.Progress = progress_bar - @op("read", deps=[]) + @op("read", deps=[], op_type=OpType.STREAMING) @async_to_sync_method async def read(self, read_config: Dict): """ read files from input sources """ - data = read_files(**read_config, cache_dir=self.working_dir) - if len(data) == 0: - logger.warning("No data to process") - return - - assert isinstance(data, list) and isinstance(data[0], dict) + count = 0 + for docs in read_files(**read_config, cache_dir=self.working_dir): + if not docs: + continue + new_docs = {compute_mm_hash(d, prefix="doc-"): d for d in docs} + _add_doc_keys = await self.full_docs_storage.filter_keys( + list(new_docs.keys()) + ) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + + if new_docs: + await self.full_docs_storage.upsert(new_docs) + await self.full_docs_storage.index_done_callback() + for doc_id in new_docs.keys(): + yield doc_id + + count += len(new_docs) + logger.info( + "[Read] Yielded %d new documents, total %d", len(new_docs), count + ) + + if count == 0: + logger.warning("[Read] No new documents to process") # TODO: configurable whether to use coreference resolution - new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data} - _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - await self.full_docs_storage.upsert(new_docs) - await self.full_docs_storage.index_done_callback() - - @op("chunk", deps=["read"]) + @op("chunk", deps=["read"], op_type=OpType.STREAMING) @async_to_sync_method - async def chunk(self, chunk_config: Dict): + async def chunk(self, chunk_config: Dict, input_stream: Iterator): """ chunk documents into smaller pieces from full_docs_storage if not already present + input_stream: document IDs from full_docs_storage + yield: chunk IDs inserted into chunks_storage """ - - new_docs = await self.meta_storage.get_new_data(self.full_docs_storage) - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - inserting_chunks = await 
chunk_documents( - new_docs, - self.tokenizer_instance, - self.progress_bar, - **chunk_config, - ) - - _add_chunk_keys = await self.chunks_storage.filter_keys( - list(inserting_chunks.keys()) - ) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - await self.chunks_storage.upsert(inserting_chunks) - await self.chunks_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() - - @op("build_kg", deps=["chunk"]) + count = 0 + for doc_id in input_stream: + doc = await self.full_docs_storage.get_by_id(doc_id) + if not doc: + logger.warning( + "[Chunk] Document %s not found in full_docs_storage", doc_id + ) + continue + + inserting_chunks = chunk_documents( + {doc_id: doc}, + self.tokenizer_instance, + self.progress_bar, + **chunk_config, + ) + + _add_chunk_keys = await self.chunks_storage.filter_keys( + list(inserting_chunks.keys()) + ) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } + + if inserting_chunks: + await self.chunks_storage.upsert(inserting_chunks) + await self.chunks_storage.index_done_callback() + count += len(inserting_chunks) + logger.info( + "[Chunk] Yielded %d new chunks for document %s, total %d", + len(inserting_chunks), + doc_id, + count, + ) + for _chunk_id in inserting_chunks.keys(): + yield _chunk_id + else: + logger.info( + "[Chunk] All chunks for document %s are already in the storage", + doc_id, + ) + if count == 0: + logger.warning("[Chunk] No new chunks to process") + + @op("build_kg", deps=["chunk"], op_type=OpType.BATCH, batch_size=32) @async_to_sync_method - async def build_kg(self): + async def build_kg(self, inputs: List): """ build knowledge graph from text chunks + inputs: chunk IDs from chunks_storage + return: None """ - # Step 1: get new chunks according to meta and chunks storage - inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage) - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return + count = 0 + # Step 1: get chunks + inserting_chunks: Dict[str, Dict] = {} + for _chunk_id in inputs: + chunk = await self.chunks_storage.get_by_id(_chunk_id) + if chunk: + inserting_chunks[_chunk_id] = chunk + + count += len(inserting_chunks) + logger.info( + "[Build KG] Inserting %d chunks, total %d", len(inserting_chunks), count + ) - logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) # Step 2: build knowledge graph from new chunks _add_entities_and_relations = await build_kg( llm_client=self.synthesizer_llm_client, @@ -173,22 +204,26 @@ async def build_kg(self): logger.warning("No entities or relations extracted from text chunks") return - # Step 3: mark meta + # Step 3: store the new entities and relations await self.graph_storage.index_done_callback() - await self.meta_storage.mark_done(self.chunks_storage) - await self.meta_storage.index_done_callback() - return _add_entities_and_relations - - @op("search", deps=["read"]) + @op("search", deps=["read"], op_type=OpType.BATCH, batch_size=64) @async_to_sync_method - async def search(self, search_config: Dict): + async def search(self, search_config: Dict, inputs: List): + """ + search new documents from full_docs_storage + input_stream: document IDs from full_docs_storage + return: None + """ logger.info("[Search] %s ...", ", 
".join(search_config["data_sources"])) - seeds = await self.meta_storage.get_new_data(self.full_docs_storage) - if len(seeds) == 0: - logger.warning("All documents are already been searched") - return + # Step 1: get documents + seeds = {} + for doc_id in inputs: + doc = await self.full_docs_storage.get_by_id(doc_id) + if doc: + seeds[doc_id] = doc + search_results = await search_all( seed_data=seeds, search_config=search_config, @@ -197,18 +232,17 @@ async def search(self, search_config: Dict): _add_search_keys = await self.search_storage.filter_keys( list(search_results.keys()) ) + search_results = { k: v for k, v in search_results.items() if k in _add_search_keys } if len(search_results) == 0: - logger.warning("All search results are already in the storage") - return + logger.warning("[Search] No new search results to add to storage") + await self.search_storage.upsert(search_results) await self.search_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() - @op("quiz_and_judge", deps=["build_kg"]) + @op("quiz_and_judge", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.warning( @@ -249,9 +283,10 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.info("Restarting synthesizer LLM client.") self.synthesizer_llm_client.restart() - @op("partition", deps=["build_kg"]) + @op("partition", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method async def partition(self, partition_config: Dict): + # TODO: partition 可以yield batches batches = await partition_kg( self.graph_storage, self.chunks_storage, @@ -261,11 +296,14 @@ async def partition(self, partition_config: Dict): await self.partition_storage.upsert(batches) return batches - @op("extract", deps=["chunk"]) + @op("extract", deps=["chunk"], op_type=OpType.STREAMING) @async_to_sync_method - async def extract(self, extract_config: Dict): - logger.info("Extracting information from given chunks...") - + async def extract(self, extract_config: Dict, input_stream: Iterator): + """ + Extract information from chunks in chunks_storage + input_stream: chunk IDs from chunks_storage + return: None + """ results = await extract_info( self.synthesizer_llm_client, self.chunks_storage, @@ -281,10 +319,10 @@ async def extract(self, extract_config: Dict): await self.meta_storage.mark_done(self.chunks_storage) await self.meta_storage.index_done_callback() - @op("generate", deps=["partition"]) + @op("generate", deps=["partition"], op_type=OpType.STREAMING) @async_to_sync_method - async def generate(self, generate_config: Dict): - + async def generate(self, generate_config: Dict, input_stream: Iterator): + # TODO: batches = self.partition_storage.data if not batches: logger.warning("No partitions found for QA generation") diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 68fd2a5d..e494631c 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -18,7 +18,6 @@ ) from .reader import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index 600ffb4a..220460c3 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -1,6 +1,5 @@ from .csv_reader import CSVReader from .json_reader import JSONReader -from .jsonl_reader import JSONLReader from .parquet_reader import ParquetReader from .pdf_reader 
import PDFReader from .pickle_reader import PickleReader diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index bc865a3b..99faa30e 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,13 +14,20 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = None, + ) -> Dataset: + """ + Read CSV files and return Ray Dataset. - df = pd.read_csv(file_path) - for _, row in df.iterrows(): - assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}" - if row["type"] == "text" and self.text_column not in row: - raise ValueError( - f"Missing '{self.text_column}' in document: {row.to_dict()}" - ) - return self.filter(df.to_dict(orient="records")) + :param input_path: Path to CSV file or list of CSV files. + :param parallelism: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated and filtered data. + """ + + ds = ray.data.read_csv(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 8253041c..1bcba4ea 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,26 +1,32 @@ -import json -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class JSONReader(BaseReader): """ - Reader for JSON files. + Reader for JSON and JSONL files. Columns: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - return self.filter(data) - raise ValueError("JSON file must contain a list of documents.") + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read JSON file and return Ray Dataset. + :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. + :param parallelism: Number of parallel workers for reading files. + :return: Ray Dataset containing validated and filtered data. 
+ """ + + ds = ray.data.read_json(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py deleted file mode 100644 index 31bc3195..00000000 --- a/graphgen/models/reader/jsonl_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -from typing import Any, Dict, List - -from graphgen.bases.base_reader import BaseReader -from graphgen.utils import logger - - -class JSONLReader(BaseReader): - """ - Reader for JSONL files. - Columns: - - type: The type of the document (e.g., "text", "image", etc.) - - if type is "text", "content" column must be present. - """ - - def read(self, file_path: str) -> List[Dict[str, Any]]: - docs = [] - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - try: - doc = json.loads(line) - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - docs.append(doc) - except json.JSONDecodeError as e: - logger.error("Error decoding JSON line: %s. Error: %s", line, e) - return self.filter(docs) diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index a325b876..5423643b 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,12 +14,22 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - df = pd.read_parquet(file_path) - data: List[Dict[str, Any]] = df.to_dict(orient="records") + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = None, + ) -> Dataset: + """ + Read Parquet files using Ray Data. - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") - return self.filter(data) + :param input_path: Path to Parquet file or list of Parquet files. + :param parallelism: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated documents. 
+ """ + if not ray.is_initialized(): + ray.init() + + ds = ray.data.read_parquet(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 94562cb5..9d5c7c27 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -5,6 +5,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +import ray +from ray.data import Dataset + from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device @@ -62,19 +65,32 @@ def __init__( self.parser = MinerUParser() self.txt_reader = TXTReader() - def read(self, file_path: str, **override) -> List[Dict[str, Any]]: - """ - file_path - **override: override MinerU parameters - """ - pdf_path = Path(file_path).expanduser().resolve() - if not pdf_path.is_file(): - raise FileNotFoundError(pdf_path) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + **override, + ) -> Dataset: + + # Ensure input_path is a list + if isinstance(input_path, str): + input_path = [input_path] + + paths_ds = ray.data.from_items(input_path) + + def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + pdf_path = row["item"] + kwargs = {**self._default_kwargs, **override} + return self._call_mineru(Path(pdf_path), kwargs) + except Exception as e: + logger.error("Failed to process %s: %s", row, e) + return [] - kwargs = {**self._default_kwargs, **override} + docs_ds = paths_ds.flat_map(process_pdf) + docs_ds = docs_ds.filter(self._should_keep_item) - mineru_result = self._call_mineru(pdf_path, kwargs) - return self.filter(mineru_result) + return docs_ds def _call_mineru( self, pdf_path: Path, kwargs: Dict[str, Any] @@ -161,18 +177,18 @@ def _try_load_cached_result( base = os.path.dirname(json_file) results = [] - for item in data: + for it in data: for key in ("img_path", "table_img_path", "equation_img_path"): - rel_path = item.get(key) + rel_path = it.get(key) if rel_path: - item[key] = str(Path(base).joinpath(rel_path).resolve()) - if item["type"] == "text": - item["content"] = item["text"] - del item["text"] + it[key] = str(Path(base).joinpath(rel_path).resolve()) + if it["type"] == "text": + it["content"] = it["text"] + del it["text"] for key in ("page_idx", "bbox", "text_level"): - if item.get(key) is not None: - del item[key] - results.append(item) + if it.get(key) is not None: + del it[key] + results.append(it) return results @staticmethod diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1a11dc11..0b0e5719 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,30 +1,82 @@ import pickle -from typing import Any, Dict, List +from typing import List, Union + +import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class PickleReader(BaseReader): """ - Read pickle files, requiring the top-level object to be List[Dict[str, Any]]. - - Columns: + Read pickle files, requiring the schema to be restored to List[Dict[str, Any]]. + Each pickle file should contain a list of dictionaries with at least: - type: The type of the document (e.g., "text", "image", etc.) 
- if type is "text", "content" column must be present. + + Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available. + For Ray >= 2.5, consider using read_pickle if available in your version. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "rb") as f: - data = pickle.load(f) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = None, + ) -> Dataset: + """ + Read Pickle files using Ray Data. + + :param input_path: Path to pickle file or list of pickle files. + :param parallelism: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated documents. + """ + if not ray.is_initialized(): + ray.init() + + # Use read_binary_files as a reliable alternative to read_pickle + ds = ray.data.read_binary_files( + input_path, override_num_blocks=parallelism, include_paths=True + ) + + # Deserialize pickle files and flatten into individual records + def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame: + all_records = [] + for _, row in batch.iterrows(): + try: + # Load pickle data from bytes + data = pickle.loads(row["bytes"]) + + # Validate structure + if not isinstance(data, list): + logger.error( + "Pickle file {row['path']} must contain a list, got {type(data)}" + ) + continue + + if not all(isinstance(item, dict) for item in data): + logger.error( + "Pickle file {row['path']} must contain a list of dictionaries" + ) + continue + + # Flatten: each dict in the list becomes a separate row + all_records.extend(data) + except Exception as e: + logger.error( + "Failed to deserialize pickle file %s: %s", row["path"], str(e) + ) + continue + + return pd.DataFrame(all_records) - if not isinstance(data, list): - raise ValueError("Pickle file must contain a list of documents.") + # Apply deserialization and flattening + ds = ds.map_batches(deserialize_batch, batch_format="pandas") - for doc in data: - if not isinstance(doc, dict): - raise ValueError("Every item in the list must be a dict.") - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") + # Validate the schema + ds = ds.map_batches(self._validate_batch, batch_format="pandas") - return self.filter(data) + # Filter valid items + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index cce167c1..406478f5 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -1,48 +1,130 @@ -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Union +import ray import rdflib +from ray.data import Dataset from rdflib import Literal from rdflib.util import guess_format from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class RDFReader(BaseReader): """ Reader for RDF files that extracts triples and represents them as dictionaries. + + Uses Ray Data for distributed processing of multiple RDF files. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def __init__(self, *, text_column: str = "content", **kwargs): + """ + Initialize RDFReader. + + :param text_column: The column name for text content (default: "content"). 
+ """ + super().__init__(**kwargs) + self.text_column = text_column + + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read RDF file(s) using Ray Data. + + :param input_path: Path to RDF file or list of RDF files. + :param parallelism: Number of parallel workers for processing. + :return: Ray Dataset containing extracted documents. + """ + if not ray.is_initialized(): + ray.init() + + # Ensure input_path is a list to prevent Ray from splitting string into characters + if isinstance(input_path, str): + input_path = [input_path] + + # Create dataset from file paths + paths_ds = ray.data.from_items(input_path) + + def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single RDF file and return list of documents.""" + try: + file_path = row["item"] + return self._parse_rdf_file(Path(file_path)) + except Exception as e: + logger.error( + "Failed to process RDF file %s: %s", row.get("item", "unknown"), e + ) + return [] + + # Process files in parallel and flatten results + docs_ds = paths_ds.flat_map(process_rdf) + + # Filter valid documents + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds + + def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: + """ + Parse a single RDF file and extract documents. + + :param file_path: Path to RDF file. + :return: List of document dictionaries. + """ + if not file_path.is_file(): + raise FileNotFoundError(f"RDF file not found: {file_path}") + g = rdflib.Graph() - fmt = guess_format(file_path) + fmt = guess_format(str(file_path)) + try: - g.parse(file_path, format=fmt) + g.parse(str(file_path), format=fmt) except Exception as e: raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e docs: List[Dict[str, Any]] = [] - text_col = self.text_column + # Process each unique subject in the RDF graph for subj in set(g.subjects()): literals = [] props = {} + + # Extract all triples for this subject for _, pred, obj in g.triples((subj, None, None)): pred_str = str(pred) + obj_str = str(obj) + + # Collect literal values as text content if isinstance(obj, Literal): - literals.append(str(obj)) - props.setdefault(pred_str, []).append(str(obj)) + literals.append(obj_str) + + # Store all properties (including non-literals) + props.setdefault(pred_str, []).append(obj_str) + # Join all literal values as the text content text = " ".join(literals).strip() if not text: - raise ValueError( - f"Subject {subj} has no literal values; " - f"missing '{text_col}' for text column." 
+ logger.warning( + "Subject %s in %s has no literal values; document will have empty '%s' field.", + subj, + file_path, + self.text_column, ) - doc = {"id": str(subj), text_col: text, "properties": props} + # Create document dictionary + doc = { + "id": str(subj), + self.text_column: text, + "properties": props, + "source_file": str(file_path), + } docs.append(doc) if not docs: - raise ValueError("RDF file contains no valid documents.") + logger.warning("RDF file %s contains no valid documents.", file_path) - return self.filter(docs) + return docs diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index ec2ff747..bb6cce9e 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,10 +1,33 @@ -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class TXTReader(BaseReader): - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - docs = [{"type": "text", self.text_column: f.read()}] - return self.filter(docs) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read text files from the specified input path. + :param input_path: Path to the input text file or list of text files. + :param parallelism: Number of blocks to override for Ray Dataset reading. + :return: Ray Dataset containing the read text data. + """ + docs_ds = ray.data.read_text( + input_path, encoding="utf-8", override_num_blocks=parallelism + ) + + docs_ds = docs_ds.map( + lambda row: { + "type": "text", + self.text_column: row["text"], + } + ) + + docs_ds = docs_ds.filter(self._should_keep_item) + return docs_ds diff --git a/graphgen/models/searcher/db/uniprot_searcher.py b/graphgen/models/searcher/db/uniprot_searcher.py index a74b623e..f5542f8c 100644 --- a/graphgen/models/searcher/db/uniprot_searcher.py +++ b/graphgen/models/searcher/db/uniprot_searcher.py @@ -27,12 +27,16 @@ def _get_pool(): return ThreadPoolExecutor(max_workers=10) +# ensure only one BLAST searcher at a time +_blast_lock = asyncio.Lock() + + class UniProtSearch(BaseSearcher): """ UniProt Search client to searcher with UniProt. 1) Get the protein by accession number. 2) Search with keywords or protein names (fuzzy searcher). - 3) Search with FASTA sequence (BLAST searcher). + 3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async. 
""" def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"): @@ -230,22 +234,21 @@ async def search( if query.startswith(">") or re.fullmatch( r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I ): - coro = loop.run_in_executor( - _get_pool(), self.get_by_fasta, query, threshold - ) + async with _blast_lock: + result = await loop.run_in_executor( + _get_pool(), self.get_by_fasta, query, threshold + ) # check if accession number elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I): - coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query) + result = await loop.run_in_executor( + _get_pool(), self.get_by_accession, query + ) else: # otherwise treat as keyword - coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query) + result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query) - result = await coro if result: result["_search_query"] = query return result - - -# TODO: use local UniProt database for large-scale searchs diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index fb37ee29..3f6a7a2b 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -24,7 +24,7 @@ async def all_keys(self) -> list[str]: async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id): + async def get_by_id(self, id) -> dict | None: return self._data.get(id, None) async def get_by_ids(self, ids, fields=None) -> list: diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 97f4b3c8..6b76de44 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,4 +1,5 @@ from .build_kg import build_kg +from .chunk import chunk_documents from .extract import extract_info from .generate import generate_qas from .init import init_llm @@ -6,4 +7,3 @@ from .quiz_and_judge import judge_statement, quiz from .read import read_files from .search import search_all -from .split import chunk_documents diff --git a/graphgen/operators/chunk/__init__.py b/graphgen/operators/chunk/__init__.py new file mode 100644 index 00000000..0f4c92ac --- /dev/null +++ b/graphgen/operators/chunk/__init__.py @@ -0,0 +1 @@ +from .chunk_documents import chunk_documents diff --git a/graphgen/operators/split/split_chunks.py b/graphgen/operators/chunk/chunk_documents.py similarity index 92% rename from graphgen/operators/split/split_chunks.py rename to graphgen/operators/chunk/chunk_documents.py index 3f728e00..280279de 100644 --- a/graphgen/operators/split/split_chunks.py +++ b/graphgen/operators/chunk/chunk_documents.py @@ -1,8 +1,6 @@ from functools import lru_cache from typing import Union -from tqdm.asyncio import tqdm as tqdm_async - from graphgen.models import ( ChineseRecursiveTextSplitter, RecursiveCharacterSplitter, @@ -38,7 +36,7 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list: return splitter.split_text(text) -async def chunk_documents( +def chunk_documents( new_docs: dict, tokenizer_instance: Tokenizer = None, progress_bar=None, @@ -47,9 +45,8 @@ async def chunk_documents( inserting_chunks = {} cur_index = 1 doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): + + for doc_key, doc in new_docs.items(): doc_type = doc.get("type") if doc_type == "text": doc_language = detect_main_language(doc["content"]) diff --git a/graphgen/operators/read/parallel_file_scanner.py 
b/graphgen/operators/read/parallel_file_scanner.py new file mode 100644 index 00000000..890a50a9 --- /dev/null +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -0,0 +1,231 @@ +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Set, Union + +from diskcache import Cache + +from graphgen.utils import logger + + +class ParallelFileScanner: + def __init__( + self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4 + ): + self.cache = Cache(cache_dir) + self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None + self.rescan = rescan + self.max_workers = max_workers + + def scan( + self, paths: Union[str, List[str]], recursive: bool = True + ) -> Dict[str, Any]: + if isinstance(paths, str): + paths = [paths] + + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_path = {} + for p in paths: + if os.path.exists(p): + future = executor.submit( + self._scan_files, Path(p).resolve(), recursive, set() + ) + future_to_path[future] = p + else: + logger.warning("[READ] Path does not exist: %s", p) + + for future in as_completed(future_to_path): + path = future_to_path[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("[READ] Error scanning path %s: %s", path, e) + results[path] = { + "error": str(e), + "files": [], + "dirs": [], + "stats": {}, + } + return results + + def _scan_files( + self, path: Path, recursive: bool, visited: Set[str] + ) -> Dict[str, Any]: + path_str = str(path) + + # Avoid cycles due to symlinks + if path_str in visited: + logger.warning("[READ] Skipping already visited path: %s", path_str) + return self._empty_result(path_str) + + # cache check + cache_key = f"scan::{path_str}::recursive::{recursive}" + cached = self.cache.get(cache_key) + if cached and not self.rescan: + logger.info("[READ] Using cached scan result for path: %s", path_str) + return cached["data"] + + logger.info("[READ] Scanning path: %s", path_str) + files, dirs = [], [] + stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0} + + try: + path_stat = path.stat() + if path.is_file(): + return self._scan_single_file(path, path_str, path_stat) + if path.is_dir(): + with os.scandir(path_str) as entries: + for entry in entries: + try: + entry_stat = entry.stat(follow_symlinks=False) + + if entry.is_dir(): + dirs.append( + { + "path": entry.path, + "name": entry.name, + "mtime": entry_stat.st_mtime, + } + ) + stats["dir_count"] += 1 + else: + # allowed suffix filter + if not self._is_allowed_file(Path(entry.path)): + continue + files.append( + { + "path": entry.path, + "name": entry.name, + "size": entry_stat.st_size, + "mtime": entry_stat.st_mtime, + } + ) + stats["total_size"] += entry_stat.st_size + stats["file_count"] += 1 + + except OSError: + stats["errors"] += 1 + + except (PermissionError, FileNotFoundError, OSError) as e: + logger.error("[READ] Failed to scan path %s: %s", path_str, e) + return {"error": str(e), "files": [], "dirs": [], "stats": stats} + + if recursive: + sub_visited = visited | {path_str} + sub_results = self._scan_subdirs(dirs, sub_visited) + + for sub_data in sub_results.values(): + files.extend(sub_data.get("files", [])) + stats["total_size"] += sub_data["stats"].get("total_size", 0) + stats["file_count"] += sub_data["stats"].get("file_count", 0) + + result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats} + 
self._cache_result(cache_key, result, path) + return result + + def _scan_single_file( + self, path: Path, path_str: str, stat: os.stat_result + ) -> Dict[str, Any]: + """Scan a single file and return its metadata""" + if not self._is_allowed_file(path): + return self._empty_result(path_str) + + return { + "path": path_str, + "files": [ + { + "path": path_str, + "name": path.name, + "size": stat.st_size, + "mtime": stat.st_mtime, + } + ], + "dirs": [], + "stats": { + "total_size": stat.st_size, + "file_count": 1, + "dir_count": 0, + "errors": 0, + }, + } + + def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]: + """ + Parallel scan subdirectories + :param dir_list + :param visited + :return: + """ + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(self._scan_files, Path(d["path"]), True, visited): d[ + "path" + ] + for d in dir_list + } + + for future in as_completed(futures): + path = futures[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("[READ] Error scanning subdirectory %s: %s", path, e) + results[path] = { + "error": str(e), + "files": [], + "dirs": [], + "stats": {}, + } + + return results + + def _cache_result(self, key: str, result: Dict, path: Path): + """Cache the scan result""" + try: + self.cache.set( + key, + { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + ) + logger.info("[READ] Cached scan result for path: %s", path) + except OSError as e: + logger.error("[READ] Failed to cache scan result for path %s: %s", path, e) + + def _is_allowed_file(self, path: Path) -> bool: + """Check if the file has an allowed suffix""" + if self.allowed_suffix is None: + return True + suffix = path.suffix.lower().lstrip(".") + return suffix in self.allowed_suffix + + def invalidate(self, path: str): + """Invalidate cache for a specific path""" + path = Path(path).resolve() + keys = [k for k in self.cache if k.startswith(f"scan::{path}")] + for k in keys: + self.cache.delete(k) + logger.info("[READ] Invalidated cache for path: %s", path) + + def close(self): + self.cache.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + @staticmethod + def _empty_result(path: str) -> Dict[str, Any]: + return { + "path": path, + "files": [], + "dirs": [], + "stats": {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0}, + } diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index b940b439..34ffee85 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -1,9 +1,10 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional, Union + +import ray from graphgen.models import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, @@ -13,8 +14,10 @@ ) from graphgen.utils import logger +from .parallel_file_scanner import ParallelFileScanner + _MAPPING = { - "jsonl": JSONLReader, + "jsonl": JSONReader, "json": JSONReader, "txt": TXTReader, "csv": CSVReader, @@ -28,59 +31,93 @@ } -def _build_reader(suffix: str, cache_dir: str | None): +def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): + """Factory function to build appropriate reader instance""" suffix = suffix.lower() - if suffix == "pdf" and cache_dir is not None: - return _MAPPING[suffix](output_dir=cache_dir) - return _MAPPING[suffix]() + reader_cls = 
_MAPPING.get(suffix) + if not reader_cls: + raise ValueError(f"Unsupported file suffix: {suffix}") + + # Special handling for PDFReader which needs output_dir + if suffix == "pdf": + if cache_dir is None: + raise ValueError("cache_dir must be provided for PDFReader") + return reader_cls(output_dir=cache_dir, **reader_kwargs) + + return reader_cls(**reader_kwargs) def read_files( - input_file: str, + input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, cache_dir: Optional[str] = None, -) -> list[dict]: - path = Path(input_file).expanduser() - if not path.exists(): - raise FileNotFoundError(f"input_path not found: {input_file}") - - if allowed_suffix is None: - support_suffix = set(_MAPPING.keys()) - else: - support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} - - # single file - if path.is_file(): - suffix = path.suffix.lstrip(".").lower() - if suffix not in support_suffix: - logger.warning( - "Skip file %s (suffix '%s' not in allowed_suffix %s)", - path, - suffix, - support_suffix, - ) - return [] - reader = _build_reader(suffix, cache_dir) - return reader.read(str(path)) - - # folder - files_to_read = [ - p for p in path.rglob("*") if p.suffix.lstrip(".").lower() in support_suffix - ] - logger.info( - "Found %d eligible file(s) under folder %s (allowed_suffix=%s)", - len(files_to_read), - input_file, - support_suffix, - ) - - all_docs: List[Dict[str, Any]] = [] - for p in files_to_read: - try: - suffix = p.suffix.lstrip(".").lower() - reader = _build_reader(suffix, cache_dir) - all_docs.extend(reader.read(str(p))) - except Exception as e: # pylint: disable=broad-except - logger.exception("Error reading %s: %s", p, e) - - return all_docs + parallelism: int = 4, + recursive: bool = True, + **reader_kwargs: Any, +) -> ray.data.Dataset: + """ + Unified entry point to read files of multiple types using Ray Data. + + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents + """ + try: + # 1. Scan all paths to discover files + logger.info("[READ] Scanning paths: %s", input_path) + scanner = ParallelFileScanner( + cache_dir=cache_dir, + allowed_suffix=allowed_suffix, + rescan=False, + max_workers=parallelism if parallelism > 0 else 1, + ) + + all_files = [] + scan_results = scanner.scan(input_path, recursive=recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + + if not all_files: + return ray.data.from_items([]) + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + for suffix, file_paths in files_by_suffix.items(): + reader = _build_reader(suffix, cache_dir, **reader_kwargs) + ds = reader.read(file_paths, parallelism=parallelism) + read_tasks.append(ds) + + # 4. 
Combine all datasets + if not read_tasks: + logger.warning("[READ] No datasets created") + return ray.data.from_items([]) + + if len(read_tasks) == 1: + logger.info("[READ] Successfully read files from %s", input_path) + return read_tasks[0] + # len(read_tasks) > 1 + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds + + except Exception as e: + logger.error("[READ] Failed to read files from %s: %s", input_path, e) + raise diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py deleted file mode 100644 index 2afc738d..00000000 --- a/graphgen/operators/split/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .split_chunks import chunk_documents diff --git a/graphgen/run.py b/graphgen/run.py index c300a6aa..c710b3ca 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -73,7 +73,7 @@ def main(): ops = collect_ops(config, graph_gen) # run operations - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) + Engine(queue_size=config.get("queue_size", 100)).run(ops, ctx) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. Data saved to %s", output_path) diff --git a/graphgen/utils/wrap.py b/graphgen/utils/wrap.py index 57776f22..efcf320a 100644 --- a/graphgen/utils/wrap.py +++ b/graphgen/utils/wrap.py @@ -1,13 +1,43 @@ +import asyncio +import inspect from functools import wraps -from typing import Any, Callable -from .loop import create_event_loop +def async_to_sync_method(func): + """Convert async method to sync method, handling both coroutines and async generators.""" -def async_to_sync_method(func: Callable) -> Callable: @wraps(func) - def wrapper(self, *args, **kwargs) -> Any: - loop = create_event_loop() - return loop.run_until_complete(func(self, *args, **kwargs)) + def wrapper(*args, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = func(*args, **kwargs) + + # handle async generator (STREAMING operation) + if inspect.isasyncgen(result): + async_gen = result + + def sync_generator(): + try: + while True: + item = loop.run_until_complete(anext(async_gen)) + yield item + except StopAsyncIteration: + pass + finally: + loop.close() + + return sync_generator() + + # handle coroutine (BARRIER operation) + if inspect.iscoroutine(result): + try: + return loop.run_until_complete(result) + finally: + loop.close() + + else: + loop.close() + return result return wrapper diff --git a/requirements.txt b/requirements.txt index fd824606..881bf280 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,8 @@ requests fastapi trafilatura aiohttp +ray +diskcache leidenalg igraph diff --git a/scripts/search/search_uniprot.sh b/scripts/search/search_uniprot.sh new file mode 100644 index 00000000..642040af --- /dev/null +++ b/scripts/search/search_uniprot.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file graphgen/configs/search_config.yaml \ +--output_dir cache/
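A minimal usage sketch of the new Ray-backed read path, assuming the module layout in this patch (read_files re-exported from graphgen.operators); the input directory, suffix list, and cache directory below are placeholder values, not taken from the repository:

import ray

from graphgen.operators import read_files

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)

    # Placeholder paths/suffixes; read_files now returns a ray.data.Dataset
    # instead of a list of dicts.
    ds = read_files(
        input_path="resources/input_examples",
        allowed_suffix=["jsonl", "txt"],
        cache_dir="cache",
        parallelism=4,
    )

    # Each row is a validated document dict with at least a "type" field.
    for row in ds.take(5):
        print(row["type"], row.get("content", "")[:80])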