From 899cf01fc734c9520c1d341a3ecf06da75747d74 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 10:38:04 +0800 Subject: [PATCH 01/26] feat: add bucket for map and all-reduce --- graphgen/engine.py | 82 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index dad75de5..f2d788c1 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -4,8 +4,14 @@ import threading import traceback +from enum import Enum, auto from functools import wraps -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional + + +class AggMode(Enum): + MAP = auto() # whenever upstream produces a result, run + ALL_REDUCE = auto() # wait for all upstream results, then run class Context(dict): @@ -22,9 +28,18 @@ def get(self, k, default=None): class OpNode: def __init__( - self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any] + self, + name: str, + deps: List[str], + compute_func: Callable[["OpNode", Context], Any], + callback_func: Optional[Callable[["OpNode", Context, List[Any]], None]] = None, + agg_mode: AggMode = AggMode.ALL_REDUCE, ): - self.name, self.deps, self.func = name, deps, func + self.name = name + self.deps = deps + self.compute_func = compute_func + self.callback_func = callback_func or (lambda self, ctx, results: None) + self.agg_mode = agg_mode def op(name: str, deps=None): @@ -97,6 +112,67 @@ def _exec(n: str): ) +class Bucket: + """ + Bucket for a single operation, collecting computation results and triggering downstream ops + """ + + def __init__( + self, name: str, size: int, mode: AggMode, callback: Callable[[List[Any]], None] + ): + self.name = name + self.size = size + self.mode = mode + self.callback = callback + self._lock = threading.Lock() + self._results: List[Any] = [] + self._done = False + + def put(self, result: Any): + with self._lock: + if self._done: + return + self._results.append(result) + + if self.mode == AggMode.MAP or len(self._results) == self.size: + self._fire() + + def _fire(self): + self._done = True + threading.Thread(target=self._callback_wrapper, daemon=True).start() + + def _callback_wrapper(self): + try: + self.callback(self._results) + except Exception: # pylint: disable=broad-except + traceback.print_exc() + + +class BucketManager: + def __init__(self): + self._buckets: dict[str, Bucket] = {} + self._lock = threading.Lock() + + def register( + self, + node_name: str, + bucket_size: int, + mode: AggMode, + callback: Callable[[List[Any]], None], + ): + with self._lock: + if node_name in self._buckets: + raise RuntimeError(f"Bucket {node_name} already registered") + self._buckets[node_name] = Bucket( + name=node_name, size=bucket_size, mode=mode, callback=callback + ) + return self._buckets[node_name] + + def get(self, node_name: str) -> Optional[Bucket]: + with self._lock: + return self._buckets.get(node_name) + + def collect_ops(config: dict, graph_gen) -> List[OpNode]: """ build operation nodes from yaml config From 8d7f2c52880672bd6bcd30e61a65cb3b61ea0747 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 11:59:20 +0800 Subject: [PATCH 02/26] feat: stream reading files --- graphgen/operators/read/read_files.py | 46 ++++++++++++--------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index b940b439..c82ca2a5 100644 --- a/graphgen/operators/read/read_files.py +++ 
b/graphgen/operators/read/read_files.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Iterator, List, Optional from graphgen.models import ( CSVReader, @@ -39,10 +39,10 @@ def read_files( input_file: str, allowed_suffix: Optional[List[str]] = None, cache_dir: Optional[str] = None, -) -> list[dict]: +) -> Iterator[list[dict]]: path = Path(input_file).expanduser() if not path.exists(): - raise FileNotFoundError(f"input_path not found: {input_file}") + raise FileNotFoundError(f"[Reader] input_path not found: {input_file}") if allowed_suffix is None: support_suffix = set(_MAPPING.keys()) @@ -54,33 +54,27 @@ def read_files( suffix = path.suffix.lstrip(".").lower() if suffix not in support_suffix: logger.warning( - "Skip file %s (suffix '%s' not in allowed_suffix %s)", + "[Reader] Skip file %s (suffix '%s' not in allowed_suffix %s)", path, suffix, support_suffix, ) - return [] + return reader = _build_reader(suffix, cache_dir) - return reader.read(str(path)) + logger.info("[Reader] Reading file %s", path) + yield reader.read(str(path)) + return # folder - files_to_read = [ - p for p in path.rglob("*") if p.suffix.lstrip(".").lower() in support_suffix - ] - logger.info( - "Found %d eligible file(s) under folder %s (allowed_suffix=%s)", - len(files_to_read), - input_file, - support_suffix, - ) - - all_docs: List[Dict[str, Any]] = [] - for p in files_to_read: - try: - suffix = p.suffix.lstrip(".").lower() - reader = _build_reader(suffix, cache_dir) - all_docs.extend(reader.read(str(p))) - except Exception as e: # pylint: disable=broad-except - logger.exception("Error reading %s: %s", p, e) - - return all_docs + logger.info("[Reader] Streaming directory %s", path) + for p in path.rglob("*"): + if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix: + try: + suffix = p.suffix.lstrip(".").lower() + reader = _build_reader(suffix, cache_dir) + logger.info("[Reader] Reading file %s", p) + docs = reader.read(str(p)) + if docs: + yield docs + except Exception: # pylint: disable=broad-except + logger.exception("[Reader] Error reading %s", p) From 39272390675a73e9ac2a637b4a87b685bf3c2738 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 12:33:53 +0800 Subject: [PATCH 03/26] fix: fix params in collect_ops --- graphgen/engine.py | 88 +++++++++++++++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index f2d788c1..aad4f1ea 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -42,7 +42,7 @@ def __init__( self.agg_mode = agg_mode -def op(name: str, deps=None): +def op(name: str, deps=None, agg_mode: AggMode = AggMode.ALL_REDUCE): deps = deps or [] def decorator(func): @@ -50,7 +50,13 @@ def decorator(func): def _wrapper(*args, **kwargs): return func(*args, **kwargs) - _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx)) + _wrapper.op_node = OpNode( + name=name, + deps=deps, + compute_func=lambda self, ctx: func(self), + callback_func=lambda self, ctx, results: None, + agg_mode=agg_mode, + ) return _wrapper return decorator @@ -59,48 +65,45 @@ def _wrapper(*args, **kwargs): class Engine: def __init__(self, max_workers: int = 4): self.max_workers = max_workers + self.bucket_mgr = BucketManager() def run(self, ops: List[OpNode], ctx: Context): - name2op = {operation.name: operation for operation in ops} - - # topological sort - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in 
graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(cur) - for child in [c for c, d in graph.items() if cur in d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) + name2op = {op.name: op for op in ops} + topo_names = [op.name for op in self._topo_sort(ops)] - if len(topo) != len(ops): - raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." - ) - - # semaphore for max_workers sem = threading.Semaphore(self.max_workers) done = {n: threading.Event() for n in name2op} exc = {} + for node in ops: + bucket_size = ctx.get(f"_bucket_size_{node.name}", 1) + self.bucket_mgr.register( + node.name, + bucket_size, + node.agg_mode, + lambda results, n=node: self._callback_wrapper(n, ctx, results), + ) + def _exec(n: str): with sem: for d in name2op[n].deps: done[d].wait() if any(d in exc for d in name2op[n].deps): - exc[n] = Exception("Skipped due to failed dependencies") + exc[n] = "Skipped due to failed dependencies" done[n].set() return + try: - name2op[n].func(name2op[n], ctx) + name2op[n].compute_func(name2op[n], ctx) except Exception: # pylint: disable=broad-except exc[n] = traceback.format_exc() - done[n].set() + finally: + done[n].set() - ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo] + ts = [ + threading.Thread(target=_exec, args=(name,), daemon=True) + for name in topo_names + ] for t in ts: t.start() for t in ts: @@ -111,6 +114,34 @@ def _exec(n: str): + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) ) + @staticmethod + def _callback_wrapper(node: OpNode, ctx: Context, results: List[Any]): + try: + node.callback_func(node, ctx, results) + except Exception: # pylint: disable=broad-except + traceback.print_exc() + + @staticmethod + def _topo_sort(ops: List[OpNode]) -> List[OpNode]: + name2op = {operation.name: operation for operation in ops} + graph = {n: set(name2op[n].deps) for n in name2op} + topo = [] + q = [n for n, d in graph.items() if not d] + while q: + cur = q.pop(0) + topo.append(name2op[cur]) + for child in [c for c, d in graph.items() if cur in d]: + graph[child].remove(cur) + if not graph[child]: + q.append(child) + + if len(topo) != len(ops): + raise ValueError( + "Cyclic dependencies detected among operations." + "Please check your configuration." 
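The scheduling in Engine.run above gates each op on the done-events of its dependencies while a semaphore caps concurrency. A minimal stdlib-only sketch of that pattern, with a hypothetical two-op pipeline rather than the project's Engine:

import threading

# hypothetical pipeline: "read" -> "chunk"
done = {"read": threading.Event(), "chunk": threading.Event()}
deps = {"read": [], "chunk": ["read"]}
sem = threading.Semaphore(2)  # cap concurrent ops, similar to Engine.max_workers

def run_op(name):
    with sem:
        for d in deps[name]:   # block until every dependency has signalled completion
            done[d].wait()
        print("running", name)
        done[name].set()       # unblock downstream ops

threads = [threading.Thread(target=run_op, args=(n,)) for n in deps]
for t in threads:
    t.start()
for t in threads:
    t.join()

Events let every worker thread start immediately and block cheaply until its inputs are ready, which is the same behaviour _exec relies on.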
+ ) + return topo + class Bucket: """ @@ -190,8 +221,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: op_node.deps = runtime_deps if "params" in stage: - op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {})) + params = stage["params"] + op_node.compute_func = lambda self, ctx, m=method, p=params: m(p) else: - op_node.func = lambda self, ctx, m=method: m() + op_node.compute_func = lambda self, ctx, m=method: m() ops.append(op_node) return ops From 166ecaf8218a77b23ab68e40123b644072714611 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 15:50:31 +0800 Subject: [PATCH 04/26] refactor: refactor engine to dataflow orchestration --- graphgen/engine.py | 337 +++++++++++++++++++++++++-------------------- 1 file changed, 190 insertions(+), 147 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index aad4f1ea..a7fd723c 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,17 +1,25 @@ """ orchestration engine for GraphGen """ - +import queue import threading import traceback from enum import Enum, auto from functools import wraps -from typing import Any, Callable, List, Optional +from typing import Callable, Dict, List + + +class OpType(Enum): + STREAMING = auto() # once data from upstream arrives, process it immediately + BARRIER = auto() # wait for all upstream data to arrive before processing + + # TODO: implement batch processing + # BATCH = auto() # process data in batches -class AggMode(Enum): - MAP = auto() # whenever upstream produces a result, run - ALL_REDUCE = auto() # wait for all upstream results, then run +# signals the end of a data stream +class EndOfStream: + pass class Context(dict): @@ -31,18 +39,16 @@ def __init__( self, name: str, deps: List[str], - compute_func: Callable[["OpNode", Context], Any], - callback_func: Optional[Callable[["OpNode", Context, List[Any]], None]] = None, - agg_mode: AggMode = AggMode.ALL_REDUCE, + func: Callable, + op_type: OpType = OpType.BARRIER, # use barrier by default ): self.name = name self.deps = deps - self.compute_func = compute_func - self.callback_func = callback_func or (lambda self, ctx, results: None) - self.agg_mode = agg_mode + self.func = func + self.op_type = op_type -def op(name: str, deps=None, agg_mode: AggMode = AggMode.ALL_REDUCE): +def op(name: str, deps=None, op_type: OpType = OpType.BARRIER): deps = deps or [] def decorator(func): @@ -51,11 +57,10 @@ def _wrapper(*args, **kwargs): return func(*args, **kwargs) _wrapper.op_node = OpNode( - name=name, - deps=deps, - compute_func=lambda self, ctx: func(self), - callback_func=lambda self, ctx, results: None, - agg_mode=agg_mode, + name, + deps, + func, + op_type=op_type, ) return _wrapper @@ -63,145 +68,157 @@ def _wrapper(*args, **kwargs): class Engine: - def __init__(self, max_workers: int = 4): - self.max_workers = max_workers - self.bucket_mgr = BucketManager() + def __init__(self, queue_size: int = 100): + self.queue_size = queue_size - def run(self, ops: List[OpNode], ctx: Context): - name2op = {op.name: op for op in ops} - topo_names = [op.name for op in self._topo_sort(ops)] - - sem = threading.Semaphore(self.max_workers) - done = {n: threading.Event() for n in name2op} - exc = {} - - for node in ops: - bucket_size = ctx.get(f"_bucket_size_{node.name}", 1) - self.bucket_mgr.register( - node.name, - bucket_size, - node.agg_mode, - lambda results, n=node: self._callback_wrapper(n, ctx, results), + @staticmethod + def _topo_sort(name2op: Dict[str, OpNode]) -> List[str]: + adj = {n: [] for n in 
name2op} + in_degree = {n: 0 for n in name2op} + + for name, operation in name2op.items(): + for dep_name in operation.deps: + if dep_name not in name2op: + raise ValueError(f"Dependency {dep_name} of {name} not found") + adj[dep_name].append(name) + in_degree[name] += 1 + + # Kahn's algorithm for topological sorting + queue_nodes = [n for n in name2op if in_degree[n] == 0] + topo_order = [] + + while queue_nodes: + u = queue_nodes.pop(0) + topo_order.append(u) + + for v in adj[u]: + in_degree[v] -= 1 + if in_degree[v] == 0: + queue_nodes.append(v) + + # cycle detection + if len(topo_order) != len(name2op): + cycle_nodes = set(name2op.keys()) - set(topo_order) + raise ValueError(f"Cyclic dependency detected among: {cycle_nodes}") + return topo_order + + def _build_channels(self, name2op): + """Return channels / consumers_of / producer_counts""" + channels, consumers_of, producer_counts = {}, {}, {n: 0 for n in name2op} + for name, operator in name2op.items(): + consumers_of[name] = [] + for dep in operator.deps: + if dep not in name2op: + raise ValueError(f"Dependency {dep} of {name} not found") + channels[(dep, name)] = queue.Queue(maxsize=self.queue_size) + consumers_of[dep].append(name) + producer_counts[name] += 1 + return channels, consumers_of, producer_counts + + def _run_workers(self, ordered_ops, channels, consumers_of, producer_counts, ctx): + """Run worker threads for each operation node.""" + exceptions, threads = {}, [] + for node in ordered_ops: + t = threading.Thread( + target=self._worker_loop, + args=(node, channels, consumers_of, producer_counts, ctx, exceptions), + daemon=True, ) - - def _exec(n: str): - with sem: - for d in name2op[n].deps: - done[d].wait() - if any(d in exc for d in name2op[n].deps): - exc[n] = "Skipped due to failed dependencies" - done[n].set() - return - - try: - name2op[n].compute_func(name2op[n], ctx) - except Exception: # pylint: disable=broad-except - exc[n] = traceback.format_exc() - finally: - done[n].set() - - ts = [ - threading.Thread(target=_exec, args=(name,), daemon=True) - for name in topo_names - ] - for t in ts: t.start() - for t in ts: + threads.append(t) + for t in threads: t.join() - if exc: - raise RuntimeError( - "Some operations failed:\n" - + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) - ) - - @staticmethod - def _callback_wrapper(node: OpNode, ctx: Context, results: List[Any]): - try: - node.callback_func(node, ctx, results) - except Exception: # pylint: disable=broad-except - traceback.print_exc() - - @staticmethod - def _topo_sort(ops: List[OpNode]) -> List[OpNode]: - name2op = {operation.name: operation for operation in ops} - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(name2op[cur]) - for child in [c for c, d in graph.items() if cur in d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) - - if len(topo) != len(ops): - raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." 
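The worker loop above wires one bounded queue.Queue per (producer, consumer) edge and terminates each channel with an EndOfStream sentinel. A standalone sketch of that hand-off, using only the standard library and hypothetical names rather than the Engine's internals:

import queue
import threading

class EndOfStream:
    # sentinel marking that the producer side of a channel is finished
    pass

channel = queue.Queue(maxsize=100)  # bounded, so a fast producer blocks instead of buffering everything

def producer():
    for item in ["doc-1", "doc-2", "doc-3"]:
        channel.put(item)
    channel.put(EndOfStream())      # always close the channel, even when nothing failed

def consumer():
    while True:
        item = channel.get()
        if isinstance(item, EndOfStream):
            break
        print("consumed", item)

t_prod = threading.Thread(target=producer)
t_cons = threading.Thread(target=consumer)
t_prod.start(); t_cons.start()
t_prod.join(); t_cons.join()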
- ) - return topo - + return exceptions -class Bucket: - """ - Bucket for a single operation, collecting computation results and triggering downstream ops - """ - - def __init__( - self, name: str, size: int, mode: AggMode, callback: Callable[[List[Any]], None] + def _worker_loop( + self, node, channels, consumers_of, producer_counts, ctx, exceptions ): - self.name = name - self.size = size - self.mode = mode - self.callback = callback - self._lock = threading.Lock() - self._results: List[Any] = [] - self._done = False - - def put(self, result: Any): - with self._lock: - if self._done: - return - self._results.append(result) + op_name = node.name - if self.mode == AggMode.MAP or len(self._results) == self.size: - self._fire() + def input_generator(): + # if no dependencies, yield None once + if not node.deps: + yield None + return - def _fire(self): - self._done = True - threading.Thread(target=self._callback_wrapper, daemon=True).start() + active_producers = producer_counts[op_name] + # collect all queues + input_queues = [channels[(dep_name, op_name)] for dep_name in node.deps] + + # loop until all producers are done + while active_producers > 0: + got_data = False + for q in input_queues: + try: + item = q.get(timeout=0.1) + if isinstance(item, EndOfStream): + active_producers -= 1 + else: + yield item + got_data = True + except queue.Empty: + continue + + if not got_data and active_producers > 0: + # barrier wait on the first active queue + item = input_queues[0].get() + if isinstance(item, EndOfStream): + active_producers -= 1 + else: + yield item + + in_stream = input_generator() - def _callback_wrapper(self): try: - self.callback(self._results) + # execute the operation + result_iter = [] + if node.op_type == OpType.BARRIER: + # consume all input + buffered_inputs = list(in_stream) + res = node.func(self, ctx, inputs=buffered_inputs) + if res is not None: + result_iter = res if isinstance(res, (list, tuple)) else [res] + + elif node.op_type == OpType.STREAMING: + # process input one by one + res = node.func(self, ctx, input_stream=in_stream) + if res is not None: + result_iter = res + else: + raise ValueError(f"Unknown OpType {node.op_type} for {op_name}") + + # output dispatch, send results to downstream consumers + if result_iter: + for item in result_iter: + for consumer_name in consumers_of[op_name]: + channels[(op_name, consumer_name)].put(item) + except Exception: # pylint: disable=broad-except traceback.print_exc() + exceptions[op_name] = traceback.format_exc() + finally: + # signal end of stream to downstream consumers + for consumer_name in consumers_of[op_name]: + channels[(op_name, consumer_name)].put(EndOfStream()) -class BucketManager: - def __init__(self): - self._buckets: dict[str, Bucket] = {} - self._lock = threading.Lock() + def run(self, ops: List[OpNode], ctx: Context): + name2op = {op.name: op for op in ops} - def register( - self, - node_name: str, - bucket_size: int, - mode: AggMode, - callback: Callable[[List[Any]], None], - ): - with self._lock: - if node_name in self._buckets: - raise RuntimeError(f"Bucket {node_name} already registered") - self._buckets[node_name] = Bucket( - name=node_name, size=bucket_size, mode=mode, callback=callback - ) - return self._buckets[node_name] + # Step 1: topo sort and validate + sorted_op_names = self._topo_sort(name2op) - def get(self, node_name: str) -> Optional[Bucket]: - with self._lock: - return self._buckets.get(node_name) + # Step 2: build channels and tracking structures + channels, consumers_of, producer_counts = 
self._build_channels(name2op) + + # Step3: start worker threads using topo order + ordered_ops = [name2op[name] for name in sorted_op_names] + exceptions = self._run_workers( + ordered_ops, channels, consumers_of, producer_counts, ctx + ) + + if exceptions: + raise RuntimeError(f"Engine encountered exceptions: {exceptions}") def collect_ops(config: dict, graph_gen) -> List[OpNode]: @@ -217,13 +234,39 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: op_node = method.op_node # if there are runtime dependencies, override them - runtime_deps = stage.get("deps", op_node.deps) - op_node.deps = runtime_deps + deps = stage.get("deps", op_node.deps) + op_type = op_node.op_type + + if op_type == OpType.BARRIER: + if "params" in stage: + + def func(self, ctx, inputs, m=method, sc=stage): + return m(sc.get("params", {}), inputs=inputs) + + else: + + def func(self, ctx, inputs, m=method): + return m(inputs=inputs) + + elif op_type == OpType.STREAMING: + if "params" in stage: + + def func(self, ctx, input_stream, m=method, sc=stage): + return m(sc.get("params", {}), input_stream=input_stream) + + else: + + def func(self, ctx, input_stream, m=method): + return m(input_stream=input_stream) - if "params" in stage: - params = stage["params"] - op_node.compute_func = lambda self, ctx, m=method, p=params: m(p) else: - op_node.compute_func = lambda self, ctx, m=method: m() - ops.append(op_node) + raise ValueError(f"Unknown OpType {op_type} for operation {name}") + + new_node = OpNode( + name=name, + deps=deps, + func=func, + op_type=op_type, + ) + ops.append(new_node) return ops From fa23d32c87df6f37483a6a96fd5f8451f1772bd2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 17:16:09 +0800 Subject: [PATCH 05/26] fix: use default if input_stream is not givin --- graphgen/engine.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index a7fd723c..589dd66d 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,6 +1,7 @@ """ orchestration engine for GraphGen """ +import inspect import queue import threading import traceback @@ -218,7 +219,13 @@ def run(self, ops: List[OpNode], ctx: Context): ) if exceptions: - raise RuntimeError(f"Engine encountered exceptions: {exceptions}") + error_msgs = "\n".join( + [ + f"Operation {name} failed with error:\n{msg}" + for name, msg in exceptions.items() + ] + ) + raise RuntimeError(f"Engine encountered exceptions:\n{error_msgs}") def collect_ops(config: dict, graph_gen) -> List[OpNode]: @@ -237,6 +244,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: deps = stage.get("deps", op_node.deps) op_type = op_node.op_type + sig = inspect.signature(method) + accepts_input_stream = "input_stream" in sig.parameters + if op_type == OpType.BARRIER: if "params" in stage: @@ -250,14 +260,26 @@ def func(self, ctx, inputs, m=method): elif op_type == OpType.STREAMING: if "params" in stage: + if accepts_input_stream: + + def func(self, ctx, input_stream, m=method, sc=stage): + return m(sc.get("params", {}), input_stream=input_stream) - def func(self, ctx, input_stream, m=method, sc=stage): - return m(sc.get("params", {}), input_stream=input_stream) + else: + + def func(self, ctx, input_stream, m=method, sc=stage): + return m(sc.get("params", {})) else: + if accepts_input_stream: + + def func(self, ctx, input_stream, m=method): + return m(input_stream=input_stream) + + else: - def func(self, ctx, input_stream, m=method): - return m(input_stream=input_stream) 
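collect_ops in this revision inspects the wrapped method's signature to decide whether it can receive input_stream. A small illustration of that check with two hypothetical stage functions (standard library only):

import inspect

def barrier_style(params, inputs=None):
    return inputs

def streaming_style(params, input_stream=None):
    yield from input_stream

for fn in (barrier_style, streaming_style):
    sig_params = inspect.signature(fn).parameters
    print(fn.__name__, "accepts input_stream:", "input_stream" in sig_params)
# barrier_style accepts input_stream: False
# streaming_style accepts input_stream: True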
+ def func(self, ctx, input_stream, m=method): + return m() else: raise ValueError(f"Unknown OpType {op_type} for operation {name}") From 543c5d90973b76ffedbd3eab49d98f7fb42282e1 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 17:25:06 +0800 Subject: [PATCH 06/26] refactor: rename chunk_document.py --- graphgen/bases/base_splitter.py | 4 ++-- graphgen/operators/__init__.py | 2 +- graphgen/operators/chunk/__init__.py | 1 + .../{split/split_chunks.py => chunk/chunk_documents.py} | 9 +++------ graphgen/operators/split/__init__.py | 1 - 5 files changed, 7 insertions(+), 10 deletions(-) create mode 100644 graphgen/operators/chunk/__init__.py rename graphgen/operators/{split/split_chunks.py => chunk/chunk_documents.py} (92%) delete mode 100644 graphgen/operators/split/__init__.py diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index b2d1ad3a..e6b31b79 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -33,7 +33,7 @@ def split_text(self, text: str) -> List[str]: """ Split the input text into smaller chunks. - :param text: The input text to be split. + :param text: The input text to be chunk. :return: A list of text chunks. """ @@ -111,7 +111,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: def _split_text_with_regex( text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index a9ce24cd..be969d71 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,4 +1,5 @@ from .build_kg import build_kg +from .chunk import chunk_documents from .extract import extract_info from .generate import generate_qas from .init import init_llm @@ -7,4 +8,3 @@ from .quiz import quiz from .read import read_files from .search import search_all -from .split import chunk_documents diff --git a/graphgen/operators/chunk/__init__.py b/graphgen/operators/chunk/__init__.py new file mode 100644 index 00000000..0f4c92ac --- /dev/null +++ b/graphgen/operators/chunk/__init__.py @@ -0,0 +1 @@ +from .chunk_documents import chunk_documents diff --git a/graphgen/operators/split/split_chunks.py b/graphgen/operators/chunk/chunk_documents.py similarity index 92% rename from graphgen/operators/split/split_chunks.py rename to graphgen/operators/chunk/chunk_documents.py index 3f728e00..280279de 100644 --- a/graphgen/operators/split/split_chunks.py +++ b/graphgen/operators/chunk/chunk_documents.py @@ -1,8 +1,6 @@ from functools import lru_cache from typing import Union -from tqdm.asyncio import tqdm as tqdm_async - from graphgen.models import ( ChineseRecursiveTextSplitter, RecursiveCharacterSplitter, @@ -38,7 +36,7 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list: return splitter.split_text(text) -async def chunk_documents( +def chunk_documents( new_docs: dict, tokenizer_instance: Tokenizer = None, progress_bar=None, @@ -47,9 +45,8 @@ async def chunk_documents( inserting_chunks = {} cur_index = 1 doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): + + for doc_key, doc in new_docs.items(): doc_type = doc.get("type") if doc_type == "text": doc_language = 
detect_main_language(doc["content"]) diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py deleted file mode 100644 index 2afc738d..00000000 --- a/graphgen/operators/split/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .split_chunks import chunk_documents From e18fe7a56f5eea799300b6cb9b9b8bf8f1f05c34 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 18:58:34 +0800 Subject: [PATCH 07/26] feat: handle async generator --- graphgen/utils/wrap.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/graphgen/utils/wrap.py b/graphgen/utils/wrap.py index 57776f22..efcf320a 100644 --- a/graphgen/utils/wrap.py +++ b/graphgen/utils/wrap.py @@ -1,13 +1,43 @@ +import asyncio +import inspect from functools import wraps -from typing import Any, Callable -from .loop import create_event_loop +def async_to_sync_method(func): + """Convert async method to sync method, handling both coroutines and async generators.""" -def async_to_sync_method(func: Callable) -> Callable: @wraps(func) - def wrapper(self, *args, **kwargs) -> Any: - loop = create_event_loop() - return loop.run_until_complete(func(self, *args, **kwargs)) + def wrapper(*args, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = func(*args, **kwargs) + + # handle async generator (STREAMING operation) + if inspect.isasyncgen(result): + async_gen = result + + def sync_generator(): + try: + while True: + item = loop.run_until_complete(anext(async_gen)) + yield item + except StopAsyncIteration: + pass + finally: + loop.close() + + return sync_generator() + + # handle coroutine (BARRIER operation) + if inspect.iscoroutine(result): + try: + return loop.run_until_complete(result) + finally: + loop.close() + + else: + loop.close() + return result return wrapper From 9cc14793f8555bce994498b9ca5d3e03fa01d575 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 19:02:50 +0800 Subject: [PATCH 08/26] wip --- graphgen/engine.py | 43 +++++++++++-- graphgen/graphgen.py | 149 ++++++++++++++++++++++++------------------- graphgen/run.py | 2 +- 3 files changed, 125 insertions(+), 69 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index 589dd66d..6ddf565d 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -13,9 +13,7 @@ class OpType(Enum): STREAMING = auto() # once data from upstream arrives, process it immediately BARRIER = auto() # wait for all upstream data to arrive before processing - - # TODO: implement batch processing - # BATCH = auto() # process data in batches + BATCH = auto() # process data in batches when threshold is reached # signals the end of a data stream @@ -42,14 +40,16 @@ def __init__( deps: List[str], func: Callable, op_type: OpType = OpType.BARRIER, # use barrier by default + batch_size: int = 32, # default batch size for BATCH operations ): self.name = name self.deps = deps self.func = func self.op_type = op_type + self.batch_size = batch_size -def op(name: str, deps=None, op_type: OpType = OpType.BARRIER): +def op(name: str, deps=None, op_type: OpType = OpType.BARRIER, batch_size: int = 32): deps = deps or [] def decorator(func): @@ -62,6 +62,7 @@ def _wrapper(*args, **kwargs): deps, func, op_type=op_type, + batch_size=batch_size, ) return _wrapper @@ -185,6 +186,27 @@ def input_generator(): res = node.func(self, ctx, input_stream=in_stream) if res is not None: result_iter = res + + elif node.op_type == OpType.BATCH: + # accumulate inputs into batches and 
process + batch = [] + for item in in_stream: + batch.append(item) + if len(batch) >= node.batch_size: + res = node.func(self, ctx, inputs=batch) + if res is not None: + result_iter.extend( + res if isinstance(res, (list, tuple)) else [res] + ) + batch = [] + # process remaining items + if batch: + res = node.func(self, ctx, inputs=batch) + if res is not None: + result_iter.extend( + res if isinstance(res, (list, tuple)) else [res] + ) + else: raise ValueError(f"Unknown OpType {node.op_type} for {op_name}") @@ -243,6 +265,7 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]: # if there are runtime dependencies, override them deps = stage.get("deps", op_node.deps) op_type = op_node.op_type + batch_size = stage.get("batch_size", op_node.batch_size) sig = inspect.signature(method) accepts_input_stream = "input_stream" in sig.parameters @@ -281,6 +304,17 @@ def func(self, ctx, input_stream, m=method): def func(self, ctx, input_stream, m=method): return m() + elif op_type == OpType.BATCH: + if "params" in stage: + + def func(self, ctx, inputs, m=method, sc=stage): + return m(sc.get("params", {}), inputs=inputs) + + else: + + def func(self, ctx, inputs, m=method): + return m(inputs=inputs) + else: raise ValueError(f"Unknown OpType {op_type} for operation {name}") @@ -289,6 +323,7 @@ def func(self, ctx, input_stream, m=method): deps=deps, func=func, op_type=op_type, + batch_size=batch_size, ) ops.append(new_node) return ops diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 108d8795..b78dc77d 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,12 +1,12 @@ import os import time -from typing import Dict +from typing import Dict, Iterator, List import gradio as gr from graphgen.bases import BaseLLMWrapper from graphgen.bases.datatypes import Chunk -from graphgen.engine import op +from graphgen.engine import OpType, op from graphgen.models import ( JsonKVStorage, JsonListStorage, @@ -88,74 +88,95 @@ def __init__( # webui self.progress_bar: gr.Progress = progress_bar - @op("read", deps=[]) + @op("read", deps=[], op_type=OpType.STREAMING) @async_to_sync_method async def read(self, read_config: Dict): """ read files from input sources """ - data = read_files(**read_config, cache_dir=self.working_dir) - if len(data) == 0: - logger.warning("No data to process") - return - - assert isinstance(data, list) and isinstance(data[0], dict) + count = 0 + for docs in read_files(**read_config, cache_dir=self.working_dir): + if not docs: + continue + new_docs = {compute_mm_hash(d, prefix="doc-"): d for d in docs} + _add_doc_keys = await self.full_docs_storage.filter_keys( + list(new_docs.keys()) + ) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + + if new_docs: + await self.full_docs_storage.upsert(new_docs) + await self.full_docs_storage.index_done_callback() + for doc_id in new_docs.keys(): + yield doc_id + + count += len(new_docs) + logger.info( + "[Read] Yielded %d new documents, total %d", len(new_docs), count + ) + + if count == 0: + logger.warning("[Read] No new documents to process") # TODO: configurable whether to use coreference resolution - new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data} - _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - await self.full_docs_storage.upsert(new_docs) - await 
self.full_docs_storage.index_done_callback() - - @op("chunk", deps=["read"]) + @op("chunk", deps=["read"], op_type=OpType.STREAMING) @async_to_sync_method - async def chunk(self, chunk_config: Dict): + async def chunk(self, chunk_config: Dict, input_stream: Iterator): """ chunk documents into smaller pieces from full_docs_storage if not already present """ - - new_docs = await self.meta_storage.get_new_data(self.full_docs_storage) - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - inserting_chunks = await chunk_documents( - new_docs, - self.tokenizer_instance, - self.progress_bar, - **chunk_config, - ) - - _add_chunk_keys = await self.chunks_storage.filter_keys( - list(inserting_chunks.keys()) - ) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - await self.chunks_storage.upsert(inserting_chunks) - await self.chunks_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() - - @op("build_kg", deps=["chunk"]) + count = 0 + for doc_id in input_stream: + doc = await self.full_docs_storage.get_by_id(doc_id) + if not doc: + logger.warning( + "[Chunk] Document %s not found in full_docs_storage", doc_id + ) + continue + + inserting_chunks = chunk_documents( + {doc_id: doc}, + self.tokenizer_instance, + self.progress_bar, + **chunk_config, + ) + + _add_chunk_keys = await self.chunks_storage.filter_keys( + list(inserting_chunks.keys()) + ) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } + + if inserting_chunks: + await self.chunks_storage.upsert(inserting_chunks) + await self.chunks_storage.index_done_callback() + count += len(inserting_chunks) + logger.info( + "[Chunk] Yielded %d new chunks for document %s, total %d", + len(inserting_chunks), + doc_id, + count, + ) + for _chunk_id in inserting_chunks.keys(): + yield _chunk_id + else: + logger.info( + "[Chunk] All chunks for document %s are already in the storage", + doc_id, + ) + if count == 0: + logger.warning("[Chunk] No new chunks to process") + + @op("build_kg", deps=["chunk"], op_type=OpType.BATCH, batch_size=32) @async_to_sync_method - async def build_kg(self): + async def build_kg(self, inputs: List): """ build knowledge graph from text chunks """ - # Step 1: get new chunks according to meta and chunks storage + # Step 1: get chunks + inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage) if len(inserting_chunks) == 0: logger.warning("All chunks are already in the storage") @@ -180,9 +201,9 @@ async def build_kg(self): return _add_entities_and_relations - @op("search", deps=["read"]) + @op("search", deps=["read"], op_type=OpType.STREAMING) @async_to_sync_method - async def search(self, search_config: Dict): + async def search(self, search_config: Dict, input_stream: Iterator): logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) seeds = await self.meta_storage.get_new_data(self.full_docs_storage) @@ -208,9 +229,9 @@ async def search(self, search_config: Dict): await self.meta_storage.mark_done(self.full_docs_storage) await self.meta_storage.index_done_callback() - @op("quiz_and_judge", deps=["build_kg"]) + @op("quiz_and_judge", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method - async def quiz_and_judge(self, quiz_and_judge_config: Dict): + async def 
quiz_and_judge(self, quiz_and_judge_config: Dict, inputs: None): logger.warning( "Quiz and Judge operation needs trainee LLM client." " Make sure to provide one." @@ -247,9 +268,9 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.info("Restarting synthesizer LLM client.") self.synthesizer_llm_client.restart() - @op("partition", deps=["build_kg"]) + @op("partition", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method - async def partition(self, partition_config: Dict): + async def partition(self, partition_config: Dict, inputs: None): batches = await partition_kg( self.graph_storage, self.chunks_storage, @@ -259,9 +280,9 @@ async def partition(self, partition_config: Dict): await self.partition_storage.upsert(batches) return batches - @op("extract", deps=["chunk"]) + @op("extract", deps=["chunk"], op_type=OpType.STREAMING) @async_to_sync_method - async def extract(self, extract_config: Dict): + async def extract(self, extract_config: Dict, input_stream: Iterator): logger.info("Extracting information from given chunks...") results = await extract_info( @@ -279,9 +300,9 @@ async def extract(self, extract_config: Dict): await self.meta_storage.mark_done(self.chunks_storage) await self.meta_storage.index_done_callback() - @op("generate", deps=["partition"]) + @op("generate", deps=["partition"], op_type=OpType.BARRIER) @async_to_sync_method - async def generate(self, generate_config: Dict): + async def generate(self, generate_config: Dict, inputs: None): batches = self.partition_storage.data if not batches: diff --git a/graphgen/run.py b/graphgen/run.py index c300a6aa..c710b3ca 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -73,7 +73,7 @@ def main(): ops = collect_ops(config, graph_gen) # run operations - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) + Engine(queue_size=config.get("queue_size", 100)).run(ops, ctx) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. 
Data saved to %s", output_path) From 4b3d9d90e3334edfc8f8c819fc26eecbf196e6dc Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 19:56:14 +0800 Subject: [PATCH 09/26] feat: adapt read, chunk, build_kg operators to new optypes --- graphgen/graphgen.py | 38 +++++++++++++++---------- graphgen/models/storage/json_storage.py | 2 +- graphgen/operators/read/read_files.py | 8 +++--- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index b78dc77d..27bfad82 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -125,6 +125,8 @@ async def read(self, read_config: Dict): async def chunk(self, chunk_config: Dict, input_stream: Iterator): """ chunk documents into smaller pieces from full_docs_storage if not already present + input_stream: document IDs from full_docs_storage + yield: chunk IDs inserted into chunks_storage """ count = 0 for doc_id in input_stream: @@ -174,15 +176,22 @@ async def chunk(self, chunk_config: Dict, input_stream: Iterator): async def build_kg(self, inputs: List): """ build knowledge graph from text chunks + inputs: chunk IDs from chunks_storage + return: None """ + count = 0 # Step 1: get chunks + inserting_chunks: Dict[str, Dict] = {} + for _chunk_id in inputs: + chunk = await self.chunks_storage.get_by_id(_chunk_id) + if chunk: + inserting_chunks[_chunk_id] = chunk + + count += len(inserting_chunks) + logger.info( + "[Build KG] Inserting %d chunks, total %d", len(inserting_chunks), count + ) - inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage) - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) # Step 2: build knowledge graph from new chunks _add_entities_and_relations = await build_kg( llm_client=self.synthesizer_llm_client, @@ -194,12 +203,8 @@ async def build_kg(self, inputs: List): logger.warning("No entities or relations extracted from text chunks") return - # Step 3: mark meta + # Step 3: store the new entities and relations await self.graph_storage.index_done_callback() - await self.meta_storage.mark_done(self.chunks_storage) - await self.meta_storage.index_done_callback() - - return _add_entities_and_relations @op("search", deps=["read"], op_type=OpType.STREAMING) @async_to_sync_method @@ -231,7 +236,7 @@ async def search(self, search_config: Dict, input_stream: Iterator): @op("quiz_and_judge", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method - async def quiz_and_judge(self, quiz_and_judge_config: Dict, inputs: None): + async def quiz_and_judge(self, quiz_and_judge_config: Dict): logger.warning( "Quiz and Judge operation needs trainee LLM client." " Make sure to provide one." 
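Streaming operators such as read and chunk above are async generators, and the async_to_sync_method wrapper from PATCH 07 drains them inside the engine's worker threads by pumping an event loop one item at a time. A minimal sketch of that bridge, with a hypothetical generator standing in for an @op method:

import asyncio

async def produce_ids():
    # stand-in for a streaming @op such as read, yielding document IDs
    for i in range(3):
        yield f"doc-{i}"

def drain(async_gen):
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                # advance the async generator by exactly one item
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

print(list(drain(produce_ids())))  # ['doc-0', 'doc-1', 'doc-2']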
@@ -270,7 +275,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict, inputs: None): @op("partition", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method - async def partition(self, partition_config: Dict, inputs: None): + async def partition(self, partition_config: Dict): batches = await partition_kg( self.graph_storage, self.chunks_storage, @@ -283,8 +288,11 @@ async def partition(self, partition_config: Dict, inputs: None): @op("extract", deps=["chunk"], op_type=OpType.STREAMING) @async_to_sync_method async def extract(self, extract_config: Dict, input_stream: Iterator): - logger.info("Extracting information from given chunks...") - + """ + Extract information from chunks in chunks_storage + input_stream: chunk IDs from chunks_storage + return: None + """ results = await extract_info( self.synthesizer_llm_client, self.chunks_storage, diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index fb37ee29..3f6a7a2b 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -24,7 +24,7 @@ async def all_keys(self) -> list[str]: async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id): + async def get_by_id(self, id) -> dict | None: return self._data.get(id, None) async def get_by_ids(self, ids, fields=None) -> list: diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index c82ca2a5..0fc1c2a4 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -42,7 +42,7 @@ def read_files( ) -> Iterator[list[dict]]: path = Path(input_file).expanduser() if not path.exists(): - raise FileNotFoundError(f"[Reader] input_path not found: {input_file}") + raise FileNotFoundError(f"[Read] input_path not found: {input_file}") if allowed_suffix is None: support_suffix = set(_MAPPING.keys()) @@ -54,19 +54,19 @@ def read_files( suffix = path.suffix.lstrip(".").lower() if suffix not in support_suffix: logger.warning( - "[Reader] Skip file %s (suffix '%s' not in allowed_suffix %s)", + "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)", path, suffix, support_suffix, ) return reader = _build_reader(suffix, cache_dir) - logger.info("[Reader] Reading file %s", path) + logger.info("[Read] Reading file %s", path) yield reader.read(str(path)) return # folder - logger.info("[Reader] Streaming directory %s", path) + logger.info("[Read] Streaming directory %s", path) for p in path.rglob("*"): if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix: try: From 8eebad1884a7fec9397e1e0f839d0448ed248499 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 21:22:14 +0800 Subject: [PATCH 10/26] fix: async_lock when using blast search --- .../models/searcher/db/uniprot_searcher.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/graphgen/models/searcher/db/uniprot_searcher.py b/graphgen/models/searcher/db/uniprot_searcher.py index a74b623e..f5542f8c 100644 --- a/graphgen/models/searcher/db/uniprot_searcher.py +++ b/graphgen/models/searcher/db/uniprot_searcher.py @@ -27,12 +27,16 @@ def _get_pool(): return ThreadPoolExecutor(max_workers=10) +# ensure only one BLAST searcher at a time +_blast_lock = asyncio.Lock() + + class UniProtSearch(BaseSearcher): """ UniProt Search client to searcher with UniProt. 1) Get the protein by accession number. 2) Search with keywords or protein names (fuzzy searcher). 
- 3) Search with FASTA sequence (BLAST searcher). + 3) Search with FASTA sequence (BLAST searcher). Note that NCBIWWW does not support async. """ def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"): @@ -230,22 +234,21 @@ async def search( if query.startswith(">") or re.fullmatch( r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I ): - coro = loop.run_in_executor( - _get_pool(), self.get_by_fasta, query, threshold - ) + async with _blast_lock: + result = await loop.run_in_executor( + _get_pool(), self.get_by_fasta, query, threshold + ) # check if accession number elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I): - coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query) + result = await loop.run_in_executor( + _get_pool(), self.get_by_accession, query + ) else: # otherwise treat as keyword - coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query) + result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query) - result = await coro if result: result["_search_query"] = query return result - - -# TODO: use local UniProt database for large-scale searchs From 1e69b0eff80f56fd23a87fb38f524f64651eb730 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 20 Nov 2025 21:22:46 +0800 Subject: [PATCH 11/26] feat: add search config --- graphgen/configs/search_config.yaml | 2 +- graphgen/graphgen.py | 37 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/graphgen/configs/search_config.yaml b/graphgen/configs/search_config.yaml index 37e65818..ff110786 100644 --- a/graphgen/configs/search_config.yaml +++ b/graphgen/configs/search_config.yaml @@ -1,7 +1,7 @@ pipeline: - name: read params: - input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples + input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - name: search params: diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 27bfad82..5b548c5b 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -68,7 +68,8 @@ def __init__( self.working_dir, namespace="graph" ) self.search_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="search" + os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), + namespace="search", ) self.rephrase_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="rephrase" @@ -206,15 +207,23 @@ async def build_kg(self, inputs: List): # Step 3: store the new entities and relations await self.graph_storage.index_done_callback() - @op("search", deps=["read"], op_type=OpType.STREAMING) + @op("search", deps=["read"], op_type=OpType.BATCH, batch_size=64) @async_to_sync_method - async def search(self, search_config: Dict, input_stream: Iterator): + async def search(self, search_config: Dict, inputs: List): + """ + search new documents from full_docs_storage + input_stream: document IDs from full_docs_storage + return: None + """ logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - seeds = await self.meta_storage.get_new_data(self.full_docs_storage) - if len(seeds) == 0: - logger.warning("All documents are already been searched") - return + # Step 1: get documents + seeds = {} + for doc_id in inputs: + doc = await self.full_docs_storage.get_by_id(doc_id) + if doc: + seeds[doc_id] = doc + search_results = await search_all( seed_data=seeds, search_config=search_config, @@ -223,16 +232,15 @@ async def search(self, search_config: Dict, input_stream: Iterator): _add_search_keys = await self.search_storage.filter_keys( list(search_results.keys()) ) + search_results = { k: v for k, v in search_results.items() if k in _add_search_keys } if len(search_results) == 0: - logger.warning("All search results are already in the storage") - return + logger.warning("[Search] No new search results to add to storage") + await self.search_storage.upsert(search_results) await self.search_storage.index_done_callback() - await self.meta_storage.mark_done(self.full_docs_storage) - await self.meta_storage.index_done_callback() @op("quiz_and_judge", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method @@ -276,6 +284,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict): @op("partition", deps=["build_kg"], op_type=OpType.BARRIER) @async_to_sync_method async def partition(self, partition_config: Dict): + # TODO: partition 可以yield batches batches = await partition_kg( self.graph_storage, self.chunks_storage, @@ -308,10 +317,10 @@ async def extract(self, extract_config: Dict, input_stream: Iterator): await self.meta_storage.mark_done(self.chunks_storage) await self.meta_storage.index_done_callback() - @op("generate", deps=["partition"], op_type=OpType.BARRIER) + @op("generate", deps=["partition"], op_type=OpType.STREAMING) @async_to_sync_method - async def generate(self, generate_config: Dict, inputs: None): - + async def generate(self, generate_config: Dict, input_stream: Iterator): + # TODO: batches = self.partition_storage.data if not batches: logger.warning("No partitions found for QA generation") From 6112d838bf2d7152888d1d8a147f0ce7db1645e5 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 14:22:35 +0800 Subject: [PATCH 12/26] feat: add search scripts --- scripts/search/search_uniprot.sh | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 
scripts/search/search_uniprot.sh diff --git a/scripts/search/search_uniprot.sh b/scripts/search/search_uniprot.sh new file mode 100644 index 00000000..642040af --- /dev/null +++ b/scripts/search/search_uniprot.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file graphgen/configs/search_config.yaml \ +--output_dir cache/ From a7a01553c6b0b94c3cfeac75ed7273772583e4b3 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 18:41:03 +0800 Subject: [PATCH 13/26] style: diable pylint error for catching too general exceptions --- .pylintrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.pylintrc b/.pylintrc index 45c2b04b..ce85b1ba 100644 --- a/.pylintrc +++ b/.pylintrc @@ -452,6 +452,7 @@ disable=raw-checker-failed, R0917, # Too many positional arguments (6/5) (too-many-positional-arguments) C0103, E0401, + W0703, # Catching too general exception Exception # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option From bc487fb95d9117e0e2c42b8a0b65123599a9f74b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 18:49:55 +0800 Subject: [PATCH 14/26] feat: add parallel scan_files --- graphgen/operators/read/read_files.py | 114 +++++++++++------ graphgen/operators/read/scan_files.py | 170 ++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 38 deletions(-) create mode 100644 graphgen/operators/read/scan_files.py diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index 0fc1c2a4..0b33e9db 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -1,6 +1,13 @@ -from pathlib import Path +from ray.data.datasource import ( + Datasource, ReadTask +) +import pyarrow as pa +from typing import List, Dict, Any, Optional, Union + from typing import Iterator, List, Optional +import ray + from graphgen.models import ( CSVReader, JSONLReader, @@ -34,47 +41,78 @@ def _build_reader(suffix: str, cache_dir: str | None): return _MAPPING[suffix](output_dir=cache_dir) return _MAPPING[suffix]() +class UnifiedFileDatasource(Datasource): + pass + def read_files( - input_file: str, + input_path: str, allowed_suffix: Optional[List[str]] = None, cache_dir: Optional[str] = None, -) -> Iterator[list[dict]]: - path = Path(input_file).expanduser() - if not path.exists(): - raise FileNotFoundError(f"[Read] input_path not found: {input_file}") + parallelism: int = 4, + **ray_kwargs, +) -> ray.data.Dataset: + """ + Reads files from the specified input path, filtering by allowed suffixes, + and returns a Ray Dataset containing the read documents. 
+ :param input_path: input file or directory path + :param allowed_suffix: list of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: directory to cache intermediate files (used for PDF reading) + :param parallelism: number of parallel workers for reading files + :param ray_kwargs: additional keyword arguments for Ray Dataset reading + :return: Ray Dataset containing the read documents + """ + + if not ray.is_initialized(): + ray.init() + - if allowed_suffix is None: - support_suffix = set(_MAPPING.keys()) - else: - support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} + return ray.data.read_datasource( + UnifiedFileDatasource( + paths=[input_path], + allowed_suffix=allowed_suffix, + cache_dir=cache_dir, + **ray_kwargs, # Pass additional Ray kwargs here + ), + parallelism=parallelism, + ) - # single file - if path.is_file(): - suffix = path.suffix.lstrip(".").lower() - if suffix not in support_suffix: - logger.warning( - "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)", - path, - suffix, - support_suffix, - ) - return - reader = _build_reader(suffix, cache_dir) - logger.info("[Read] Reading file %s", path) - yield reader.read(str(path)) - return - # folder - logger.info("[Read] Streaming directory %s", path) - for p in path.rglob("*"): - if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix: - try: - suffix = p.suffix.lstrip(".").lower() - reader = _build_reader(suffix, cache_dir) - logger.info("[Reader] Reading file %s", p) - docs = reader.read(str(p)) - if docs: - yield docs - except Exception: # pylint: disable=broad-except - logger.exception("[Reader] Error reading %s", p) + # path = Path(input_file).expanduser() + # if not path.exists(): + # raise FileNotFoundError(f"[Read] input_path not found: {input_file}") + # + # if allowed_suffix is None: + # support_suffix = set(_MAPPING.keys()) + # else: + # support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} + # + # # single file + # if path.is_file(): + # suffix = path.suffix.lstrip(".").lower() + # if suffix not in support_suffix: + # logger.warning( + # "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)", + # path, + # suffix, + # support_suffix, + # ) + # return + # reader = _build_reader(suffix, cache_dir) + # logger.info("[Read] Reading file %s", path) + # yield reader.read(str(path)) + # return + # + # # folder + # logger.info("[Read] Streaming directory %s", path) + # for p in path.rglob("*"): + # if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix: + # try: + # suffix = p.suffix.lstrip(".").lower() + # reader = _build_reader(suffix, cache_dir) + # logger.info("[Reader] Reading file %s", p) + # docs = reader.read(str(p)) + # if docs: + # yield docs + # except Exception: # pylint: disable=broad-except + # logger.exception("[Reader] Error reading %s", p) diff --git a/graphgen/operators/read/scan_files.py b/graphgen/operators/read/scan_files.py new file mode 100644 index 00000000..2da0a60d --- /dev/null +++ b/graphgen/operators/read/scan_files.py @@ -0,0 +1,170 @@ +import os +import time +from typing import List, Dict, Any, Set, Union +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from diskcache import Cache +from graphgen.utils import logger + +class ParallelDirScanner: + def __init__(self, + cache_dir: str, + allowed_suffix, + rescan: bool = False, + max_workers: int = 4 + ): + self.cache = Cache(cache_dir) + self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None + self.rescan = rescan + 
self.max_workers = max_workers + + def scan(self, paths: Union[str, List[str]], recursive: bool = True) -> Dict[str, Any]: + if isinstance(paths, str): + paths = [paths] + + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_path = { + executor.submit(self._scan_dir, Path(p).resolve(), recursive, set()): p + for p in paths if os.path.exists(p) + } + + for future in as_completed(future_to_path): + path = future_to_path[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("Error scanning path %s: %s", path, e) + results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}} + + return results + + def _scan_dir(self, path: Path, recursive: bool, visited: Set[str]) -> Dict[str, Any]: + path_str = str(path) + + # Avoid cycles due to symlinks + if path_str in visited: + logger.warning("Skipping already visited path: %s", path_str) + return self._empty_result(path_str) + + # cache check + cache_key = f"scan::{path_str}::recursive::{recursive}" + cached = self.cache.get(cache_key) + if cached and not self.rescan: + logger.info("Using cached scan result for path: %s", path_str) + return cached['data'] + + logger.info("Scanning path: %s", path_str) + files, dirs = [], [] + stats = {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0} + + try: + with os.scandir(path_str) as entries: + for entry in entries: + try: + entry_stat = entry.stat(follow_symlinks=False) + + if entry.is_dir(): + dirs.append({ + 'path': entry.path, + 'name': entry.name, + 'mtime': entry_stat.st_mtime + }) + stats['dir_count'] += 1 + else: + # allowed suffix filter + if self.allowed_suffix: + suffix = Path(entry.name).suffix.lower() + if suffix not in self.allowed_suffix: + continue + + files.append({ + 'path': entry.path, + 'name': entry.name, + 'size': entry_stat.st_size, + 'mtime': entry_stat.st_mtime + }) + stats['total_size'] += entry_stat.st_size + stats['file_count'] += 1 + + except OSError: + stats['errors'] += 1 + + except (PermissionError, FileNotFoundError, OSError) as e: + logger.error("Failed to scan directory %s: %s", path_str, e) + return {'error': str(e), 'files': [], 'dirs': [], 'stats': stats} + + if recursive: + sub_visited = visited | {path_str} + sub_results = self._scan_subdirs(dirs, sub_visited) + + for sub_data in sub_results.values(): + files.extend(sub_data.get('files', [])) + stats['total_size'] += sub_data['stats'].get('total_size', 0) + stats['file_count'] += sub_data['stats'].get('file_count', 0) + + result = {'path': path_str, 'files': files, 'dirs': dirs, 'stats': stats} + self._cache_result(cache_key, result, path) + return result + + def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]: + """ + Parallel scan subdirectories + :param dir_list + :param visited + :return: + """ + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(self._scan_dir, Path(d['path']), True, visited): d['path'] + for d in dir_list + } + + for future in as_completed(futures): + path = futures[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("Error scanning subdirectory %s: %s", path, e) + results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}} + + return results + + def _cache_result(self, key: str, result: Dict, path: Path): + """Cache the scan result""" + try: + self.cache.set(key, { + 'data': result, + 'dir_mtime': path.stat().st_mtime, + 'cached_at': time.time() + }) + 
logger.info(f"Cached scan result for: {path}") + except OSError: + pass + + def invalidate(self, path: str): + """Invalidate cache for a specific path""" + path = Path(path).resolve() + keys = [k for k in self.cache if k.startswith(f"scan:{path}")] + for k in keys: + self.cache.delete(k) + logger.info(f"Invalidated cache for path: {path}") + + def close(self): + self.cache.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + @staticmethod + def _empty_result(path: str) -> Dict[str, Any]: + return { + 'path': path, + 'files': [], + 'dirs': [], + 'stats': {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0} + } From bd2f7c471b1604605c39451494788e93889e19c2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 20:26:32 +0800 Subject: [PATCH 15/26] refactor: refactor txt_reader using ray data --- graphgen/bases/base_reader.py | 81 ++++++++++++++-------------- graphgen/models/reader/txt_reader.py | 33 ++++++++++-- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 89778469..37e78414 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,8 +1,9 @@ import os from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import requests +from ray.data import Dataset class BaseReader(ABC): @@ -14,52 +15,50 @@ def __init__(self, text_column: str = "content"): self.text_column = text_column @abstractmethod - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read data from the specified file path. - :param file_path: Path to the input file. - :return: List of dictionaries containing the data. + :param input_path: Path to the input file or list of file paths. + :return: Ray Dataset containing the read data. """ - @staticmethod - def filter(data: List[dict]) -> List[dict]: + def _should_keep_item(self, item: Dict[str, Any]) -> bool: """ - Filter out entries with empty or missing text in the specified column. + Determine whether to keep the given item based on the text column. - :param data: List of dictionaries containing the data. - :return: Filtered list of dictionaries. + :param item: Dictionary representing a data entry. + :return: True if the item should be kept, False otherwise. """ + item_type = item.get("type") + assert item_type in [ + "text", + "image", + "table", + "equation", + "protein", + ], f"Unsupported item type: {item_type}" + if item_type == "text": + content = item.get(self.text_column, "").strip() + return bool(content) + return True - def _image_exists(path_or_url: str, timeout: int = 3) -> bool: - """ - Check if an image exists at the given local path or URL. - :param path_or_url: Local file path or remote URL of the image. - :param timeout: Timeout for remote URL requests in seconds. - :return: True if the image exists, False otherwise. 
- """ - if not path_or_url: - return False - if not path_or_url.startswith(("http://", "https://", "ftp://")): - path = path_or_url.replace("file://", "", 1) - path = os.path.abspath(path) - return os.path.isfile(path) - try: - resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) - return resp.status_code == 200 - except requests.RequestException: - return False - - filtered_data = [] - for item in data: - if item.get("type") == "text": - content = item.get("content", "").strip() - if content: - filtered_data.append(item) - elif item.get("type") in ("image", "table", "equation"): - img_path = item.get("img_path") - if _image_exists(img_path): - filtered_data.append(item) - else: - filtered_data.append(item) - return filtered_data + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. + """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except requests.RequestException: + return False diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index ec2ff747..3ad33a73 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,10 +1,33 @@ -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class TXTReader(BaseReader): - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - docs = [{"type": "text", self.text_column: f.read()}] - return self.filter(docs) + def read( + self, + input_path: Union[str, List[str]], + override_num_blocks: int = 4, + ) -> Dataset: + """ + Read text files from the specified input path. + :param input_path: Path to the input text file or list of text files. + :param override_num_blocks: Number of blocks to override for Ray Dataset reading. + :return: Ray Dataset containing the read text data. 
+ """ + docs_ds = ray.data.read_text( + input_path, encoding="utf-8", override_num_blocks=override_num_blocks + ) + + docs_ds = docs_ds.map( + lambda row: { + "type": "text", + self.text_column: row["text"], + } + ) + + docs_ds = docs_ds.filter(self._should_keep_item) + return docs_ds From 0422bd02b9a7fe32b135b06b9df69fa5aed6c9ab Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 20:43:46 +0800 Subject: [PATCH 16/26] refactor: refactor csv_reader using ray data --- graphgen/bases/base_reader.py | 16 ++++++++++++++ graphgen/models/reader/csv_reader.py | 33 ++++++++++++++++++---------- requirements.txt | 2 ++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 37e78414..91d55fcd 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Union +import pandas as pd import requests from ray.data import Dataset @@ -43,6 +44,21 @@ def _should_keep_item(self, item: Dict[str, Any]) -> bool: return bool(content) return True + def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """ + Validate data format. + """ + if "type" not in batch.columns: + raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") + + if "text" in batch["type"].values: + if self.text_column not in batch.columns: + raise ValueError( + f"Missing '{self.text_column}' column for text documents" + ) + + return batch + @staticmethod def _image_exists(path_or_url: str, timeout: int = 3) -> bool: """ diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index bc865a3b..e1e14ffe 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,13 +14,23 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read( + self, + input_path: Union[str, List[str]], + override_num_blocks: int = None, + ) -> Dataset: + """ + Read CSV files and return Ray Dataset. - df = pd.read_csv(file_path) - for _, row in df.iterrows(): - assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}" - if row["type"] == "text" and self.text_column not in row: - raise ValueError( - f"Missing '{self.text_column}' in document: {row.to_dict()}" - ) - return self.filter(df.to_dict(orient="records")) + :param input_path: Path to CSV file or list of CSV files. + :param override_num_blocks: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated and filtered data. 
+ """ + + ds = ray.data.read_csv(input_path, override_num_blocks=override_num_blocks) + + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + + ds = ds.filter(self._should_keep_item) + + return ds diff --git a/requirements.txt b/requirements.txt index fd824606..881bf280 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,8 @@ requests fastapi trafilatura aiohttp +ray +diskcache leidenalg igraph From 36e80ef8b5c9ebd1d3c52fd57aa376b34a1c321b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 21:01:11 +0800 Subject: [PATCH 17/26] refactor: refactor json_reader using ray data --- graphgen/models/reader/csv_reader.py | 3 --- graphgen/models/reader/json_reader.py | 34 ++++++++++++++++----------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index e1e14ffe..ccd136ad 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -28,9 +28,6 @@ def read( """ ds = ray.data.read_csv(input_path, override_num_blocks=override_num_blocks) - ds = ds.map_batches(self._validate_batch, batch_format="pandas") - ds = ds.filter(self._should_keep_item) - return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 8253041c..a8fe4bc5 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,5 +1,7 @@ -import json -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -12,15 +14,19 @@ class JSONReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - return self.filter(data) - raise ValueError("JSON file must contain a list of documents.") + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read JSON file and return Ray Dataset. + :param input_path: Path to JSON file or list of JSON files. + :param parallelism: Number of parallel workers for reading files. + :return: Ray Dataset containing validated and filtered data. + """ + + ds = ray.data.read_json(input_path, parallelism=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds From db8252ca23eb10f3e78379607f5d11287ba81fed Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 21:07:28 +0800 Subject: [PATCH 18/26] refactor: refactor parquet_reader using ray data --- graphgen/models/reader/json_reader.py | 4 +-- graphgen/models/reader/jsonl_reader.py | 30 ----------------------- graphgen/models/reader/parquet_reader.py | 31 ++++++++++++++++-------- 3 files changed, 23 insertions(+), 42 deletions(-) delete mode 100644 graphgen/models/reader/jsonl_reader.py diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index a8fe4bc5..cd18e4f5 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -8,7 +8,7 @@ class JSONReader(BaseReader): """ - Reader for JSON files. 
+ Reader for JSON and JSONL files. Columns: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. @@ -21,7 +21,7 @@ def read( ) -> Dataset: """ Read JSON file and return Ray Dataset. - :param input_path: Path to JSON file or list of JSON files. + :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. :param parallelism: Number of parallel workers for reading files. :return: Ray Dataset containing validated and filtered data. """ diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py deleted file mode 100644 index 31bc3195..00000000 --- a/graphgen/models/reader/jsonl_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -from typing import Any, Dict, List - -from graphgen.bases.base_reader import BaseReader -from graphgen.utils import logger - - -class JSONLReader(BaseReader): - """ - Reader for JSONL files. - Columns: - - type: The type of the document (e.g., "text", "image", etc.) - - if type is "text", "content" column must be present. - """ - - def read(self, file_path: str) -> List[Dict[str, Any]]: - docs = [] - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - try: - doc = json.loads(line) - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - docs.append(doc) - except json.JSONDecodeError as e: - logger.error("Error decoding JSON line: %s. Error: %s", line, e) - return self.filter(docs) diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index a325b876..e521e6d8 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,12 +14,22 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - df = pd.read_parquet(file_path) - data: List[Dict[str, Any]] = df.to_dict(orient="records") + def read( + self, + input_path: Union[str, List[str]], + override_num_blocks: int = None, + ) -> Dataset: + """ + Read Parquet files using Ray Data. - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") - return self.filter(data) + :param input_path: Path to Parquet file or list of Parquet files. + :param override_num_blocks: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated documents. 
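The Parquet path expects the same layout. A small sketch that builds a compatible file with pandas and reads it back; the path and rows are illustrative, and pyarrow is assumed to be installed for to_parquet:

import pandas as pd

from graphgen.models import ParquetReader

pd.DataFrame(
    [
        {"type": "text", "content": "Alpha-helices are a common protein motif."},
        {"type": "table", "content": "", "img_path": "./tables/t1.png"},
    ]
).to_parquet("docs.parquet", index=False)

ds = ParquetReader().read("docs.parquet")  # read() initializes Ray if needed
print(ds.count())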
+ """ + if not ray.is_initialized(): + ray.init() + + ds = ray.data.read_parquet(input_path, override_num_blocks=override_num_blocks) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds From 97f7e75de70088dfdc987de91743ec5a11fec36f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 21 Nov 2025 23:01:29 +0800 Subject: [PATCH 19/26] refactor: refactor pdf_reader using ray data --- graphgen/models/__init__.py | 1 - graphgen/models/reader/__init__.py | 1 - graphgen/models/reader/pdf_reader.py | 75 ++++++++++++++++++++-------- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 68fd2a5d..e494631c 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -18,7 +18,6 @@ ) from .reader import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index 600ffb4a..220460c3 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -1,6 +1,5 @@ from .csv_reader import CSVReader from .json_reader import JSONReader -from .jsonl_reader import JSONLReader from .parquet_reader import ParquetReader from .pdf_reader import PDFReader from .pickle_reader import PickleReader diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 94562cb5..dd685475 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -5,6 +5,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +import ray +from ray.data import Dataset + from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device @@ -62,19 +65,32 @@ def __init__( self.parser = MinerUParser() self.txt_reader = TXTReader() - def read(self, file_path: str, **override) -> List[Dict[str, Any]]: - """ - file_path - **override: override MinerU parameters - """ - pdf_path = Path(file_path).expanduser().resolve() - if not pdf_path.is_file(): - raise FileNotFoundError(pdf_path) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + **override, + ) -> Dataset: + + # Ensure input_path is a list + if isinstance(input_path, str): + input_path = [input_path] - kwargs = {**self._default_kwargs, **override} + paths_ds = ray.data.from_items(input_path) - mineru_result = self._call_mineru(pdf_path, kwargs) - return self.filter(mineru_result) + def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + pdf_path = row["item"] + kwargs = {**self._default_kwargs, **override} + return self._call_mineru(Path(pdf_path), kwargs) + except Exception as e: + logger.error("Failed to process %s: %s", row, e) + return [] + + docs_ds = paths_ds.flat_map(process_pdf) + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds def _call_mineru( self, pdf_path: Path, kwargs: Dict[str, Any] @@ -161,18 +177,18 @@ def _try_load_cached_result( base = os.path.dirname(json_file) results = [] - for item in data: + for it in data: for key in ("img_path", "table_img_path", "equation_img_path"): - rel_path = item.get(key) + rel_path = it.get(key) if rel_path: - item[key] = str(Path(base).joinpath(rel_path).resolve()) - if item["type"] == "text": - item["content"] = item["text"] - del item["text"] + it[key] = str(Path(base).joinpath(rel_path).resolve()) + if it["type"] == "text": 
+ it["content"] = it["text"] + del it["text"] for key in ("page_idx", "bbox", "text_level"): - if item.get(key) is not None: - del item[key] - results.append(item) + if it.get(key) is not None: + del it[key] + results.append(it) return results @staticmethod @@ -231,3 +247,22 @@ def _check_bin() -> None: "MinerU is not installed or not found in PATH. Please install it from pip: \n" "pip install -U 'mineru[core]'" ) from exc + + +if __name__ == "__main__": + reader = PDFReader( + output_dir="./output", + method="auto", + backend="pipeline", + device="cpu", + lang="en", + formula=True, + table=True, + ) + dataset = reader.read( + "/home/PJLAB/chenzihong/Project/graphgen/resources/input_examples/pdf_demo.pdf", + parallelism=2, + ) + + for item in dataset.take_all(): + print(item) From ac99aa857ae4e751cd8c4f1aef2d84be3e069065 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Sat, 22 Nov 2025 14:21:36 +0800 Subject: [PATCH 20/26] refactor: refactor pickle_reader using ray data --- graphgen/models/reader/pdf_reader.py | 19 ------ graphgen/models/reader/pickle_reader.py | 84 ++++++++++++++++++++----- 2 files changed, 68 insertions(+), 35 deletions(-) diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index dd685475..9d5c7c27 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -247,22 +247,3 @@ def _check_bin() -> None: "MinerU is not installed or not found in PATH. Please install it from pip: \n" "pip install -U 'mineru[core]'" ) from exc - - -if __name__ == "__main__": - reader = PDFReader( - output_dir="./output", - method="auto", - backend="pipeline", - device="cpu", - lang="en", - formula=True, - table=True, - ) - dataset = reader.read( - "/home/PJLAB/chenzihong/Project/graphgen/resources/input_examples/pdf_demo.pdf", - parallelism=2, - ) - - for item in dataset.take_all(): - print(item) diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1a11dc11..1af7ca77 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,30 +1,82 @@ import pickle -from typing import Any, Dict, List +from typing import List, Union + +import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class PickleReader(BaseReader): """ - Read pickle files, requiring the top-level object to be List[Dict[str, Any]]. - - Columns: + Read pickle files, requiring the schema to be restored to List[Dict[str, Any]]. + Each pickle file should contain a list of dictionaries with at least: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. + + Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available. + For Ray >= 2.5, consider using read_pickle if available in your version. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "rb") as f: - data = pickle.load(f) + def read( + self, + input_path: Union[str, List[str]], + override_num_blocks: int = None, + ) -> Dataset: + """ + Read Pickle files using Ray Data. + + :param input_path: Path to pickle file or list of pickle files. + :param override_num_blocks: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated documents. 
+ """ + if not ray.is_initialized(): + ray.init() + + # Use read_binary_files as a reliable alternative to read_pickle + ds = ray.data.read_binary_files( + input_path, override_num_blocks=override_num_blocks, include_paths=True + ) + + # Deserialize pickle files and flatten into individual records + def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame: + all_records = [] + for _, row in batch.iterrows(): + try: + # Load pickle data from bytes + data = pickle.loads(row["bytes"]) + + # Validate structure + if not isinstance(data, list): + logger.error( + "Pickle file {row['path']} must contain a list, got {type(data)}" + ) + continue + + if not all(isinstance(item, dict) for item in data): + logger.error( + "Pickle file {row['path']} must contain a list of dictionaries" + ) + continue + + # Flatten: each dict in the list becomes a separate row + all_records.extend(data) + except Exception as e: + logger.error( + "Failed to deserialize pickle file %s: %s", row["path"], str(e) + ) + continue + + return pd.DataFrame(all_records) - if not isinstance(data, list): - raise ValueError("Pickle file must contain a list of documents.") + # Apply deserialization and flattening + ds = ds.map_batches(deserialize_batch, batch_format="pandas") - for doc in data: - if not isinstance(doc, dict): - raise ValueError("Every item in the list must be a dict.") - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") + # Validate the schema + ds = ds.map_batches(self._validate_batch, batch_format="pandas") - return self.filter(data) + # Filter valid items + ds = ds.filter(self._should_keep_item) + return ds From 3d9185a70f9f6448651f07382411f5238370b2a9 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Sat, 22 Nov 2025 14:53:12 +0800 Subject: [PATCH 21/26] refactor: refactor rdf_reader using ray data --- graphgen/models/reader/rdf_reader.py | 108 +++++++++++++++++++++++---- 1 file changed, 95 insertions(+), 13 deletions(-) diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index cce167c1..406478f5 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -1,48 +1,130 @@ -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Union +import ray import rdflib +from ray.data import Dataset from rdflib import Literal from rdflib.util import guess_format from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class RDFReader(BaseReader): """ Reader for RDF files that extracts triples and represents them as dictionaries. + + Uses Ray Data for distributed processing of multiple RDF files. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def __init__(self, *, text_column: str = "content", **kwargs): + """ + Initialize RDFReader. + + :param text_column: The column name for text content (default: "content"). + """ + super().__init__(**kwargs) + self.text_column = text_column + + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read RDF file(s) using Ray Data. + + :param input_path: Path to RDF file or list of RDF files. + :param parallelism: Number of parallel workers for processing. + :return: Ray Dataset containing extracted documents. 
+ """ + if not ray.is_initialized(): + ray.init() + + # Ensure input_path is a list to prevent Ray from splitting string into characters + if isinstance(input_path, str): + input_path = [input_path] + + # Create dataset from file paths + paths_ds = ray.data.from_items(input_path) + + def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single RDF file and return list of documents.""" + try: + file_path = row["item"] + return self._parse_rdf_file(Path(file_path)) + except Exception as e: + logger.error( + "Failed to process RDF file %s: %s", row.get("item", "unknown"), e + ) + return [] + + # Process files in parallel and flatten results + docs_ds = paths_ds.flat_map(process_rdf) + + # Filter valid documents + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds + + def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: + """ + Parse a single RDF file and extract documents. + + :param file_path: Path to RDF file. + :return: List of document dictionaries. + """ + if not file_path.is_file(): + raise FileNotFoundError(f"RDF file not found: {file_path}") + g = rdflib.Graph() - fmt = guess_format(file_path) + fmt = guess_format(str(file_path)) + try: - g.parse(file_path, format=fmt) + g.parse(str(file_path), format=fmt) except Exception as e: raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e docs: List[Dict[str, Any]] = [] - text_col = self.text_column + # Process each unique subject in the RDF graph for subj in set(g.subjects()): literals = [] props = {} + + # Extract all triples for this subject for _, pred, obj in g.triples((subj, None, None)): pred_str = str(pred) + obj_str = str(obj) + + # Collect literal values as text content if isinstance(obj, Literal): - literals.append(str(obj)) - props.setdefault(pred_str, []).append(str(obj)) + literals.append(obj_str) + + # Store all properties (including non-literals) + props.setdefault(pred_str, []).append(obj_str) + # Join all literal values as the text content text = " ".join(literals).strip() if not text: - raise ValueError( - f"Subject {subj} has no literal values; " - f"missing '{text_col}' for text column." 
+ logger.warning( + "Subject %s in %s has no literal values; document will have empty '%s' field.", + subj, + file_path, + self.text_column, ) - doc = {"id": str(subj), text_col: text, "properties": props} + # Create document dictionary + doc = { + "id": str(subj), + self.text_column: text, + "properties": props, + "source_file": str(file_path), + } docs.append(doc) if not docs: - raise ValueError("RDF file contains no valid documents.") + logger.warning("RDF file %s contains no valid documents.", file_path) - return self.filter(docs) + return docs From d5924f0d2ed8bb82675a6c10d77156f23c59f3c2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 24 Nov 2025 13:48:27 +0800 Subject: [PATCH 22/26] fix: fix scanning file path --- .../operators/read/parallel_file_scanner.py | 231 ++++++++++++++++++ graphgen/operators/read/scan_files.py | 170 ------------- 2 files changed, 231 insertions(+), 170 deletions(-) create mode 100644 graphgen/operators/read/parallel_file_scanner.py delete mode 100644 graphgen/operators/read/scan_files.py diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py new file mode 100644 index 00000000..499cc224 --- /dev/null +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -0,0 +1,231 @@ +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Set, Union + +from diskcache import Cache + +from graphgen.utils import logger + + +class ParallelFileScanner: + def __init__( + self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4 + ): + self.cache = Cache(cache_dir) + self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None + self.rescan = rescan + self.max_workers = max_workers + + def scan( + self, paths: Union[str, List[str]], recursive: bool = True + ) -> Dict[str, Any]: + if isinstance(paths, str): + paths = [paths] + + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + future_to_path = {} + for p in paths: + if os.path.exists(p): + future = executor.submit( + self._scan_files, Path(p).resolve(), recursive, set() + ) + future_to_path[future] = p + else: + logger.warning("[READ] Path does not exist: %s", p) + + for future in as_completed(future_to_path): + path = future_to_path[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("[READ] Error scanning path %s: %s", path, e) + results[path] = { + "error": str(e), + "files": [], + "dirs": [], + "stats": {}, + } + return results + + def _scan_files( + self, path: Path, recursive: bool, visited: Set[str] + ) -> Dict[str, Any]: + path_str = str(path) + + # Avoid cycles due to symlinks + if path_str in visited: + logger.warning("[READ] Skipping already visited path: %s", path_str) + return self._empty_result(path_str) + + # cache check + cache_key = f"scan::{path_str}::recursive::{recursive}" + cached = self.cache.get(cache_key) + if cached and not self.rescan: + logger.info("[READ] Using cached scan result for path: %s", path_str) + return cached["data"] + + logger.info("[READ] Scanning path: %s", path_str) + files, dirs = [], [] + stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0} + + try: + path_stat = path.stat() + if path.is_file(): + return self._scan_single_file(path, path_str, path_stat) + if path.is_dir(): + with os.scandir(path_str) as entries: + for entry in entries: + try: + entry_stat = entry.stat(follow_symlinks=False) + + if 
entry.is_dir(): + dirs.append( + { + "path": entry.path, + "name": entry.name, + "mtime": entry_stat.st_mtime, + } + ) + stats["dir_count"] += 1 + else: + # allowed suffix filter + if not self._is_allowed_file(Path(entry.path)): + continue + files.append( + { + "path": entry.path, + "name": entry.name, + "size": entry_stat.st_size, + "mtime": entry_stat.st_mtime, + } + ) + stats["total_size"] += entry_stat.st_size + stats["file_count"] += 1 + + except OSError: + stats["errors"] += 1 + + except (PermissionError, FileNotFoundError, OSError) as e: + logger.error("[READ] Failed to scan path %s: %s", path_str, e) + return {"error": str(e), "files": [], "dirs": [], "stats": stats} + + if recursive: + sub_visited = visited | {path_str} + sub_results = self._scan_subdirs(dirs, sub_visited) + + for sub_data in sub_results.values(): + files.extend(sub_data.get("files", [])) + stats["total_size"] += sub_data["stats"].get("total_size", 0) + stats["file_count"] += sub_data["stats"].get("file_count", 0) + + result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats} + self._cache_result(cache_key, result, path) + return result + + def _scan_single_file( + self, path: Path, path_str: str, stat: os.stat_result + ) -> Dict[str, Any]: + """Scan a single file and return its metadata""" + if not self._is_allowed_file(path): + return self._empty_result(path_str) + + return { + "path": path_str, + "files": [ + { + "path": path_str, + "name": path.name, + "size": stat.st_size, + "mtime": stat.st_mtime, + } + ], + "dirs": [], + "stats": { + "total_size": stat.st_size, + "file_count": 1, + "dir_count": 0, + "errors": 0, + }, + } + + def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]: + """ + Parallel scan subdirectories + :param dir_list + :param visited + :return: + """ + results = {} + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(self._scan_files, Path(d["path"]), True, visited): d[ + "path" + ] + for d in dir_list + } + + for future in as_completed(futures): + path = futures[future] + try: + results[path] = future.result() + except Exception as e: + logger.error("[READ] Error scanning subdirectory %s: %s", path, e) + results[path] = { + "error": str(e), + "files": [], + "dirs": [], + "stats": {}, + } + + return results + + def _cache_result(self, key: str, result: Dict, path: Path): + """Cache the scan result""" + try: + self.cache.set( + key, + { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + ) + logger.info("[READ] Cached scan result for path: %s", path) + except OSError: + pass + + def _is_allowed_file(self, path: Path) -> bool: + """Check if the file has an allowed suffix""" + if self.allowed_suffix is None: + return True + suffix = path.suffix.lower().lstrip(".") + return suffix in self.allowed_suffix + + def invalidate(self, path: str): + """Invalidate cache for a specific path""" + path = Path(path).resolve() + keys = [k for k in self.cache if k.startswith(f"scan::{path}")] + for k in keys: + self.cache.delete(k) + logger.info("[READ] Invalidated cache for path: %s", path) + + def close(self): + self.cache.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + @staticmethod + def _empty_result(path: str) -> Dict[str, Any]: + return { + "path": path, + "files": [], + "dirs": [], + "stats": {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0}, + } diff --git a/graphgen/operators/read/scan_files.py 
b/graphgen/operators/read/scan_files.py deleted file mode 100644 index 2da0a60d..00000000 --- a/graphgen/operators/read/scan_files.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import time -from typing import List, Dict, Any, Set, Union -from pathlib import Path -from concurrent.futures import ThreadPoolExecutor, as_completed -from diskcache import Cache -from graphgen.utils import logger - -class ParallelDirScanner: - def __init__(self, - cache_dir: str, - allowed_suffix, - rescan: bool = False, - max_workers: int = 4 - ): - self.cache = Cache(cache_dir) - self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None - self.rescan = rescan - self.max_workers = max_workers - - def scan(self, paths: Union[str, List[str]], recursive: bool = True) -> Dict[str, Any]: - if isinstance(paths, str): - paths = [paths] - - results = {} - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_path = { - executor.submit(self._scan_dir, Path(p).resolve(), recursive, set()): p - for p in paths if os.path.exists(p) - } - - for future in as_completed(future_to_path): - path = future_to_path[future] - try: - results[path] = future.result() - except Exception as e: - logger.error("Error scanning path %s: %s", path, e) - results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}} - - return results - - def _scan_dir(self, path: Path, recursive: bool, visited: Set[str]) -> Dict[str, Any]: - path_str = str(path) - - # Avoid cycles due to symlinks - if path_str in visited: - logger.warning("Skipping already visited path: %s", path_str) - return self._empty_result(path_str) - - # cache check - cache_key = f"scan::{path_str}::recursive::{recursive}" - cached = self.cache.get(cache_key) - if cached and not self.rescan: - logger.info("Using cached scan result for path: %s", path_str) - return cached['data'] - - logger.info("Scanning path: %s", path_str) - files, dirs = [], [] - stats = {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0} - - try: - with os.scandir(path_str) as entries: - for entry in entries: - try: - entry_stat = entry.stat(follow_symlinks=False) - - if entry.is_dir(): - dirs.append({ - 'path': entry.path, - 'name': entry.name, - 'mtime': entry_stat.st_mtime - }) - stats['dir_count'] += 1 - else: - # allowed suffix filter - if self.allowed_suffix: - suffix = Path(entry.name).suffix.lower() - if suffix not in self.allowed_suffix: - continue - - files.append({ - 'path': entry.path, - 'name': entry.name, - 'size': entry_stat.st_size, - 'mtime': entry_stat.st_mtime - }) - stats['total_size'] += entry_stat.st_size - stats['file_count'] += 1 - - except OSError: - stats['errors'] += 1 - - except (PermissionError, FileNotFoundError, OSError) as e: - logger.error("Failed to scan directory %s: %s", path_str, e) - return {'error': str(e), 'files': [], 'dirs': [], 'stats': stats} - - if recursive: - sub_visited = visited | {path_str} - sub_results = self._scan_subdirs(dirs, sub_visited) - - for sub_data in sub_results.values(): - files.extend(sub_data.get('files', [])) - stats['total_size'] += sub_data['stats'].get('total_size', 0) - stats['file_count'] += sub_data['stats'].get('file_count', 0) - - result = {'path': path_str, 'files': files, 'dirs': dirs, 'stats': stats} - self._cache_result(cache_key, result, path) - return result - - def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]: - """ - Parallel scan subdirectories - :param dir_list - :param visited - :return: - """ - results = {} - with 
ThreadPoolExecutor(max_workers=self.max_workers) as executor: - futures = { - executor.submit(self._scan_dir, Path(d['path']), True, visited): d['path'] - for d in dir_list - } - - for future in as_completed(futures): - path = futures[future] - try: - results[path] = future.result() - except Exception as e: - logger.error("Error scanning subdirectory %s: %s", path, e) - results[path] = {'error': str(e), 'files': [], 'dirs': [], 'stats': {}} - - return results - - def _cache_result(self, key: str, result: Dict, path: Path): - """Cache the scan result""" - try: - self.cache.set(key, { - 'data': result, - 'dir_mtime': path.stat().st_mtime, - 'cached_at': time.time() - }) - logger.info(f"Cached scan result for: {path}") - except OSError: - pass - - def invalidate(self, path: str): - """Invalidate cache for a specific path""" - path = Path(path).resolve() - keys = [k for k in self.cache if k.startswith(f"scan:{path}")] - for k in keys: - self.cache.delete(k) - logger.info(f"Invalidated cache for path: {path}") - - def close(self): - self.cache.close() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - @staticmethod - def _empty_result(path: str) -> Dict[str, Any]: - return { - 'path': path, - 'files': [], - 'dirs': [], - 'stats': {'total_size': 0, 'file_count': 0, 'dir_count': 0, 'errors': 0} - } From 1e7108061bd0448765930a22c45fe2ce7cff9080 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 24 Nov 2025 14:32:54 +0800 Subject: [PATCH 23/26] fix: fix read_files --- graphgen/models/reader/csv_reader.py | 6 +- graphgen/models/reader/json_reader.py | 2 +- graphgen/models/reader/parquet_reader.py | 6 +- graphgen/models/reader/pickle_reader.py | 6 +- graphgen/models/reader/txt_reader.py | 6 +- graphgen/operators/read/read_files.py | 308 +++++++++++++++++------ 6 files changed, 247 insertions(+), 87 deletions(-) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index ccd136ad..99faa30e 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -17,17 +17,17 @@ class CSVReader(BaseReader): def read( self, input_path: Union[str, List[str]], - override_num_blocks: int = None, + parallelism: int = None, ) -> Dataset: """ Read CSV files and return Ray Dataset. :param input_path: Path to CSV file or list of CSV files. - :param override_num_blocks: Number of blocks for Ray Dataset reading. + :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated and filtered data. """ - ds = ray.data.read_csv(input_path, override_num_blocks=override_num_blocks) + ds = ray.data.read_csv(input_path, override_num_blocks=parallelism) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index cd18e4f5..1bcba4ea 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -26,7 +26,7 @@ def read( :return: Ray Dataset containing validated and filtered data. 
""" - ds = ray.data.read_json(input_path, parallelism=parallelism) + ds = ray.data.read_json(input_path, override_num_blocks=parallelism) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index e521e6d8..5423643b 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -17,19 +17,19 @@ class ParquetReader(BaseReader): def read( self, input_path: Union[str, List[str]], - override_num_blocks: int = None, + parallelism: int = None, ) -> Dataset: """ Read Parquet files using Ray Data. :param input_path: Path to Parquet file or list of Parquet files. - :param override_num_blocks: Number of blocks for Ray Dataset reading. + :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated documents. """ if not ray.is_initialized(): ray.init() - ds = ray.data.read_parquet(input_path, override_num_blocks=override_num_blocks) + ds = ray.data.read_parquet(input_path, override_num_blocks=parallelism) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1af7ca77..0b0e5719 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -23,13 +23,13 @@ class PickleReader(BaseReader): def read( self, input_path: Union[str, List[str]], - override_num_blocks: int = None, + parallelism: int = None, ) -> Dataset: """ Read Pickle files using Ray Data. :param input_path: Path to pickle file or list of pickle files. - :param override_num_blocks: Number of blocks for Ray Dataset reading. + :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated documents. """ if not ray.is_initialized(): @@ -37,7 +37,7 @@ def read( # Use read_binary_files as a reliable alternative to read_pickle ds = ray.data.read_binary_files( - input_path, override_num_blocks=override_num_blocks, include_paths=True + input_path, override_num_blocks=parallelism, include_paths=True ) # Deserialize pickle files and flatten into individual records diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index 3ad33a73..bb6cce9e 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -10,16 +10,16 @@ class TXTReader(BaseReader): def read( self, input_path: Union[str, List[str]], - override_num_blocks: int = 4, + parallelism: int = 4, ) -> Dataset: """ Read text files from the specified input path. :param input_path: Path to the input text file or list of text files. - :param override_num_blocks: Number of blocks to override for Ray Dataset reading. + :param parallelism: Number of blocks to override for Ray Dataset reading. :return: Ray Dataset containing the read text data. 
""" docs_ds = ray.data.read_text( - input_path, encoding="utf-8", override_num_blocks=override_num_blocks + input_path, encoding="utf-8", override_num_blocks=parallelism ) docs_ds = docs_ds.map( diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index 0b33e9db..c4933ce4 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -1,16 +1,13 @@ -from ray.data.datasource import ( - Datasource, ReadTask -) -import pyarrow as pa -from typing import List, Dict, Any, Optional, Union - -from typing import Iterator, List, Optional +from pathlib import Path +from typing import Any, Iterable, List, Optional, Union +import pyarrow as pa import ray +from ray.data.block import Block, BlockMetadata +from ray.data.datasource import Datasource, ReadTask from graphgen.models import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, @@ -20,8 +17,10 @@ ) from graphgen.utils import logger +from .parallel_file_scanner import ParallelFileScanner + _MAPPING = { - "jsonl": JSONLReader, + "jsonl": JSONReader, "json": JSONReader, "txt": TXTReader, "csv": CSVReader, @@ -35,84 +34,245 @@ } -def _build_reader(suffix: str, cache_dir: str | None): +def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): + """Factory function to build appropriate reader instance""" suffix = suffix.lower() - if suffix == "pdf" and cache_dir is not None: - return _MAPPING[suffix](output_dir=cache_dir) - return _MAPPING[suffix]() + reader_cls = _MAPPING.get(suffix) + if not reader_cls: + raise ValueError(f"Unsupported file suffix: {suffix}") + # Special handling for PDFReader which needs output_dir + if suffix == "pdf": + if cache_dir is None: + raise ValueError("cache_dir must be provided for PDFReader") + return reader_cls(output_dir=cache_dir, **reader_kwargs) + + return reader_cls(**reader_kwargs) + + +# pylint: disable=abstract-method class UnifiedFileDatasource(Datasource): - pass + """ + A unified Ray DataSource that can read multiple file types + and automatically route to the appropriate reader. + """ + + def __init__( + self, + paths: Union[str, List[str]], + allowed_suffix: Optional[List[str]] = None, + cache_dir: Optional[str] = None, + recursive: bool = True, + **reader_kwargs, + ): + """ + Initialize the datasource. + + :param paths: File or directory paths to read + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (used for PDF processing) + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + """ + self.paths = [paths] if isinstance(paths, str) else paths + self.allowed_suffix = ( + [s.lower().lstrip(".") for s in allowed_suffix] + if allowed_suffix + else list(_MAPPING.keys()) + ) + self.cache_dir = cache_dir + self.recursive = recursive + self.reader_kwargs = reader_kwargs + + # Validate allowed suffixes + unsupported = set(self.allowed_suffix) - set(_MAPPING.keys()) + if unsupported: + raise ValueError(f"Unsupported file suffixes: {unsupported}") + + def get_read_tasks( + self, parallelism: int, per_task_row_limit: Optional[int] = None + ) -> List[ReadTask]: + """ + Create read tasks for all discovered files. + + :param parallelism: Number of parallel workers + :param per_task_row_limit: Optional limit on rows per task + :return: List of ReadTask objects + """ + # 1. 
Scan all paths to discover files + logger.info("[READ] Scanning paths: %s", self.paths) + scanner = ParallelFileScanner( + cache_dir=self.cache_dir, + allowed_suffix=self.allowed_suffix, + rescan=False, + max_workers=parallelism if parallelism > 0 else 1, + ) + + all_files = [] + scan_results = scanner.scan(self.paths, recursive=self.recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + if not all_files: + return [] + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if suffix not in self.allowed_suffix: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + + for suffix, file_paths in files_by_suffix.items(): + # Split files into chunks for parallel processing + num_chunks = min(parallelism, len(file_paths)) + if num_chunks == 0: + continue + + chunks = [[] for _ in range(num_chunks)] + for i, path in enumerate(file_paths): + chunks[i % num_chunks].append(path) + + # Create a task for each chunk + for chunk in chunks: + if not chunk: + continue + + # Use factory function to avoid mutable default argument issue + def make_read_fn( + file_paths_chunk, suffix_val, reader_kwargs_val, cache_dir_val + ): + def _read_fn() -> Iterable[Block]: + """ + Read a chunk of files and return blocks. + This function runs in a Ray worker. + """ + all_records = [] + + for file_path in file_paths_chunk: + try: + # Build reader for this file + reader = _build_reader( + suffix_val, cache_dir_val, **reader_kwargs_val + ) + + # Read the file - readers return Dataset + ds = reader.read(file_path, parallelism=parallelism) + + # Convert Dataset to list of dicts + records = ds.take_all() + all_records.extend(records) + + except Exception as e: + logger.error( + "[READ] Error reading file %s: %s", file_path, e + ) + continue + + # Convert list of dicts to PyArrow Table (Block) + if all_records: + # Create PyArrow Table from records + table = pa.Table.from_pylist(all_records) + yield table + + return _read_fn + + # Create closure with current loop variables + read_fn = make_read_fn( + chunk, suffix, self.reader_kwargs, self.cache_dir + ) + + # Calculate metadata for this task + total_bytes = sum( + Path(fp).stat().st_size for fp in chunk if Path(fp).exists() + ) + + # input_files must be Optional[str], not List[str] + # Use first file as representative or None if empty + first_file = chunk[0] if chunk else None + + metadata = BlockMetadata( + num_rows=None, # Unknown until read + size_bytes=total_bytes, + input_files=first_file, + exec_stats=None, + ) + + read_tasks.append( + ReadTask( + read_fn=read_fn, + metadata=metadata, + schema=None, # Will be inferred + per_task_row_limit=per_task_row_limit, + ) + ) + + logger.info("[READ] Created %d read tasks", len(read_tasks)) + return read_tasks + + def estimate_inmemory_data_size(self) -> Optional[int]: + """ + Estimate the total size of data in memory. + This helps Ray optimize task scheduling. 
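The make_read_fn factory inside get_read_tasks is what lets each ReadTask capture its own chunk, suffix and kwargs; closing over the loop variables directly would leave every task reading the values from the final iteration (a late-binding issue rather than a mutable-default one). A standalone illustration with made-up file names:

# Closures created in a loop share the loop variable, not its value at creation time:
tasks_bad = [lambda: chunk for chunk in (["a.txt"], ["b.txt"])]
print([t() for t in tasks_bad])   # [['b.txt'], ['b.txt']]

# Routing the value through a factory (or a default argument) freezes it per task:
def make_task(files):
    def _task():
        return files
    return _task

tasks_ok = [make_task(chunk) for chunk in (["a.txt"], ["b.txt"])]
print([t() for t in tasks_ok])    # [['a.txt'], ['b.txt']]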
+ """ + try: + total_size = 0 + for path in self.paths: + scan_results = ParallelFileScanner( + cache_dir=self.cache_dir, + allowed_suffix=self.allowed_suffix, + rescan=False, + max_workers=1, + ).scan(path, recursive=self.recursive) + + for result in scan_results.values(): + total_size += result.get("stats", {}).get("total_size", 0) + return total_size + except Exception: + # Return None if estimation fails + return None def read_files( - input_path: str, + input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, cache_dir: Optional[str] = None, parallelism: int = 4, - **ray_kwargs, + recursive: bool = True, + **reader_kwargs: Any, ) -> ray.data.Dataset: """ - Reads files from the specified input path, filtering by allowed suffixes, - and returns a Ray Dataset containing the read documents. - :param input_path: input file or directory path - :param allowed_suffix: list of allowed file suffixes (e.g., ['pdf', 'txt']) - :param cache_dir: directory to cache intermediate files (used for PDF reading) - :param parallelism: number of parallel workers for reading files - :param ray_kwargs: additional keyword arguments for Ray Dataset reading - :return: Ray Dataset containing the read documents + Unified entry point to read files of multiple types using Ray Data. + + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents """ if not ray.is_initialized(): ray.init() - - return ray.data.read_datasource( - UnifiedFileDatasource( - paths=[input_path], - allowed_suffix=allowed_suffix, - cache_dir=cache_dir, - **ray_kwargs, # Pass additional Ray kwargs here - ), - parallelism=parallelism, - ) - - - # path = Path(input_file).expanduser() - # if not path.exists(): - # raise FileNotFoundError(f"[Read] input_path not found: {input_file}") - # - # if allowed_suffix is None: - # support_suffix = set(_MAPPING.keys()) - # else: - # support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} - # - # # single file - # if path.is_file(): - # suffix = path.suffix.lstrip(".").lower() - # if suffix not in support_suffix: - # logger.warning( - # "[Read] Skip file %s (suffix '%s' not in allowed_suffix %s)", - # path, - # suffix, - # support_suffix, - # ) - # return - # reader = _build_reader(suffix, cache_dir) - # logger.info("[Read] Reading file %s", path) - # yield reader.read(str(path)) - # return - # - # # folder - # logger.info("[Read] Streaming directory %s", path) - # for p in path.rglob("*"): - # if p.is_file() and p.suffix.lstrip(".").lower() in support_suffix: - # try: - # suffix = p.suffix.lstrip(".").lower() - # reader = _build_reader(suffix, cache_dir) - # logger.info("[Reader] Reading file %s", p) - # docs = reader.read(str(p)) - # if docs: - # yield docs - # except Exception: # pylint: disable=broad-except - # logger.exception("[Reader] Error reading %s", p) + try: + return ray.data.read_datasource( + UnifiedFileDatasource( + paths=input_path, + allowed_suffix=allowed_suffix, + cache_dir=cache_dir, + recursive=recursive, + **reader_kwargs, + ), + parallelism=parallelism, + ) + except Exception as e: + logger.error("[READ] Failed to read files from %s: %s", input_path, e) + raise From 
00551e31dbbccb7dfbd979016370eaa6e5e94692 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 24 Nov 2025 14:38:58 +0800 Subject: [PATCH 24/26] fix: fix pylint problems --- graphgen/operators/read/read_files.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index c4933ce4..df330f48 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -178,7 +178,8 @@ def _read_fn() -> Iterable[Block]: # Convert list of dicts to PyArrow Table (Block) if all_records: # Create PyArrow Table from records - table = pa.Table.from_pylist(all_records) + # pylint: disable=no-value-for-parameter + table = pa.Table.from_pylist(mapping=all_records) yield table return _read_fn From cb2833c056d2bfa8c604cbf59160aba7d8a58f22 Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:44:07 +0800 Subject: [PATCH 25/26] fix: fix for pull request finding 'Empty except' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- graphgen/operators/read/parallel_file_scanner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py index 499cc224..890a50a9 100644 --- a/graphgen/operators/read/parallel_file_scanner.py +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -194,8 +194,8 @@ def _cache_result(self, key: str, result: Dict, path: Path): }, ) logger.info("[READ] Cached scan result for path: %s", path) - except OSError: - pass + except OSError as e: + logger.error("[READ] Failed to cache scan result for path %s: %s", path, e) def _is_allowed_file(self, path: Path) -> bool: """Check if the file has an allowed suffix""" From f391c24f074d11cd19771d2e9d4f8022fc63c7b4 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Mon, 24 Nov 2025 19:22:11 +0800 Subject: [PATCH 26/26] perf: optimize read_files.py by deleting implementation of ray.data.DataSource --- graphgen/operators/read/read_files.py | 240 +++++--------------------- 1 file changed, 42 insertions(+), 198 deletions(-) diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index df330f48..34ffee85 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -1,10 +1,7 @@ from pathlib import Path -from typing import Any, Iterable, List, Optional, Union +from typing import Any, List, Optional, Union -import pyarrow as pa import ray -from ray.data.block import Block, BlockMetadata -from ray.data.datasource import Datasource, ReadTask from graphgen.models import ( CSVReader, @@ -50,230 +47,77 @@ def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): return reader_cls(**reader_kwargs) -# pylint: disable=abstract-method -class UnifiedFileDatasource(Datasource): - """ - A unified Ray DataSource that can read multiple file types - and automatically route to the appropriate reader. +def read_files( + input_path: Union[str, List[str]], + allowed_suffix: Optional[List[str]] = None, + cache_dir: Optional[str] = None, + parallelism: int = 4, + recursive: bool = True, + **reader_kwargs: Any, +) -> ray.data.Dataset: """ + Unified entry point to read files of multiple types using Ray Data. 
- def __init__( - self, - paths: Union[str, List[str]], - allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - recursive: bool = True, - **reader_kwargs, - ): - """ - Initialize the datasource. - - :param paths: File or directory paths to read - :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) - :param cache_dir: Directory to cache intermediate files (used for PDF processing) - :param recursive: Whether to scan directories recursively - :param reader_kwargs: Additional kwargs passed to readers - """ - self.paths = [paths] if isinstance(paths, str) else paths - self.allowed_suffix = ( - [s.lower().lstrip(".") for s in allowed_suffix] - if allowed_suffix - else list(_MAPPING.keys()) - ) - self.cache_dir = cache_dir - self.recursive = recursive - self.reader_kwargs = reader_kwargs - - # Validate allowed suffixes - unsupported = set(self.allowed_suffix) - set(_MAPPING.keys()) - if unsupported: - raise ValueError(f"Unsupported file suffixes: {unsupported}") - - def get_read_tasks( - self, parallelism: int, per_task_row_limit: Optional[int] = None - ) -> List[ReadTask]: - """ - Create read tasks for all discovered files. - - :param parallelism: Number of parallel workers - :param per_task_row_limit: Optional limit on rows per task - :return: List of ReadTask objects - """ + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents + """ + try: # 1. Scan all paths to discover files - logger.info("[READ] Scanning paths: %s", self.paths) + logger.info("[READ] Scanning paths: %s", input_path) scanner = ParallelFileScanner( - cache_dir=self.cache_dir, - allowed_suffix=self.allowed_suffix, + cache_dir=cache_dir, + allowed_suffix=allowed_suffix, rescan=False, max_workers=parallelism if parallelism > 0 else 1, ) all_files = [] - scan_results = scanner.scan(self.paths, recursive=self.recursive) + scan_results = scanner.scan(input_path, recursive=recursive) for result in scan_results.values(): all_files.extend(result.get("files", [])) logger.info("[READ] Found %d files to process", len(all_files)) + if not all_files: - return [] + return ray.data.from_items([]) # 2. Group files by suffix to use appropriate reader files_by_suffix = {} for file_info in all_files: suffix = Path(file_info["path"]).suffix.lower().lstrip(".") - if suffix not in self.allowed_suffix: + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: continue files_by_suffix.setdefault(suffix, []).append(file_info["path"]) # 3. Create read tasks read_tasks = [] - for suffix, file_paths in files_by_suffix.items(): - # Split files into chunks for parallel processing - num_chunks = min(parallelism, len(file_paths)) - if num_chunks == 0: - continue - - chunks = [[] for _ in range(num_chunks)] - for i, path in enumerate(file_paths): - chunks[i % num_chunks].append(path) - - # Create a task for each chunk - for chunk in chunks: - if not chunk: - continue - - # Use factory function to avoid mutable default argument issue - def make_read_fn( - file_paths_chunk, suffix_val, reader_kwargs_val, cache_dir_val - ): - def _read_fn() -> Iterable[Block]: - """ - Read a chunk of files and return blocks. 
- This function runs in a Ray worker. - """ - all_records = [] - - for file_path in file_paths_chunk: - try: - # Build reader for this file - reader = _build_reader( - suffix_val, cache_dir_val, **reader_kwargs_val - ) - - # Read the file - readers return Dataset - ds = reader.read(file_path, parallelism=parallelism) - - # Convert Dataset to list of dicts - records = ds.take_all() - all_records.extend(records) + reader = _build_reader(suffix, cache_dir, **reader_kwargs) + ds = reader.read(file_paths, parallelism=parallelism) + read_tasks.append(ds) - except Exception as e: - logger.error( - "[READ] Error reading file %s: %s", file_path, e - ) - continue + # 4. Combine all datasets + if not read_tasks: + logger.warning("[READ] No datasets created") + return ray.data.from_items([]) - # Convert list of dicts to PyArrow Table (Block) - if all_records: - # Create PyArrow Table from records - # pylint: disable=no-value-for-parameter - table = pa.Table.from_pylist(mapping=all_records) - yield table + if len(read_tasks) == 1: + logger.info("[READ] Successfully read files from %s", input_path) + return read_tasks[0] + # len(read_tasks) > 1 + combined_ds = read_tasks[0].union(*read_tasks[1:]) - return _read_fn + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds - # Create closure with current loop variables - read_fn = make_read_fn( - chunk, suffix, self.reader_kwargs, self.cache_dir - ) - - # Calculate metadata for this task - total_bytes = sum( - Path(fp).stat().st_size for fp in chunk if Path(fp).exists() - ) - - # input_files must be Optional[str], not List[str] - # Use first file as representative or None if empty - first_file = chunk[0] if chunk else None - - metadata = BlockMetadata( - num_rows=None, # Unknown until read - size_bytes=total_bytes, - input_files=first_file, - exec_stats=None, - ) - - read_tasks.append( - ReadTask( - read_fn=read_fn, - metadata=metadata, - schema=None, # Will be inferred - per_task_row_limit=per_task_row_limit, - ) - ) - - logger.info("[READ] Created %d read tasks", len(read_tasks)) - return read_tasks - - def estimate_inmemory_data_size(self) -> Optional[int]: - """ - Estimate the total size of data in memory. - This helps Ray optimize task scheduling. - """ - try: - total_size = 0 - for path in self.paths: - scan_results = ParallelFileScanner( - cache_dir=self.cache_dir, - allowed_suffix=self.allowed_suffix, - rescan=False, - max_workers=1, - ).scan(path, recursive=self.recursive) - - for result in scan_results.values(): - total_size += result.get("stats", {}).get("total_size", 0) - return total_size - except Exception: - # Return None if estimation fails - return None - - -def read_files( - input_path: Union[str, List[str]], - allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - parallelism: int = 4, - recursive: bool = True, - **reader_kwargs: Any, -) -> ray.data.Dataset: - """ - Unified entry point to read files of multiple types using Ray Data. 
- - :param input_path: File or directory path(s) to read from - :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) - :param cache_dir: Directory to cache intermediate files (PDF processing) - :param parallelism: Number of parallel workers - :param recursive: Whether to scan directories recursively - :param reader_kwargs: Additional kwargs passed to readers - :return: Ray Dataset containing all documents - """ - - if not ray.is_initialized(): - ray.init() - - try: - return ray.data.read_datasource( - UnifiedFileDatasource( - paths=input_path, - allowed_suffix=allowed_suffix, - cache_dir=cache_dir, - recursive=recursive, - **reader_kwargs, - ), - parallelism=parallelism, - ) except Exception as e: logger.error("[READ] Failed to read files from %s: %s", input_path, e) raise
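After PATCH 26/26, read_files() is a thin wrapper rather than a ray.data.Datasource: it scans the input paths with ParallelFileScanner, builds one reader per file suffix, and unions the per-suffix Ray Datasets into a single ray.data.Dataset. A minimal calling sketch of that final entry point follows. The example paths, the cache directory, and the explicit ray.init() call are illustrative assumptions, not part of the patches above (the new implementation no longer initializes Ray itself, so the caller is assumed to do so, or to rely on Ray Data's implicit initialization).

    # Sketch: calling the post-PATCH-26 read_files() entry point.
    # Paths and cache_dir below are hypothetical placeholders.
    import ray

    from graphgen.operators.read.read_files import read_files

    if __name__ == "__main__":
        # read_files() no longer calls ray.init() after PATCH 26/26,
        # so initialize Ray explicitly (or let Ray Data auto-initialize).
        ray.init(ignore_reinit_error=True)

        # Scan a directory tree, keeping only PDF and TXT files; the
        # appropriate reader is resolved per suffix inside read_files().
        ds = read_files(
            input_path="data/docs",            # hypothetical directory
            allowed_suffix=["pdf", "txt"],
            cache_dir="/tmp/graphgen_cache",   # used for PDF intermediate files
            parallelism=4,
            recursive=True,
        )

        # The result is a single Ray Dataset (per-suffix datasets unioned),
        # so standard Ray Data operations apply downstream.
        print(ds.count())
        for row in ds.take(3):
            print(row)

One consequence of the PATCH 26/26 design worth noting when using it this way: because each suffix group is read through its reader up front and the resulting datasets are combined with union(), an empty scan returns ray.data.from_items([]) rather than raising, so callers should check ds.count() before assuming documents were found.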