1 change: 1 addition & 0 deletions .pylintrc
@@ -452,6 +452,7 @@ disable=raw-checker-failed,
R0917, # Too many positional arguments (6/5) (too-many-positional-arguments)
C0103,
E0401,
W0718, # Catching too general exception Exception (broad-except)

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
28 changes: 15 additions & 13 deletions graphgen/graphgen.py
@@ -67,15 +67,16 @@ def __init__(
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
)
self.search_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="search"
)
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
self.partition_storage: JsonListStorage = JsonListStorage(
self.working_dir, namespace="partition"
)
self.search_storage: JsonKVStorage = JsonKVStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="search",
)
self.qa_storage: JsonListStorage = JsonListStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="qa",
@@ -94,23 +95,24 @@ async def read(self, read_config: Dict):
"""
read files from input sources
"""
data = read_files(**read_config, cache_dir=self.working_dir)
if len(data) == 0:
logger.warning("No data to process")
return
doc_stream = read_files(**read_config, cache_dir=self.working_dir)

assert isinstance(data, list) and isinstance(data[0], dict)
batch = {}
for doc in doc_stream:
doc_id = compute_mm_hash(doc, prefix="doc-")

# TODO: configurable whether to use coreference resolution
batch[doc_id] = doc
if batch:
self.full_docs_storage.upsert(batch)
self.full_docs_storage.index_done_callback()

new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
_add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
# TODO: configurable whether to use coreference resolution

_add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All documents are already in the storage")
return

self.full_docs_storage.upsert(new_docs)
self.full_docs_storage.index_done_callback()

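A minimal driver sketch for the streaming read() flow above. The keys in read_config mirror the read_files() signature later in this diff (input_file, allowed_suffix, max_workers, rescan); cache_dir is supplied by read() itself. The GraphGen constructor arguments and the concrete paths here are assumptions for illustration, not part of the PR.

import asyncio

from graphgen.graphgen import GraphGen  # import path taken from the file path in this diff

async def main():
    gg = GraphGen(working_dir="cache")  # constructor signature assumed
    # read() streams documents from read_files() and deduplicates them by hash
    await gg.read(
        {
            "input_file": "data/raw_docs",
            "allowed_suffix": ["md", "txt", "pdf"],
            "max_workers": 4,
            "rescan": False,
        }
    )

if __name__ == "__main__":
    asyncio.run(main())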
231 changes: 231 additions & 0 deletions graphgen/operators/read/parallel_file_scanner.py
@@ -0,0 +1,231 @@
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Set, Union

from diskcache import Cache

from graphgen.utils import logger


class ParallelFileScanner:
def __init__(
self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4
):
self.cache = Cache(cache_dir)
self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
self.rescan = rescan
self.max_workers = max_workers

def scan(
self, paths: Union[str, List[str]], recursive: bool = True
) -> Dict[str, Any]:
if isinstance(paths, str):
paths = [paths]

results = {}
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_path = {}
for p in paths:
if os.path.exists(p):
future = executor.submit(
self._scan_files, Path(p).resolve(), recursive, set()
)
future_to_path[future] = p
else:
logger.warning("[READ] Path does not exist: %s", p)

for future in as_completed(future_to_path):
path = future_to_path[future]
try:
results[path] = future.result()
except Exception as e:
logger.error("[READ] Error scanning path %s: %s", path, e)
results[path] = {
"error": str(e),
"files": [],
"dirs": [],
"stats": {},
}
return results

def _scan_files(
self, path: Path, recursive: bool, visited: Set[str]
) -> Dict[str, Any]:
path_str = str(path)

# Avoid cycles due to symlinks
if path_str in visited:
logger.warning("[READ] Skipping already visited path: %s", path_str)
return self._empty_result(path_str)

# cache check
cache_key = f"scan::{path_str}::recursive::{recursive}"
cached = self.cache.get(cache_key)
if cached and not self.rescan:
logger.info("[READ] Using cached scan result for path: %s", path_str)
return cached["data"]

logger.info("[READ] Scanning path: %s", path_str)
files, dirs = [], []
stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0}

try:
path_stat = path.stat()
if path.is_file():
return self._scan_single_file(path, path_str, path_stat)
if path.is_dir():
with os.scandir(path_str) as entries:
for entry in entries:
try:
entry_stat = entry.stat(follow_symlinks=False)

if entry.is_dir():
dirs.append(
{
"path": entry.path,
"name": entry.name,
"mtime": entry_stat.st_mtime,
}
)
stats["dir_count"] += 1
else:
# allowed suffix filter
if not self._is_allowed_file(Path(entry.path)):
continue
files.append(
{
"path": entry.path,
"name": entry.name,
"size": entry_stat.st_size,
"mtime": entry_stat.st_mtime,
}
)
stats["total_size"] += entry_stat.st_size
stats["file_count"] += 1

except OSError:
stats["errors"] += 1

except (PermissionError, FileNotFoundError, OSError) as e:
logger.error("[READ] Failed to scan path %s: %s", path_str, e)
return {"error": str(e), "files": [], "dirs": [], "stats": stats}

if recursive:
sub_visited = visited | {path_str}
sub_results = self._scan_subdirs(dirs, sub_visited)

for sub_data in sub_results.values():
files.extend(sub_data.get("files", []))
stats["total_size"] += sub_data["stats"].get("total_size", 0)
stats["file_count"] += sub_data["stats"].get("file_count", 0)

result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats}
self._cache_result(cache_key, result, path)
return result

def _scan_single_file(
self, path: Path, path_str: str, stat: os.stat_result
) -> Dict[str, Any]:
"""Scan a single file and return its metadata"""
if not self._is_allowed_file(path):
return self._empty_result(path_str)

return {
"path": path_str,
"files": [
{
"path": path_str,
"name": path.name,
"size": stat.st_size,
"mtime": stat.st_mtime,
}
],
"dirs": [],
"stats": {
"total_size": stat.st_size,
"file_count": 1,
"dir_count": 0,
"errors": 0,
},
}

def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]:
"""
Parallel scan subdirectories
:param dir_list
:param visited
:return:
"""
results = {}
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._scan_files, Path(d["path"]), True, visited): d[
"path"
]
for d in dir_list
}

for future in as_completed(futures):
path = futures[future]
try:
results[path] = future.result()
except Exception as e:
logger.error("[READ] Error scanning subdirectory %s: %s", path, e)
results[path] = {
"error": str(e),
"files": [],
"dirs": [],
"stats": {},
}

return results

def _cache_result(self, key: str, result: Dict, path: Path):
"""Cache the scan result"""
try:
self.cache.set(
key,
{
"data": result,
"dir_mtime": path.stat().st_mtime,
"cached_at": time.time(),
},
)
logger.info("[READ] Cached scan result for path: %s", path)
except OSError as e:
logger.error("[READ] Failed to cache scan result for path %s: %s", path, e)

def _is_allowed_file(self, path: Path) -> bool:
"""Check if the file has an allowed suffix"""
if self.allowed_suffix is None:
return True
suffix = path.suffix.lower().lstrip(".")
return suffix in self.allowed_suffix

def invalidate(self, path: str):
"""Invalidate cache for a specific path"""
path = Path(path).resolve()
keys = [k for k in self.cache if k.startswith(f"scan::{path}")]
for k in keys:
self.cache.delete(k)
logger.info("[READ] Invalidated cache for path: %s", path)

def close(self):
self.cache.close()

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

@staticmethod
def _empty_result(path: str) -> Dict[str, Any]:
return {
"path": path,
"files": [],
"dirs": [],
"stats": {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0},
}
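A minimal usage sketch for the scanner above, assuming the class is importable from its new module path; the directory name and suffix set are placeholders.

from graphgen.operators.read.parallel_file_scanner import ParallelFileScanner

with ParallelFileScanner(
    cache_dir="cache",                     # diskcache directory for scan results
    allowed_suffix={"md", "txt", "pdf"},   # only these suffixes are collected
    rescan=False,                          # reuse cached scans when present
    max_workers=4,
) as scanner:
    results = scanner.scan("data/raw_docs", recursive=True)
    for root, result in results.items():
        if "error" in result:
            print(f"scan failed for {root}: {result['error']}")
            continue
        stats = result["stats"]
        print(f"{root}: {stats['file_count']} files, {stats['total_size']} bytes")
        for f in result["files"]:
            print("  ", f["path"])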
67 changes: 40 additions & 27 deletions graphgen/operators/read/read_files.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterator, List, Optional

from graphgen.models import (
CSVReader,
@@ -13,6 +13,8 @@
)
from graphgen.utils import logger

from .parallel_file_scanner import ParallelFileScanner

_MAPPING = {
"jsonl": JSONLReader,
"json": JSONReader,
@@ -39,7 +41,20 @@ def read_files(
input_file: str,
allowed_suffix: Optional[List[str]] = None,
cache_dir: Optional[str] = None,
) -> list[dict]:
max_workers: int = 4,
rescan: bool = False,
) -> Iterator[Dict[str, Any]]:
"""
Read files from a path using parallel scanning and appropriate readers.

Args:
input_file: Path to a file or directory
allowed_suffix: List of file suffixes to read. If None, uses all supported types
cache_dir: Directory for caching PDF extraction and scan results
max_workers: Number of workers for parallel scanning
rescan: Whether to force rescan even if cached results exist
"""

path = Path(input_file).expanduser()
if not path.exists():
raise FileNotFoundError(f"input_path not found: {input_file}")
@@ -49,38 +64,36 @@
else:
support_suffix = {s.lower().lstrip(".") for s in allowed_suffix}

# single file
if path.is_file():
suffix = path.suffix.lstrip(".").lower()
if suffix not in support_suffix:
logger.warning(
"Skip file %s (suffix '%s' not in allowed_suffix %s)",
path,
suffix,
support_suffix,
)
return []
reader = _build_reader(suffix, cache_dir)
return reader.read(str(path))

# folder
files_to_read = [
p for p in path.rglob("*") if p.suffix.lstrip(".").lower() in support_suffix
]
with ParallelFileScanner(
cache_dir=cache_dir or "cache",
allowed_suffix=support_suffix,
rescan=rescan,
max_workers=max_workers,
) as scanner:
scan_results = scanner.scan(str(path), recursive=True)

# Extract files from scan results
files_to_read = []
for scanned_path, path_result in scan_results.items():
if "error" in path_result:
logger.warning("Error scanning %s: %s", scanned_path, path_result["error"])
continue
files_to_read.extend(path_result.get("files", []))

logger.info(
"Found %d eligible file(s) under folder %s (allowed_suffix=%s)",
len(files_to_read),
input_file,
support_suffix,
)

all_docs: List[Dict[str, Any]] = []
for p in files_to_read:
for file_info in files_to_read:
try:
suffix = p.suffix.lstrip(".").lower()
file_path = file_info["path"]
suffix = Path(file_path).suffix.lstrip(".").lower()
reader = _build_reader(suffix, cache_dir)
all_docs.extend(reader.read(str(p)))
except Exception as e: # pylint: disable=broad-except
logger.exception("Error reading %s: %s", p, e)

return all_docs
yield from reader.read(file_path)

except Exception as e: # pylint: disable=broad-except
logger.exception("Error reading %s: %s", file_info.get("path"), e)
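A minimal consumption sketch for the reworked read_files(), which now yields documents lazily instead of returning a list. The import path, directory, and suffix list below are illustrative assumptions.

from graphgen.operators.read.read_files import read_files

docs = []
for doc in read_files(
    input_file="data/raw_docs",
    allowed_suffix=["md", "jsonl"],
    cache_dir="cache",
    max_workers=4,
    rescan=False,
):
    # Each yielded item is a dict produced by the per-suffix reader
    # (JSONLReader, JSONReader, CSVReader, ...), so documents can be
    # processed one at a time without loading the whole corpus into memory.
    docs.append(doc)

print(f"loaded {len(docs)} document(s)")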