diff --git a/docs/en/concepts/09-transaction.md b/docs/en/concepts/09-transaction.md new file mode 100644 index 000000000..3d7ced04a --- /dev/null +++ b/docs/en/concepts/09-transaction.md @@ -0,0 +1,374 @@ +# Transaction Mechanism + +OpenViking's transaction mechanism protects the consistency of core write operations (`rm`, `mv`, `add_resource`, `session.commit`), ensuring that VikingFS, VectorDB, and QueueManager remain consistent even when failures occur. + +## Design Philosophy + +OpenViking is a context database where FS is the source of truth and VectorDB is a derived index. A lost index can be rebuilt from source data, but lost source data is unrecoverable. Therefore: + +> **Better to miss a search result than to return a bad one.** + +## Design Principles + +1. **Transactions cover synchronous operations only**: FS + VectorDB operations run inside transactions; SemanticQueue/EmbeddingQueue enqueue runs after commit (as post_actions) — they are idempotent and retriable +2. **On by default**: All data operations automatically use transactions; no extra configuration needed +3. **Write-exclusive**: Path locks ensure only one write transaction can operate on a path at a time +4. **Undo Log model**: Record reverse operations before each change; replay them in reverse order on failure +5. **Persistent journal**: Each transaction writes a journal file to AGFS for crash recovery + +## Architecture + +``` +Service Layer (rm / mv / add_resource / session.commit) + | + v ++--[TransactionContext async context manager]--+ +| | +| 1. Create transaction + write journal | +| 2. Acquire path lock (poll + timeout) | +| 3. Execute operations (FS + VectorDB) | +| 4. Record Undo Log (mark completed) | +| 5. Commit / Rollback | +| 6. Execute post_actions (enqueue etc) | +| 7. 
Release lock + clean up journal | +| | +| On exception: reverse Undo Log + unlock | ++----------------------------------------------+ + | + v +Storage Layer (VikingFS, VectorDB, QueueManager) +``` + +## Consistency Issues and Solutions + +### rm(uri) + +| Problem | Solution | +|---------|----------| +| Delete file first, then index -> file gone but index remains -> search returns non-existent file | **Reverse order**: delete index first, then file. Index deletion failure -> both file and index intact | + +Transaction flow: + +``` +1. Begin transaction, acquire lock (lock_mode="subtree") +2. Snapshot VectorDB records (for rollback recovery) +3. Delete VectorDB index -> immediately invisible to search +4. Delete FS file +5. Commit -> release lock -> delete journal +``` + +Rollback: Step 4 fails -> restore VectorDB records from snapshot. + +### mv(old_uri, new_uri) + +| Problem | Solution | +|---------|----------| +| File moved to new path but index points to old path -> search returns old path (doesn't exist) | Transaction wrapper; rollback on failure | + +Transaction flow: + +``` +1. Begin transaction, acquire lock (lock_mode="mv", SUBTREE on both source and destination for directories) +2. Move FS file +3. Update VectorDB URIs +4. Commit -> release lock -> delete journal +``` + +Rollback: Step 3 fails -> move file back to original location. + +### add_resource + +| Problem | Solution | +|---------|----------| +| File moved from temp to final directory, then crash -> file exists but never searchable | Two separate paths for first-time add vs incremental update | + +First-time add and incremental update are two independent paths: + +**First-time add** (target does not exist) — handled in `ResourceProcessor.process_resource` Phase 3.5: + +``` +1. Begin transaction, lock parent_path of final_uri (lock_mode="point") +2. Record undo: fs_write_new (uri=dst_path) +3. agfs.mv temp directory -> final location +4. Commit -> release lock -> delete journal +5. 
Clean up temp directory +6. Enqueue SemanticMsg(uri=final, target_uri=None) -> DAG runs on final, no callback +``` + +Crash recovery: Undo deletes the incomplete dst_path; re-run `add_resource` to retry. + +**Incremental update** (target already exists) — temp stays in place: + +``` +1. Enqueue SemanticMsg(uri=temp, target_uri=final) -> DAG runs on temp +2. DAG completion triggers sync_diff_callback or move_temp_to_target_callback +3. Each VikingFS.rm / VikingFS.mv inside callbacks creates its own independent transaction +``` + +Note: DAG callbacks do NOT wrap operations in an outer TransactionContext. Each `VikingFS.rm` and `VikingFS.mv` has its own transaction internally. An outer lock would conflict with these inner locks (e.g. outer POINT lock on target_path vs inner SUBTREE lock from `rm`) causing deadlock. + +### session.commit() + +| Problem | Solution | +|---------|----------| +| Messages cleared but archive not written -> conversation data lost | Phase 1 without transaction (incomplete archive has no side effects) + Phase 2 with redo transaction | + +LLM calls have unpredictable latency (5s~60s+) and cannot be inside a lock-holding transaction. The design splits into two phases: + +``` +Phase 1 — Archive (no transaction, no lock): + 1. Generate archive summary (LLM) + 2. Write archive (history/archive_N/messages.jsonl + summaries) + 3. Clear messages.jsonl + 4. Clear in-memory message list + +Phase 2 — Memory extraction + write (transaction, lock_mode="none", redo semantics): + 1. Record init_info (archive_uri, session_uri, user identity) + 2. Extract memories from archived messages (LLM) + 3. Write current message state + 4. Write relations + 5. Register post_action: enqueue SemanticQueue + 6. Commit +``` + +**Redo semantics**: Phase 2 does not register undo log entries. On crash recovery, memory extraction and writing are re-executed from the archive (`_redo_session_memory`) instead of being rolled back. 
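
The redo pattern above can be sketched in a few lines of plain Python. This is a generic illustration only — the journal layout, file names, and helpers here are invented for the example and are not OpenViking's actual `_redo_session_memory` implementation. The key properties it demonstrates: the journal persists only `init_info` (no undo entries), the work itself is an idempotent keyed write, and startup recovery *re-runs* the work from `init_info` instead of rolling it back.

```python
import json
import os
import tempfile


class RedoJournal:
    """Minimal redo-style journal: persist init_info, replay after a crash.

    Illustrative sketch only -- the file layout and names are invented,
    not OpenViking's actual journal format.
    """

    def __init__(self, path: str) -> None:
        self.path = path

    def begin(self, init_info: dict) -> None:
        # Persist everything needed to redo the work from scratch.
        with open(self.path, "w") as f:
            json.dump({"status": "EXEC", "init_info": init_info}, f)

    def commit(self) -> None:
        # A missing journal means the transaction fully completed.
        os.remove(self.path)


def extract_and_write(init_info: dict, store: dict) -> None:
    # Idempotent step: a keyed write, safe to repeat after a crash.
    store[init_info["archive_uri"]] = f"memories from {init_info['archive_uri']}"


def recover(journal_path: str, store: dict) -> bool:
    # Startup scan: a leftover journal means commit never ran -> redo, don't undo.
    if not os.path.exists(journal_path):
        return False
    with open(journal_path) as f:
        init_info = json.load(f)["init_info"]
    extract_and_write(init_info, store)
    os.remove(journal_path)
    return True


# Simulate a crash between begin() and commit():
journal_path = os.path.join(tempfile.mkdtemp(), "journal.json")
store: dict = {}
tx = RedoJournal(journal_path)
tx.begin({"archive_uri": "history/archive_0"})
extract_and_write({"archive_uri": "history/archive_0"}, store)
# (process dies here; commit() never runs)

assert recover(journal_path, store)  # restart finds the journal and redoes the work
assert store["history/archive_0"] == "memories from history/archive_0"
```

Because the write is keyed by `archive_uri`, replaying it after a crash overwrites rather than duplicates — which is exactly why redo semantics suit Phase 2, where slow LLM calls make holding locks (and hence undo-based rollback) impractical.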
+ +**Crash recovery analysis**: + +| Crash point | State | Recovery action | +|------------|-------|----------------| +| During Phase 1 archive write | No transaction | Incomplete archive; next commit scans history/ for index, unaffected | +| Phase 1 archive complete but messages not cleared | No transaction | Archive complete + messages still present = redundant but safe | +| During Phase 2 memory extraction/write | Journal EXEC | On startup: `_redo_session_memory` redoes extraction + write + enqueue from archive | +| After Phase 2 commit | Journal COMMIT | On startup: replay `post_action("enqueue_semantic")` | + +## TransactionContext + +`TransactionContext` is an **async** context manager that encapsulates the full transaction lifecycle: + +```python +from openviking.storage.transaction import TransactionContext, get_transaction_manager + +tx_manager = get_transaction_manager() + +async with TransactionContext(tx_manager, "rm", [path], lock_mode="subtree") as tx: + # Record undo (call before making changes) + seq = tx.record_undo("vectordb_delete", {"record_ids": ids, "records_snapshot": snapshot}) + # Execute change + delete_from_vector_store(uris) + # Mark completed + tx.mark_completed(seq) + + # Register post-commit action (optional) + tx.add_post_action("enqueue_semantic", {"uri": uri, ...}) + + # Commit + await tx.commit() +# Auto-rollback if commit() not called +``` + +**Lock modes**: + +| lock_mode | Use case | Behavior | +|-----------|----------|----------| +| `point` | Write operations | Lock the specified path; conflicts with any lock on the same path and any SUBTREE lock on ancestors | +| `subtree` | Delete operations | Lock the subtree root; conflicts with any lock on the same path, any lock on descendants, and any SUBTREE lock on ancestors | +| `mv` | Move operations | Directory move: SUBTREE lock on both source and destination; File move: POINT lock on source parent and destination (controlled by `src_is_dir`) | +| `none` | Lock-free operations | 
Skip lock acquisition, transition directly to EXEC status. Used for session.commit Phase 2 and other scenarios that don't require path mutual exclusion | + +## Lock Types (POINT vs SUBTREE) + +The lock mechanism uses two lock types to handle different conflict patterns: + +| | POINT on same path | SUBTREE on same path | POINT on descendant | SUBTREE on ancestor | +|---|---|---|---|---| +| **POINT** | Conflict | Conflict | — | Conflict | +| **SUBTREE** | Conflict | Conflict | Conflict | Conflict | + +- **POINT (P)**: Used for write and semantic-processing operations. Only locks a single directory. Blocks if any ancestor holds a SUBTREE lock. +- **SUBTREE (S)**: Used for rm and mv operations. Logically covers the entire subtree but only writes **one lock file** at the root. Before acquiring, scans all descendants and ancestor directories for conflicting locks. + +## Undo Log + +Each transaction maintains an Undo Log recording the reverse action for each step: + +| op_type | Forward operation | Rollback action | +|---------|-------------------|-----------------| +| `fs_mv` | Move file | Move back | +| `fs_rm` | Delete file | Skip (irreversible; rm is always the last step by design) | +| `fs_write_new` | Create new file/directory | Delete | +| `fs_mkdir` | Create directory | Delete | +| `vectordb_delete` | Delete index records | Restore from snapshot | +| `vectordb_upsert` | Insert index records | Delete | +| `vectordb_update_uri` | Update URI | Restore old value | + +Rollback rules: Only entries with `completed=True` are rolled back, in **reverse order**. Each step has independent try-catch (best-effort). During crash recovery, `recover_all=True` also reverses uncompleted entries to clean up partial operations. + +### Context Reconstruction + +VectorDB rollback operations require a `RequestContext` (containing account_id, user_id, agent_id, role). 
Since the original context is unavailable during crash recovery, `_ctx_*` fields are serialized into undo params when calling record_undo: + +- `_ctx_account_id`: Account ID +- `_ctx_user_id`: User ID +- `_ctx_agent_id`: Agent ID +- `_ctx_role`: Role + +During rollback, `_reconstruct_ctx()` rebuilds the context from these fields. If reconstruction fails (missing fields), the VectorDB rollback step is skipped with a warning. + +## Lock Mechanism + +### Lock Protocol + +Lock file path: `{path}/.path.ovlock` + +Lock file content (Fencing Token): +``` +{transaction_id}:{time_ns}:{lock_type} +``` + +Where `lock_type` is `P` (POINT) or `S` (SUBTREE). + +### Lock Acquisition (POINT mode) + +``` +loop until timeout (poll interval: 200ms): + 1. Check target directory exists + 2. Check if target directory is locked by another transaction + - Stale lock? -> remove and retry + - Active lock? -> wait + 3. Check all ancestor directories for SUBTREE locks + - Stale lock? -> remove and retry + - Active lock? -> wait + 4. Write POINT (P) lock file + 5. TOCTOU double-check: re-scan ancestors for SUBTREE locks + - Conflict found: compare (timestamp, tx_id) + - Later one (larger timestamp/tx_id) backs off (removes own lock) to prevent livelock + - Wait and retry + 6. Verify lock file ownership (fencing token matches) + 7. Success + +Timeout (default 0 = no-wait) raises LockAcquisitionError +``` + +### Lock Acquisition (SUBTREE mode) + +``` +loop until timeout (poll interval: 200ms): + 1. Check target directory exists + 2. Check if target directory is locked by another transaction + - Stale lock? -> remove and retry + - Active lock? -> wait + 3. Check all ancestor directories for SUBTREE locks + - Stale lock? -> remove and retry + - Active lock? -> wait + 4. Scan all descendant directories for any locks by other transactions + - Stale lock? -> remove and retry + - Active lock? -> wait + 5. Write SUBTREE (S) lock file (only one file, at the root path) + 6. 
TOCTOU double-check: re-scan descendants and ancestors + - Conflict found: compare (timestamp, tx_id) + - Later one (larger timestamp/tx_id) backs off (removes own lock) to prevent livelock + - Wait and retry + 7. Verify lock file ownership (fencing token matches) + 8. Success + +Timeout (default 0 = no-wait) raises LockAcquisitionError +``` + +### Lock Expiry Cleanup + +**Stale lock detection**: PathLock checks the fencing token timestamp. Locks older than `lock_expire` (default 300s) are considered stale and are removed automatically during acquisition. + +**Transaction timeout**: TransactionManager checks active transactions every 60 seconds. Transactions with `updated_at` exceeding the transaction timeout (default 3600s) are rolled back. + +## Transaction Journal + +Each transaction persists a journal in AGFS: + +``` +/local/_system/transactions/{tx_id}/journal.json +``` + +Contains: transaction ID, status, lock paths, init_info, undo_log, post_actions. + +### Lifecycle + +``` +Create transaction -> write journal (INIT) +Acquire lock -> update journal (ACQUIRE -> EXEC) +Execute changes -> update journal per step (mark undo entry completed) +Commit -> update journal (COMMIT + post_actions) + -> execute post_actions -> release locks -> delete journal +Rollback -> execute undo log -> release locks -> delete journal +``` + +## Crash Recovery + +`TransactionManager.start()` automatically scans for residual journals on startup: + +| Journal status at crash | Recovery action | +|------------------------|----------------| +| `COMMIT` + non-empty post_actions | Replay post_actions -> release locks -> delete journal | +| `COMMIT` + empty post_actions / `RELEASED` | Release locks -> delete journal | +| `EXEC` / `FAIL` / `RELEASING` (`session_memory` operation) | Redo memory extraction + write from archive (`_redo_session_memory`) -> release locks -> delete journal | +| `EXEC` / `FAIL` / `RELEASING` (all undo entries completed) | Roll forward (treat as committed, replay 
post_actions) -> release locks -> delete journal | +| `EXEC` / `FAIL` / `RELEASING` (other) | Execute undo log rollback (`recover_all=True`) -> release locks -> delete journal | +| `INIT` / `ACQUIRE` | Clean up orphan locks (using init_info.lock_paths) -> delete journal (no changes were made) | + +### Defense Summary + +| Failure scenario | Defense | Recovery timing | +|-----------------|--------|-----------------| +| Crash during transaction | Journal + undo log rollback | On restart | +| Crash after commit, before enqueue | Journal post_actions replay | On restart | +| Crash after enqueue, before worker processes | QueueFS SQLite persistence | Worker auto-pulls after restart | +| Crash during session.commit Phase 2 | Journal + redo (re-extract memories from archive) | On restart | +| Orphan index | Cleaned on L2 on-demand load | When user accesses | +| Crash between lock creation and journal update | init_info records intended lock paths; recovery checks and cleans orphan locks | On restart | + +## Transaction State Machine + +``` +INIT -> ACQUIRE -> EXEC -> COMMIT -> RELEASING -> RELEASED + | + FAIL -> RELEASING -> RELEASED +``` + +- `INIT`: Transaction created, waiting for lock +- `ACQUIRE`: Acquiring lock +- `EXEC`: Transaction operations executing +- `COMMIT`: Committed, post_actions may be pending +- `FAIL`: Execution failed, entering rollback +- `RELEASING`: Releasing locks +- `RELEASED`: Locks released, transaction complete + +## Configuration + +The transaction mechanism is enabled by default with no extra configuration needed. **The default behavior is no-wait**: if the path is locked, `LockAcquisitionError` is raised immediately. 
To allow wait/retry, configure the `storage.transaction` section: + +```json +{ + "storage": { + "transaction": { + "lock_timeout": 5.0, + "lock_expire": 300.0 + } + } +} +``` + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `lock_timeout` | float | Lock acquisition timeout (seconds). `0` = fail immediately if locked (default). `> 0` = wait/retry up to this many seconds. | `0.0` | +| `lock_expire` | float | Stale lock expiry threshold (seconds). Locks held longer than this by a crashed process are force-released. | `300.0` | + +### QueueFS Persistence + +The transaction mechanism relies on QueueFS using the SQLite backend to ensure enqueued tasks survive process restarts. This is the default configuration and requires no manual setup. + +## Related Documentation + +- [Architecture](./01-architecture.md) - System architecture overview +- [Storage](./05-storage.md) - AGFS and vector store +- [Session Management](./08-session.md) - Session and memory management +- [Configuration](../guides/01-configuration.md) - Configuration reference diff --git a/docs/en/guides/01-configuration.md b/docs/en/guides/01-configuration.md index 7c2779b60..d9ddd5477 100644 --- a/docs/en/guides/01-configuration.md +++ b/docs/en/guides/01-configuration.md @@ -515,7 +515,6 @@ Supports S3 storage in VirtualHostStyle mode, such as TOS. - #### vectordb Vector database storage configuration @@ -639,6 +638,30 @@ When `root_api_key` is configured, the server enables multi-tenant authenticatio For startup and deployment details see [Deployment](./03-deployment.md), for authentication see [Authentication](./04-authentication.md). +## storage.transaction Section + +The transaction mechanism is enabled by default and usually requires no configuration. **The default behavior is no-wait**: if the target path is already locked by another transaction, the operation fails immediately with `LockAcquisitionError`. 
Set `lock_timeout` to a positive value to allow polling/retry. + +```json +{ + "storage": { + "transaction": { + "lock_timeout": 5.0, + "lock_expire": 300.0, + "max_parallel_locks": 8 + } + } +} +``` + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `lock_timeout` | float | Path lock acquisition timeout (seconds). `0` = fail immediately if locked (default). `> 0` = wait/retry up to this many seconds, then raise `LockAcquisitionError`. | `0.0` | +| `lock_expire` | float | Stale lock expiry threshold (seconds). Locks held longer than this by a crashed process are force-released. | `300.0` | +| `max_parallel_locks` | int | Max parallel locks during recursive locking for rm/mv operations | `8` | + +For details on the transaction mechanism, see [Transaction Mechanism](../concepts/09-transaction.md). + ## Full Schema ```json @@ -673,6 +696,11 @@ For startup and deployment details see [Deployment](./03-deployment.md), for aut "url": "string", "timeout": 10 }, + "transaction": { + "lock_timeout": 0.0, + "lock_expire": 300.0, + "max_parallel_locks": 8 + }, "vectordb": { "backend": "local|remote", "url": "string", diff --git a/docs/zh/concepts/09-transaction.md b/docs/zh/concepts/09-transaction.md index 503ed683f..2d42815a6 100644 --- a/docs/zh/concepts/09-transaction.md +++ b/docs/zh/concepts/09-transaction.md @@ -1,167 +1,374 @@ # 事务机制 -OpenViking 的事务机制为 AI Agent 上下文数据库提供可靠的操作保障,解决数据一致性、并发控制和错误恢复等核心问题。 +OpenViking 的事务机制保护核心写操作(`rm`、`mv`、`add_resource`、`session.commit`)的一致性,确保 VikingFS、VectorDB、QueueManager 三个子系统在故障时不会出现数据不一致。 -## 概览 +## 设计哲学 +OpenViking 是上下文数据库,FS 是源数据,VectorDB 是派生索引。索引丢了可从源数据重建,源数据丢失不可恢复。因此: + +> **宁可搜不到,不要搜到坏结果。** + +## 设计原则 + +1. **事务只覆盖同步部分**:FS + VectorDB 操作在事务内;SemanticQueue/EmbeddingQueue 的 enqueue 在事务提交后执行(post_actions),它们是幂等的,失败可重试 +2. **默认生效**:所有数据操作命令自动开启事务机制,用户无需额外配置 +3. **写互斥**:通过路径锁保证同一路径同一时间只有一个写事务 +4. **Undo Log 模型**:变更前记录反向操作,失败时反序执行回滚 +5. 
**事务日志持久化**:每个事务在 AGFS 中写入 journal 文件,支持崩溃恢复 + +## 架构 + +``` +Service Layer (rm / mv / add_resource / session.commit) + │ + ▼ +┌──[TransactionContext 异步上下文管理器]──┐ +│ │ +│ 1. 创建事务 + 写 journal │ +│ 2. 获取路径锁(轮询 + 超时) │ +│ 3. 执行操作(FS + VectorDB) │ +│ 4. 记录 Undo Log(每步完成后标记) │ +│ 5. Commit / Rollback │ +│ 6. 执行 post_actions(enqueue 等) │ +│ 7. 释放锁 + 清理 journal │ +│ │ +│ 异常时:反序执行 Undo Log → 释放锁 │ +└─────────────────────────────────────────┘ + │ + ▼ +Storage Layer (VikingFS, VectorDB, QueueManager) ``` -操作请求 → TransactionManager → 锁保护 → 执行操作 → 状态更新 - ↓ ↓ ↓ - 事务ID分配和事务状态管理 路径锁校验和加锁 +## 一致性问题与解决方案 + +### rm(uri) + +| 问题 | 方案 | +|------|------| +| 先删文件再删索引 → 文件已删但索引残留 → 搜索返回不存在的文件 | **调换顺序**:先删索引再删文件。索引删除失败 → 文件和索引都在,搜索正常 | + +事务流程: -事务生命周期:开始操作 → 创建事务 → 锁保护生效 → 文件系统同步操作 → 摘要和索引异步操作 → 移除锁保护 → 事务结束 +``` +1. 开始事务,加锁(lock_mode="subtree") +2. 快照 VectorDB 中受影响的记录(用于回滚恢复) +3. 删除 VectorDB 索引 → 搜索立刻不可见 +4. 删除 FS 文件 +5. 提交 → 删锁 → 删 journal ``` -**设计原则**: -1. **最小化锁粒度**:仅支持路径锁机制,不实现复杂的 MVCC 等 -2. **写互斥优先**:暂不实现读锁(共享锁),先承诺写操作的互斥性 -3. **渐进式扩展**:避免过度设计,聚焦核心需求,未来需要时再添加更复杂的锁机制 -4. **默认生效**:所有数据操作命令均开启事务机制,用户无需额外配置 +回滚:第 4 步失败 → 从快照恢复 VectorDB 记录,文件和索引都在。 -## 核心需求分析 +### mv(old_uri, new_uri) -OpenViking 的数据操作命令(如 `add_resource`、`rm`、`mv` 等)存在以下无保护操作问题: +| 问题 | 方案 | +|------|------| +| 文件移到新路径但索引指向旧路径 → 搜索返回旧路径(不存在) | 事务包装,移动失败则回滚 | -1. **并发冲突**:多个用户同时操作同一目录可能导致数据不一致 -2. **无原子性**:`add_resource` 多阶段操作中,某个阶段失败可能留下中间状态 -3. **无可观测性**:操作结果无法预测,用户无法直接观察到正在操作的状态 +事务流程: -## 系统一致性要求 +``` +1. 开始事务,加锁(lock_mode="mv",目录移动时源和目标均 SUBTREE) +2. 移动 FS 文件 +3. 更新 VectorDB 中的 URI +4. 提交 → 删锁 → 删 journal +``` + +回滚:第 3 步失败 → 把文件移回原位。 -从系统分析的角度,OpenViking 要求实现组件间的分布式一致性: +### add_resource -1. **向量索引的最终一致**:所有上下文数据的向量表征依托独立的向量数据库或向量索引实现,要求确保在任何操作序列下,向量表示的更新都能实现最终一致 -2. **文件系统的读写一致性**:所有上下文数据的文件系统表示依托 VikingFS 实现,底层为 AGFS 桥接的分布式文件系统,要求确保在任何操作序列下,文件系统的更新都能保证数据不会损坏或丢失 -3. 
**队列和异步数据处理的一致性**:所有上下文数据的异步操作依托队列实现,要求确保在任何操作序列下,队列中的数据都能实现最终一致,即队列中的数据会最终被处理,不会丢失或重复 +| 问题 | 方案 | +|------|------| +| 文件从临时目录移到正式目录后崩溃 → 文件存在但永远搜不到 | 首次添加与增量更新分离为两条独立路径 | -## TransactionManager(事务管理器) +首次添加和增量更新是两条独立路径: -TransactionManager 是全局单例,负责管理事务生命周期和锁机制实现。 +**首次添加**(target 不存在)— 在 `ResourceProcessor.process_resource` Phase 3.5 中处理: -### 核心职责 +``` +1. 开始事务,锁 final_uri 的父目录(lock_mode="point") +2. 记录 undo: fs_write_new(uri=dst_path) +3. agfs.mv 临时目录 → 正式位置 +4. 提交 → 删锁 → 删 journal +5. 清理临时目录 +6. 入队 SemanticMsg(uri=final, target_uri=None) → DAG 在 final 上跑,无 callback +``` -- 分配事务ID -- 管理事务生命周期(开始、提交、回滚) -- 提供事务的锁机制实现接口,防止死锁 +崩溃恢复:undo 删除不完整的 dst_path;重新执行 `add_resource` 即可重试。 -### 关键特性 +**增量更新**(target 已存在)— temp 保持不动: ``` -路径锁 + 写互斥 = 并发冲突防护 +1. 入队 SemanticMsg(uri=temp, target_uri=final) → DAG 在 temp 上跑 +2. DAG 完成后触发 sync_diff_callback 或 move_temp_to_target_callback +3. callback 内的每个 VikingFS.rm / VikingFS.mv 各自创建独立事务 ``` -- **路径锁**:锁定目标目录,防止并发的目录级操作如目录删除、目录移动等 -- **写互斥**:同一时间只允许一个事务写操作,路径锁机制确保所有写操作的互斥性 -- **事务结束状态**:事务有明确的结束状态,包括完成、失败丢弃等 +注意:DAG callback 不在外层包裹 TransactionContext。每个 `VikingFS.rm` 和 `VikingFS.mv` 内部各自有独立事务保护。外层锁会与内部锁冲突(如外层 POINT lock on target_path 与内部 `rm` 的 SUBTREE lock 冲突)导致死锁。 + +### session.commit() -### 事务状态机 +| 问题 | 方案 | +|------|------| +| 消息已清空但 archive 未写入 → 对话数据丢失 | Phase 1 无事务(archive 不完整无副作用)+ Phase 2 redo 事务 | + +LLM 调用耗时不可控(5s~60s+),不能放在持锁事务内。设计拆为两个阶段: ``` -INIT → AQUIRE → EXEC → COMMIT/FAIL → RELEASING → RELEASED +Phase 1 — 归档(无事务、无锁): + 1. 生成归档摘要(LLM) + 2. 写 archive(history/archive_N/messages.jsonl + 摘要) + 3. 清空 messages.jsonl + 4. 清空内存中的消息列表 + +Phase 2 — 记忆提取 + 写入(事务,lock_mode="none",redo 语义): + 1. 记录 init_info(archive_uri、session_uri、用户身份信息) + 2. 从归档消息提取 memories(LLM) + 3. 写当前消息状态 + 4. 写 relations + 5. 注册 post_action: enqueue SemanticQueue + 6. 
提交 ``` -**状态说明**: -- `INIT`:事务初始化完成,等待锁获取 -- `AQUIRE`:正在获取锁资源 -- `EXEC`:事务操作正在执行 -- `COMMIT/FAIL`:事务执行完成,进入最终状态 -- `RELEASING`:正在释放锁资源 -- `RELEASED`:锁资源已完全释放,事务结束 +**Redo 语义**:Phase 2 不注册 undo log。崩溃恢复时从 archive 重新执行记忆提取和写入(`_redo_session_memory`),而非回滚。 + +**崩溃恢复分析**: -### 事务记录属性 +| 崩溃时间点 | 状态 | 恢复动作 | +|-----------|------|---------| +| Phase 1 写 archive 中途 | 无事务 | archive 不完整,下次 commit 从 history/ 扫描 index,不受影响 | +| Phase 1 archive 完成但 messages 未清空 | 无事务 | archive 完整 + messages 仍在 = 数据冗余但安全 | +| Phase 2 记忆提取/写入中途 | journal EXEC | 启动恢复:`_redo_session_memory` 从 archive 重做提取+写入+入队 | +| Phase 2 commit 后 | journal COMMIT | 启动恢复:重放 `post_action("enqueue_semantic")` | + +## TransactionContext + +`TransactionContext` 是**异步**上下文管理器,封装事务的完整生命周期: ```python -TransactionRecord( - id: str, # 事务ID,采用 uuid 格式,唯一标识一个事务 - locks: List[str], # 锁列表 - status: str, # 当前状态 - init_info: Dict, # 事务初始化信息 - rollback_info: Dict, # 回滚信息 - created_at: float, # 创建时间 - updated_at: float, # 更新时间 -) +from openviking.storage.transaction import TransactionContext, get_transaction_manager + +tx_manager = get_transaction_manager() + +async with TransactionContext(tx_manager, "rm", [path], lock_mode="subtree") as tx: + # 记录 undo(变更前调用) + seq = tx.record_undo("vectordb_delete", {"record_ids": ids, "records_snapshot": snapshot}) + # 执行变更 + delete_from_vector_store(uris) + # 标记完成 + tx.mark_completed(seq) + + # 注册提交后动作(可选) + tx.add_post_action("enqueue_semantic", {"uri": uri, ...}) + + # 提交 + await tx.commit() +# 未 commit 时自动回滚 ``` -### 设计决策 +**锁模式**: -- 暂不实现共享锁(读锁),简化设计 -- 锁粒度仅限目录,不实现范围锁机制 -- 不实现复杂的死锁检测,通过超时机制防止死锁,事务超时后自动释放所有锁 -- 支持可选的自下而上并行加锁模式,提升大型目录树操作的性能和一致性 -- 事务状态机增加AQUIRE+RELEASING状态,明确跟踪锁释放过程,提高系统可观测性 +| lock_mode | 用途 | 行为 | +|-----------|------|------| +| `point` | 写操作 | 锁定指定路径;与同路径的任何锁和祖先目录的 SUBTREE 锁冲突 | +| `subtree` | 删除操作 | 锁定子树根节点;与同路径的任何锁、后代目录的任何锁和祖先目录的 SUBTREE 锁冲突 | +| `mv` | 移动操作 | 目录移动:源和目标均加 SUBTREE 锁;文件移动:源父目录和目标均加 POINT 锁(通过 `src_is_dir` 控制) | +| `none` | 无锁操作 | 跳过锁获取,直接进入 EXEC 状态。用于 
session.commit Phase 2 等不需要路径互斥的场景 | -## 锁机制 +## 锁类型(POINT vs SUBTREE) + +锁机制使用两种锁类型来处理不同的冲突场景: + +| | 同路径 POINT | 同路径 SUBTREE | 后代 POINT | 祖先 SUBTREE | +|---|---|---|---|---| +| **POINT** | 冲突 | 冲突 | — | 冲突 | +| **SUBTREE** | 冲突 | 冲突 | 冲突 | 冲突 | + +- **POINT (P)**:用于写操作和语义处理。只锁单个目录。若祖先目录持有 SUBTREE 锁则阻塞。 +- **SUBTREE (S)**:用于删除和移动操作。逻辑上覆盖整个子树,但只在根目录写**一个锁文件**。获取前扫描所有后代和祖先目录确认无冲突锁。 -锁机制是事务管理的核心组件,当前只提供路径锁类型。 +## Undo Log -### 锁类型 +每个事务维护一个 Undo Log,记录每步操作的反向动作: -| 锁类型 | 作用范围 | 用例 | -|--------|----------|------| -| 路径锁 | 整个目录 | 用于阻止目录被意外整体移动或删除,确保事务操作过程的路径合法性 +| op_type | 正向操作 | 回滚动作 | +|---------|---------|---------| +| `fs_mv` | 移动文件 | 移回原位 | +| `fs_rm` | 删除文件 | 跳过(不可逆,设计上 rm 是最后一步) | +| `fs_write_new` | 创建新文件/目录 | 删除 | +| `fs_mkdir` | 创建目录 | 删除 | +| `vectordb_delete` | 删除索引记录 | 从快照恢复 | +| `vectordb_upsert` | 插入索引记录 | 删除 | +| `vectordb_update_uri` | 更新 URI | 恢复旧值 | + +回滚规则:只回滚 `completed=True` 的条目,**反序执行**。每步独立 try-catch(best-effort)。崩溃恢复时使用 `recover_all=True`,也会回滚未完成的条目以清理部分操作残留。 + +### 上下文重建 + +VectorDB 回滚操作需要 `RequestContext`(包含 account_id、user_id、agent_id、role)。由于崩溃恢复时原始上下文不可用,record_undo 时在 undo params 中序列化 `_ctx_*` 字段: + +- `_ctx_account_id`:账户 ID +- `_ctx_user_id`:用户 ID +- `_ctx_agent_id`:代理 ID +- `_ctx_role`:角色 + +回滚时通过 `_reconstruct_ctx()` 从这些字段重建上下文。若重建失败(字段缺失),该 VectorDB 回滚步骤将被跳过并记录警告。 + +## 锁机制 ### 锁协议 +锁文件路径:`{path}/.path.ovlock` + +锁文件内容(Fencing Token): ``` -viking://resources/github/volcengine/OpenViking/.path.ovlock +{transaction_id}:{time_ns}:{lock_type} ``` -- 锁文件存在即表示已加锁 -- 文件内容为事务ID,用于标识当前事务 -- 事务操作完成后,删除锁文件以释放锁 +其中 `lock_type` 为 `P`(POINT)或 `S`(SUBTREE)。 + +### 获取锁流程(POINT 模式) -### 加锁流程 +``` +循环直到超时(轮询间隔:200ms): + 1. 检查目标目录存在 + 2. 检查目标路径是否被其他事务锁定 + - 陈旧锁? → 移除后重试 + - 活跃锁? → 等待 + 3. 检查所有祖先目录是否有 SUBTREE 锁 + - 陈旧锁? → 移除后重试 + - 活跃锁? → 等待 + 4. 写入 POINT (P) 锁文件 + 5. TOCTOU 双重检查:重新扫描祖先目录的 SUBTREE 锁 + - 发现冲突:比较 (timestamp, tx_id) + - 后到者(更大的 timestamp/tx_id)主动让步(删除自己的锁),防止活锁 + - 等待后重试 + 6. 验证锁文件归属(fencing token 匹配) + 7. 
成功 + +超时(默认 0 = 不等待)抛出 LockAcquisitionError +``` -#### 普通操作加锁流程 +### 获取锁流程(SUBTREE 模式) ``` -1. 检查目标目录是否存在 -2. 检查目标目录是否已被其他事务锁定 -3. 检查目标目录的父目录是否已被其他事务锁定 -4. 创建 .path.ovlock 文件,文件内容为事务ID -5. 再次检查目标目录的父目录是否已被其他事务锁定 -6. 读取刚创建的 .path.ovlock 文件内容,确认为当前事务ID -7. 一切正常,则返回加锁成功 +循环直到超时(轮询间隔:200ms): + 1. 检查目标目录存在 + 2. 检查目标路径是否被其他事务锁定 + - 陈旧锁? → 移除后重试 + - 活跃锁? → 等待 + 3. 检查所有祖先目录是否有 SUBTREE 锁 + - 陈旧锁? → 移除后重试 + - 活跃锁? → 等待 + 4. 扫描所有后代目录,检查是否有其他事务持有的锁 + - 陈旧锁? → 移除后重试 + - 活跃锁? → 等待 + 5. 写入 SUBTREE (S) 锁文件(只写一个文件,在根路径) + 6. TOCTOU 双重检查:重新扫描后代目录和祖先目录 + - 发现冲突:比较 (timestamp, tx_id) + - 后到者(更大的 timestamp/tx_id)主动让步(删除自己的锁),防止活锁 + - 等待后重试 + 7. 验证锁文件归属(fencing token 匹配) + 8. 成功 + +超时(默认 0 = 不等待)抛出 LockAcquisitionError ``` -#### rm 操作加锁流程 +### 锁过期清理 + +**陈旧锁检测**:PathLock 检查 fencing token 中的时间戳。超过 `lock_expire`(默认 300s)的锁被视为陈旧锁,在加锁过程中自动移除。 + +**事务超时**:TransactionManager 每 60 秒检查活跃事务,`updated_at` 超过事务超时时间(默认 3600s)的事务强制回滚。 + +## 事务日志(Journal) +每个事务在 AGFS 持久化一份 journal: + +``` +/local/_system/transactions/{tx_id}/journal.json ``` -# 传统串行模式:存在更大的竞态条件窗口 -1. 检查目标目录是否存在 -2. 检查目标目录是否已被其他事务锁定 -3. 检查目标目录的父目录是否已被其他事务锁定 -4. 在目标目录下创建 .path.ovlock 文件,文件内容为事务ID -5. 递归地在目标目录的所有子目录下创建 .path.ovlock 文件 -6. 如果发生加锁失败,移除所有已经创建的 .path.ovlock 文件 -7. 一切正常,则返回加锁成功 -# 自下而上并行模式 -1. 并行遍历整个目录树,收集所有子目录路径 -2. 按照目录层级从深到浅排序,从最深层子目录开始 -3. 以有限并行度(默认最大8)批量创建 .path.ovlock 文件 -4. 最后锁定目标目录 -5. 
如果任一位置加锁失败,逆序移除所有已经创建的 .path.ovlock 文件 +内容包含:事务 ID、状态、锁路径、init_info、undo_log、post_actions。 + +### 生命周期 + ``` +创建事务 → 写 journal(INIT) +获取锁 → 更新 journal(ACQUIRE → EXEC) +执行变更 → 每步更新 journal(标记 undo entry completed) +提交 → 更新 journal(COMMIT + post_actions) + → 执行 post_actions → 删锁 → 删 journal +回滚 → 执行 undo log → 删锁 → 删 journal +``` + +## 崩溃恢复 + +`TransactionManager.start()` 启动时自动扫描残留 journal: -#### mv 操作加锁流程 +| 崩溃时 journal 状态 | 恢复方式 | +|---------------------|---------| +| `COMMIT` + post_actions 非空 | 重放 post_actions → 删锁 → 删 journal | +| `COMMIT` + post_actions 为空 / `RELEASED` | 删锁 → 删 journal | +| `EXEC` / `FAIL` / `RELEASING`(`session_memory` 操作) | 从 archive 重做记忆提取+写入(`_redo_session_memory`) → 删锁 → 删 journal | +| `EXEC` / `FAIL` / `RELEASING`(所有 undo 均 completed) | 前滚(视为已提交,重放 post_actions) → 删锁 → 删 journal | +| `EXEC` / `FAIL` / `RELEASING`(其他) | 执行 undo log 回滚(`recover_all=True`) → 删锁 → 删 journal | +| `INIT` / `ACQUIRE` | 通过 init_info.lock_paths 清理孤儿锁 → 删 journal(变更未执行) | + +### 防线总结 + +| 异常场景 | 防线 | 恢复时机 | +|---------|------|---------| +| 事务内崩溃 | journal + undo log 回滚 | 重启时 | +| 提交后 enqueue 前崩溃 | journal post_actions 重放 | 重启时 | +| enqueue 后 worker 处理前崩溃 | QueueFS SQLite 持久化 | worker 重启后自动拉取 | +| session.commit Phase 2 中崩溃 | journal + redo(从 archive 重做记忆提取) | 重启时 | +| 孤儿索引 | L2 按需加载时清理 | 用户访问时 | +| 加锁后 journal 更新前崩溃 | init_info 记录预期锁路径,恢复时检查并清理孤儿锁 | 重启时 | + +## 事务状态机 ``` -1. 先参照 rm 操作对原目录进行加锁 -2. 
再参照普通操作过程对新目录进行加锁 +INIT → ACQUIRE → EXEC → COMMIT → RELEASING → RELEASED + ↓ + FAIL → RELEASING → RELEASED ``` -### 锁机制性能分析 +- `INIT`:事务已创建,等待锁获取 +- `ACQUIRE`:正在获取锁 +- `EXEC`:事务操作执行中 +- `COMMIT`:已提交,可能有 post_actions 待执行 +- `FAIL`:执行失败,进入回滚 +- `RELEASING`:正在释放锁 +- `RELEASED`:锁已释放,事务结束 + +## 配置 + +事务机制默认启用,无需额外配置。**默认不等待**:若路径被锁定则立即抛出 `LockAcquisitionError`。如需允许等待重试,可通过 `storage.transaction` 段配置: + +```json +{ + "storage": { + "transaction": { + "lock_timeout": 5.0, + "lock_expire": 300.0 + } + } +} +``` + +| 参数 | 类型 | 说明 | 默认值 | +|------|------|------|--------| +| `lock_timeout` | float | 获取锁的等待超时(秒)。`0` = 立即失败(默认);`> 0` = 最多等待此时间 | `0.0` | +| `lock_expire` | float | 锁过期时间(秒),超过此时间的事务锁将被视为陈旧锁并强制释放 | `300.0` | + +### QueueFS 持久化 -- 并行遍历采用广度优先策略,同时处理同一层级的所有目录 -- 并行加锁从最深层开始,逐层向上锁定,确保整个目录树的一致性 -- 有限并行度(默认最大8)避免AGFS服务过载 -- 加锁失败时采用逆序回滚,确保所有已加锁目录都能正确释放 -- 事务状态机明确区分锁管理过程(AQUIRE+RELEASING状态),提高系统可观测性和调试效率 +事务机制依赖 QueueFS 使用 SQLite 后端,确保 enqueue 的任务在进程重启后可恢复。这是默认配置,无需手动设置。 ## 相关文档 - [架构概述](./01-architecture.md) - 系统整体架构 - [存储架构](./05-storage.md) - AGFS 和向量库 -- [会话管理](./08-session.md) - 会话和记忆管理 \ No newline at end of file +- [会话管理](./08-session.md) - 会话和记忆管理 +- [配置](../guides/01-configuration.md) - 配置文件说明 diff --git a/docs/zh/guides/01-configuration.md b/docs/zh/guides/01-configuration.md index 889c737ab..e4befcde9 100644 --- a/docs/zh/guides/01-configuration.md +++ b/docs/zh/guides/01-configuration.md @@ -489,10 +489,9 @@ AST 提取支持:Python、JavaScript/TypeScript、Rust、Go、Java、C/C++。 - #### vectordb -向量库存储的配置 +向量库存储的配置 | 参数 | 类型 | 说明 | 默认值 | |------|------|------|--------| @@ -614,6 +613,30 @@ HTTP 客户端(`SyncHTTPClient` / `AsyncHTTPClient`)和 CLI 工具连接远 启动方式和部署详情见 [服务部署](./03-deployment.md),认证详情见 [认证](./04-authentication.md)。 +## storage.transaction 段 + +事务机制默认启用,通常无需配置。**默认行为是不等待**:若目标路径已被其他事务锁定,操作立即失败并抛出 `LockAcquisitionError`。若需要等待重试,请将 `lock_timeout` 设为正数。 + +```json +{ + "storage": { + "transaction": { + "lock_timeout": 5.0, + "lock_expire": 300.0, + "max_parallel_locks": 8 + } 
+ } +} +``` + +| 参数 | 类型 | 说明 | 默认值 | +|------|------|------|--------| +| `lock_timeout` | float | 获取路径锁的等待超时(秒)。`0` = 立即失败(默认);`> 0` = 最多等待此时间后抛出 `LockAcquisitionError` | `0.0` | +| `lock_expire` | float | 锁过期时间(秒)。超过此时间的事务锁将被视为崩溃进程遗留的陈旧锁并强制释放 | `300.0` | +| `max_parallel_locks` | int | rm/mv 操作递归加锁时的最大并行数 | `8` | + +事务机制的详细说明见 [事务机制](../concepts/09-transaction.md)。 + ## 完整 Schema ```json @@ -648,6 +671,11 @@ HTTP 客户端(`SyncHTTPClient` / `AsyncHTTPClient`)和 CLI 工具连接远 "url": "string", "timeout": 10 }, + "transaction": { + "lock_timeout": 0.0, + "lock_expire": 300.0, + "max_parallel_locks": 8 + }, "vectordb": { "backend": "local|remote", "url": "string", diff --git a/openviking/agfs_manager.py b/openviking/agfs_manager.py index 14ed124ae..9ae796f24 100644 --- a/openviking/agfs_manager.py +++ b/openviking/agfs_manager.py @@ -133,9 +133,23 @@ def _generate_config(self) -> Path: "version": "1.0.0", }, }, + # TODO(multi-node): SQLite backend is single-node only. Each AGFS instance + # gets its own isolated queue.db under its own data_path, so messages + # enqueued on node A are invisible to node B. For multi-node deployments, + # switch backend to "tidb" or "mysql" so all nodes share the same queue. + # + # Additionally, the TiDB backend currently uses immediate soft-delete on + # Dequeue (no two-phase status='processing' transition), meaning there is + # no at-least-once guarantee: a worker crash loses the in-flight message. + # The TiDB backend's Ack() and RecoverStale() are both no-ops and must be + # implemented before it can be used safely in production. 
"queuefs": { "enabled": True, "path": "/queue", + "config": { + "backend": "sqlite", + "db_path": str(self.data_path / "_system" / "queue" / "queue.db"), + }, }, }, } @@ -196,6 +210,7 @@ def start(self) -> None: self._check_port_available() self.vikingfs_path.mkdir(parents=True, exist_ok=True) + (self.data_path / "_system" / "queue").mkdir(parents=True, exist_ok=True) # NOTICE: should use viking://temp/ instead of self.vikingfs_path / "temp" # Create temp directory for Parser use # (self.vikingfs_path / "temp").mkdir(exist_ok=True) diff --git a/openviking/async_client.py b/openviking/async_client.py index 294d15daf..680b6ee88 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -97,6 +97,11 @@ async def reset(cls) -> None: await cls._instance.close() cls._instance = None + # Also reset transaction manager singleton + from openviking.storage.transaction import reset_transaction_manager + + reset_transaction_manager() + # ============= Session methods ============= def session(self, session_id: Optional[str] = None, must_exist: bool = False) -> Session: diff --git a/openviking/eval/ragas/__init__.py b/openviking/eval/ragas/__init__.py index 03336bc73..df2952105 100644 --- a/openviking/eval/ragas/__init__.py +++ b/openviking/eval/ragas/__init__.py @@ -111,8 +111,8 @@ def _create_ragas_llm_from_config() -> Optional[Any]: RAGAS LLM instance or None if VLM is not configured. 
""" try: - from openai import OpenAI - from ragas.llms import llm_factory + from langchain_openai import ChatOpenAI + from ragas.llms import LangchainLLMWrapper except ImportError: return None @@ -124,11 +124,12 @@ def _create_ragas_llm_from_config() -> Optional[Any]: logger.info(f"Using RAGAS LLM from environment: model={model_name}, base_url={api_base}") - client = OpenAI( + openai_model = ChatOpenAI( + model=model_name, api_key=api_key, base_url=api_base, ) - return llm_factory(model_name, client=client) + return LangchainLLMWrapper(openai_model) try: from openviking_cli.utils.config import get_openviking_config @@ -151,13 +152,13 @@ def _create_ragas_llm_from_config() -> Optional[Any]: ) return None - client = OpenAI( + model_name = vlm_config.model or "gpt-4o-mini" + openai_model = ChatOpenAI( + model=model_name, api_key=vlm_config.api_key, base_url=vlm_config.api_base, ) - - model_name = vlm_config.model or "gpt-4o-mini" - return llm_factory(model_name, client=client) + return LangchainLLMWrapper(openai_model) class RagasEvaluator(BaseEvaluator): diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index 1d8420709..2792ab888 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -61,6 +61,7 @@ def __init__( document_param: str = "retrieval.passage", late_chunking: Optional[bool] = None, config: Optional[Dict[str, Any]] = None, + task: Optional[str] = None, ): """Initialize Jina AI Dense Embedder @@ -89,8 +90,11 @@ def __init__( self.api_key = api_key self.api_base = api_base or "https://api.jina.ai/v1" self.dimension = dimension - if context == "query": - self.task: Optional[str] = query_param + # Direct task overrides context-based logic + if task is not None: + self.task: Optional[str] = task + elif context == "query": + self.task = query_param elif context == "document": self.task = document_param else: diff --git 
a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index d0a0fe7f2..c92e9e350 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -62,6 +62,7 @@ def __init__( document_param: Optional[str] = None, config: Optional[Dict[str, Any]] = None, max_tokens: Optional[int] = None, + input_type: Optional[str] = None, ): """Initialize OpenAI-Compatible Dense Embedder @@ -104,22 +105,24 @@ def __init__( self.api_base = api_base self.dimension = dimension - # Symmetric by default: only activate input_type if user explicitly sets either value - non_symmetric = query_param is not None or document_param is not None - if not non_symmetric: - self.input_type: Optional[str] = None - elif context == "query": - self.input_type = query_param if query_param is not None else "query" - elif context == "document": - self.input_type = document_param if document_param is not None else "passage" + # Direct input_type overrides context-based logic + if input_type is not None: + self.input_type: Optional[str] = input_type else: - self.input_type = None + # Symmetric by default: only activate input_type if user explicitly sets either value + non_symmetric = query_param is not None or document_param is not None + if not non_symmetric: + self.input_type = None + elif context == "query": + self.input_type = query_param if query_param is not None else "query" + elif context == "document": + self.input_type = document_param if document_param is not None else "passage" + else: + self.input_type = None - if not self.api_key: - raise ValueError("api_key is required") # Allow missing api_key when api_base is set (e.g. 
local OpenAI-compatible servers) if not self.api_key and not self.api_base: - raise ValueError("api_key is required (or set api_base for local servers)") + raise ValueError("api_key is required") # Initialize OpenAI client # Use a placeholder api_key when not provided (for local OpenAI-compatible servers) diff --git a/openviking/parse/tree_builder.py b/openviking/parse/tree_builder.py index 4260065e4..5f070cee4 100644 --- a/openviking/parse/tree_builder.py +++ b/openviking/parse/tree_builder.py @@ -21,7 +21,7 @@ """ import logging -from typing import TYPE_CHECKING, Optional +from typing import Optional from openviking.core.building_tree import BuildingTree from openviking.core.context import Context @@ -31,9 +31,6 @@ from openviking.utils import parse_code_hosting_url from openviking_cli.utils.uri import VikingURI -if TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) @@ -78,6 +75,31 @@ def _get_base_uri( # Agent scope return "viking://agent" + async def _resolve_unique_uri(self, uri: str, max_attempts: int = 100) -> str: + """Return a URI that does not collide with an existing resource. + + If *uri* is free, return it unchanged. Otherwise append ``_1``, + ``_2``, ... until a free name is found. 
+ """ + viking_fs = get_viking_fs() + + async def _exists(u: str) -> bool: + try: + await viking_fs.stat(u) + return True + except Exception: + return False + + if not await _exists(uri): + return uri + + for i in range(1, max_attempts + 1): + candidate = f"{uri}_{i}" + if not await _exists(candidate): + return candidate + + raise FileExistsError(f"Cannot resolve unique name for {uri} after {max_attempts} attempts") + # ============================================================================ # v5.0 Methods (temporary directory + SemanticQueue architecture) # ============================================================================ @@ -145,7 +167,10 @@ async def finalize_from_temp( raise ValueError(f"Parent URI is not a directory: {parent_uri}") candidate_uri = VikingURI(base_uri).join(final_doc_name).uri - final_uri = candidate_uri + if to_uri: + final_uri = candidate_uri + else: + final_uri = await self._resolve_unique_uri(candidate_uri) tree = BuildingTree( source_path=source_path, diff --git a/openviking/service/core.py b/openviking/service/core.py index b07bdb840..4c5a2670c 100644 --- a/openviking/service/core.py +++ b/openviking/service/core.py @@ -132,16 +132,27 @@ def _init_storage( logger.warning("AGFS client not initialized, skipping queue manager") # Initialize VikingDBManager with QueueManager - self._vikingdb_manager = VikingDBManager(vectordb_config=config.vectordb, queue_manager=self._queue_manager) + self._vikingdb_manager = VikingDBManager( + vectordb_config=config.vectordb, queue_manager=self._queue_manager + ) - # Configure queues if QueueManager is available + # Configure queues if QueueManager is available. + # Workers are NOT started here — start() is called after VikingFS is initialized + # in initialize(), so that recovered tasks don't race against VikingFS init. 
if self._queue_manager: - self._queue_manager.setup_standard_queues(self._vikingdb_manager) + self._queue_manager.setup_standard_queues(self._vikingdb_manager, start=False) # Initialize TransactionManager (fail-fast if AGFS missing) if self._agfs_client is None: raise RuntimeError("AGFS client not initialized for TransactionManager") - self._transaction_manager = init_transaction_manager(agfs=self._agfs_client) + tx_cfg = config.transaction + self._transaction_manager = init_transaction_manager( + agfs=self._agfs_client, + max_parallel_locks=tx_cfg.max_parallel_locks, + lock_timeout=tx_cfg.lock_timeout, + lock_expire=tx_cfg.lock_expire, + vector_store=self._vikingdb_manager, + ) @property def _agfs(self) -> Any: @@ -255,6 +266,14 @@ async def initialize(self) -> None: if enable_recorder: logger.info("VikingFS IO Recorder enabled") + # Start queue workers now that VikingFS is ready. + # Doing it here (rather than in _init_storage) ensures that any tasks + # recovered from a previous crash are not processed before VikingFS is + # initialized, which would cause "VikingFS not initialized" errors. 
+ if self._queue_manager: + self._queue_manager.start() + logger.info("QueueManager workers started") + # Initialize directories directory_initializer = DirectoryInitializer(vikingdb=self._vikingdb_manager) self._directory_initializer = directory_initializer @@ -306,7 +325,7 @@ async def initialize(self) -> None: async def close(self) -> None: """Close OpenViking and release resources.""" if self._transaction_manager: - self._transaction_manager.stop() + await self._transaction_manager.stop() self._transaction_manager = None if self._vikingdb_manager: diff --git a/openviking/service/debug_service.py b/openviking/service/debug_service.py index 9c3cf39b3..7dffff65c 100644 --- a/openviking/service/debug_service.py +++ b/openviking/service/debug_service.py @@ -138,13 +138,14 @@ def vlm(self) -> ComponentStatus: @property def transaction(self) -> ComponentStatus: """Get transaction status.""" - transaction_manager = get_transaction_manager() - if transaction_manager is None: + try: + transaction_manager = get_transaction_manager() + except Exception: return ComponentStatus( name="transaction", is_healthy=False, has_errors=True, - status="Transaction manager not initialized.", + status="Not initialized", ) observer = TransactionObserver(transaction_manager) return ComponentStatus( diff --git a/openviking/session/session.py b/openviking/session/session.py index 554180272..5726b8e32 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -220,85 +220,17 @@ def update_tool_part( self._update_message_in_jsonl() def commit(self) -> Dict[str, Any]: - """Commit session: create archive, extract memories, persist.""" - result = { - "session_id": self.session_id, - "status": "committed", - "memories_extracted": 0, - "active_count_updated": 0, - "archived": False, - "stats": None, - } - if not self._messages: - get_current_telemetry().set("memory.extracted", 0) - return result - - # 1. 
Archive current messages - self._compression.compression_index += 1 - messages_to_archive = self._messages.copy() - - summary = self._generate_archive_summary(messages_to_archive) - archive_abstract = self._extract_abstract_from_summary(summary) - archive_overview = summary - - self._write_archive( - index=self._compression.compression_index, - messages=messages_to_archive, - abstract=archive_abstract, - overview=archive_overview, - ) - - self._compression.original_count += len(messages_to_archive) - result["archived"] = True - - self._messages.clear() - logger.info( - f"Archived: {len(messages_to_archive)} messages → history/archive_{self._compression.compression_index:03d}/" - ) - - # 2. Extract long-term memories - if self._session_compressor: - logger.info( - f"Starting memory extraction from {len(messages_to_archive)} archived messages" - ) - memories = run_async( - self._session_compressor.extract_long_term_memories( - messages=messages_to_archive, - user=self.user, - session_id=self.session_id, - ctx=self.ctx, - ) - ) - logger.info(f"Extracted {len(memories)} memories") - result["memories_extracted"] = len(memories) - self._stats.memories_extracted += len(memories) - get_current_telemetry().set("memory.extracted", len(memories)) - - # 3. Write current messages to AGFS - self._write_to_agfs(self._messages) - - # 4. Create relations - self._write_relations() - - # 5. Update active_count - active_count_updated = self._update_active_counts() - result["active_count_updated"] = active_count_updated + """Sync wrapper for commit_async().""" + return run_async(self.commit_async()) - # 6. 
Update statistics - self._stats.compression_count = self._compression.compression_index - result["stats"] = { - "total_turns": self._stats.total_turns, - "contexts_used": self._stats.contexts_used, - "skills_used": self._stats.skills_used, - "memories_extracted": self._stats.memories_extracted, - } + async def commit_async(self) -> Dict[str, Any]: + """Async commit session: two-phase approach. - self._stats.total_tokens = 0 - logger.info(f"Session {self.session_id} committed") - return result + Phase 1 (Archive, no transaction): Write archive, clear messages. + Phase 2 (Memory, transaction with redo semantics): Extract memories, write, enqueue. + """ + from openviking.storage.transaction import TransactionContext, get_transaction_manager - async def commit_async(self) -> Dict[str, Any]: - """Async commit session: create archive, extract memories, persist.""" result = { "session_id": self.session_id, "status": "committed", @@ -311,7 +243,9 @@ async def commit_async(self) -> Dict[str, Any]: get_current_telemetry().set("memory.extracted", 0) return result - # 1. 
Archive current messages + tx_manager = get_transaction_manager() + + # ===== Preparation (no transaction) ===== self._compression.compression_index += 1 messages_to_archive = self._messages.copy() @@ -319,48 +253,80 @@ async def commit_async(self) -> Dict[str, Any]: archive_abstract = self._extract_abstract_from_summary(summary) archive_overview = summary + # ===== Phase 1: Archive (no transaction, no lock) ===== + archive_uri = ( + f"{self._session_uri}/history/archive_{self._compression.compression_index:03d}" + ) await self._write_archive_async( index=self._compression.compression_index, messages=messages_to_archive, abstract=archive_abstract, overview=archive_overview, ) + await self._write_to_agfs_async(messages=[]) + self._messages.clear() self._compression.original_count += len(messages_to_archive) result["archived"] = True - - self._messages.clear() logger.info( - f"Archived: {len(messages_to_archive)} messages → history/archive_{self._compression.compression_index:03d}/" + f"Archived: {len(messages_to_archive)} messages → " + f"history/archive_{self._compression.compression_index:03d}/" ) - # 2. 
Extract long-term memories - if self._session_compressor: - logger.info( - f"Starting memory extraction from {len(messages_to_archive)} archived messages" - ) - memories = await self._session_compressor.extract_long_term_memories( - messages=messages_to_archive, - user=self.user, - session_id=self.session_id, - ctx=self.ctx, + # ===== Phase 2: Memory extraction + write (transaction, redo semantics) ===== + async with TransactionContext( + tx_manager, + "session_memory", + [], + lock_mode="none", + ) as tx: + # Store redo info so _recover_one can redo from archive on crash + tx.record.init_info.update( + { + "archive_uri": archive_uri, + "session_uri": self._session_uri, + "account_id": self.ctx.account_id, + "user_id": self.ctx.user.user_id, + "agent_id": self.ctx.user.agent_id, + "role": self.ctx.role.value, + } ) - logger.info(f"Extracted {len(memories)} memories") - result["memories_extracted"] = len(memories) - self._stats.memories_extracted += len(memories) - get_current_telemetry().set("memory.extracted", len(memories)) - - # 3. Write current messages to AGFS - await self._write_to_agfs_async(self._messages) - # 4. 
Create relations - await self._write_relations_async() + if self._session_compressor: + logger.info( + f"Starting memory extraction from {len(messages_to_archive)} archived messages" + ) + memories = await self._session_compressor.extract_long_term_memories( + messages=messages_to_archive, + user=self.user, + session_id=self.session_id, + ctx=self.ctx, + ) + logger.info(f"Extracted {len(memories)} memories") + result["memories_extracted"] = len(memories) + self._stats.memories_extracted += len(memories) + get_current_telemetry().set("memory.extracted", len(memories)) + + await self._write_to_agfs_async(self._messages) + await self._write_relations_async() + tx.add_post_action( + "enqueue_semantic", + { + "uri": self._session_uri, + "context_type": "memory", + "account_id": self.ctx.account_id, + "user_id": self.ctx.user.user_id, + "agent_id": self.ctx.user.agent_id, + "role": self.ctx.role.value, + }, + ) + await tx.commit() - # 5. Update active_count + # Update active_count active_count_updated = await self._update_active_counts_async() result["active_count_updated"] = active_count_updated - # 6. 
Update statistics + # Update statistics self._stats.compression_count = self._compression.compression_index result["stats"] = { "total_turns": self._stats.total_turns, @@ -762,6 +728,56 @@ def _write_relations(self) -> None: except Exception as e: logger.warning(f"Failed to create relation to {usage.uri}: {e}") + def _write_checkpoint(self, data: Dict[str, Any]) -> None: + """Write a commit checkpoint file for crash recovery.""" + if not self._viking_fs: + return + + checkpoint = { + **data, + "session_id": self.session_id, + "compression_index": self._compression.compression_index, + "timestamp": get_current_timestamp(), + } + run_async( + self._viking_fs.write_file( + f"{self._session_uri}/.commit_checkpoint.json", + json.dumps(checkpoint, ensure_ascii=False), + ctx=self.ctx, + ) + ) + + async def _write_checkpoint_async(self, data: Dict[str, Any]) -> None: + """Write a commit checkpoint file for crash recovery (async).""" + if not self._viking_fs: + return + + checkpoint = { + **data, + "session_id": self.session_id, + "compression_index": self._compression.compression_index, + "timestamp": get_current_timestamp(), + } + await self._viking_fs.write_file( + f"{self._session_uri}/.commit_checkpoint.json", + json.dumps(checkpoint, ensure_ascii=False), + ctx=self.ctx, + ) + + def _read_checkpoint(self) -> Optional[Dict[str, Any]]: + """Read commit checkpoint file if it exists.""" + if not self._viking_fs: + return None + try: + content = run_async( + self._viking_fs.read_file( + f"{self._session_uri}/.commit_checkpoint.json", ctx=self.ctx + ) + ) + return json.loads(content) + except Exception: + return None + async def _write_relations_async(self) -> None: """Create relations to used contexts/tools (async).""" if not self._viking_fs: diff --git a/openviking/storage/errors.py b/openviking/storage/errors.py index bc3e36be2..7f6a483b2 100644 --- a/openviking/storage/errors.py +++ b/openviking/storage/errors.py @@ -29,3 +29,15 @@ class 
ConnectionError(StorageException): class SchemaError(StorageException): """Raised when schema validation fails.""" + + +class TransactionError(VikingDBException): + """Raised when a transaction operation fails.""" + + +class LockAcquisitionError(TransactionError): + """Raised when lock acquisition fails.""" + + +class TransactionRollbackError(TransactionError): + """Raised when transaction rollback fails.""" diff --git a/openviking/storage/local_fs.py b/openviking/storage/local_fs.py index c4683dd72..3d23566dc 100644 --- a/openviking/storage/local_fs.py +++ b/openviking/storage/local_fs.py @@ -11,6 +11,7 @@ from openviking.server.identity import RequestContext from openviking.storage.queuefs import EmbeddingQueue, get_queue_manager from openviking.storage.queuefs.embedding_msg_converter import EmbeddingMsgConverter +from openviking_cli.exceptions import NotFoundError from openviking_cli.utils.logger import get_logger from openviking_cli.utils.uri import VikingURI @@ -178,7 +179,7 @@ async def import_ovpack( f"Resource already exists at {root_uri}. Use force=True to overwrite." 
) logger.info(f"[local_fs] Overwriting existing resource at {root_uri}") - except FileNotFoundError: + except NotFoundError: # Path does not exist, safe to import pass @@ -204,9 +205,10 @@ async def import_ovpack( if not zip_path: continue - # Normalize path separators to handle Windows-created ZIPs - zip_path = zip_path.replace("\\", "/") + # Validate before normalization so backslash paths are rejected safe_zip_path = _validate_ovpack_member_path(zip_path, base_name) + # Normalize path separators to handle Windows-created ZIPs + safe_zip_path = safe_zip_path.replace("\\", "/") # Handle directory entries if safe_zip_path.endswith("/"): diff --git a/openviking/storage/observers/transaction_observer.py b/openviking/storage/observers/transaction_observer.py index dce4555d5..e29b76656 100644 --- a/openviking/storage/observers/transaction_observer.py +++ b/openviking/storage/observers/transaction_observer.py @@ -81,7 +81,7 @@ def _format_status_as_table(self, transactions: Dict[str, Any]) -> str: # Group transactions by status status_counts = { TransactionStatus.INIT: 0, - TransactionStatus.AQUIRE: 0, + TransactionStatus.ACQUIRE: 0, TransactionStatus.EXEC: 0, TransactionStatus.COMMIT: 0, TransactionStatus.FAIL: 0, @@ -107,7 +107,7 @@ def _format_status_as_table(self, transactions: Dict[str, Any]) -> str: status_priority = { TransactionStatus.EXEC: 0, - TransactionStatus.AQUIRE: 1, + TransactionStatus.ACQUIRE: 1, TransactionStatus.RELEASING: 2, TransactionStatus.INIT: 3, TransactionStatus.COMMIT: 4, @@ -206,7 +206,7 @@ def get_status_summary(self) -> Dict[str, int]: summary = { "INIT": 0, - "AQUIRE": 0, + "ACQUIRE": 0, "EXEC": 0, "COMMIT": 0, "FAIL": 0, diff --git a/openviking/storage/queuefs/named_queue.py b/openviking/storage/queuefs/named_queue.py index ad79202b0..ce7705b79 100644 --- a/openviking/storage/queuefs/named_queue.py +++ b/openviking/storage/queuefs/named_queue.py @@ -198,6 +198,21 @@ async def enqueue(self, data: Union[str, Dict[str, Any]]) -> str: msg_id 
= self._agfs.write(enqueue_file, data.encode("utf-8")) return msg_id if isinstance(msg_id, str) else str(msg_id) + async def ack(self, msg_id: str) -> None: + """Acknowledge successful processing of a message (deletes it from persistent storage). + + Must be called after the dequeue handler finishes processing a message. + If not called (e.g. process crashes), the message will be automatically + re-queued on the next startup via RecoverStale. + """ + if not msg_id: + return + ack_file = f"{self.path}/ack" + try: + self._agfs.write(ack_file, msg_id.encode("utf-8")) + except Exception as e: + logger.warning(f"[NamedQueue] Ack failed for {self.name} msg_id={msg_id}: {e}") + def _read_queue_message(self) -> Optional[Dict[str, Any]]: """Read and remove one message from the AGFS queue; return parsed dict or None. @@ -217,15 +232,30 @@ def _read_queue_message(self) -> Optional[Dict[str, Any]]: return json.loads(raw.decode("utf-8")) async def dequeue(self) -> Optional[Dict[str, Any]]: - """Get and remove message from queue, then invoke the dequeue handler.""" + """Dequeue a message, process it, then ack to confirm deletion. + + Flow (at-least-once delivery): + 1. Read from /dequeue → backend marks message as 'processing' (not deleted yet) + 2. Call on_dequeue() → actual processing + 3. Call ack() → backend deletes the message permanently + + If the process crashes between steps 1 and 3, the backend's RecoverStale + on the next startup resets the message back to 'pending' for retry. + """ await self._ensure_initialized() try: data = self._read_queue_message() if data is None: return None + # Capture message ID before passing data to handler (handler may modify it) + msg_id = data.get("id", "") if isinstance(data, dict) else "" if self._dequeue_handler: self._on_dequeue_start() data = await self._dequeue_handler.on_dequeue(data) + # Ack unconditionally after handler returns (success or handled error). 
+ # If on_dequeue raises, the exception propagates and ack is skipped — + # the message will be recovered on next startup. + await self.ack(msg_id) return data except Exception as e: logger.debug(f"[NamedQueue] Dequeue failed for {self.name}: {e}") diff --git a/openviking/storage/queuefs/queue_manager.py b/openviking/storage/queuefs/queue_manager.py index 95e9aeb20..52b42476f 100644 --- a/openviking/storage/queuefs/queue_manager.py +++ b/openviking/storage/queuefs/queue_manager.py @@ -107,16 +107,16 @@ def start(self) -> None: logger.info("[QueueManager] Started") - def setup_standard_queues(self, vector_store: Any) -> None: + def setup_standard_queues(self, vector_store: Any, start: bool = True) -> None: """ Setup standard queues (Embedding and Semantic) with their handlers. - This method initializes the EmbeddingQueue with TextEmbeddingHandler - and the SemanticQueue with SemanticProcessor, then ensures the - queue manager is started. - Args: vector_store: Vector store instance for handlers to write results. + start: Whether to start worker threads immediately (default True). + Pass False when the consumer depends on resources that are + not yet initialized (e.g. VikingFS); call start() manually + after those resources are ready. 
""" # Import handlers here to avoid circular dependencies from openviking.storage.collection_schemas import TextEmbeddingHandler @@ -140,8 +140,8 @@ def setup_standard_queues(self, vector_store: Any) -> None: ) logger.info("Semantic queue initialized with SemanticProcessor") - # Start QueueManager processing - self.start() + if start: + self.start() def _start_queue_worker(self, queue: NamedQueue) -> None: """Start a dedicated worker thread for a queue if not already running.""" @@ -207,10 +207,14 @@ async def _worker_async_concurrent( async def process_one(data: Dict[str, Any]) -> None: async with sem: + msg_id = data.get("id", "") if isinstance(data, dict) else "" try: await queue.process_dequeued(data) + # Ack after successful processing (delete from persistent storage). + await queue.ack(msg_id) except Exception as e: # Handler did not call report_error; decrement in_progress manually. + # Do NOT ack — let RecoverStale re-queue on next startup. queue._on_process_error(str(e), data) logger.error(f"[QueueManager] Concurrent worker error for {queue.name}: {e}") @@ -241,9 +245,21 @@ async def process_one(data: Dict[str, Any]) -> None: await asyncio.sleep(self._poll_interval) - # Drain remaining in-flight tasks on shutdown + # Drain remaining in-flight tasks on shutdown (with timeout) if active_tasks: - await asyncio.gather(*active_tasks, return_exceptions=True) + try: + await asyncio.wait_for( + asyncio.gather(*active_tasks, return_exceptions=True), + timeout=5.0, + ) + except asyncio.TimeoutError: + logger.warning( + f"[QueueManager] Drain timeout for {queue.name}, " + f"cancelling {len(active_tasks)} in-flight task(s)" + ) + for t in active_tasks: + t.cancel() + await asyncio.gather(*active_tasks, return_exceptions=True) def stop(self) -> None: """Stop QueueManager and release resources.""" @@ -254,8 +270,10 @@ def stop(self) -> None: # Stop queue workers for stop_event in self._queue_stop_events.values(): stop_event.set() - for thread in 
self._queue_threads.values(): - thread.join() + for name, thread in self._queue_threads.items(): + thread.join(timeout=10.0) + if thread.is_alive(): + logger.warning(f"[QueueManager] Worker thread {name} did not exit in time") self._queue_threads.clear() self._queue_stop_events.clear() @@ -280,9 +298,6 @@ def get_queue( allow_create: bool = False, ) -> NamedQueue: """Get or create a named queue object.""" - if not self._started: - self.start() - if name not in self._queues: if not allow_create: raise RuntimeError(f"Queue {name} does not exist and allow_create is False") diff --git a/openviking/storage/queuefs/semantic_dag.py b/openviking/storage/queuefs/semantic_dag.py index 9ed037b6d..7495a170e 100644 --- a/openviking/storage/queuefs/semantic_dag.py +++ b/openviking/storage/queuefs/semantic_dag.py @@ -505,11 +505,15 @@ def _finalize_children_abstracts(self, node: DirNode) -> List[Dict[str, str]]: return results async def _overview_task(self, dir_uri: str) -> None: + from openviking.storage.errors import LockAcquisitionError + from openviking.storage.transaction import TransactionContext, get_transaction_manager + node = self._nodes.get(dir_uri) if not node: return need_vectorize = True children_changed = True + abstract = "" try: overview = None abstract = None @@ -531,11 +535,22 @@ async def _overview_task(self, dir_uri: str) -> None: ) abstract = self._processor._extract_abstract_from_overview(overview) + dir_path = self._viking_fs._uri_to_path(dir_uri, ctx=self._ctx) try: - await self._viking_fs.write_file(f"{dir_uri}/.overview.md", overview, ctx=self._ctx) - await self._viking_fs.write_file(f"{dir_uri}/.abstract.md", abstract, ctx=self._ctx) - except Exception as e: - logger.warning(f"Failed to write overview/abstract for {dir_uri}: {e}") + # No undo entries recorded: semantic files (.overview.md / .abstract.md) are + # regenerable, so residual writes after a crash are acceptable. 
+ async with TransactionContext( + get_transaction_manager(), "semantic_dag", [dir_path], lock_mode="point" + ) as tx: + await self._viking_fs.write_file( + f"{dir_uri}/.overview.md", overview, ctx=self._ctx + ) + await self._viking_fs.write_file( + f"{dir_uri}/.abstract.md", abstract, ctx=self._ctx + ) + await tx.commit() + except LockAcquisitionError: + logger.info(f"[SemanticDag] {dir_uri} does not exist or is locked, skipping") try: if need_vectorize: @@ -554,7 +569,6 @@ async def _overview_task(self, dir_uri: str) -> None: except Exception as e: logger.error(f"Failed to generate overview for {dir_uri}: {e}", exc_info=True) - abstract = "" finally: self._stats.done_nodes += 1 self._stats.in_progress_nodes = max(0, self._stats.in_progress_nodes - 1) diff --git a/openviking/storage/queuefs/semantic_processor.py b/openviking/storage/queuefs/semantic_processor.py index 60dc06e64..0f5c0642c 100644 --- a/openviking/storage/queuefs/semantic_processor.py +++ b/openviking/storage/queuefs/semantic_processor.py @@ -276,6 +276,70 @@ def get_dag_stats(self) -> Optional["DagStats"]: return None return self._dag_executor.get_stats() + async def _process_single_directory( + self, + uri: str, + context_type: str, + children_uris: List[str], + file_paths: List[str], + ) -> None: + """Process single directory, generate .abstract.md and .overview.md.""" + from openviking.storage.errors import LockAcquisitionError + from openviking.storage.transaction import TransactionContext, get_transaction_manager + + viking_fs = get_viking_fs() + dir_path = viking_fs._uri_to_path(uri, ctx=self._current_ctx) + + try: + # No undo entries recorded: semantic files (.overview.md / .abstract.md) are + # regenerable, so residual writes after a crash are acceptable. + async with TransactionContext( + get_transaction_manager(), "semantic", [dir_path], lock_mode="point" + ) as tx: + # 1. 
Collect .abstract.md from subdirectories + children_abstracts = await self._collect_children_abstracts(children_uris) + + # 2. Concurrently generate summaries for files in directory + tasks = [ + self._generate_single_file_summary(fp, ctx=self._current_ctx) + for fp in file_paths + ] + file_summaries = await asyncio.gather(*tasks) + + # 3. Generate .overview.md + overview = await self._generate_overview(uri, file_summaries, children_abstracts) + + # 4. Extract abstract from overview + abstract = self._extract_abstract_from_overview(overview) + + # 5. Write files + await viking_fs.write_file(f"{uri}/.overview.md", overview, ctx=self._current_ctx) + await viking_fs.write_file(f"{uri}/.abstract.md", abstract, ctx=self._current_ctx) + + logger.debug(f"Generated overview and abstract for {uri}") + + # 6. Vectorize directory and files concurrently + vectorize_tasks = [ + self._vectorize_directory_simple(uri, context_type, abstract, overview), + *( + self._vectorize_single_file( + parent_uri=uri, + context_type=context_type, + file_path=fp, + summary_dict=summary, + ) + for fp, summary in zip(file_paths, file_summaries) + ), + ] + results = await asyncio.gather(*vectorize_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + logger.error(f"Vectorization failed: {result}", exc_info=True) + + await tx.commit() + except LockAcquisitionError: + logger.info(f"[SemanticProcessor] {uri} does not exist or is locked, skipping") + async def _process_memory_directory(self, msg: SemanticMsg) -> None: """Process a memory directory with special handling. diff --git a/openviking/storage/transaction/__init__.py b/openviking/storage/transaction/__init__.py index b6c06d6e5..afbc3e1e1 100644 --- a/openviking/storage/transaction/__init__.py +++ b/openviking/storage/transaction/__init__.py @@ -6,22 +6,31 @@ Provides transaction management and lock mechanisms for data operations. 
""" +from openviking.storage.transaction.context_manager import TransactionContext +from openviking.storage.transaction.journal import TransactionJournal from openviking.storage.transaction.path_lock import PathLock from openviking.storage.transaction.transaction_manager import ( TransactionManager, get_transaction_manager, init_transaction_manager, + reset_transaction_manager, ) from openviking.storage.transaction.transaction_record import ( TransactionRecord, TransactionStatus, ) +from openviking.storage.transaction.undo import UndoEntry, execute_rollback __all__ = [ "PathLock", + "TransactionContext", + "TransactionJournal", "TransactionManager", "TransactionRecord", "TransactionStatus", - "init_transaction_manager", + "UndoEntry", + "execute_rollback", "get_transaction_manager", + "init_transaction_manager", + "reset_transaction_manager", ] diff --git a/openviking/storage/transaction/context_manager.py b/openviking/storage/transaction/context_manager.py new file mode 100644 index 000000000..09697e10f --- /dev/null +++ b/openviking/storage/transaction/context_manager.py @@ -0,0 +1,159 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Transaction context manager for OpenViking. + +Provides an async context manager that wraps a set of operations in a +transaction with automatic rollback on failure. +""" + +from typing import Any, Dict, List, Optional + +from openviking.storage.errors import LockAcquisitionError, TransactionError +from openviking.storage.transaction.transaction_record import TransactionRecord +from openviking.storage.transaction.undo import UndoEntry +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + + +class TransactionContext: + """Async context manager for transactional operations. + + Usage:: + + async with TransactionContext(tx_manager, "rm", [path], lock_mode="subtree") as tx: + seq = tx.record_undo("fs_rm", {"uri": uri}) + # ... do work ... 
+ tx.mark_completed(seq) + await tx.commit() + """ + + def __init__( + self, + tx_manager: Any, + operation: str, + lock_paths: List[str], + lock_mode: str = "point", + mv_dst_path: Optional[str] = None, + src_is_dir: bool = True, + ): + self._tx_manager = tx_manager + self._operation = operation + self._lock_paths = lock_paths + self._lock_mode = lock_mode + self._mv_dst_path = mv_dst_path + self._src_is_dir = src_is_dir + self._record: Optional[TransactionRecord] = None + self._committed = False + self._sequence = 0 + + @property + def record(self) -> TransactionRecord: + if self._record is None: + raise TransactionError("Transaction not started") + return self._record + + async def __aenter__(self) -> "TransactionContext": + self._record = self._tx_manager.create_transaction( + init_info={ + "operation": self._operation, + "lock_paths": self._lock_paths, + "lock_mode": self._lock_mode, + "mv_dst_path": self._mv_dst_path, + } + ) + tx_id = self._record.id + + # Write journal BEFORE acquiring locks so that crash recovery can + # find orphan locks via init_info even if the process dies between + # lock creation and journal update. 
+ try: + self._tx_manager.journal.write(self._record.to_journal()) + except Exception as e: + logger.warning(f"[Transaction] Failed to write journal for {tx_id}: {e}") + + success = False + if self._lock_mode == "none": + # No lock acquisition — transition directly to EXEC status + tx = self._tx_manager.get_transaction(tx_id) + if tx: + from openviking.storage.transaction.transaction_record import TransactionStatus + + tx.update_status(TransactionStatus.EXEC) + success = True + elif self._lock_mode == "subtree": + for path in self._lock_paths: + success = await self._tx_manager.acquire_lock_subtree(tx_id, path) + if not success: + break + elif self._lock_mode == "mv": + if len(self._lock_paths) < 1 or not self._mv_dst_path: + raise TransactionError("mv lock mode requires lock_paths[0] and mv_dst_path") + success = await self._tx_manager.acquire_lock_mv( + tx_id, + self._lock_paths[0], + self._mv_dst_path, + src_is_dir=self._src_is_dir, + ) + else: + # "point" mode (default) + for path in self._lock_paths: + success = await self._tx_manager.acquire_lock_point(tx_id, path) + if not success: + break + + if not success: + await self._tx_manager.rollback(tx_id) + raise LockAcquisitionError( + f"Failed to acquire {self._lock_mode} lock for {self._lock_paths}" + ) + + # Update journal with actual lock paths now populated in the record. 
+ try: + self._tx_manager.journal.update(self._record.to_journal()) + except Exception as e: + logger.warning(f"[Transaction] Failed to update journal for {tx_id}: {e}") + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if not self._committed: + try: + await self._tx_manager.rollback(self._record.id) + except Exception as e: + logger.error(f"Rollback failed during __aexit__: {e}") + return False + + def record_undo(self, op_type: str, params: Dict[str, Any]) -> int: + seq = self._sequence + self._sequence += 1 + entry = UndoEntry(sequence=seq, op_type=op_type, params=params) + self.record.undo_log.append(entry) + + try: + self._tx_manager.journal.update(self.record.to_journal()) + except Exception as e: + logger.debug(f"[Transaction] Failed to persist journal: {e}") + + return seq + + def mark_completed(self, sequence: int) -> None: + for entry in self.record.undo_log: + if entry.sequence == sequence: + entry.completed = True + break + + try: + self._tx_manager.journal.update(self.record.to_journal()) + except Exception as e: + logger.debug(f"[Transaction] Failed to persist journal: {e}") + + def add_post_action(self, action_type: str, params: Dict[str, Any]) -> None: + self.record.post_actions.append({"type": action_type, "params": params}) + + async def commit(self) -> None: + success = await self._tx_manager.commit(self._record.id) + if not success: + raise TransactionError(f"Failed to commit transaction {self._record.id}") + self._committed = True diff --git a/openviking/storage/transaction/journal.py b/openviking/storage/transaction/journal.py new file mode 100644 index 000000000..6cb144749 --- /dev/null +++ b/openviking/storage/transaction/journal.py @@ -0,0 +1,113 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Transaction journal for crash recovery. 
+
+Persists transaction state to AGFS so that incomplete transactions can be
+detected and recovered after a process restart.
+"""
+
+import json
+from typing import Any, Dict, List
+
+from openviking.pyagfs import AGFSClient
+from openviking_cli.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+# Journal root path (global, not behind VikingFS URI mapping)
+_JOURNAL_ROOT = "/local/_system/transactions"
+
+
+class TransactionJournal:
+    """Persists transaction records to AGFS for crash recovery.
+
+    Journal files live at ``/local/_system/transactions/{tx_id}/journal.json``.
+    """
+
+    def __init__(self, agfs: AGFSClient):
+        self._agfs = agfs
+
+    def _tx_dir(self, tx_id: str) -> str:
+        return f"{_JOURNAL_ROOT}/{tx_id}"
+
+    def _journal_path(self, tx_id: str) -> str:
+        return f"{_JOURNAL_ROOT}/{tx_id}/journal.json"
+
+    def _ensure_dir(self, path: str) -> None:
+        """Create directory, ignoring already-exists errors."""
+        try:
+            self._agfs.mkdir(path)
+        except Exception as e:
+            # mkdir on an already-existing directory is the expected case here,
+            # so log quietly instead of warning on every call.
+            logger.debug(f"[Journal] mkdir {path}: {e}")
+
+    def write(self, data: Dict[str, Any]) -> None:
+        """Create a new journal entry for a transaction.
+
+        Args:
+            data: Serialized transaction record (from TransactionRecord.to_journal()).
+        """
+        tx_id = data["id"]
+        self._ensure_dir("/local/_system")
+        self._ensure_dir(_JOURNAL_ROOT)
+        self._ensure_dir(self._tx_dir(tx_id))
+        payload = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
+        self._agfs.write(self._journal_path(tx_id), payload)
+        logger.info(f"[Journal] Written: {self._journal_path(tx_id)}")
+
+    def update(self, data: Dict[str, Any]) -> None:
+        """Overwrite an existing journal entry.
+
+        Args:
+            data: Updated serialized transaction record.
+        """
+        tx_id = data["id"]
+        payload = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
+        self._agfs.write(self._journal_path(tx_id), payload)
+
+    def read(self, tx_id: str) -> Dict[str, Any]:
+        """Read a journal entry.
+
+        Args:
+            tx_id: Transaction ID.
+
+        Returns:
+            Parsed journal data.
+
+        Raises:
+            FileNotFoundError: If journal does not exist.
+        """
+        content = self._agfs.cat(self._journal_path(tx_id))
+        if isinstance(content, bytes):
+            content = content.decode("utf-8")
+        return json.loads(content)
+
+    def delete(self, tx_id: str) -> None:
+        """Delete a transaction's journal directory.
+
+        Args:
+            tx_id: Transaction ID.
+        """
+        try:
+            self._agfs.rm(self._tx_dir(tx_id), recursive=True)
+            logger.debug(f"[Journal] Deleted journal for tx {tx_id}")
+        except Exception as e:
+            logger.warning(f"[Journal] Failed to delete journal for tx {tx_id}: {e}")
+
+    def list_all(self) -> List[str]:
+        """List all transaction IDs that have journal entries.
+
+        Returns:
+            List of transaction ID strings.
+        """
+        try:
+            entries = self._agfs.ls(_JOURNAL_ROOT)
+            tx_ids = []
+            if isinstance(entries, list):
+                for entry in entries:
+                    # Entries may be dicts or bare name strings; only dicts carry
+                    # an isDir flag, so guard before calling .get() on them.
+                    if isinstance(entry, dict):
+                        name = entry.get("name", "")
+                        is_dir = entry.get("isDir", True)
+                    else:
+                        name, is_dir = str(entry), True
+                    if name and name not in (".", "..") and is_dir:
+                        tx_ids.append(name)
+            return tx_ids
+        except Exception:
+            return []
diff --git a/openviking/storage/transaction/path_lock.py b/openviking/storage/transaction/path_lock.py
index 1a2ad7b0d..5de99743a 100644
--- a/openviking/storage/transaction/path_lock.py
+++ b/openviking/storage/transaction/path_lock.py
@@ -1,14 +1,6 @@
-# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
-# SPDX-License-Identifier: Apache-2.0
-"""
-Path lock implementation for transaction management.
-
-Provides path-based locking mechanism to prevent concurrent directory operations.
-Lock protocol: viking://resources/.../.path.ovlock file exists = locked -""" - import asyncio -from typing import List, Optional +import time +from typing import Optional, Tuple from openviking.pyagfs import AGFSClient from openviking.storage.transaction.transaction_record import TransactionRecord @@ -19,312 +11,357 @@ # Lock file name LOCK_FILE_NAME = ".path.ovlock" +# Lock type constants +LOCK_TYPE_POINT = "P" +LOCK_TYPE_SUBTREE = "S" -class PathLock: - """Path lock manager for transaction-based directory locking. +# Default poll interval when waiting for a lock (seconds) +_POLL_INTERVAL = 0.2 - Implements path-based locking using lock files (.path.ovlock) to prevent - concurrent operations on the same directory tree. - """ - def __init__(self, agfs_client: AGFSClient): - """Initialize path lock manager. +def _make_fencing_token(tx_id: str, lock_type: str = LOCK_TYPE_POINT) -> str: + return f"{tx_id}:{time.time_ns()}:{lock_type}" - Args: - agfs_client: AGFS client for file system operations - """ - self._agfs = agfs_client - def _get_lock_path(self, path: str) -> str: - """Get lock file path for a directory. 
+def _parse_fencing_token(token: str) -> Tuple[str, int, str]: + if token.endswith(f":{LOCK_TYPE_POINT}") or token.endswith(f":{LOCK_TYPE_SUBTREE}"): + lock_type = token[-1] + rest = token[:-2] + idx = rest.rfind(":") + if idx >= 0: + tx_id_part = rest[:idx] + ts_part = rest[idx + 1 :] + try: + return tx_id_part, int(ts_part), lock_type + except ValueError: + pass + return rest, 0, lock_type - Args: - path: Directory path to lock + if ":" in token: + idx = token.rfind(":") + tx_id_part = token[:idx] + ts_part = token[idx + 1 :] + try: + return tx_id_part, int(ts_part), LOCK_TYPE_POINT + except ValueError: + pass + + return token, 0, LOCK_TYPE_POINT - Returns: - Lock file path (path/.path.ovlock) - """ - # Remove trailing slash if present + +class PathLock: + def __init__(self, agfs_client: AGFSClient, lock_expire: float = 300.0): + self._agfs = agfs_client + self._lock_expire = lock_expire + + def _get_lock_path(self, path: str) -> str: path = path.rstrip("/") return f"{path}/{LOCK_FILE_NAME}" def _get_parent_path(self, path: str) -> Optional[str]: - """Get parent directory path. - - Args: - path: Directory path - - Returns: - Parent directory path or None if at root - """ path = path.rstrip("/") if "/" not in path: return None parent = path.rsplit("/", 1)[0] return parent if parent else None - async def _is_locked_by_other(self, lock_path: str, transaction_id: str) -> bool: - """Check if path is locked by another transaction. 
- - Args: - lock_path: Lock file path - transaction_id: Current transaction ID - - Returns: - True if locked by another transaction, False otherwise - """ + def _read_token(self, lock_path: str) -> Optional[str]: try: - content = self._agfs.cat(lock_path) + content = self._agfs.read(lock_path) if isinstance(content, bytes): - lock_owner = content.decode("utf-8").strip() + token = content.decode("utf-8").strip() else: - lock_owner = str(content).strip() - return lock_owner != transaction_id + token = str(content).strip() + return token if token else None except Exception: - # Lock file doesn't exist or can't be read - not locked - return False + return None - async def _create_lock_file(self, lock_path: str, transaction_id: str) -> None: - """Create lock file with transaction ID. + async def _is_locked_by_other(self, lock_path: str, transaction_id: str) -> bool: + token = self._read_token(lock_path) + if token is None: + return False + lock_owner, _, _ = _parse_fencing_token(token) + return lock_owner != transaction_id - Args: - lock_path: Lock file path - transaction_id: Transaction ID to write to lock file - """ - self._agfs.write(lock_path, transaction_id.encode("utf-8")) + async def _create_lock_file( + self, lock_path: str, transaction_id: str, lock_type: str = LOCK_TYPE_POINT + ) -> None: + token = _make_fencing_token(transaction_id, lock_type) + self._agfs.write(lock_path, token.encode("utf-8")) async def _verify_lock_ownership(self, lock_path: str, transaction_id: str) -> bool: - """Verify lock file is owned by current transaction. 
- - Args: - lock_path: Lock file path - transaction_id: Current transaction ID + token = self._read_token(lock_path) + if token is None: + return False + lock_owner, _, _ = _parse_fencing_token(token) + return lock_owner == transaction_id - Returns: - True if lock is owned by current transaction, False otherwise - """ + async def _remove_lock_file(self, lock_path: str) -> bool: try: - content = self._agfs.cat(lock_path) - if isinstance(content, bytes): - lock_owner = content.decode("utf-8").strip() - else: - lock_owner = str(content).strip() - return lock_owner == transaction_id - except Exception: + self._agfs.rm(lock_path) + return True + except Exception as e: + if "not found" in str(e).lower(): + return True return False - async def _remove_lock_file(self, lock_path: str) -> None: - """Remove lock file. - - Args: - lock_path: Lock file path - """ + def is_lock_stale(self, lock_path: str, expire_seconds: float = 300.0) -> bool: + token = self._read_token(lock_path) + if token is None: + return True + _, ts, _ = _parse_fencing_token(token) + if ts == 0: + return True + age = (time.time_ns() - ts) / 1e9 + return age > expire_seconds + + async def _check_ancestors_for_subtree(self, path: str, exclude_tx_id: str) -> Optional[str]: + parent = self._get_parent_path(path) + while parent: + lock_path = self._get_lock_path(parent) + token = self._read_token(lock_path) + if token is not None: + owner_id, _, lock_type = _parse_fencing_token(token) + if owner_id != exclude_tx_id and lock_type == LOCK_TYPE_SUBTREE: + return lock_path + parent = self._get_parent_path(parent) + return None + + async def _scan_descendants_for_locks(self, path: str, exclude_tx_id: str) -> Optional[str]: try: - self._agfs.rm(lock_path) - except Exception: - # Lock file might not exist, ignore - pass + entries = self._agfs.ls(path) + if not isinstance(entries, list): + return None + for entry in entries: + if not isinstance(entry, dict): + continue + name = entry.get("name", "") + if not name or 
name in (".", ".."): + continue + if not entry.get("isDir", False): + continue + subdir = f"{path.rstrip('/')}/{name}" + subdir_lock = self._get_lock_path(subdir) + token = self._read_token(subdir_lock) + if token is not None: + owner_id, _, _ = _parse_fencing_token(token) + if owner_id != exclude_tx_id: + return subdir_lock + result = await self._scan_descendants_for_locks(subdir, exclude_tx_id) + if result: + return result + except Exception as e: + logger.warning(f"Failed to scan descendants of {path}: {e}") + return None - async def acquire_normal(self, path: str, transaction: TransactionRecord) -> bool: - """Acquire path lock for normal operations. - - Lock acquisition flow for normal operations: - 1. Check if target directory exists - 2. Check if target directory is locked by another transaction - 3. Check if parent directory is locked by another transaction - 4. Create .path.ovlock file with transaction ID - 5. Check again if parent directory is locked by another transaction - 6. Read lock file to confirm it contains current transaction ID - 7. 
Return success if all checks pass - - Args: - path: Directory path to lock - transaction: Transaction record - - Returns: - True if lock acquired successfully, False otherwise - """ + async def acquire_point( + self, path: str, transaction: TransactionRecord, timeout: float = 0.0 + ) -> bool: transaction_id = transaction.id lock_path = self._get_lock_path(path) - parent_path = self._get_parent_path(path) + deadline = asyncio.get_running_loop().time() + timeout - # Step 1: Check if target directory exists try: self._agfs.stat(path) except Exception: - logger.warning(f"Directory does not exist: {path}") - return False - - # Step 2: Check if target directory is locked by another transaction - if await self._is_locked_by_other(lock_path, transaction_id): - logger.warning(f"Path already locked by another transaction: {path}") - return False - - # Step 3: Check if parent directory is locked by another transaction - if parent_path: - parent_lock_path = self._get_lock_path(parent_path) - if await self._is_locked_by_other(parent_lock_path, transaction_id): - logger.warning(f"Parent path locked by another transaction: {parent_path}") - return False - - # Step 4: Create lock file - try: - await self._create_lock_file(lock_path, transaction_id) - except Exception as e: - logger.error(f"Failed to create lock file: {e}") + logger.warning(f"[POINT] Directory does not exist: {path}") return False - # Step 5: Check again if parent directory is locked - if parent_path: - parent_lock_path = self._get_lock_path(parent_path) - if await self._is_locked_by_other(parent_lock_path, transaction_id): - logger.warning(f"Parent path locked after lock creation: {parent_path}") - await self._remove_lock_file(lock_path) + while True: + if await self._is_locked_by_other(lock_path, transaction_id): + if self.is_lock_stale(lock_path, self._lock_expire): + logger.warning(f"[POINT] Removing stale lock: {lock_path}") + await self._remove_lock_file(lock_path) + continue + if 
asyncio.get_running_loop().time() >= deadline: + logger.warning(f"[POINT] Timeout waiting for lock on: {path}") + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + ancestor_conflict = await self._check_ancestors_for_subtree(path, transaction_id) + if ancestor_conflict: + if self.is_lock_stale(ancestor_conflict, self._lock_expire): + logger.warning( + f"[POINT] Removing stale ancestor SUBTREE lock: {ancestor_conflict}" + ) + await self._remove_lock_file(ancestor_conflict) + continue + if asyncio.get_running_loop().time() >= deadline: + logger.warning( + f"[POINT] Timeout waiting for ancestor SUBTREE lock: {ancestor_conflict}" + ) + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + try: + await self._create_lock_file(lock_path, transaction_id, LOCK_TYPE_POINT) + except Exception as e: + logger.error(f"[POINT] Failed to create lock file: {e}") return False - # Step 6: Verify lock ownership - if not await self._verify_lock_ownership(lock_path, transaction_id): - logger.error(f"Lock ownership verification failed: {path}") - return False - - # Step 7: Success - add lock to transaction - transaction.add_lock(lock_path) - logger.debug(f"Lock acquired: {lock_path}") - return True - - async def _collect_subdirectories(self, path: str) -> List[str]: - """Collect all subdirectory paths recursively. 
- - Args: - path: Root directory path - - Returns: - List of all subdirectory paths - """ - subdirs = [] - try: - entries = self._agfs.ls(path) - if isinstance(entries, list): - for entry in entries: - if isinstance(entry, dict) and entry.get("isDir"): - entry_path = entry.get("name", "") - if entry_path: - subdirs.append(entry_path) - # Recursively collect subdirectories - subdirs.extend(await self._collect_subdirectories(entry_path)) - except Exception as e: - logger.warning(f"Failed to list directory {path}: {e}") - - return subdirs + backed_off = False + conflict_after = await self._check_ancestors_for_subtree(path, transaction_id) + if conflict_after: + their_token = self._read_token(conflict_after) + if their_token: + their_tx_id, their_ts, _ = _parse_fencing_token(their_token) + my_token = self._read_token(lock_path) + _, my_ts, _ = ( + _parse_fencing_token(my_token) if my_token else ("", 0, LOCK_TYPE_POINT) + ) + if (my_ts, transaction_id) > (their_ts, their_tx_id): + logger.debug(f"[POINT] Backing off (livelock guard) on {path}") + await self._remove_lock_file(lock_path) + backed_off = True + if asyncio.get_running_loop().time() >= deadline: + if not backed_off: + await self._remove_lock_file(lock_path) + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + if not await self._verify_lock_ownership(lock_path, transaction_id): + logger.debug(f"[POINT] Lock ownership verification failed: {path}") + if asyncio.get_running_loop().time() >= deadline: + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + transaction.add_lock(lock_path) + logger.debug(f"[POINT] Lock acquired: {lock_path}") + return True - async def acquire_rm( - self, path: str, transaction: TransactionRecord, max_parallel: int = 8 + async def acquire_subtree( + self, path: str, transaction: TransactionRecord, timeout: float = 0.0 ) -> bool: - """Acquire path lock for rm operation using bottom-up parallel locking. 
- - Lock acquisition flow for rm operations (parallel bottom-up mode): - 1. Collect all subdirectory paths recursively - 2. Sort by depth (deepest first) - 3. Create lock files in batches with limited parallelism - 4. Lock the target directory last - 5. If any lock fails, release all acquired locks in reverse order - - Args: - path: Directory path to lock - transaction: Transaction record - max_parallel: Maximum number of parallel lock operations - - Returns: - True if all locks acquired successfully, False otherwise - """ transaction_id = transaction.id lock_path = self._get_lock_path(path) - acquired_locks = [] + deadline = asyncio.get_running_loop().time() + timeout - # Step 1: Collect all subdirectories - subdirs = await self._collect_subdirectories(path) + try: + self._agfs.stat(path) + except Exception: + logger.warning(f"[SUBTREE] Directory does not exist: {path}") + return False - # Step 2: Sort by depth (deepest first) - subdirs.sort(key=lambda p: p.count("/"), reverse=True) + while True: + if await self._is_locked_by_other(lock_path, transaction_id): + if self.is_lock_stale(lock_path, self._lock_expire): + logger.warning(f"[SUBTREE] Removing stale lock: {lock_path}") + await self._remove_lock_file(lock_path) + continue + if asyncio.get_running_loop().time() >= deadline: + logger.warning(f"[SUBTREE] Timeout waiting for lock on: {path}") + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + # Check ancestor paths for SUBTREE locks held by other transactions + ancestor_conflict = await self._check_ancestors_for_subtree(path, transaction_id) + if ancestor_conflict: + if self.is_lock_stale(ancestor_conflict, self._lock_expire): + logger.warning( + f"[SUBTREE] Removing stale ancestor SUBTREE lock: {ancestor_conflict}" + ) + await self._remove_lock_file(ancestor_conflict) + continue + if asyncio.get_running_loop().time() >= deadline: + logger.warning( + f"[SUBTREE] Timeout waiting for ancestor SUBTREE lock: {ancestor_conflict}" + ) + return False + 
await asyncio.sleep(_POLL_INTERVAL) + continue + + desc_conflict = await self._scan_descendants_for_locks(path, transaction_id) + if desc_conflict: + if self.is_lock_stale(desc_conflict, self._lock_expire): + logger.warning(f"[SUBTREE] Removing stale descendant lock: {desc_conflict}") + await self._remove_lock_file(desc_conflict) + continue + if asyncio.get_running_loop().time() >= deadline: + logger.warning( + f"[SUBTREE] Timeout waiting for descendant lock: {desc_conflict}" + ) + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + try: + await self._create_lock_file(lock_path, transaction_id, LOCK_TYPE_SUBTREE) + except Exception as e: + logger.error(f"[SUBTREE] Failed to create lock file: {e}") + return False - # Step 3: Create lock files in batches - try: - # Lock subdirectories in batches - for i in range(0, len(subdirs), max_parallel): - batch = subdirs[i : i + max_parallel] - tasks = [] - for subdir in batch: - subdir_lock_path = self._get_lock_path(subdir) - tasks.append(self._create_lock_file(subdir_lock_path, transaction_id)) - - # Execute batch in parallel - await asyncio.gather(*tasks) - acquired_locks.extend([self._get_lock_path(s) for s in batch]) - - # Step 4: Lock target directory - await self._create_lock_file(lock_path, transaction_id) - acquired_locks.append(lock_path) - - # Add all locks to transaction - for lock in acquired_locks: - transaction.add_lock(lock) - - logger.debug(f"RM locks acquired for {len(acquired_locks)} paths") + backed_off = False + conflict_after = await self._scan_descendants_for_locks(path, transaction_id) + if not conflict_after: + conflict_after = await self._check_ancestors_for_subtree(path, transaction_id) + if conflict_after: + their_token = self._read_token(conflict_after) + if their_token: + their_tx_id, their_ts, _ = _parse_fencing_token(their_token) + my_token = self._read_token(lock_path) + _, my_ts, _ = ( + _parse_fencing_token(my_token) if my_token else ("", 0, LOCK_TYPE_SUBTREE) + ) + if (my_ts, 
transaction_id) > (their_ts, their_tx_id): + logger.debug(f"[SUBTREE] Backing off (livelock guard) on {path}") + await self._remove_lock_file(lock_path) + backed_off = True + if asyncio.get_running_loop().time() >= deadline: + if not backed_off: + await self._remove_lock_file(lock_path) + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + if not await self._verify_lock_ownership(lock_path, transaction_id): + logger.debug(f"[SUBTREE] Lock ownership verification failed: {path}") + if asyncio.get_running_loop().time() >= deadline: + return False + await asyncio.sleep(_POLL_INTERVAL) + continue + + transaction.add_lock(lock_path) + logger.debug(f"[SUBTREE] Lock acquired: {lock_path}") return True - except Exception as e: - logger.error(f"Failed to acquire RM locks: {e}") - # Step 5: Release all acquired locks in reverse order - for lock in reversed(acquired_locks): - await self._remove_lock_file(lock) - return False - async def acquire_mv( self, src_path: str, dst_path: str, transaction: TransactionRecord, - max_parallel: int = 8, + timeout: float = 0.0, + src_is_dir: bool = True, ) -> bool: - """Acquire path lock for mv operation. - - Lock acquisition flow for mv operations: - 1. Lock source directory (using RM-style locking) - 2. 
Lock destination directory (using normal locking) - - Args: - src_path: Source directory path - dst_path: Destination directory path - transaction: Transaction record - max_parallel: Maximum number of parallel lock operations - - Returns: - True if all locks acquired successfully, False otherwise - """ - # Step 1: Lock source directory - if not await self.acquire_rm(src_path, transaction, max_parallel): - logger.warning(f"Failed to lock source path: {src_path}") - return False - - # Step 2: Lock destination directory - if not await self.acquire_normal(dst_path, transaction): - logger.warning(f"Failed to lock destination path: {dst_path}") - # Release source locks - await self.release(transaction) - return False + if src_is_dir: + if not await self.acquire_subtree(src_path, transaction, timeout=timeout): + logger.warning(f"[MV] Failed to acquire SUBTREE lock on source: {src_path}") + return False + if not await self.acquire_subtree(dst_path, transaction, timeout=timeout): + logger.warning(f"[MV] Failed to acquire SUBTREE lock on destination: {dst_path}") + await self.release(transaction) + return False + else: + src_parent = src_path.rsplit("/", 1)[0] if "/" in src_path else src_path + if not await self.acquire_point(src_parent, transaction, timeout=timeout): + logger.warning(f"[MV] Failed to acquire POINT lock on source parent: {src_parent}") + return False + if not await self.acquire_point(dst_path, transaction, timeout=timeout): + logger.warning(f"[MV] Failed to acquire POINT lock on destination: {dst_path}") + await self.release(transaction) + return False - logger.debug(f"MV locks acquired: {src_path} -> {dst_path}") + logger.debug(f"[MV] Locks acquired: {src_path} -> {dst_path}") return True async def release(self, transaction: TransactionRecord) -> None: - """Release all locks held by the transaction. 
- - Args: - transaction: Transaction record - """ - # Release locks in reverse order (LIFO) + lock_count = len(transaction.locks) for lock_path in reversed(transaction.locks): await self._remove_lock_file(lock_path) transaction.remove_lock(lock_path) - logger.debug(f"Released {len(transaction.locks)} locks for transaction {transaction.id}") + logger.debug(f"Released {lock_count} locks for transaction {transaction.id}") diff --git a/openviking/storage/transaction/transaction_manager.py b/openviking/storage/transaction/transaction_manager.py index da76cde7d..041b84239 100644 --- a/openviking/storage/transaction/transaction_manager.py +++ b/openviking/storage/transaction/transaction_manager.py @@ -9,7 +9,7 @@ import asyncio import threading import time -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from openviking.pyagfs import AGFSClient from openviking.storage.transaction.path_lock import PathLock @@ -34,6 +34,7 @@ class TransactionManager: - Allocating transaction IDs - Managing transaction lifecycle (start, commit, rollback) - Providing transaction lock mechanism interface, preventing deadlocks + - Persisting transaction state to journal for crash recovery """ def __init__( @@ -41,6 +42,9 @@ def __init__( agfs_client: AGFSClient, timeout: int = 3600, max_parallel_locks: int = 8, + lock_timeout: float = 0.0, + lock_expire: float = 300.0, + vector_store: Optional[Any] = None, ): """Initialize transaction manager. @@ -48,11 +52,21 @@ def __init__( agfs_client: AGFS client for file system operations timeout: Transaction timeout in seconds (default: 3600) max_parallel_locks: Maximum number of parallel lock operations (default: 8) + lock_timeout: Path lock acquisition timeout in seconds. + 0 (default) = fail immediately if locked. + > 0 = wait/retry up to this many seconds. + lock_expire: Stale lock expiry threshold in seconds (default: 300s). + vector_store: Optional vector store for VectorDB rollback operations. 
""" + from openviking.storage.transaction.journal import TransactionJournal + self._agfs = agfs_client self._timeout = timeout self._max_parallel_locks = max_parallel_locks - self._path_lock = PathLock(agfs_client) + self._lock_timeout = lock_timeout + self._vector_store = vector_store + self._path_lock = PathLock(agfs_client, lock_expire=lock_expire) + self._journal = TransactionJournal(agfs_client) # Active transactions: {transaction_id: TransactionRecord} self._transactions: Dict[str, TransactionRecord] = {} @@ -65,10 +79,15 @@ def __init__( f"TransactionManager initialized (timeout={timeout}s, max_parallel_locks={max_parallel_locks})" ) + @property + def journal(self): + return self._journal + async def start(self) -> None: """Start transaction manager. - Starts the background cleanup task for timed-out transactions. + Starts the background cleanup task and recovers any pending transactions + left from a previous process crash. """ if self._running: logger.debug("TransactionManager already running") @@ -76,9 +95,15 @@ async def start(self) -> None: self._running = True self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + # Recover any transactions that were interrupted by a previous crash. + # Journal entries are written BEFORE lock acquisition, so every orphan + # lock has a corresponding journal entry that recovery can use to clean it up. + await self._recover_pending_transactions() + logger.info("TransactionManager started") - def stop(self) -> None: + async def stop(self) -> None: """Stop transaction manager. Stops the background cleanup task and releases all resources. 
@@ -92,11 +117,17 @@ def stop(self) -> None: # Cancel cleanup task if self._cleanup_task: self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass self._cleanup_task = None - # Release all active transactions + # Release all active transactions' locks for tx_id in list(self._transactions.keys()): - self._transactions.pop(tx_id, None) + tx = self._transactions.pop(tx_id, None) + if tx: + await self._path_lock.release(tx) logger.info("TransactionManager stopped") @@ -124,6 +155,215 @@ async def _cleanup_timed_out(self) -> None: logger.warning(f"Transaction timed out: {tx_id}") await self.rollback(tx_id) + async def _recover_pending_transactions(self) -> None: + """Recover pending transactions from journal after a crash. + + Reads all journal entries and rolls back any transactions that were + not cleanly committed or rolled back. + """ + try: + pending_ids = self._journal.list_all() + except Exception as e: + logger.warning(f"Failed to list journal entries for recovery: {e}") + return + + if not pending_ids: + return + + logger.info(f"Found {len(pending_ids)} pending transaction(s) to recover") + + for tx_id in pending_ids: + try: + await self._recover_one(tx_id) + except Exception as e: + logger.error(f"Failed to recover transaction {tx_id}: {e}") + + async def _recover_one(self, tx_id: str) -> None: + """Recover a single transaction from journal. 
+ + Recovery strategy by status: + COMMIT with post_actions → replay post_actions (enqueue etc.), then clean up + COMMIT with no post_actions / RELEASED → just clean up + EXEC / FAIL / RELEASING → roll back completed and partial ops, then clean up + INIT / ACQUIRE → nothing executed yet, just clean up + """ + from openviking.storage.transaction.undo import execute_rollback + + try: + data = self._journal.read(tx_id) + except Exception as e: + logger.warning(f"Cannot read journal for tx {tx_id}: {e}") + return + + tx = TransactionRecord.from_journal(data) + logger.info(f"Recovering transaction {tx_id} (status={tx.status})") + + if tx.status == TransactionStatus.COMMIT: + # Transaction was committed — replay any unfinished post_actions + if tx.post_actions: + logger.info( + f"Replaying {len(tx.post_actions)} post_action(s) for committed tx {tx_id}" + ) + try: + await self._execute_post_actions(tx.post_actions) + except Exception as e: + logger.warning(f"Post-action replay failed for tx {tx_id}: {e}") + elif tx.status in (TransactionStatus.INIT, TransactionStatus.ACQUIRE): + # Transaction never executed any operations — nothing to roll back. + # However, locks may have been created before the journal was updated + # with the actual locks list. Use init_info.lock_paths to find and + # clean up orphan lock files owned by this transaction. + logger.info(f"Transaction {tx_id} never executed, cleaning up orphan locks") + if not tx.locks: + await self._cleanup_orphan_locks_from_init_info(tx_id, tx.init_info) + else: + # EXEC / FAIL / RELEASING: process crashed mid-operation + operation = tx.init_info.get("operation", "") + if operation == "session_memory": + # Redo: re-extract memories from archive and write + try: + await self._redo_session_memory(tx) + except Exception as e: + logger.warning(f"Redo session_memory failed for tx {tx_id}: {e}") + else: + # Default: roll back completed and partial ops + # Pass recover_all=True so partial (completed=False) ops are also reversed, + # e.g. 
a directory mv that started but never finished still leaves residue. + try: + await execute_rollback( + tx.undo_log, + self._agfs, + vector_store=self._vector_store, + recover_all=True, + ) + except Exception as e: + logger.warning(f"Rollback during recovery failed for tx {tx_id}: {e}") + + # Release any lock files still present + await self._path_lock.release(tx) + + # Clean up journal + try: + self._journal.delete(tx_id) + except Exception: + pass + + logger.info(f"Recovered transaction {tx_id}") + + async def _cleanup_orphan_locks_from_init_info( + self, tx_id: str, init_info: Dict[str, Any] + ) -> None: + """Clean up orphan lock files using lock path hints from init_info. + + When a crash occurs between lock creation and journal update, the + journal's ``locks`` list is empty but ``init_info.lock_paths`` records + the paths that were intended to be locked. This method checks those + paths and removes any lock files still owned by this transaction. + """ + from openviking.storage.transaction.path_lock import LOCK_FILE_NAME, _parse_fencing_token + + lock_paths = init_info.get("lock_paths", []) + lock_mode = init_info.get("lock_mode", "point") + mv_dst_path = init_info.get("mv_dst_path") + + # Collect all candidate paths to check + paths_to_check = list(lock_paths) + if lock_mode == "mv" and mv_dst_path: + paths_to_check.append(mv_dst_path) + + for path in paths_to_check: + lock_file = f"{path.rstrip('/')}/{LOCK_FILE_NAME}" + try: + token = self._path_lock._read_token(lock_file) + if token is None: + continue + owner_id, _, _ = _parse_fencing_token(token) + if owner_id == tx_id: + await self._path_lock._remove_lock_file(lock_file) + logger.info(f"Removed orphan lock for tx {tx_id}: {lock_file}") + except Exception as e: + logger.warning(f"Failed to check orphan lock {lock_file}: {e}") + + async def _redo_session_memory(self, tx: TransactionRecord) -> None: + """Redo a session_memory transaction from its archived messages. 
+ + On crash during Phase 2 of session commit, we redo memory extraction + from the archive rather than rolling back. + """ + import json + + from openviking.message import Message + from openviking.server.identity import RequestContext, Role + from openviking_cli.session.user_id import UserIdentifier + + archive_uri = tx.init_info.get("archive_uri") + session_uri = tx.init_info.get("session_uri") + account_id = tx.init_info.get("account_id", "default") + user_id = tx.init_info.get("user_id", "default") + agent_id = tx.init_info.get("agent_id", "default") + role_str = tx.init_info.get("role", "root") + + if not archive_uri or not session_uri: + logger.warning("Cannot redo session_memory: missing archive_uri or session_uri") + return + + # 1. Read archived messages from AGFS + messages_path = f"{archive_uri}/messages.jsonl" + try: + agfs_path = messages_path.replace("viking://", "") + content = self._agfs.cat(agfs_path) + if isinstance(content, bytes): + content = content.decode("utf-8") + except Exception as e: + logger.warning(f"Cannot read archive for redo: {messages_path}: {e}") + return + + messages = [] + for line in content.strip().split("\n"): + if line.strip(): + try: + messages.append(Message.from_dict(json.loads(line))) + except Exception: + pass + + if not messages: + logger.warning(f"No messages found in archive for redo: {archive_uri}") + return + + # 2. Build request context for memory extraction + user = UserIdentifier(user_id=user_id, agent_id=agent_id) + ctx = RequestContext(user=user, role=Role(role_str), account_id=account_id) + + # 3. Re-extract memories + from openviking.session.compressor import SessionCompressor + + compressor = SessionCompressor() + session_id = session_uri.rstrip("/").rsplit("/", 1)[-1] + memories = await compressor.extract_long_term_memories( + messages=messages, + user=user, + session_id=session_id, + ctx=ctx, + ) + logger.info(f"Redo: extracted {len(memories)} memories from {archive_uri}") + + # 4. 
Enqueue semantic processing + await self._execute_post_actions( + [ + { + "type": "enqueue_semantic", + "params": { + "uri": session_uri, + "context_type": "memory", + "account_id": account_id, + "user_id": user_id, + "agent_id": agent_id, + "role": role_str, + }, + } + ] + ) + def create_transaction(self, init_info: Optional[Dict[str, Any]] = None) -> TransactionRecord: """Create a new transaction. @@ -163,13 +403,15 @@ async def begin(self, transaction_id: str) -> bool: logger.error(f"Transaction not found: {transaction_id}") return False - tx.update_status(TransactionStatus.AQUIRE) + tx.update_status(TransactionStatus.ACQUIRE) logger.debug(f"Transaction begun: {transaction_id}") return True async def commit(self, transaction_id: str) -> bool: """Commit a transaction. + Executes post-actions, releases all locks, and removes the journal entry. + Args: transaction_id: Transaction ID @@ -184,6 +426,16 @@ async def commit(self, transaction_id: str) -> bool: # Update status to COMMIT tx.update_status(TransactionStatus.COMMIT) + # Persist final committed state before releasing + try: + self._journal.update(tx.to_journal()) + except Exception: + pass + + # Execute post-actions (best-effort, errors are logged but don't fail commit) + if tx.post_actions: + await self._execute_post_actions(tx.post_actions) + # Release all locks tx.update_status(TransactionStatus.RELEASING) await self._path_lock.release(tx) @@ -194,18 +446,29 @@ async def commit(self, transaction_id: str) -> bool: # Remove from active transactions self._transactions.pop(transaction_id, None) + # Clean up journal entry (last step — lock is already released) + try: + self._journal.delete(transaction_id) + except Exception as e: + logger.warning(f"Failed to delete journal on commit for {transaction_id}: {e}") + logger.debug(f"Transaction committed: {transaction_id}") return True async def rollback(self, transaction_id: str) -> bool: """Rollback a transaction. 
+ Executes undo log entries in reverse order, releases all locks, + and removes the journal entry. + Args: transaction_id: Transaction ID Returns: True if transaction rolled back successfully, False otherwise """ + from openviking.storage.transaction.undo import execute_rollback + tx = self.get_transaction(transaction_id) if not tx: logger.error(f"Transaction not found: {transaction_id}") @@ -214,6 +477,25 @@ async def rollback(self, transaction_id: str) -> bool: # Update status to FAIL tx.update_status(TransactionStatus.FAIL) + # Persist rollback state + try: + self._journal.update(tx.to_journal()) + except Exception: + pass + + # Execute undo log (best-effort) + if tx.undo_log: + try: + await execute_rollback( + tx.undo_log, + self._agfs, + vector_store=self._vector_store, + ) + except Exception as e: + logger.warning( + f"Undo log execution failed during rollback of {transaction_id}: {e}" + ) + # Release all locks tx.update_status(TransactionStatus.RELEASING) await self._path_lock.release(tx) @@ -224,11 +506,67 @@ async def rollback(self, transaction_id: str) -> bool: # Remove from active transactions self._transactions.pop(transaction_id, None) + # Clean up journal entry (last step — lock is already released) + try: + self._journal.delete(transaction_id) + except Exception as e: + logger.warning(f"Failed to delete journal on rollback for {transaction_id}: {e}") + logger.debug(f"Transaction rolled back: {transaction_id}") return True - async def acquire_lock_normal(self, transaction_id: str, path: str) -> bool: - """Acquire path lock for normal (non-rm/mv) operations. + async def _execute_post_actions(self, post_actions: List[Dict[str, Any]]) -> None: + """Execute post-commit actions. + + Post-actions are executed after a successful commit. Errors are logged + but do not affect the commit outcome. 
+ + Args: + post_actions: List of post-action dicts with 'type' and 'params' keys + """ + for action in post_actions: + action_type = action.get("type", "") + params = action.get("params", {}) + try: + if action_type == "enqueue_semantic": + await self._post_enqueue_semantic(params) + else: + logger.warning(f"Unknown post-action type: {action_type}") + except Exception as e: + logger.warning(f"Post-action '{action_type}' failed: {e}") + + async def _post_enqueue_semantic(self, params: Dict[str, Any]) -> None: + """Execute enqueue_semantic post-action.""" + from openviking.storage.queuefs import get_queue_manager + from openviking.storage.queuefs.semantic_msg import SemanticMsg + + queue_manager = get_queue_manager() + if queue_manager is None: + logger.debug("No queue manager available, skipping enqueue_semantic post-action") + return + + uri = params.get("uri") + context_type = params.get("context_type", "resource") + account_id = params.get("account_id", "default") + user_id = params.get("user_id", "default") + agent_id = params.get("agent_id", "default") + role = params.get("role", "root") + if not uri: + return + + msg = SemanticMsg( + uri=uri, + context_type=context_type, + account_id=account_id, + user_id=user_id, + agent_id=agent_id, + role=role, + ) + semantic_queue = queue_manager.get_queue(queue_manager.SEMANTIC) + await semantic_queue.enqueue(msg) + + async def acquire_lock_point(self, transaction_id: str, path: str) -> bool: + """Acquire POINT lock for write/semantic-processing operations. 
Args: transaction_id: Transaction ID @@ -242,8 +580,8 @@ async def acquire_lock_normal(self, transaction_id: str, path: str) -> bool: logger.error(f"Transaction not found: {transaction_id}") return False - tx.update_status(TransactionStatus.AQUIRE) - success = await self._path_lock.acquire_normal(path, tx) + tx.update_status(TransactionStatus.ACQUIRE) + success = await self._path_lock.acquire_point(path, tx, timeout=self._lock_timeout) if success: tx.update_status(TransactionStatus.EXEC) @@ -252,15 +590,15 @@ async def acquire_lock_normal(self, transaction_id: str, path: str) -> bool: return success - async def acquire_lock_rm( - self, transaction_id: str, path: str, max_parallel: Optional[int] = None + async def acquire_lock_subtree( + self, transaction_id: str, path: str, timeout: Optional[float] = None ) -> bool: - """Acquire path lock for rm operation. + """Acquire SUBTREE lock for rm/mv-source operations. Args: transaction_id: Transaction ID - path: Directory path to lock - max_parallel: Maximum number of parallel lock operations (default: from config) + path: Directory path to lock (root of the subtree) + timeout: Maximum time to wait for the lock in seconds (default: from config) Returns: True if lock acquired successfully, False otherwise @@ -270,9 +608,9 @@ async def acquire_lock_rm( logger.error(f"Transaction not found: {transaction_id}") return False - tx.update_status(TransactionStatus.AQUIRE) - parallel = max_parallel or self._max_parallel_locks - success = await self._path_lock.acquire_rm(path, tx, parallel) + tx.update_status(TransactionStatus.ACQUIRE) + effective_timeout = timeout if timeout is not None else self._lock_timeout + success = await self._path_lock.acquire_subtree(path, tx, timeout=effective_timeout) if success: tx.update_status(TransactionStatus.EXEC) @@ -286,15 +624,17 @@ async def acquire_lock_mv( transaction_id: str, src_path: str, dst_path: str, - max_parallel: Optional[int] = None, + timeout: Optional[float] = None, + src_is_dir: 
bool = True, ) -> bool: """Acquire path lock for mv operation. Args: transaction_id: Transaction ID - src_path: Source directory path - dst_path: Destination directory path - max_parallel: Maximum number of parallel lock operations (default: from config) + src_path: Source path + dst_path: Destination parent directory path + timeout: Maximum time to wait for each lock in seconds (default: from config) + src_is_dir: Whether the source is a directory Returns: True if lock acquired successfully, False otherwise @@ -304,9 +644,11 @@ async def acquire_lock_mv( logger.error(f"Transaction not found: {transaction_id}") return False - tx.update_status(TransactionStatus.AQUIRE) - parallel = max_parallel or self._max_parallel_locks - success = await self._path_lock.acquire_mv(src_path, dst_path, tx, parallel) + tx.update_status(TransactionStatus.ACQUIRE) + effective_timeout = timeout if timeout is not None else self._lock_timeout + success = await self._path_lock.acquire_mv( + src_path, dst_path, tx, timeout=effective_timeout, src_is_dir=src_is_dir + ) if success: tx.update_status(TransactionStatus.EXEC) @@ -336,6 +678,9 @@ def init_transaction_manager( agfs: AGFSClient, tx_timeout: int = 3600, max_parallel_locks: int = 8, + lock_timeout: float = 0.0, + lock_expire: float = 300.0, + vector_store: Optional[Any] = None, ) -> TransactionManager: """Initialize transaction manager singleton. @@ -343,6 +688,11 @@ def init_transaction_manager( agfs: AGFS client instance tx_timeout: Transaction timeout in seconds (default: 3600) max_parallel_locks: Maximum number of parallel lock operations (default: 8) + lock_timeout: Path lock acquisition timeout in seconds. + 0 (default) = fail immediately if locked. + > 0 = wait/retry up to this many seconds. + lock_expire: Stale lock expiry threshold in seconds (default: 300s). + vector_store: Optional vector store for VectorDB rollback operations. 
Returns: TransactionManager instance @@ -359,16 +709,31 @@ agfs_client=agfs, timeout=tx_timeout, max_parallel_locks=max_parallel_locks, + lock_timeout=lock_timeout, + lock_expire=lock_expire, + vector_store=vector_store, ) logger.info("TransactionManager initialized as singleton") return _transaction_manager -def get_transaction_manager() -> Optional[TransactionManager]: - """Get transaction manager singleton. +def get_transaction_manager() -> TransactionManager: + """Get transaction manager singleton.""" + if _transaction_manager is None: + raise RuntimeError( + "TransactionManager not initialized. Call init_transaction_manager() first." + ) + return _transaction_manager - Returns: - TransactionManager instance or None if not initialized + +def reset_transaction_manager() -> None: + """Reset the transaction manager singleton (for testing). + + This function should ONLY be used in tests to clean up state between tests. + It clears the global singleton instance without performing cleanup; make sure + to call stop() first if the manager is still running. """ - return _transaction_manager + global _transaction_manager + with _lock: + _transaction_manager = None diff --git a/openviking/storage/transaction/transaction_record.py b/openviking/storage/transaction/transaction_record.py index fba6480b9..b9eb0656f 100644 --- a/openviking/storage/transaction/transaction_record.py +++ b/openviking/storage/transaction/transaction_record.py @@ -16,11 +16,11 @@ class TransactionStatus(str, Enum): """Transaction status enumeration. 
- Status machine: INIT -> AQUIRE -> EXEC -> COMMIT/FAIL -> RELEASING -> RELEASED + Status machine: INIT -> ACQUIRE -> EXEC -> COMMIT/FAIL -> RELEASING -> RELEASED """ INIT = "INIT" # Transaction initialized, waiting for lock acquisition - AQUIRE = "AQUIRE" # Acquiring lock resources + ACQUIRE = "ACQUIRE" # Acquiring lock resources EXEC = "EXEC" # Transaction operation in progress COMMIT = "COMMIT" # Transaction completed successfully FAIL = "FAIL" # Transaction failed @@ -41,6 +41,8 @@ class TransactionRecord: status: Current transaction status init_info: Transaction initialization information rollback_info: Information for rollback operations + undo_log: List of undo entries for rollback + post_actions: Actions to execute after successful commit created_at: Creation timestamp (Unix timestamp in seconds) updated_at: Last update timestamp (Unix timestamp in seconds) """ @@ -50,44 +52,30 @@ class TransactionRecord: status: TransactionStatus = field(default=TransactionStatus.INIT) init_info: Dict[str, Any] = field(default_factory=dict) rollback_info: Dict[str, Any] = field(default_factory=dict) + undo_log: List[Any] = field(default_factory=list) + post_actions: List[Dict[str, Any]] = field(default_factory=list) created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) def update_status(self, status: TransactionStatus) -> None: - """Update transaction status and timestamp. - - Args: - status: New transaction statusudi - """ + """Update transaction status and timestamp.""" self.status = status self.updated_at = time.time() def add_lock(self, lock_path: str) -> None: - """Add a lock to the transaction. - - Args: - lock_path: Path to be locked - """ + """Add a lock to the transaction.""" if lock_path not in self.locks: self.locks.append(lock_path) self.updated_at = time.time() def remove_lock(self, lock_path: str) -> None: - """Remove a lock from the transaction. 
- - Args: - lock_path: Path to be unlocked - """ + """Remove a lock from the transaction.""" if lock_path in self.locks: self.locks.remove(lock_path) self.updated_at = time.time() def to_dict(self) -> Dict[str, Any]: - """Convert transaction record to dictionary. - - Returns: - Dictionary representation of the transaction record - """ + """Convert transaction record to dictionary.""" return { "id": self.id, "locks": self.locks, @@ -98,16 +86,45 @@ def to_dict(self) -> Dict[str, Any]: "updated_at": self.updated_at, } + def to_journal(self) -> Dict[str, Any]: + """Serialize to journal format (includes undo_log and post_actions).""" + from openviking.storage.transaction.undo import UndoEntry + + return { + "id": self.id, + "locks": self.locks, + "status": str(self.status), + "init_info": self.init_info, + "undo_log": [e.to_dict() if isinstance(e, UndoEntry) else e for e in self.undo_log], + "post_actions": self.post_actions, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransactionRecord": - """Create transaction record from dictionary. 
+ def from_journal(cls, data: Dict[str, Any]) -> "TransactionRecord": + """Restore from journal format.""" + from openviking.storage.transaction.undo import UndoEntry + + status_str = data.get("status", "INIT") + status = TransactionStatus(status_str) if isinstance(status_str, str) else status_str + undo_log = [UndoEntry.from_dict(e) for e in data.get("undo_log", [])] - Args: - data: Dictionary representation of the transaction record + return cls( + id=data.get("id", str(uuid.uuid4())), + locks=data.get("locks", []), + status=status, + init_info=data.get("init_info", {}), + rollback_info={}, + undo_log=undo_log, + post_actions=data.get("post_actions", []), + created_at=data.get("created_at", time.time()), + updated_at=data.get("updated_at", time.time()), + ) - Returns: - TransactionRecord instance - """ + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TransactionRecord": + """Create transaction record from dictionary.""" status_str = data.get("status", "INIT") status = TransactionStatus(status_str) if isinstance(status_str, str) else status_str diff --git a/openviking/storage/transaction/undo.py b/openviking/storage/transaction/undo.py new file mode 100644 index 000000000..0b5b31130 --- /dev/null +++ b/openviking/storage/transaction/undo.py @@ -0,0 +1,178 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Undo log and rollback executor for transaction management. + +Records operations performed within a transaction so they can be reversed +on rollback. Each UndoEntry captures one atomic sub-operation. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + + +def _reconstruct_ctx(params: Dict[str, Any]) -> Optional[Any]: + """Reconstruct a RequestContext from serialized _ctx_* fields in undo params. + + Returns None if the required fields are missing. 
+ """ + account_id = params.get("_ctx_account_id") + user_id = params.get("_ctx_user_id") + agent_id = params.get("_ctx_agent_id") + role_value = params.get("_ctx_role") + if account_id is None or user_id is None: + return None + try: + from openviking.server.identity import RequestContext, Role + from openviking_cli.session.user_id import UserIdentifier + + role = Role(role_value) if role_value in {r.value for r in Role} else Role.ROOT + user = UserIdentifier(account_id, user_id, agent_id or "default") + return RequestContext(user=user, role=role) + except Exception as e: + logger.warning(f"[Rollback] Failed to reconstruct ctx: {e}") + return None + + +@dataclass +class UndoEntry: + """A single undo log entry representing one reversible sub-operation. + + Attributes: + sequence: Monotonically increasing index within the transaction. + op_type: Operation type (fs_mv, fs_rm, fs_mkdir, fs_write_new, + vectordb_upsert, vectordb_delete, vectordb_update_uri). + params: Parameters needed to reverse the operation. + completed: Whether the forward operation completed successfully. + """ + + sequence: int + op_type: str + params: Dict[str, Any] = field(default_factory=dict) + completed: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "sequence": self.sequence, + "op_type": self.op_type, + "params": self.params, + "completed": self.completed, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "UndoEntry": + return cls( + sequence=data.get("sequence", 0), + op_type=data.get("op_type", ""), + params=data.get("params", {}), + completed=data.get("completed", False), + ) + + +async def execute_rollback( + undo_log: List[UndoEntry], + agfs: Any, + vector_store: Optional[Any] = None, + ctx: Optional[Any] = None, + recover_all: bool = False, +) -> None: + """Execute rollback by reversing operations in reverse order. + + Best-effort: each step is wrapped in try-except so a single failure + does not prevent subsequent undo steps from running. 
+ + Args: + undo_log: List of undo entries to process. + agfs: AGFS client for filesystem operations. + vector_store: Optional vector store client. + ctx: Optional request context. + recover_all: If True, also attempt to reverse entries that were not + marked completed (used during crash recovery to clean up partial + operations such as a directory mv that only half-finished). + """ + if recover_all: + entries = list(undo_log) + else: + entries = [e for e in undo_log if e.completed] + entries.sort(key=lambda e: e.sequence, reverse=True) + + for entry in entries: + try: + await _rollback_entry(entry, agfs, vector_store, ctx) + logger.info(f"[Rollback] Reversed {entry.op_type} seq={entry.sequence}") + except Exception as e: + logger.warning( + f"[Rollback] Failed to reverse {entry.op_type} seq={entry.sequence}: {e}" + ) + + +async def _rollback_entry( + entry: UndoEntry, + agfs: Any, + vector_store: Optional[Any], + ctx: Optional[Any], +) -> None: + """Dispatch rollback for a single undo entry.""" + op = entry.op_type + params = entry.params + + if op == "fs_mv": + agfs.mv(params["dst"], params["src"]) + + elif op == "fs_rm": + logger.debug("[Rollback] fs_rm is not reversible, skipping") + + elif op == "fs_mkdir": + try: + agfs.rm(params["uri"]) + except Exception: + pass + + elif op == "fs_write_new": + try: + agfs.rm(params["uri"], recursive=True) + except Exception: + pass + + elif op == "vectordb_upsert": + if vector_store: + record_id = params.get("record_id") + if record_id: + restored_ctx = _reconstruct_ctx(params) + if restored_ctx: + await vector_store.delete([record_id], ctx=restored_ctx) + else: + logger.warning("[Rollback] vectordb_upsert: cannot reconstruct ctx, skipping") + + elif op == "vectordb_delete": + if vector_store: + restored_ctx = _reconstruct_ctx(params) + if restored_ctx is None: + logger.warning("[Rollback] vectordb_delete: cannot reconstruct ctx, skipping") + else: + records_snapshot = params.get("records_snapshot", []) + for record in 
records_snapshot: + try: + await vector_store.upsert(record, ctx=restored_ctx) + except Exception as e: + logger.warning(f"[Rollback] Failed to restore vector record: {e}") + + elif op == "vectordb_update_uri": + if vector_store: + restored_ctx = _reconstruct_ctx(params) + if restored_ctx is None: + logger.warning("[Rollback] vectordb_update_uri: cannot reconstruct ctx, skipping") + else: + await vector_store.update_uri_mapping( + ctx=restored_ctx, + uri=params["new_uri"], + new_uri=params["old_uri"], + new_parent_uri=params.get("old_parent_uri", ""), + ) + + else: + logger.warning(f"[Rollback] Unknown op_type: {op}") diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index 7c45f5441..dc20acd17 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -22,7 +22,6 @@ from pathlib import PurePath from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from openviking.pyagfs.exceptions import AGFSHTTPError from openviking.server.identity import RequestContext, Role from openviking.telemetry import get_current_telemetry from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime @@ -289,15 +288,69 @@ async def rm( This method is idempotent: deleting a non-existent file succeeds after cleaning up any orphan index records. + + Wrapped in a transaction: deletes VectorDB records first, then FS files. + On rollback, VectorDB records are restored from snapshot. 
""" + from openviking.storage.transaction import TransactionContext, get_transaction_manager + self._ensure_access(uri, ctx) path = self._uri_to_path(uri, ctx=ctx) target_uri = self._path_to_uri(path, ctx=ctx) - uris_to_delete = await self._collect_uris(path, recursive, ctx=ctx) - uris_to_delete.append(target_uri) - result = self.agfs.rm(path, recursive=recursive) - await self._delete_from_vector_store(uris_to_delete, ctx=ctx) - return result + + tx_manager = get_transaction_manager() + + # Check existence and determine lock strategy + try: + stat = self.agfs.stat(path) + is_dir = stat.get("isDir", False) if isinstance(stat, dict) else False + except Exception: + # Path does not exist: clean up any orphan index records and return + uris_to_delete = await self._collect_uris(path, recursive, ctx=ctx) + uris_to_delete.append(target_uri) + await self._delete_from_vector_store(uris_to_delete, ctx=ctx) + logger.info(f"[VikingFS] rm target not found, cleaned orphan index: {uri}") + return {} + + if is_dir: + lock_paths = [path] + lock_mode = "subtree" + else: + parent = path.rsplit("/", 1)[0] if "/" in path else path + lock_paths = [parent] + lock_mode = "point" + + async with TransactionContext(tx_manager, "rm", lock_paths, lock_mode=lock_mode) as tx: + # Collect URIs inside the lock to avoid race conditions + uris_to_delete = await self._collect_uris(path, recursive, ctx=ctx) + uris_to_delete.append(target_uri) + + # Snapshot vector records for rollback + records_snapshot = await self._snapshot_vector_records(uris_to_delete, ctx=ctx) + + # Step 1: Delete from VectorDB first + real_ctx = self._ctx_or_default(ctx) + seq_vdb = tx.record_undo( + "vectordb_delete", + { + "uris": uris_to_delete, + "records_snapshot": records_snapshot, + "_ctx_account_id": real_ctx.account_id, + "_ctx_user_id": real_ctx.user.user_id, + "_ctx_agent_id": real_ctx.user.agent_id, + "_ctx_role": real_ctx.role.value, + }, + ) + await self._delete_from_vector_store(uris_to_delete, ctx=ctx) + 
tx.mark_completed(seq_vdb) + + # Step 2: Delete from FS + seq_fs = tx.record_undo("fs_rm", {"uri": path, "recursive": recursive}) + result = self.agfs.rm(path, recursive=recursive) + tx.mark_completed(seq_fs) + + await tx.commit() + return result async def mv( self, @@ -305,24 +358,91 @@ async def mv( new_uri: str, ctx: Optional[RequestContext] = None, ) -> Dict[str, Any]: - """Move file/directory + recursively update vector index.""" + """Move file/directory + recursively update vector index. + + Implemented as cp + rm to avoid lock files being carried by FS mv. + On rollback, the copy is deleted and the source remains intact. + """ + from openviking.pyagfs.helpers import cp as agfs_cp + from openviking.storage.transaction import TransactionContext, get_transaction_manager + self._ensure_access(old_uri, ctx) self._ensure_access(new_uri, ctx) old_path = self._uri_to_path(old_uri, ctx=ctx) new_path = self._uri_to_path(new_uri, ctx=ctx) target_uri = self._path_to_uri(old_path, ctx=ctx) - uris_to_move = await self._collect_uris(old_path, recursive=True, ctx=ctx) - uris_to_move.append(target_uri) + tx_manager = get_transaction_manager() + + # Verify source exists and determine type before locking try: - result = self.agfs.mv(old_path, new_path) + stat = self.agfs.stat(old_path) + is_dir = stat.get("isDir", False) if isinstance(stat, dict) else False + except Exception: + raise FileNotFoundError(f"mv source not found: {old_uri}") + + dst_parent = new_path.rsplit("/", 1)[0] if "/" in new_path else new_path + + async with TransactionContext( + tx_manager, + "mv", + [old_path], + lock_mode="mv", + mv_dst_path=dst_parent, + src_is_dir=is_dir, + ) as tx: + # Collect URIs inside the lock to avoid race conditions + uris_to_move = await self._collect_uris(old_path, recursive=True, ctx=ctx) + uris_to_move.append(target_uri) + + # Step 1: Copy source to destination + seq_cp = tx.record_undo("fs_write_new", {"uri": new_path}) + try: + agfs_cp(self.agfs, old_path, new_path, 
recursive=is_dir) + except Exception as e: + if "not found" in str(e).lower(): + await self._delete_from_vector_store(uris_to_move, ctx=ctx) + logger.info(f"[VikingFS] mv source not found, cleaned orphan index: {old_uri}") + raise + tx.mark_completed(seq_cp) + + # Step 2: Remove carried lock file from the copy (directory only) + if is_dir: + carried_lock = new_path.rstrip("/") + "/.path.ovlock" + try: + self.agfs.rm(carried_lock) + except Exception: + pass + + # Step 3: Update VectorDB URIs + old_uri_stripped = old_uri.rstrip("/") + old_parent_uri = ( + old_uri_stripped.rsplit("/", 1)[0] + "/" if "/" in old_uri_stripped else "" + ) + real_ctx = self._ctx_or_default(ctx) + seq_vdb = tx.record_undo( + "vectordb_update_uri", + { + "old_uri": old_uri, + "new_uri": new_uri, + "old_parent_uri": old_parent_uri, + "uris": uris_to_move, + "_ctx_account_id": real_ctx.account_id, + "_ctx_user_id": real_ctx.user.user_id, + "_ctx_agent_id": real_ctx.user.agent_id, + "_ctx_role": real_ctx.role.value, + }, + ) await self._update_vector_store_uris(uris_to_move, old_uri, new_uri, ctx=ctx) - return result - except AGFSHTTPError as e: - if e.status_code == 404: - await self._delete_from_vector_store(uris_to_move, ctx=ctx) - logger.info(f"[VikingFS] mv source not found, cleaned orphan index: {old_uri}") - raise + tx.mark_completed(seq_vdb) + + # Step 4: Remove source (lock file gets deleted along with it) + seq_rm = tx.record_undo("fs_rm", {"uri": old_path, "recursive": is_dir}) + self.agfs.rm(old_path, recursive=is_dir) + tx.mark_completed(seq_rm) + + await tx.commit() + return {} async def grep( self, @@ -1071,19 +1191,6 @@ def _handle_agfs_content(self, result: Union[bytes, Any, None]) -> str: return str(result) except Exception: return "" - """Handle AGFSClient content return types consistently.""" - if isinstance(result, bytes): - return result.decode("utf-8") - elif hasattr(result, "content"): - return result.content.decode("utf-8") - elif result is None: - return "" - else: - # 
Try to convert to string - try: - return str(result) - except Exception: - return "" def _infer_context_type(self, uri: str): """Infer context_type from URI. Returns None when ambiguous.""" @@ -1099,6 +1206,33 @@ def _infer_context_type(self, uri: str): # ========== Vector Sync Helper Methods ========== + async def _snapshot_vector_records( + self, uris: List[str], ctx: Optional[RequestContext] = None + ) -> List[Dict[str, Any]]: + """Snapshot vector records for the given URIs (for rollback). + + Queries VectorDB metadata (without embedding vectors) so that + records can be restored during rollback. + """ + vector_store = self._get_vector_store() + if not vector_store: + return [] + + real_ctx = self._ctx_or_default(ctx) + snapshots = [] + for uri in uris: + try: + records = await vector_store.get_context_by_uri( + uri=uri, + limit=10, + ctx=real_ctx, + ) + if records: + snapshots.extend(records) + except Exception as e: + logger.debug(f"[VikingFS] Failed to snapshot vector record for {uri}: {e}") + return snapshots + async def _collect_uris( self, path: str, recursive: bool, ctx: Optional[RequestContext] = None ) -> List[str]: @@ -1296,6 +1430,12 @@ async def read_file( """ self._ensure_access(uri, ctx) path = self._uri_to_path(uri, ctx=ctx) + # Verify the file exists before reading, because AGFS read returns + # empty bytes for non-existent files instead of raising an error. 
+ try: + self.agfs.stat(path) + except Exception: + raise NotFoundError(uri, "file") try: content = self.agfs.read(path) except Exception: diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index 880af30b7..29026d49f 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -882,13 +882,21 @@ def _seed_uri_for_id(uri: str, level: int) -> str: async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: updated = 0 for uri in uris: - records = await self.get_context_by_uri(uri=uri, limit=1, ctx=ctx) + records = await self.get_context_by_uri(uri=uri, limit=100, ctx=ctx) if not records: continue - record = records[0] - current = int(record.get("active_count", 0) or 0) - record["active_count"] = current + 1 - if await self.upsert(record, ctx=ctx): + record_ids = [r["id"] for r in records if r.get("id")] + if not record_ids: + continue + # Re-fetch by ID to get full records including vectors + full_records = await self.get(record_ids, ctx=ctx) + uri_updated = False + for record in full_records: + current = int(record.get("active_count", 0) or 0) + record["active_count"] = current + 1 + if await self.upsert(record, ctx=ctx): + uri_updated = True + if uri_updated: updated += 1 return updated diff --git a/openviking/utils/agfs_utils.py b/openviking/utils/agfs_utils.py index 511cf2d04..8db073d95 100644 --- a/openviking/utils/agfs_utils.py +++ b/openviking/utils/agfs_utils.py @@ -99,6 +99,10 @@ def mount_agfs_backend(agfs: Any, agfs_config: Any) -> None: local_dir = plugin_config["config"]["local_dir"] os.makedirs(local_dir, exist_ok=True) logger.debug(f"[AGFSUtils] Ensured local directory exists: {local_dir}") + # Ensure queuefs db_path parent directory exists before mounting + if plugin_name == "queuefs" and "db_path" in plugin_config.get("config", {}): + db_path = plugin_config["config"]["db_path"] + 
os.makedirs(os.path.dirname(db_path), exist_ok=True) try: agfs.unmount(mount_path) diff --git a/openviking/utils/resource_processor.py b/openviking/utils/resource_processor.py index bd56a2eae..d6c44194f 100644 --- a/openviking/utils/resource_processor.py +++ b/openviking/utils/resource_processor.py @@ -7,6 +7,7 @@ as described in the OpenViking design document. """ +import asyncio import time from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -209,6 +210,49 @@ async def process_resource( return result + # ============ Phase 3.5: 首次添加立即落盘 ============ + root_uri = result.get("root_uri") + temp_uri = result.get("temp_uri") # temp_doc_uri + + if root_uri and temp_uri: + viking_fs = get_viking_fs() + target_exists = await viking_fs.exists(root_uri, ctx=ctx) + if not target_exists: + # 第一次添加:事务保护下将 temp 移到 final + from openviking.storage.transaction import ( + TransactionContext, + get_transaction_manager, + ) + + dst_path = viking_fs._uri_to_path(root_uri, ctx=ctx) + parent_path = dst_path.rsplit("/", 1)[0] if "/" in dst_path else dst_path + + # 确保父目录存在 + parent_uri = "/".join(root_uri.rsplit("/", 1)[:-1]) + if parent_uri: + await viking_fs.mkdir(parent_uri, exist_ok=True, ctx=ctx) + + async with TransactionContext( + get_transaction_manager(), + "finalize_from_temp", + [parent_path], + lock_mode="point", + ) as tx: + seq = tx.record_undo("fs_write_new", {"uri": dst_path}) + src_path = viking_fs._uri_to_path(temp_uri, ctx=ctx) + await asyncio.to_thread(viking_fs.agfs.mv, src_path, dst_path) + tx.mark_completed(seq) + await tx.commit() + + # 清理 temp 根目录 + try: + await viking_fs.delete_temp(parse_result.temp_dir_path, ctx=ctx) + except Exception: + pass + + # 更新 temp_uri → DAG 直接在 final 上跑 + result["temp_uri"] = root_uri + # ============ Phase 4: Optional Steps ============ build_index = kwargs.get("build_index", True) temp_uri_for_summarize = result.get("temp_uri") or parse_result.temp_dir_path diff --git a/openviking/utils/summarizer.py 
b/openviking/utils/summarizer.py index 22076f925..36b879e86 100644 --- a/openviking/utils/summarizer.py +++ b/openviking/utils/summarizer.py @@ -71,7 +71,7 @@ async def summarize( role=ctx.role.value, skip_vectorization=skip_vectorization, telemetry_id=telemetry.telemetry_id if telemetry.enabled else "", - target_uri=uri, + target_uri=uri if uri != temp_uri else None, ) await semantic_queue.enqueue(msg) enqueued_count += 1 diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 81ad64d84..ba63891cd 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -67,7 +67,15 @@ def sync_provider_backend(cls, data: Any) -> Any: if backend is not None and provider is None: data["provider"] = backend - for key in ("input_type", "query_value", "document_value", "query_task", "document_task"): + for key in ( + "input_type", + "query_param", + "document_param", + "query_value", + "document_value", + "query_task", + "document_task", + ): value = data.get(key) if isinstance(value, str): data[key] = value.lower() @@ -176,7 +184,13 @@ def validate_config(self): ) return self - def _create_embedder(self, provider: str, embedder_type: str, config: EmbeddingModelConfig, context: Optional[str] = None): + def _create_embedder( + self, + provider: str, + embedder_type: str, + config: EmbeddingModelConfig, + context: Optional[str] = None, + ): """Factory method to create embedder instance based on provider and type. 
Args: @@ -211,7 +225,8 @@ def _create_embedder(self, provider: str, embedder_type: str, config: EmbeddingM OpenAIDenseEmbedder, lambda cfg: { "model_name": cfg.model, - "api_key": cfg.api_key or "no-key", # Placeholder for local OpenAI-compatible servers + "api_key": cfg.api_key + or "no-key", # Placeholder for local OpenAI-compatible servers "api_base": cfg.api_base, "dimension": cfg.dimension, "context": context, @@ -302,7 +317,8 @@ def _create_embedder(self, provider: str, embedder_type: str, config: EmbeddingM OpenAIDenseEmbedder, lambda cfg: { "model_name": cfg.model, - "api_key": cfg.api_key or "no-key", # Ollama ignores the key, but client requires non-empty + "api_key": cfg.api_key + or "no-key", # Ollama ignores the key, but client requires non-empty "api_base": cfg.api_base or "http://localhost:11434/v1", "dimension": cfg.dimension, "max_tokens": cfg.max_tokens, @@ -381,13 +397,18 @@ def _get_contextual_embedder(self, context: str): # OpenAI models are symmetric by default (no input_type sent). # Non-symmetric mode is activated implicitly when the user sets # query_param or document_param in the config. - non_symmetric = self.dense.query_param is not None or self.dense.document_param is not None + non_symmetric = ( + self.dense.query_param is not None or self.dense.document_param is not None + ) effective_context = context if non_symmetric else None return self._create_embedder(provider, "dense", self.dense, context=effective_context) if provider == "jina": - # Jina models are non-symmetric by default (task is always sent). 
- return self._create_embedder(provider, "dense", self.dense, context=context) + non_symmetric = ( + self.dense.query_param is not None or self.dense.document_param is not None + ) + effective_context = context if non_symmetric else None + return self._create_embedder(provider, "dense", self.dense, context=effective_context) return self.get_embedder() diff --git a/openviking_cli/utils/config/storage_config.py b/openviking_cli/utils/config/storage_config.py index 8daf6a792..b8b4bfea6 100644 --- a/openviking_cli/utils/config/storage_config.py +++ b/openviking_cli/utils/config/storage_config.py @@ -8,6 +8,7 @@ from openviking_cli.utils.logger import get_logger from .agfs_config import AGFSConfig +from .transaction_config import TransactionConfig from .vectordb_config import VectorDBBackendConfig logger = get_logger(__name__) @@ -25,6 +26,11 @@ class StorageConfig(BaseModel): agfs: AGFSConfig = Field(default_factory=lambda: AGFSConfig(), description="AGFS configuration") + transaction: TransactionConfig = Field( + default_factory=lambda: TransactionConfig(), + description="Transaction mechanism configuration", + ) + vectordb: VectorDBBackendConfig = Field( default_factory=lambda: VectorDBBackendConfig(), description="VectorDB backend configuration", diff --git a/openviking_cli/utils/config/transaction_config.py b/openviking_cli/utils/config/transaction_config.py new file mode 100644 index 000000000..fac8c2aa6 --- /dev/null +++ b/openviking_cli/utils/config/transaction_config.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +from pydantic import BaseModel, Field + + +class TransactionConfig(BaseModel): + """Configuration for the transaction mechanism. + + By default, lock acquisition does not wait (``lock_timeout=0``): if a + conflicting lock is held the operation fails immediately with + ``LockAcquisitionError``. 
Set ``lock_timeout`` to a positive value to + allow the caller to block and retry for up to that many seconds. + """ + + lock_timeout: float = Field( + default=0.0, + description=( + "Path lock acquisition timeout (seconds). " + "0 = fail immediately if locked (default). " + "> 0 = wait/retry up to this many seconds before raising LockAcquisitionError." + ), + ) + + lock_expire: float = Field( + default=300.0, + description=( + "Stale lock expiry threshold (seconds). " + "Locks held longer than this by a crashed process are force-released." + ), + ) + + max_parallel_locks: int = Field( + default=8, + description="Maximum parallel lock operations during recursive rm/mv.", + ) + + model_config = {"extra": "forbid"} diff --git a/tests/agfs/test_fs_binding.py b/tests/agfs/test_fs_binding.py index ed8d3d333..e55ff6fd8 100644 --- a/tests/agfs/test_fs_binding.py +++ b/tests/agfs/test_fs_binding.py @@ -13,6 +13,7 @@ import pytest +from openviking.storage.transaction import init_transaction_manager, reset_transaction_manager from openviking.storage.viking_fs import init_viking_fs from openviking_cli.utils.config.agfs_config import AGFSConfig @@ -32,16 +33,16 @@ async def viking_fs_binding_instance(): # Create AGFS client agfs_client = create_agfs_client(AGFS_CONF) - # Initialize VikingFS with client + # Initialize TransactionManager and VikingFS with client + init_transaction_manager(agfs=agfs_client) vfs = init_viking_fs(agfs=agfs_client) # make sure default/temp directory exists await vfs.mkdir("viking://temp/", exist_ok=True) - # Ensure test directory exists - await vfs.mkdir("viking://temp/", exist_ok=True) - yield vfs + reset_transaction_manager() + @pytest.mark.asyncio class TestVikingFSBindingLocal: diff --git a/tests/agfs/test_fs_binding_s3.py b/tests/agfs/test_fs_binding_s3.py index 692b869d7..aa7a753b6 100644 --- a/tests/agfs/test_fs_binding_s3.py +++ b/tests/agfs/test_fs_binding_s3.py @@ -13,6 +13,7 @@ import pytest +from openviking.storage.transaction import 
init_transaction_manager, reset_transaction_manager from openviking.storage.viking_fs import init_viking_fs from openviking_cli.utils.config.agfs_config import AGFSConfig @@ -57,11 +58,14 @@ async def viking_fs_binding_s3_instance(): # Create AGFS client agfs_client = create_agfs_client(AGFS_CONF) - # Initialize VikingFS with client + # Initialize TransactionManager and VikingFS with client + init_transaction_manager(agfs=agfs_client) vfs = init_viking_fs(agfs=agfs_client) yield vfs + reset_transaction_manager() + @pytest.mark.asyncio class TestVikingFSBindingS3: diff --git a/tests/agfs/test_fs_local.py b/tests/agfs/test_fs_local.py index 3a428ed68..9e59f6103 100644 --- a/tests/agfs/test_fs_local.py +++ b/tests/agfs/test_fs_local.py @@ -10,6 +10,7 @@ import pytest from openviking.agfs_manager import AGFSManager +from openviking.storage.transaction import init_transaction_manager, reset_transaction_manager from openviking.storage.viking_fs import init_viking_fs from openviking_cli.utils.config.agfs_config import AGFSConfig @@ -39,13 +40,15 @@ async def viking_fs_instance(): # Create AGFS client agfs_client = create_agfs_client(AGFS_CONF) - # Initialize VikingFS with client + # Initialize TransactionManager and VikingFS with client + init_transaction_manager(agfs=agfs_client) vfs = init_viking_fs(agfs=agfs_client) # make sure default/temp directory exists await vfs.mkdir("viking://temp/", exist_ok=True) yield vfs + reset_transaction_manager() # AGFSManager.stop is synchronous manager.stop() diff --git a/tests/agfs/test_fs_s3.py b/tests/agfs/test_fs_s3.py index 330c70893..67a54e40a 100644 --- a/tests/agfs/test_fs_s3.py +++ b/tests/agfs/test_fs_s3.py @@ -13,6 +13,7 @@ import pytest from openviking.agfs_manager import AGFSManager +from openviking.storage.transaction import init_transaction_manager, reset_transaction_manager from openviking.storage.viking_fs import VikingFS, init_viking_fs from openviking_cli.utils.config.agfs_config import AGFSConfig @@ -46,7 +47,8 @@ 
def load_agfs_config() -> AGFSConfig: AGFS_CONF = load_agfs_config() -AGFS_CONF.mode = "http-client" +if AGFS_CONF is not None: + AGFS_CONF.mode = "http-client" # 2. Skip tests if no S3 config found or backend is not S3 pytestmark = pytest.mark.skipif( @@ -81,11 +83,13 @@ async def viking_fs_instance(): # Create AGFS client agfs_client = create_agfs_client(AGFS_CONF) - # Initialize VikingFS with client + # Initialize TransactionManager and VikingFS with client + init_transaction_manager(agfs=agfs_client) vfs = init_viking_fs(agfs=agfs_client) yield vfs + reset_transaction_manager() # AGFSManager.stop is synchronous manager.stop() diff --git a/tests/client/test_resource_management.py b/tests/client/test_resource_management.py index 98302fd7e..f42ce1324 100644 --- a/tests/client/test_resource_management.py +++ b/tests/client/test_resource_management.py @@ -49,7 +49,7 @@ async def test_add_resource_with_to(self, client: AsyncOpenViking, sample_markdo """Test adding resource to specified target""" result = await client.add_resource( path=str(sample_markdown_file), - to="viking://resources/custom/", + to="viking://resources/custom/sample", reason="Test resource", ) diff --git a/tests/integration/test_add_resource_index.py b/tests/integration/test_add_resource_index.py index 32421e695..3da7cc256 100644 --- a/tests/integration/test_add_resource_index.py +++ b/tests/integration/test_add_resource_index.py @@ -1,10 +1,8 @@ -import pytest -import asyncio -import os import json -import shutil -from pathlib import Path -from unittest.mock import MagicMock, AsyncMock, patch +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from openviking.async_client import AsyncOpenViking from openviking_cli.utils.config.open_viking_config import OpenVikingConfigSingleton @@ -88,6 +86,20 @@ async def test_add_resource_indexing_logic(test_config, tmp_path): mock_agfs = MockLocalAGFS(root_path=tmp_path / "mock_agfs_root") + # Create mock parse result for Phase 1 
(media processor) + mock_parse_result = MagicMock() + mock_parse_result.source_path = str(resource_file) + mock_parse_result.meta = {} + mock_parse_result.temp_dir_path = "/tmp/fake_temp_dir" + mock_parse_result.warnings = [] + mock_parse_result.source_format = "markdown" + + # Create mock context tree for Phase 2/3 (tree builder) + mock_context_tree = MagicMock() + mock_context_tree.root = MagicMock() + mock_context_tree.root.uri = "viking://resources/test_doc" + mock_context_tree.root.temp_uri = None + # Patch the Summarizer and IndexBuilder to verify calls with ( patch( @@ -96,6 +108,16 @@ async def test_add_resource_indexing_logic(test_config, tmp_path): patch("openviking.utils.agfs_utils.create_agfs_client", return_value=mock_agfs), patch("openviking.agfs_manager.AGFSManager.start"), patch("openviking.agfs_manager.AGFSManager.stop"), + patch( + "openviking.utils.media_processor.UnifiedResourceProcessor.process", + new_callable=AsyncMock, + return_value=mock_parse_result, + ), + patch( + "openviking.parse.tree_builder.TreeBuilder.finalize_from_temp", + new_callable=AsyncMock, + return_value=mock_context_tree, + ), ): mock_summarize.return_value = {"status": "success"} diff --git a/tests/integration/test_full_workflow.py b/tests/integration/test_full_workflow.py index 3f86b5599..823cefd74 100644 --- a/tests/integration/test_full_workflow.py +++ b/tests/integration/test_full_workflow.py @@ -67,11 +67,17 @@ async def test_add_search_read_workflow( # 3. 
Read searched resource if search_result.resources: - res = await client.tree(search_result.resources[0].uri) - for data in res: - if not data["isDir"]: - content = await client.read(data["uri"]) - assert len(content) > 0 + uri = search_result.resources[0].uri + info = await client.stat(uri) + if info.get("isDir"): + res = await client.tree(uri) + for data in res: + if not data["isDir"]: + content = await client.read(data["uri"]) + assert len(content) > 0 + else: + content = await client.read(uri) + assert len(content) > 0 class TestSessionWorkflow: diff --git a/tests/retrieve/test_hierarchical_retriever_rerank.py b/tests/retrieve/test_hierarchical_retriever_rerank.py index ffaea6a8f..f72682b32 100644 --- a/tests/retrieve/test_hierarchical_retriever_rerank.py +++ b/tests/retrieve/test_hierarchical_retriever_rerank.py @@ -180,8 +180,8 @@ def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatc "hello", ["viking://resources"], [ - {"uri": "viking://resources/root-a", "abstract": "root A", "_score": 0.2}, - {"uri": "viking://resources/root-b", "abstract": "root B", "_score": 0.8}, + {"uri": "viking://resources/root-a", "abstract": "root A", "_score": 0.2, "level": 1}, + {"uri": "viking://resources/root-b", "abstract": "root B", "_score": 0.8, "level": 1}, ], mode=RetrieverMode.THINKING, ) diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 627798b4a..78dbb63e7 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -20,8 +20,10 @@ from openviking.server.config import ServerConfig from openviking.server.identity import RequestContext, Role from openviking.service.core import OpenVikingService +from openviking.storage.transaction import reset_transaction_manager from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils.config.embedding_config import EmbeddingConfig +from openviking_cli.utils.config.vlm_config import VLMConfig # 
--------------------------------------------------------------------------- # Paths @@ -67,6 +69,20 @@ def get_dimension(self) -> int: return FakeEmbedder +def _install_fake_vlm(monkeypatch): + """Use a fake VLM so server tests never hit external LLM APIs.""" + + async def _fake_get_completion(self, prompt, thinking=False, max_retries=0): + return "# Test Summary\n\nFake summary for testing.\n\n## Details\nTest content." + + async def _fake_get_vision_completion(self, prompt, images, thinking=False): + return "Fake image description for testing." + + monkeypatch.setattr(VLMConfig, "is_available", lambda self: True) + monkeypatch.setattr(VLMConfig, "get_completion_async", _fake_get_completion) + monkeypatch.setattr(VLMConfig, "get_vision_completion_async", _fake_get_vision_completion) + + # --------------------------------------------------------------------------- # Core fixtures: service + app + async client (HTTP API tests, in-process) # --------------------------------------------------------------------------- @@ -94,7 +110,9 @@ def sample_markdown_file(temp_dir: Path) -> Path: @pytest_asyncio.fixture(scope="function") async def service(temp_dir: Path, monkeypatch): """Create and initialize an OpenVikingService in embedded mode.""" + reset_transaction_manager() fake_embedder_cls = _install_fake_embedder(monkeypatch) + _install_fake_vlm(monkeypatch) svc = OpenVikingService( path=str(temp_dir / "data"), user=UserIdentifier.the_default_user("test_user") ) @@ -102,6 +120,7 @@ async def service(temp_dir: Path, monkeypatch): svc.viking_fs.query_embedder = fake_embedder_cls() yield svc await svc.close() + reset_transaction_manager() @pytest_asyncio.fixture(scope="function") @@ -146,7 +165,9 @@ async def client_with_resource(client, service, sample_markdown_file): async def running_server(temp_dir: Path, monkeypatch): """Start a real uvicorn server in a background thread.""" await AsyncOpenViking.reset() + reset_transaction_manager() fake_embedder_cls = 
_install_fake_embedder(monkeypatch) + _install_fake_vlm(monkeypatch) svc = OpenVikingService( path=str(temp_dir / "sdk_data"), user=UserIdentifier.the_default_user("sdk_test_user") diff --git a/tests/server/test_api_filesystem.py b/tests/server/test_api_filesystem.py index 79058d375..3a0da611c 100644 --- a/tests/server/test_api_filesystem.py +++ b/tests/server/test_api_filesystem.py @@ -66,14 +66,6 @@ async def test_tree(client: httpx.AsyncClient): assert body["status"] == "ok" -async def test_stat_after_add_resource(client_with_resource): - client, uri = client_with_resource - resp = await client.get("/api/v1/fs/stat", params={"uri": uri}) - assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ok" - - async def test_stat_not_found(client: httpx.AsyncClient): resp = await client.get( "/api/v1/fs/stat", @@ -84,18 +76,28 @@ async def test_stat_not_found(client: httpx.AsyncClient): assert body["status"] == "error" -async def test_rm_resource(client_with_resource): +async def test_resource_ops(client_with_resource): + """Test stat, ls_recursive, mv, rm on a single shared resource.""" + import uuid + client, uri = client_with_resource - resp = await client.request("DELETE", "/api/v1/fs", params={"uri": uri, "recursive": True}) + + # stat + resp = await client.get("/api/v1/fs/stat", params={"uri": uri}) assert resp.status_code == 200 assert resp.json()["status"] == "ok" + # ls recursive + resp = await client.get( + "/api/v1/fs/ls", + params={"uri": "viking://", "recursive": True}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert isinstance(body["result"], list) -async def test_mv_resource(client_with_resource): - import uuid - - client, uri = client_with_resource - # Use a unique name to avoid conflicts with leftover data + # mv unique = uuid.uuid4().hex[:8] new_uri = uri.rstrip("/") + f"_mv_{unique}/" resp = await client.post( @@ -105,14 +107,7 @@ async def 
test_mv_resource(client_with_resource): assert resp.status_code == 200 assert resp.json()["status"] == "ok" - -async def test_ls_recursive(client_with_resource): - client, _ = client_with_resource - resp = await client.get( - "/api/v1/fs/ls", - params={"uri": "viking://", "recursive": True}, - ) + # rm (on the moved uri) + resp = await client.request("DELETE", "/api/v1/fs", params={"uri": new_uri, "recursive": True}) assert resp.status_code == 200 - body = resp.json() - assert body["status"] == "ok" - assert isinstance(body["result"], list) + assert resp.json()["status"] == "ok" diff --git a/tests/server/test_api_resources.py b/tests/server/test_api_resources.py index 48e214e8f..6e5911273 100644 --- a/tests/server/test_api_resources.py +++ b/tests/server/test_api_resources.py @@ -120,7 +120,7 @@ async def test_add_resource_with_to(client: httpx.AsyncClient, sample_markdown_f "/api/v1/resources", json={ "path": str(sample_markdown_file), - "to": "viking://resources/custom/", + "to": "viking://resources/custom/sample", "reason": "test resource", }, ) diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py index e7bb1a80f..0f8f94e6c 100644 --- a/tests/session/test_memory_dedup_actions.py +++ b/tests/session/test_memory_dedup_actions.py @@ -179,7 +179,6 @@ async def test_find_similar_memories_uses_path_must_filter_and__score(self): assert len(similar) == 1 assert similar[0].uri == existing.uri call = vikingdb.search_similar_memories.await_args.kwargs - assert call["account_id"] == "acc1" assert call["owner_space"] == _make_user().user_space_name() assert call["category_uri_prefix"] == ( f"viking://user/{_make_user().user_space_name()}/memories/preferences/" diff --git a/tests/session/test_session_commit.py b/tests/session/test_session_commit.py index 60a42d027..efa57fc70 100644 --- a/tests/session/test_session_commit.py +++ b/tests/session/test_session_commit.py @@ -6,9 +6,6 @@ from openviking import AsyncOpenViking from 
openviking.message import TextPart from openviking.session import Session -from tests.utils.mock_context import make_test_ctx - -ctx = make_test_ctx() class TestCommit: @@ -98,12 +95,14 @@ async def test_active_count_incremented_after_commit(self, client_with_resource_ """ client, uri = client_with_resource_sync vikingdb = client._client.service.vikingdb_manager + # Use the client's own context to match the account_id used when adding the resource + client_ctx = client._client._ctx # Look up the record by URI records_before = await vikingdb.get_context_by_uri( uri=uri, limit=1, - ctx=ctx, + ctx=client_ctx, ) assert records_before, f"Resource not found for URI: {uri}" count_before = records_before[0].get("active_count") or 0 @@ -121,7 +120,7 @@ async def test_active_count_incremented_after_commit(self, client_with_resource_ records_after = await vikingdb.get_context_by_uri( uri=uri, limit=1, - ctx=ctx, + ctx=client_ctx, ) assert records_after, f"Record disappeared after commit for URI: {uri}" count_after = records_after[0].get("active_count") or 0 diff --git a/tests/session/test_session_compressor_vikingdb.py b/tests/session/test_session_compressor_vikingdb.py index 71e225333..7cb38d020 100644 --- a/tests/session/test_session_compressor_vikingdb.py +++ b/tests/session/test_session_compressor_vikingdb.py @@ -15,8 +15,12 @@ async def test_delete_existing_memory_uses_vikingdb_manager(): compressor = SessionCompressor.__new__(SessionCompressor) compressor.vikingdb = AsyncMock() + compressor._pending_semantic_changes = {} viking_fs = AsyncMock() - memory = SimpleNamespace(uri="viking://user/user1/memories/events/e1") + memory = SimpleNamespace( + uri="viking://user/user1/memories/events/e1", + parent_uri="viking://user/user1/memories/events", + ) ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) ok = await SessionCompressor._delete_existing_memory(compressor, memory, viking_fs, ctx) diff --git a/tests/storage/test_collection_schemas.py 
b/tests/storage/test_collection_schemas.py index c25dbbda6..c5a17501f 100644 --- a/tests/storage/test_collection_schemas.py +++ b/tests/storage/test_collection_schemas.py @@ -25,7 +25,7 @@ def __init__(self, embedder: _DummyEmbedder): self.storage = SimpleNamespace(vectordb=SimpleNamespace(name="context")) self.embedding = SimpleNamespace( dimension=2, - get_embedder=lambda: embedder, + get_document_embedder=lambda: embedder, ) diff --git a/tests/storage/test_semantic_dag_skip_files.py b/tests/storage/test_semantic_dag_skip_files.py index 6b52fa4e9..6fdf30eef 100644 --- a/tests/storage/test_semantic_dag_skip_files.py +++ b/tests/storage/test_semantic_dag_skip_files.py @@ -1,6 +1,8 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import AsyncMock, MagicMock + import pytest from openviking.server.identity import RequestContext, Role @@ -8,6 +10,24 @@ from openviking_cli.session.user_id import UserIdentifier +def _mock_transaction_layer(monkeypatch): + """Patch transaction layer to no-op for DAG tests.""" + mock_tx = MagicMock() + mock_tx.commit = AsyncMock() + monkeypatch.setattr( + "openviking.storage.transaction.context_manager.TransactionContext.__aenter__", + AsyncMock(return_value=mock_tx), + ) + monkeypatch.setattr( + "openviking.storage.transaction.context_manager.TransactionContext.__aexit__", + AsyncMock(return_value=False), + ) + monkeypatch.setattr( + "openviking.storage.transaction.get_transaction_manager", + lambda: MagicMock(), + ) + + class _FakeVikingFS: def __init__(self, tree): self._tree = tree @@ -19,6 +39,9 @@ async def ls(self, uri, ctx=None): async def write_file(self, path, content, ctx=None): self.writes.append((path, content)) + def _uri_to_path(self, uri, ctx=None): + return uri.replace("viking://", "/local/acc1/") + class _FakeProcessor: def __init__(self): @@ -57,6 +80,7 @@ async def register(self, **_kwargs): @pytest.mark.asyncio async def 
test_messages_jsonl_excluded_from_summary(monkeypatch): """messages.jsonl should be skipped by _list_dir and never summarized.""" + _mock_transaction_layer(monkeypatch) root_uri = "viking://session/test-session" tree = { root_uri: [ @@ -91,6 +115,7 @@ async def test_messages_jsonl_excluded_from_summary(monkeypatch): @pytest.mark.asyncio async def test_messages_jsonl_excluded_in_subdirectory(monkeypatch): """messages.jsonl in a subdirectory should also be skipped.""" + _mock_transaction_layer(monkeypatch) root_uri = "viking://session/test-session" tree = { root_uri: [ diff --git a/tests/storage/test_semantic_dag_stats.py b/tests/storage/test_semantic_dag_stats.py index 8cb8c1f8f..23dde0410 100644 --- a/tests/storage/test_semantic_dag_stats.py +++ b/tests/storage/test_semantic_dag_stats.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +from unittest.mock import AsyncMock, MagicMock import pytest @@ -21,6 +22,9 @@ async def ls(self, uri, ctx=None): async def write_file(self, path, content, ctx=None): self.writes.append((path, content)) + def _uri_to_path(self, uri, ctx=None): + return uri.replace("viking://", "/local/acc1/") + class _FakeProcessor: def __init__(self): @@ -75,6 +79,22 @@ async def test_semantic_dag_stats_collects_nodes(monkeypatch): lambda: _DummyTracker(), ) + # Mock transaction layer: TransactionContext as no-op passthrough + mock_tx = MagicMock() + mock_tx.commit = AsyncMock() + monkeypatch.setattr( + "openviking.storage.transaction.context_manager.TransactionContext.__aenter__", + AsyncMock(return_value=mock_tx), + ) + monkeypatch.setattr( + "openviking.storage.transaction.context_manager.TransactionContext.__aexit__", + AsyncMock(return_value=False), + ) + monkeypatch.setattr( + "openviking.storage.transaction.get_transaction_manager", + lambda: MagicMock(), + ) + processor = _FakeProcessor() ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) executor = SemanticDagExecutor( diff --git 
a/tests/test_session_task_tracking.py b/tests/test_session_task_tracking.py index 8a61fe4d6..1306d5003 100644 --- a/tests/test_session_task_tracking.py +++ b/tests/test_session_task_tracking.py @@ -181,7 +181,7 @@ async def test_task_failed_when_memory_extraction_raises(api_client): async def failing_extract(_context, _user, _session_id): raise RuntimeError("memory_extraction_failed: synthetic extractor error") - service.sessions._session_compressor.extractor.extract_strict = failing_extract + service.sessions._session_compressor.extractor.extract = failing_extract resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False}) task_id = resp.json()["result"]["task_id"] diff --git a/tests/test_telemetry_runtime.py b/tests/test_telemetry_runtime.py index eecdbf04b..585670127 100644 --- a/tests/test_telemetry_runtime.py +++ b/tests/test_telemetry_runtime.py @@ -12,6 +12,7 @@ from openviking.server.identity import RequestContext, Role from openviking.service.resource_service import ResourceService from openviking.storage.collection_schemas import TextEmbeddingHandler +from openviking.storage.queuefs.semantic_dag import DagStats from openviking.storage.queuefs.semantic_msg import SemanticMsg from openviking.storage.queuefs.semantic_processor import SemanticProcessor from openviking.telemetry import ( @@ -134,15 +135,25 @@ class FakeVikingFS: async def ls(self, uri, ctx=None): return [] - async def fake_process_single_directory(**kwargs): - assert get_current_telemetry() is telemetry - get_current_telemetry().record_token_usage("llm", 11, 7) + class _FakeDagExecutor: + def __init__(self, **kwargs): + pass + + async def run(self, root_uri): + assert get_current_telemetry() is telemetry + get_current_telemetry().record_token_usage("llm", 11, 7) + + def get_stats(self): + return DagStats() monkeypatch.setattr( "openviking.storage.queuefs.semantic_processor.get_viking_fs", lambda: FakeVikingFS(), ) - monkeypatch.setattr(processor, 
"_process_single_directory", fake_process_single_directory) + monkeypatch.setattr( + "openviking.storage.queuefs.semantic_processor.SemanticDagExecutor", + lambda **kwargs: _FakeDagExecutor(**kwargs), + ) try: await processor.on_dequeue( @@ -179,13 +190,13 @@ def __init__(self): self.storage = SimpleNamespace(vectordb=SimpleNamespace(name="context")) self.embedding = SimpleNamespace( dimension=2, - get_embedder=lambda: _TelemetryAwareEmbedder(), + get_document_embedder=lambda: _TelemetryAwareEmbedder(), ) class _DummyVikingDB: is_closing = False - async def upsert(self, _data): + async def upsert(self, _data, *, ctx=None): return "rec-1" monkeypatch.setattr( diff --git a/tests/transaction/__init__.py b/tests/transaction/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/transaction/conftest.py b/tests/transaction/conftest.py new file mode 100644 index 000000000..05fac402b --- /dev/null +++ b/tests/transaction/conftest.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Shared fixtures for transaction tests using real AGFS and VectorDB backends.""" + +import os +import shutil +import uuid + +import pytest + +from openviking.agfs_manager import AGFSManager +from openviking.server.identity import RequestContext, Role +from openviking.storage.collection_schemas import CollectionSchemas +from openviking.storage.transaction.journal import TransactionJournal +from openviking.storage.transaction.path_lock import LOCK_FILE_NAME, _make_fencing_token +from openviking.storage.transaction.transaction_manager import TransactionManager +from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend +from openviking.utils.agfs_utils import create_agfs_client +from openviking_cli.session.user_id import UserIdentifier +from openviking_cli.utils.config.agfs_config import AGFSConfig +from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig + +AGFS_CONF = AGFSConfig( + path="/tmp/ov-tx-test", backend="local", port=1834, url="http://localhost:1834", timeout=10 +) + +VECTOR_DIM = 4 +COLLECTION_NAME = "tx_test_ctx" + +# Clean slate before session starts +if os.path.exists(AGFS_CONF.path): + shutil.rmtree(AGFS_CONF.path) + + +@pytest.fixture(scope="session") +def agfs_manager(): + manager = AGFSManager(config=AGFS_CONF) + manager.start() + yield manager + manager.stop() + + +@pytest.fixture(scope="session") +def agfs_client(agfs_manager): + return create_agfs_client(AGFS_CONF) + + +def _mkdir_ok(agfs_client, path): + """Create directory, ignoring already-exists errors.""" + try: + agfs_client.mkdir(path) + except Exception: + pass # already exists + + +@pytest.fixture +def test_dir(agfs_client): + """每个测试独享隔离目录,自动清理。""" + path = f"/local/tx-tests/{uuid.uuid4().hex}" + _mkdir_ok(agfs_client, "/local") + _mkdir_ok(agfs_client, "/local/tx-tests") + _mkdir_ok(agfs_client, path) + yield path + try: + agfs_client.rm(path, recursive=True) + except Exception: + pass + + +# 
--------------------------------------------------------------------------- +# VectorDB fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def vector_store(tmp_path_factory): + """Session-scoped real local VectorDB backend.""" + db_path = str(tmp_path_factory.mktemp("vectordb")) + config = VectorDBBackendConfig( + backend="local", + name=COLLECTION_NAME, + path=db_path, + dimension=VECTOR_DIM, + ) + store = VikingVectorIndexBackend(config=config) + + import asyncio + + schema = CollectionSchemas.context_collection(COLLECTION_NAME, VECTOR_DIM) + asyncio.get_event_loop().run_until_complete(store.create_collection(COLLECTION_NAME, schema)) + + yield store + + asyncio.get_event_loop().run_until_complete(store.close()) + + +@pytest.fixture(scope="session") +def request_ctx(): + """Session-scoped RequestContext for VectorDB operations.""" + user = UserIdentifier("default", "test_user", "default") + return RequestContext(user=user, role=Role.ROOT) + + +# --------------------------------------------------------------------------- +# Transaction fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def tx_manager(agfs_client, vector_store): + """Function-scoped TransactionManager with real backends.""" + return TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + + +@pytest.fixture +def journal(agfs_client): + """Function-scoped TransactionJournal with real AGFS backend.""" + return TransactionJournal(agfs_client) + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def file_exists(agfs_client, path) -> bool: + """Check if a file/dir exists in AGFS.""" + try: + agfs_client.stat(path) + return True + except Exception: + return False + + +def make_lock_file(agfs_client, dir_path, tx_id, 
lock_type="P") -> str: + """Create a real lock file in AGFS and return its path.""" + lock_path = f"{dir_path.rstrip('/')}/{LOCK_FILE_NAME}" + token = _make_fencing_token(tx_id, lock_type) + agfs_client.write(lock_path, token.encode("utf-8")) + return lock_path diff --git a/tests/transaction/test_concurrent_lock.py b/tests/transaction/test_concurrent_lock.py new file mode 100644 index 000000000..e98279e49 --- /dev/null +++ b/tests/transaction/test_concurrent_lock.py @@ -0,0 +1,103 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for concurrent lock acquisition using real AGFS backend.""" + +import asyncio +import uuid + +from openviking.storage.transaction.path_lock import PathLock +from openviking.storage.transaction.transaction_record import TransactionRecord + + +class TestConcurrentLock: + async def test_point_mutual_exclusion_same_path(self, agfs_client, test_dir): + """两个任务竞争同一路径的 POINT 锁,均最终成功(串行执行)。""" + lock = PathLock(agfs_client) + + results = {} + + async def holder(tx_id): + tx = TransactionRecord(id=tx_id) + ok = await lock.acquire_point(test_dir, tx, timeout=5.0) + if ok: + await asyncio.sleep(0.3) + await lock.release(tx) + results[tx_id] = ok + + await asyncio.gather( + holder("tx-conc-1"), + holder("tx-conc-2"), + ) + + # Both should eventually succeed (one waits for the other) + assert results["tx-conc-1"] is True + assert results["tx-conc-2"] is True + + async def test_subtree_blocks_concurrent_point_child(self, agfs_client, test_dir): + """SUBTREE on parent 持锁期间,子目录的 POINT 被阻塞,释放后成功。""" + child = f"{test_dir}/child-{uuid.uuid4().hex}" + agfs_client.mkdir(child) + + lock = PathLock(agfs_client) + parent_acquired = asyncio.Event() + parent_released = asyncio.Event() + + child_result = {} + + async def parent_holder(): + tx = TransactionRecord(id="tx-sub-parent") + ok = await lock.acquire_subtree(test_dir, tx, timeout=5.0) + assert ok is True + parent_acquired.set() + await 
asyncio.sleep(0.5) + await lock.release(tx) + parent_released.set() + + async def child_worker(): + await parent_acquired.wait() + tx = TransactionRecord(id="tx-sub-child") + ok = await lock.acquire_point(child, tx, timeout=5.0) + child_result["ok"] = ok + child_result["after_release"] = parent_released.is_set() + if ok: + await lock.release(tx) + + await asyncio.gather(parent_holder(), child_worker()) + + assert child_result["ok"] is True + # Child should succeed only after parent released + assert child_result["after_release"] is True + + async def test_point_child_blocks_concurrent_subtree_parent(self, agfs_client, test_dir): + """POINT on child 持锁期间,父目录的 SUBTREE 被阻塞,释放后成功。""" + child = f"{test_dir}/child-{uuid.uuid4().hex}" + agfs_client.mkdir(child) + + lock = PathLock(agfs_client) + child_acquired = asyncio.Event() + child_released = asyncio.Event() + + parent_result = {} + + async def child_holder(): + tx = TransactionRecord(id="tx-rev-child") + ok = await lock.acquire_point(child, tx, timeout=5.0) + assert ok is True + child_acquired.set() + await asyncio.sleep(0.5) + await lock.release(tx) + child_released.set() + + async def parent_worker(): + await child_acquired.wait() + tx = TransactionRecord(id="tx-rev-parent") + ok = await lock.acquire_subtree(test_dir, tx, timeout=5.0) + parent_result["ok"] = ok + parent_result["after_release"] = child_released.is_set() + if ok: + await lock.release(tx) + + await asyncio.gather(child_holder(), parent_worker()) + + assert parent_result["ok"] is True + assert parent_result["after_release"] is True diff --git a/tests/transaction/test_context_manager.py b/tests/transaction/test_context_manager.py new file mode 100644 index 000000000..bf077bf91 --- /dev/null +++ b/tests/transaction/test_context_manager.py @@ -0,0 +1,226 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for TransactionContext.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from openviking.storage.errors import LockAcquisitionError +from openviking.storage.transaction.context_manager import TransactionContext +from openviking.storage.transaction.transaction_record import TransactionRecord, TransactionStatus + + +def _make_tx_manager(lock_succeeds=True): + """Create a mock TransactionManager with async methods.""" + tx_manager = MagicMock() + record = TransactionRecord(id="tx-test", status=TransactionStatus.INIT) + + tx_manager.create_transaction.return_value = record + tx_manager.acquire_lock_point = AsyncMock(return_value=lock_succeeds) + tx_manager.acquire_lock_subtree = AsyncMock(return_value=lock_succeeds) + tx_manager.acquire_lock_mv = AsyncMock(return_value=lock_succeeds) + tx_manager.commit = AsyncMock(return_value=True) + tx_manager.rollback = AsyncMock(return_value=True) + + journal = MagicMock() + tx_manager.journal = journal + + return tx_manager, record + + +class TestTransactionContextNormal: + async def test_commit_success(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "test_op", ["/path"]) as tx: + seq = tx.record_undo("fs_write_new", {"uri": "/path/file"}) + tx.mark_completed(seq) + await tx.commit() + + tx_manager.commit.assert_called_once_with("tx-test") + tx_manager.rollback.assert_not_called() + + async def test_rollback_on_exception(self): + tx_manager, record = _make_tx_manager() + + with pytest.raises(ValueError): + async with TransactionContext(tx_manager, "test_op", ["/path"]) as tx: + seq = tx.record_undo("fs_write_new", {"uri": "/path/file"}) + tx.mark_completed(seq) + raise ValueError("something went wrong") + + tx_manager.rollback.assert_called_once_with("tx-test") + tx_manager.commit.assert_not_called() + + async def test_rollback_on_no_commit(self): + tx_manager, record = _make_tx_manager() + + async with 
TransactionContext(tx_manager, "test_op", ["/path"]) as tx: + tx.record_undo("fs_write_new", {"uri": "/path/file"}) + # Forgot to call tx.commit() + + tx_manager.rollback.assert_called_once_with("tx-test") + + async def test_lock_failure_raises(self): + tx_manager, record = _make_tx_manager(lock_succeeds=False) + + with pytest.raises(LockAcquisitionError): + async with TransactionContext(tx_manager, "test_op", ["/path"]) as _tx: + pass + + +class TestTransactionContextLockModes: + async def test_subtree_lock_mode(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "rm_op", ["/path"], lock_mode="subtree") as tx: + await tx.commit() + + tx_manager.acquire_lock_subtree.assert_called_once() + + async def test_mv_lock_mode(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext( + tx_manager, "mv_op", ["/src"], lock_mode="mv", mv_dst_path="/dst" + ) as tx: + await tx.commit() + + tx_manager.acquire_lock_mv.assert_called_once_with( + "tx-test", "/src", "/dst", src_is_dir=True + ) + + async def test_point_lock_mode(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "write_op", ["/path"], lock_mode="point") as tx: + await tx.commit() + + tx_manager.acquire_lock_point.assert_called_once() + + +class TestTransactionContextUndoLog: + async def test_undo_entries_tracked(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "test", ["/path"]) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": "/a"}) + s1 = tx.record_undo("fs_write_new", {"uri": "/a/f.txt"}) + tx.mark_completed(s0) + tx.mark_completed(s1) + await tx.commit() + + assert len(record.undo_log) == 2 + assert record.undo_log[0].completed is True + assert record.undo_log[1].completed is True + + +class TestTransactionContextPostActions: + async def test_post_actions_added(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, 
"test", ["/path"]) as tx: + tx.add_post_action("enqueue_semantic", {"uri": "viking://test"}) + await tx.commit() + + assert len(record.post_actions) == 1 + assert record.post_actions[0]["type"] == "enqueue_semantic" + + +class TestTransactionContextEdgeCases: + async def test_commit_failure_raises_transaction_error(self): + """When TransactionManager.commit() returns False, TransactionError is raised.""" + from openviking.storage.errors import TransactionError + + tx_manager, record = _make_tx_manager() + tx_manager.commit = AsyncMock(return_value=False) + + with pytest.raises(TransactionError, match="Failed to commit"): + async with TransactionContext(tx_manager, "test", ["/path"]) as tx: + await tx.commit() + + async def test_mv_mode_missing_dst_raises(self): + """mv lock mode without mv_dst_path raises TransactionError.""" + from openviking.storage.errors import TransactionError + + tx_manager, record = _make_tx_manager() + + with pytest.raises(TransactionError, match="mv lock mode requires"): + async with TransactionContext( + tx_manager, "mv_op", ["/src"], lock_mode="mv", mv_dst_path=None + ) as _tx: + pass + + async def test_mark_completed_nonexistent_sequence_is_noop(self): + """mark_completed with a sequence not in undo_log doesn't crash.""" + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "test", ["/path"]) as tx: + seq = tx.record_undo("fs_mkdir", {"uri": "/a"}) + tx.mark_completed(999) # Nonexistent sequence + # Original entry should remain unmarked + assert record.undo_log[0].completed is False + tx.mark_completed(seq) + assert record.undo_log[0].completed is True + await tx.commit() + + async def test_journal_update_failure_does_not_break_transaction(self): + """Journal update failures during record_undo/mark_completed are silently ignored.""" + tx_manager, record = _make_tx_manager() + tx_manager.journal.update.side_effect = Exception("disk full") + + # Should not raise despite journal failures + async with 
TransactionContext(tx_manager, "test", ["/path"]) as tx: + seq = tx.record_undo("fs_mkdir", {"uri": "/a"}) + tx.mark_completed(seq) + await tx.commit() + + assert len(record.undo_log) == 1 + assert record.undo_log[0].completed is True + + async def test_record_property_before_enter_raises(self): + """Accessing tx.record before __aenter__ raises TransactionError.""" + from openviking.storage.errors import TransactionError + + tx_manager, _ = _make_tx_manager() + ctx = TransactionContext(tx_manager, "test", ["/path"]) + + with pytest.raises(TransactionError, match="Transaction not started"): + _ = ctx.record + + async def test_multiple_undo_entries_sequence_increments(self): + tx_manager, record = _make_tx_manager() + + async with TransactionContext(tx_manager, "test", ["/path"]) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": "/a"}) + s1 = tx.record_undo("fs_write_new", {"uri": "/a/f"}) + s2 = tx.record_undo("fs_mv", {"src": "/a", "dst": "/b"}) + assert s0 == 0 + assert s1 == 1 + assert s2 == 2 + await tx.commit() + + async def test_multiple_lock_paths_point_mode(self): + """Multiple lock_paths in point mode: each path gets acquire_lock_point called.""" + tx_manager, record = _make_tx_manager() + + async with TransactionContext( + tx_manager, "multi", ["/path1", "/path2"], lock_mode="point" + ) as tx: + await tx.commit() + + assert tx_manager.acquire_lock_point.call_count == 2 + + async def test_subtree_multiple_paths_stops_on_first_failure(self): + """If acquiring subtree lock on first path fails, second path is not attempted.""" + tx_manager, record = _make_tx_manager(lock_succeeds=False) + + with pytest.raises(LockAcquisitionError): + async with TransactionContext( + tx_manager, "rm", ["/path1", "/path2"], lock_mode="subtree" + ) as _tx: + pass + + # Only called once (failed on first path) + assert tx_manager.acquire_lock_subtree.call_count == 1 diff --git a/tests/transaction/test_crash_recovery.py b/tests/transaction/test_crash_recovery.py new file mode 100644 
index 000000000..21569edda --- /dev/null +++ b/tests/transaction/test_crash_recovery.py @@ -0,0 +1,561 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Integration test: crash recovery from journal using real AGFS and VectorDB backends.""" + +import uuid +from unittest.mock import AsyncMock, patch + +from openviking.storage.transaction.journal import TransactionJournal +from openviking.storage.transaction.transaction_manager import TransactionManager +from openviking.storage.transaction.transaction_record import ( + TransactionRecord, + TransactionStatus, +) +from openviking.storage.transaction.undo import UndoEntry + +from .conftest import VECTOR_DIM, _mkdir_ok, file_exists, make_lock_file + + +def _write_journal(journal, record): + """Write a TransactionRecord to real journal storage.""" + journal.write(record.to_journal()) + + +class TestCrashRecovery: + """ + Core technique: simulate crash recovery. + + 1. Create real FS state via agfs_client + 2. Build TransactionRecord, write to real journal + 3. Create fresh TransactionManager (simulates process restart) + 4. Call manager._recover_pending_transactions() + 5. 
Verify final state via agfs_client.stat()/cat() and vector_store.get() + """ + + async def test_recover_commit_no_rollback(self, agfs_client, vector_store, test_dir): + """COMMIT status → committed files NOT rolled back, journal cleaned up.""" + # Create a file that was part of a committed transaction + committed_file = f"{test_dir}/committed.txt" + agfs_client.write(committed_file, b"committed data") + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-commit-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.COMMIT, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_write_new", + params={"uri": committed_file}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + # New manager (simulates restart) + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + # File should still exist (no rollback for committed tx) + assert file_exists(agfs_client, committed_file) + # Journal should be cleaned up + assert tx_id not in journal.list_all() + + async def test_recover_commit_replays_post_actions(self, agfs_client, vector_store, test_dir): + """COMMIT + post_actions → replay post_actions.""" + journal = TransactionJournal(agfs_client) + tx_id = f"tx-post-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.COMMIT, + locks=[], + undo_log=[], + post_actions=[ + { + "type": "enqueue_semantic", + "params": { + "uri": "viking://test-post", + "context_type": "resource", + "account_id": "acc", + }, + } + ], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + + with patch.object(manager, "_execute_post_actions", new_callable=AsyncMock) as mock_post: + await manager._recover_pending_transactions() + + mock_post.assert_called_once() + assert tx_id not in journal.list_all() + + async def 
test_recover_exec_rollback_fs_mv(self, agfs_client, vector_store, test_dir): + """EXEC status with fs_mv → recovery rolls back → file moved back.""" + src = f"{test_dir}/exec-mv-src" + dst = f"{test_dir}/exec-mv-dst" + _mkdir_ok(agfs_client, src) + agfs_client.write(f"{src}/data.txt", b"mv-data") + + # Simulate: forward mv happened, then crash + agfs_client.mv(src, dst) + assert not file_exists(agfs_client, src) + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-exec-mv-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mv", + params={"src": src, "dst": dst}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert file_exists(agfs_client, src) + assert not file_exists(agfs_client, dst) + assert tx_id not in journal.list_all() + + async def test_recover_exec_rollback_fs_mkdir(self, agfs_client, vector_store, test_dir): + """EXEC with fs_mkdir → recovery → directory removed.""" + new_dir = f"{test_dir}/exec-mkdir" + _mkdir_ok(agfs_client, new_dir) + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-exec-mkdir-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + params={"uri": new_dir}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, new_dir) + assert tx_id not in journal.list_all() + + async def test_recover_exec_rollback_fs_write_new(self, agfs_client, vector_store, test_dir): + """EXEC with fs_write_new → recovery → file removed.""" + 
file_path = f"{test_dir}/exec-write.txt" + agfs_client.write(file_path, b"to-be-rolled-back") + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-exec-write-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_write_new", + params={"uri": file_path}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, file_path) + assert tx_id not in journal.list_all() + + async def test_recover_exec_rollback_vectordb_upsert( + self, agfs_client, vector_store, request_ctx, test_dir + ): + """EXEC with vectordb_upsert → recovery → record deleted from VectorDB.""" + record_id = str(uuid.uuid4()) + record = { + "id": record_id, + "uri": f"viking://resources/crash-upsert-{record_id}.md", + "parent_uri": "viking://resources/", + "account_id": "default", + "context_type": "resource", + "level": 2, + "vector": [0.5] * VECTOR_DIM, + "name": "crash-upsert", + "description": "test", + "abstract": "test", + } + await vector_store.upsert(record, ctx=request_ctx) + assert len(await vector_store.get([record_id], ctx=request_ctx)) == 1 + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-exec-vdb-{uuid.uuid4().hex[:8]}" + tx_record = TransactionRecord( + id=tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="vectordb_upsert", + params={ + "record_id": record_id, + "_ctx_account_id": "default", + "_ctx_user_id": "test_user", + "_ctx_role": "root", + }, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, tx_record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + results = await vector_store.get([record_id], 
ctx=request_ctx) + assert len(results) == 0 + assert tx_id not in journal.list_all() + + async def test_recover_fail_triggers_rollback(self, agfs_client, vector_store, test_dir): + """FAIL status → also triggers rollback.""" + new_dir = f"{test_dir}/fail-dir" + _mkdir_ok(agfs_client, new_dir) + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-fail-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.FAIL, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + params={"uri": new_dir}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, new_dir) + assert tx_id not in journal.list_all() + + async def test_recover_releasing_triggers_rollback(self, agfs_client, vector_store, test_dir): + """RELEASING status → rollback + lock cleanup.""" + new_dir = f"{test_dir}/releasing-dir" + _mkdir_ok(agfs_client, new_dir) + + lock_path = make_lock_file(agfs_client, test_dir, "tx-releasing-placeholder", "S") + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-releasing-{uuid.uuid4().hex[:8]}" + # Rewrite lock with correct tx_id + lock_path = make_lock_file(agfs_client, test_dir, tx_id, "S") + + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.RELEASING, + locks=[lock_path], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + params={"uri": new_dir}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, new_dir) + assert not file_exists(agfs_client, lock_path) + assert tx_id not in journal.list_all() + + async def test_recover_exec_includes_incomplete(self, agfs_client, 
vector_store, test_dir): + """EXEC recovery uses recover_all=True → also reverses incomplete entries.""" + new_dir = f"{test_dir}/exec-incomplete" + _mkdir_ok(agfs_client, new_dir) + + journal = TransactionJournal(agfs_client) + tx_id = f"tx-exec-inc-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + params={"uri": new_dir}, + completed=False, # incomplete, but recover_all=True reverses it + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, new_dir) + assert tx_id not in journal.list_all() + + async def test_recover_init_cleans_locks(self, agfs_client, vector_store, test_dir): + """INIT status → no rollback, just lock cleanup + journal delete.""" + lock_dir = f"{test_dir}/init-lock-dir" + _mkdir_ok(agfs_client, lock_dir) + + tx_id = f"tx-init-{uuid.uuid4().hex[:8]}" + lock_path = make_lock_file(agfs_client, lock_dir, tx_id, "P") + + journal = TransactionJournal(agfs_client) + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.INIT, + locks=[lock_path], + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, lock_path) + assert tx_id not in journal.list_all() + + async def test_recover_acquire_cleans_locks(self, agfs_client, vector_store, test_dir): + """ACQUIRE status → same as INIT, clean up only.""" + lock_dir = f"{test_dir}/acquire-lock-dir" + _mkdir_ok(agfs_client, lock_dir) + + tx_id = f"tx-acq-{uuid.uuid4().hex[:8]}" + lock_path = make_lock_file(agfs_client, lock_dir, tx_id, "P") + + journal = TransactionJournal(agfs_client) + record = 
TransactionRecord( + id=tx_id, + status=TransactionStatus.ACQUIRE, + locks=[lock_path], + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, lock_path) + assert tx_id not in journal.list_all() + + async def test_recover_init_orphan_lock_via_init_info( + self, agfs_client, vector_store, test_dir + ): + """INIT with empty locks but init_info.lock_paths → clean orphan lock owned by tx.""" + orphan_dir = f"{test_dir}/orphan-dir" + _mkdir_ok(agfs_client, orphan_dir) + + tx_id = f"tx-orphan-{uuid.uuid4().hex[:8]}" + lock_path = make_lock_file(agfs_client, orphan_dir, tx_id, "S") + + journal = TransactionJournal(agfs_client) + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.INIT, + locks=[], # Empty — crash happened before journal recorded locks + init_info={ + "operation": "rm", + "lock_paths": [orphan_dir], + "lock_mode": "subtree", + }, + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, lock_path) + assert tx_id not in journal.list_all() + + async def test_recover_init_orphan_lock_other_owner(self, agfs_client, vector_store, test_dir): + """INIT with orphan lock owned by different tx → not removed.""" + orphan_dir = f"{test_dir}/orphan-other" + _mkdir_ok(agfs_client, orphan_dir) + + other_tx_id = f"tx-OTHER-{uuid.uuid4().hex[:8]}" + lock_path = make_lock_file(agfs_client, orphan_dir, other_tx_id, "S") + + tx_id = f"tx-innocent-{uuid.uuid4().hex[:8]}" + journal = TransactionJournal(agfs_client) + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.INIT, + locks=[], + init_info={ + "operation": "rm", + "lock_paths": [orphan_dir], + "lock_mode": "subtree", + 
}, + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + # Lock file should still exist — owned by different tx + assert file_exists(agfs_client, lock_path) + assert tx_id not in journal.list_all() + + async def test_recover_mv_orphan_both_paths(self, agfs_client, vector_store, test_dir): + """INIT mv operation → check both lock_paths and mv_dst_path for orphan locks.""" + src_dir = f"{test_dir}/mv-orphan-src" + dst_dir = f"{test_dir}/mv-orphan-dst" + _mkdir_ok(agfs_client, src_dir) + _mkdir_ok(agfs_client, dst_dir) + + tx_id = f"tx-mv-orphan-{uuid.uuid4().hex[:8]}" + src_lock = make_lock_file(agfs_client, src_dir, tx_id, "S") + dst_lock = make_lock_file(agfs_client, dst_dir, tx_id, "P") + + journal = TransactionJournal(agfs_client) + record = TransactionRecord( + id=tx_id, + status=TransactionStatus.INIT, + locks=[], + init_info={ + "operation": "mv", + "lock_paths": [src_dir], + "lock_mode": "mv", + "mv_dst_path": dst_dir, + }, + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + # Both orphan locks should be cleaned up + assert not file_exists(agfs_client, src_lock) + assert not file_exists(agfs_client, dst_lock) + assert tx_id not in journal.list_all() + + async def test_recover_multiple_transactions(self, agfs_client, vector_store, test_dir): + """Multiple journal entries are all recovered.""" + dir_a = f"{test_dir}/multi-tx-a" + _mkdir_ok(agfs_client, dir_a) + + journal = TransactionJournal(agfs_client) + + # tx-a: EXEC with mkdir → should rollback + tx_a = f"tx-multi-a-{uuid.uuid4().hex[:8]}" + record_a = TransactionRecord( + id=tx_a, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + 
params={"uri": dir_a}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record_a) + + # tx-b: COMMIT → no rollback, just cleanup + tx_b = f"tx-multi-b-{uuid.uuid4().hex[:8]}" + record_b = TransactionRecord( + id=tx_b, + status=TransactionStatus.COMMIT, + locks=[], + undo_log=[], + post_actions=[], + ) + _write_journal(journal, record_b) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + assert not file_exists(agfs_client, dir_a) # rolled back + assert tx_a not in journal.list_all() + assert tx_b not in journal.list_all() + + async def test_recover_corrupted_journal_skips(self, agfs_client, vector_store, test_dir): + """Corrupted journal entry → skipped, others still processed.""" + journal = TransactionJournal(agfs_client) + + # Write a corrupted journal entry (invalid JSON) + bad_tx_id = f"tx-bad-{uuid.uuid4().hex[:8]}" + _mkdir_ok(agfs_client, "/local/_system") + _mkdir_ok(agfs_client, "/local/_system/transactions") + bad_dir = f"/local/_system/transactions/{bad_tx_id}" + _mkdir_ok(agfs_client, bad_dir) + agfs_client.write(f"{bad_dir}/journal.json", b"NOT VALID JSON {{{{") + + # Write a good journal entry + good_dir = f"{test_dir}/good-recovery" + _mkdir_ok(agfs_client, good_dir) + + good_tx_id = f"tx-good-{uuid.uuid4().hex[:8]}" + record = TransactionRecord( + id=good_tx_id, + status=TransactionStatus.EXEC, + locks=[], + undo_log=[ + UndoEntry( + sequence=0, + op_type="fs_mkdir", + params={"uri": good_dir}, + completed=True, + ) + ], + post_actions=[], + ) + _write_journal(journal, record) + + manager = TransactionManager(agfs_client=agfs_client, vector_store=vector_store) + await manager._recover_pending_transactions() + + # Good tx should still be recovered + assert not file_exists(agfs_client, good_dir) + assert good_tx_id not in journal.list_all() diff --git a/tests/transaction/test_e2e.py b/tests/transaction/test_e2e.py new file mode 100644 index 
000000000..d7b850c40 --- /dev/null +++ b/tests/transaction/test_e2e.py @@ -0,0 +1,437 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end transaction tests using real AGFS backend. + +These tests exercise the full stack: TransactionContext → TransactionManager → +PathLock → Journal → AGFS, verifying the complete acquire → operate → commit/rollback +→ release → journal cleanup lifecycle. +""" + +import uuid + +import pytest + +from openviking.storage.transaction.context_manager import TransactionContext +from openviking.storage.transaction.journal import TransactionJournal +from openviking.storage.transaction.path_lock import LOCK_FILE_NAME +from openviking.storage.transaction.transaction_manager import TransactionManager + + +@pytest.fixture +def tx_manager(agfs_client): + """Create a real TransactionManager backed by the test AGFS.""" + manager = TransactionManager( + agfs_client=agfs_client, + timeout=3600, + max_parallel_locks=8, + lock_timeout=1.0, + lock_expire=1.0, + ) + return manager + + +class TestE2ECommit: + async def test_full_commit_lifecycle(self, agfs_client, tx_manager, test_dir): + """Full lifecycle: context enter → record undo → commit → locks released → journal cleaned.""" + async with TransactionContext( + tx_manager, "test_write", [test_dir], lock_mode="point" + ) as tx: + # Lock should be acquired + lock_path = f"{test_dir}/{LOCK_FILE_NAME}" + token = agfs_client.cat(lock_path) + assert token is not None + + # Record some operations + seq = tx.record_undo("fs_write_new", {"uri": f"{test_dir}/file.txt"}) + agfs_client.write(f"{test_dir}/file.txt", b"hello") + tx.mark_completed(seq) + + # Add post action + tx.add_post_action( + "enqueue_semantic", + {"uri": "viking://test", "context_type": "resource", "account_id": "default"}, + ) + + await tx.commit() + + # After commit: lock should be released + try: + agfs_client.cat(lock_path) + raise AssertionError("Lock file should be 
gone after commit") + except Exception: + pass # Expected + + # Transaction should be removed from manager + assert tx_manager.get_transaction(tx.record.id) is None + + async def test_commit_file_persists(self, agfs_client, tx_manager, test_dir): + """Files written inside a committed transaction persist.""" + file_path = f"{test_dir}/committed-file.txt" + + async with TransactionContext(tx_manager, "write_op", [test_dir], lock_mode="point") as tx: + seq = tx.record_undo("fs_write_new", {"uri": file_path}) + agfs_client.write(file_path, b"committed data") + tx.mark_completed(seq) + await tx.commit() + + content = agfs_client.cat(file_path) + assert content == b"committed data" + + +class TestE2ERollback: + async def test_explicit_exception_triggers_rollback(self, agfs_client, tx_manager, test_dir): + """Exception inside context → auto-rollback → undo operations reversed.""" + new_dir = f"{test_dir}/to-be-rolled-back-{uuid.uuid4().hex}" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "failing_op", [test_dir], lock_mode="point" + ) as tx: + seq = tx.record_undo("fs_mkdir", {"uri": new_dir}) + agfs_client.mkdir(new_dir) + tx.mark_completed(seq) + + raise RuntimeError("simulated failure") + + # Directory should be removed by rollback + try: + agfs_client.stat(new_dir) + raise AssertionError("Directory should be removed by rollback") + except Exception: + pass + + # Lock should be released + lock_path = f"{test_dir}/{LOCK_FILE_NAME}" + try: + agfs_client.cat(lock_path) + raise AssertionError("Lock should be released after rollback") + except Exception: + pass + + async def test_no_commit_triggers_rollback(self, agfs_client, tx_manager, test_dir): + """Exiting context without calling commit() triggers auto-rollback.""" + new_dir = f"{test_dir}/forgot-commit-{uuid.uuid4().hex}" + + async with TransactionContext(tx_manager, "no_commit", [test_dir], lock_mode="point") as tx: + seq = tx.record_undo("fs_mkdir", {"uri": new_dir}) + 
agfs_client.mkdir(new_dir) + tx.mark_completed(seq) + # Intentionally not calling tx.commit() + + # Directory should be removed by rollback + try: + agfs_client.stat(new_dir) + raise AssertionError("Directory should be removed by rollback") + except Exception: + pass + + +class TestE2EMvLock: + async def test_mv_lock_acquires_both_paths(self, agfs_client, tx_manager, test_dir): + """mv lock mode acquires SUBTREE on both source and destination.""" + src = f"{test_dir}/mv-src-{uuid.uuid4().hex}" + dst = f"{test_dir}/mv-dst-{uuid.uuid4().hex}" + agfs_client.mkdir(src) + agfs_client.mkdir(dst) + + async with TransactionContext( + tx_manager, "mv_op", [src], lock_mode="mv", mv_dst_path=dst + ) as tx: + # Both lock files should exist + src_token = agfs_client.cat(f"{src}/{LOCK_FILE_NAME}") + dst_token = agfs_client.cat(f"{dst}/{LOCK_FILE_NAME}") + src_token_str = src_token.decode("utf-8") if isinstance(src_token, bytes) else src_token + dst_token_str = dst_token.decode("utf-8") if isinstance(dst_token, bytes) else dst_token + + assert ":S" in src_token_str # SUBTREE on source + assert ":S" in dst_token_str # SUBTREE on destination + + await tx.commit() + + # Both locks released + for path in [f"{src}/{LOCK_FILE_NAME}", f"{dst}/{LOCK_FILE_NAME}"]: + try: + agfs_client.cat(path) + raise AssertionError(f"Lock {path} should be gone") + except Exception: + pass + + +class TestE2ESubtreeRollback: + async def test_subtree_lock_with_rollback(self, agfs_client, tx_manager, test_dir): + """Subtree lock + rollback: undo is executed and lock released.""" + target = f"{test_dir}/sub-rb-{uuid.uuid4().hex}" + agfs_client.mkdir(target) + + child = f"{target}/child-{uuid.uuid4().hex}" + + with pytest.raises(ValueError): + async with TransactionContext(tx_manager, "rm_op", [target], lock_mode="subtree") as tx: + seq = tx.record_undo("fs_mkdir", {"uri": child}) + agfs_client.mkdir(child) + tx.mark_completed(seq) + + raise ValueError("abort rm") + + # Child dir should be removed by rollback 
+ try: + agfs_client.stat(child) + raise AssertionError("Child should be cleaned up") + except Exception: + pass + + # Lock released + try: + agfs_client.cat(f"{target}/{LOCK_FILE_NAME}") + raise AssertionError("Lock should be released") + except Exception: + pass + + +class TestE2EJournalCleanup: + async def test_journal_cleaned_after_commit(self, agfs_client, tx_manager, test_dir): + """After successful commit, the journal entry for the transaction is deleted.""" + journal = TransactionJournal(agfs_client) + + async with TransactionContext( + tx_manager, "journal_test", [test_dir], lock_mode="point" + ) as tx: + tx_id = tx.record.id + await tx.commit() + + # Journal should be cleaned up + all_ids = journal.list_all() + assert tx_id not in all_ids + + async def test_journal_cleaned_after_rollback(self, agfs_client, tx_manager, test_dir): + """After rollback, the journal entry is also cleaned up.""" + journal = TransactionJournal(agfs_client) + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "journal_rb", [test_dir], lock_mode="point" + ) as tx: + tx_id = tx.record.id + raise RuntimeError("force rollback") + + all_ids = journal.list_all() + assert tx_id not in all_ids + + +class TestE2EMvRollback: + async def test_mv_rollback_moves_file_back(self, agfs_client, tx_manager, test_dir): + """mv fails before commit → the file is moved back to its original location.""" + src = f"{test_dir}/mv-rb-src-{uuid.uuid4().hex}" + dst_parent = f"{test_dir}/mv-rb-dst-{uuid.uuid4().hex}" + agfs_client.mkdir(src) + agfs_client.mkdir(dst_parent) + + # Write a file inside src + agfs_client.write(f"{src}/data.txt", b"important") + + dst = f"{dst_parent}/moved" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "mv_op", [src], lock_mode="mv", mv_dst_path=dst_parent + ) as tx: + seq = tx.record_undo("fs_mv", {"src": src, "dst": dst}) + agfs_client.mv(src, dst) + tx.mark_completed(seq) + + raise RuntimeError("abort after mv") + + # src should be restored (mv reversed: 
dst → src) + content = agfs_client.cat(f"{src}/data.txt") + assert content == b"important" + + # dst should no longer exist + try: + agfs_client.stat(dst) + raise AssertionError("dst should not exist after rollback") + except Exception: + pass + + async def test_mv_commit_persists(self, agfs_client, tx_manager, test_dir): + """mv commit succeeds → the file exists at the new location and the old location is gone.""" + src = f"{test_dir}/mv-ok-src-{uuid.uuid4().hex}" + dst_parent = f"{test_dir}/mv-ok-dst-{uuid.uuid4().hex}" + agfs_client.mkdir(src) + agfs_client.mkdir(dst_parent) + agfs_client.write(f"{src}/data.txt", b"moved-data") + + dst = f"{dst_parent}/moved" + + async with TransactionContext( + tx_manager, "mv_op", [src], lock_mode="mv", mv_dst_path=dst_parent + ) as tx: + seq = tx.record_undo("fs_mv", {"src": src, "dst": dst}) + agfs_client.mv(src, dst) + tx.mark_completed(seq) + await tx.commit() + + # File at new location + content = agfs_client.cat(f"{dst}/data.txt") + assert content == b"moved-data" + + # Old location gone + try: + agfs_client.stat(src) + raise AssertionError("src should not exist after committed mv") + except Exception: + pass + + +class TestE2EMultiStepRollback: + async def test_multi_step_rollback_reverses_all(self, agfs_client, tx_manager, test_dir): + """Multi-step operation (mkdir + write + mkdir) fails midway → all steps rolled back in reverse order. + + Execution order: seq0 mkdir /a → seq1 write /a/f.txt → seq2 mkdir /a/sub, + then an exception is raised after seq2 completes. + Rollback order: seq2 rm /a/sub → seq1 rm /a/f.txt → seq0 rm /a + """ + dir_a = f"{test_dir}/multi-a-{uuid.uuid4().hex}" + file_f = f"{dir_a}/f.txt" + dir_sub = f"{dir_a}/sub" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "multi_step", [test_dir], lock_mode="point" + ) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": dir_a}) + agfs_client.mkdir(dir_a) + tx.mark_completed(s0) + + s1 = tx.record_undo("fs_write_new", {"uri": file_f}) + agfs_client.write(file_f, b"content") + tx.mark_completed(s1) + + s2 = tx.record_undo("fs_mkdir", {"uri": dir_sub}) + agfs_client.mkdir(dir_sub) + 
tx.mark_completed(s2) + + raise RuntimeError("abort after all steps") + + # Everything should be cleaned up in reverse order + for path in [dir_sub, file_f, dir_a]: + try: + agfs_client.stat(path) + raise AssertionError(f"{path} should not exist after rollback") + except Exception: + pass + + async def test_partial_step_rollback(self, agfs_client, tx_manager, test_dir): + """Two-step operation crashes midway through step two (no mark_completed) → only step one is rolled back. + + seq0 mkdir (completed=True) → seq1 write (completed=False, exception raised before marking). + Rollback processes only seq0. + """ + dir_a = f"{test_dir}/partial-{uuid.uuid4().hex}" + file_f = f"{dir_a}/f.txt" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "partial", [test_dir], lock_mode="point" + ) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": dir_a}) + agfs_client.mkdir(dir_a) + tx.mark_completed(s0) + + _s1 = tx.record_undo("fs_write_new", {"uri": file_f}) + agfs_client.write(file_f, b"half-done") + # NOT calling tx.mark_completed(s1) — simulates crash mid-operation + raise RuntimeError("crash before marking s1 completed") + + # dir_a (seq0, completed) should be rolled back + try: + agfs_client.stat(dir_a) + raise AssertionError("dir_a should be rolled back") + except Exception: + pass + + # file_f was written but undo entry not marked completed → not rolled back by normal mode + # However, file_f is inside dir_a which was removed, so it's gone too + + async def test_rollback_order_matters_nested_dirs(self, agfs_client, tx_manager, test_dir): + """Nested-directory rollback order: the child must be removed before the parent. + + seq0 mkdir /parent → seq1 mkdir /parent/child + Rollback must run seq1 (rm child) → seq0 (rm parent); otherwise removing the non-empty parent fails. + """ + parent = f"{test_dir}/nested-parent-{uuid.uuid4().hex}" + child = f"{parent}/child" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "nested", [test_dir], lock_mode="point" + ) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": parent}) + agfs_client.mkdir(parent) + tx.mark_completed(s0) + + s1 = tx.record_undo("fs_mkdir", 
{"uri": child}) + agfs_client.mkdir(child) + tx.mark_completed(s1) + + raise RuntimeError("abort nested") + + # Both gone (child first, then parent) + for path in [child, parent]: + try: + agfs_client.stat(path) + raise AssertionError(f"{path} should not exist") + except Exception: + pass + + async def test_rollback_failure_best_effort_continues(self, agfs_client, tx_manager, test_dir): + """If one rollback step fails, later steps still execute (best-effort). + + seq0 mkdir /a → seq1 mkdir /b + /b is removed manually (simulating the target already being gone when seq1 is rolled back); + seq0's rollback should still run. + """ + dir_a = f"{test_dir}/be-a-{uuid.uuid4().hex}" + dir_b = f"{test_dir}/be-b-{uuid.uuid4().hex}" + + with pytest.raises(RuntimeError): + async with TransactionContext( + tx_manager, "best_effort", [test_dir], lock_mode="point" + ) as tx: + s0 = tx.record_undo("fs_mkdir", {"uri": dir_a}) + agfs_client.mkdir(dir_a) + tx.mark_completed(s0) + + s1 = tx.record_undo("fs_mkdir", {"uri": dir_b}) + agfs_client.mkdir(dir_b) + tx.mark_completed(s1) + + # Manually remove dir_b before rollback — simulates external interference + agfs_client.rm(dir_b) + + raise RuntimeError("abort") + + # dir_b removal during rollback "fails" (already gone), but dir_a should still be rolled back + try: + agfs_client.stat(dir_a) + raise AssertionError("dir_a should be rolled back despite dir_b failure") + except Exception: + pass + + +class TestE2ESequentialTransactions: + async def test_sequential_transactions_on_same_path(self, agfs_client, tx_manager, test_dir): + """Sequential transactions on the same path all succeed.""" + for i in range(3): + async with TransactionContext( + tx_manager, f"seq_{i}", [test_dir], lock_mode="point" + ) as tx: + seq = tx.record_undo("fs_write_new", {"uri": f"{test_dir}/f{i}.txt"}) + agfs_client.write(f"{test_dir}/f{i}.txt", f"data-{i}".encode()) + tx.mark_completed(seq) + await tx.commit() + + # All files should exist + for i in range(3): + content = agfs_client.cat(f"{test_dir}/f{i}.txt") + assert content == f"data-{i}".encode() + + assert 
tx_manager.get_transaction_count() == 0 diff --git a/tests/transaction/test_journal.py b/tests/transaction/test_journal.py new file mode 100644 index 000000000..57f1e483c --- /dev/null +++ b/tests/transaction/test_journal.py @@ -0,0 +1,215 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for transaction journal.""" + +import json +import uuid +from unittest.mock import MagicMock + +from openviking.storage.transaction.journal import TransactionJournal + + +class TestTransactionJournal: + def _make_journal(self) -> tuple: + agfs = MagicMock() + journal = TransactionJournal(agfs) + return journal, agfs + + def test_write_calls_agfs_write_with_correct_data(self): + journal, agfs = self._make_journal() + data = {"id": "tx-1", "status": "INIT", "locks": []} + + journal.write(data) + + # Should call agfs.write with the journal path and serialized data + agfs.write.assert_called_once() + path, payload = agfs.write.call_args[0] + assert "tx-1" in path + assert path.endswith("journal.json") + parsed = json.loads(payload.decode("utf-8")) + assert parsed["id"] == "tx-1" + assert parsed["status"] == "INIT" + + def test_write_ensures_directories_exist(self): + journal, agfs = self._make_journal() + data = {"id": "tx-1", "status": "INIT", "locks": []} + + journal.write(data) + + # Should call mkdir at least once (for parent dirs) + assert agfs.mkdir.called + + def test_update_overwrites(self): + journal, agfs = self._make_journal() + data = {"id": "tx-2", "status": "EXEC", "locks": []} + + journal.update(data) + + agfs.write.assert_called_once() + path, payload = agfs.write.call_args[0] + assert json.loads(payload.decode("utf-8"))["status"] == "EXEC" + + def test_read_parses_json(self): + journal, agfs = self._make_journal() + agfs.cat.return_value = json.dumps({"id": "tx-3", "status": "EXEC"}).encode("utf-8") + + result = journal.read("tx-3") + assert result["id"] == "tx-3" + assert result["status"] == "EXEC" + 
+ def test_read_handles_string_response(self): + """Some AGFS backends may return str instead of bytes.""" + journal, agfs = self._make_journal() + agfs.cat.return_value = json.dumps({"id": "tx-str", "status": "INIT"}) + + result = journal.read("tx-str") + assert result["id"] == "tx-str" + + def test_delete_removes_directory(self): + journal, agfs = self._make_journal() + journal.delete("tx-4") + agfs.rm.assert_called_once() + path = agfs.rm.call_args[0][0] + assert "tx-4" in path + + def test_list_all_returns_tx_ids(self): + journal, agfs = self._make_journal() + agfs.ls.return_value = [ + {"name": "tx-a", "isDir": True}, + {"name": "tx-b", "isDir": True}, + {"name": ".", "isDir": True}, + ] + + result = journal.list_all() + assert "tx-a" in result + assert "tx-b" in result + assert "." not in result + + def test_list_all_filters_dotdot(self): + journal, agfs = self._make_journal() + agfs.ls.return_value = [ + {"name": "..", "isDir": True}, + {"name": "tx-real", "isDir": True}, + ] + + result = journal.list_all() + assert ".." 
not in result + assert "tx-real" in result + + def test_list_all_empty_on_error(self): + journal, agfs = self._make_journal() + agfs.ls.side_effect = Exception("not found") + + result = journal.list_all() + assert result == [] + + def test_delete_tolerates_missing(self): + journal, agfs = self._make_journal() + agfs.rm.side_effect = Exception("not found") + # Should not raise + journal.delete("tx-missing") + + def test_write_with_post_actions(self): + journal, agfs = self._make_journal() + data = { + "id": "tx-5", + "status": "COMMIT", + "locks": [], + "post_actions": [ + {"type": "enqueue_semantic", "params": {"uri": "viking://test"}}, + ], + } + journal.write(data) + path, payload = agfs.write.call_args[0] + parsed = json.loads(payload.decode("utf-8")) + assert len(parsed["post_actions"]) == 1 + assert parsed["post_actions"][0]["type"] == "enqueue_semantic" + + def test_write_with_undo_log(self): + journal, agfs = self._make_journal() + data = { + "id": "tx-6", + "status": "EXEC", + "locks": [], + "undo_log": [ + { + "sequence": 0, + "op_type": "fs_mv", + "params": {"src": "/a", "dst": "/b"}, + "completed": True, + }, + ], + } + journal.write(data) + _, payload = agfs.write.call_args[0] + parsed = json.loads(payload.decode("utf-8")) + assert len(parsed["undo_log"]) == 1 + assert parsed["undo_log"][0]["op_type"] == "fs_mv" + + +class TestTransactionJournalIntegration: + """Integration tests using real AGFS backend to verify persistence behavior.""" + + def test_write_read_roundtrip(self, agfs_client): + journal = TransactionJournal(agfs_client) + tx_id = f"tx-int-{uuid.uuid4().hex}" + data = {"id": tx_id, "status": "INIT", "locks": [], "undo_log": []} + + journal.write(data) + result = journal.read(tx_id) + + assert result["id"] == tx_id + assert result["status"] == "INIT" + + journal.delete(tx_id) + + def test_update_overwrites(self, agfs_client): + journal = TransactionJournal(agfs_client) + tx_id = f"tx-int-{uuid.uuid4().hex}" + + journal.write({"id": tx_id, 
"status": "INIT", "locks": []}) + journal.update({"id": tx_id, "status": "EXEC", "locks": []}) + + result = journal.read(tx_id) + assert result["status"] == "EXEC" + + journal.delete(tx_id) + + def test_delete_removes_journal(self, agfs_client): + journal = TransactionJournal(agfs_client) + tx_id = f"tx-int-{uuid.uuid4().hex}" + + journal.write({"id": tx_id, "status": "INIT", "locks": []}) + journal.delete(tx_id) + + try: + journal.read(tx_id) + raise AssertionError("Should have raised after deletion") + except Exception: + pass # Expected + + def test_list_all_returns_written_ids(self, agfs_client): + journal = TransactionJournal(agfs_client) + tx_id_a = f"tx-int-{uuid.uuid4().hex}" + tx_id_b = f"tx-int-{uuid.uuid4().hex}" + + journal.write({"id": tx_id_a, "status": "INIT", "locks": []}) + journal.write({"id": tx_id_b, "status": "INIT", "locks": []}) + + result = journal.list_all() + assert tx_id_a in result + assert tx_id_b in result + + journal.delete(tx_id_a) + journal.delete(tx_id_b) + + def test_list_all_empty_when_none(self, agfs_client): + """After cleanup, list_all should not include previously deleted entries.""" + journal = TransactionJournal(agfs_client) + tx_id = f"tx-int-{uuid.uuid4().hex}" + + journal.write({"id": tx_id, "status": "INIT", "locks": []}) + journal.delete(tx_id) + + result = journal.list_all() + assert tx_id not in result diff --git a/tests/transaction/test_path_lock.py b/tests/transaction/test_path_lock.py new file mode 100644 index 000000000..2f3b6afc0 --- /dev/null +++ b/tests/transaction/test_path_lock.py @@ -0,0 +1,334 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for path lock with fencing tokens.""" + +import time +from unittest.mock import MagicMock + +from openviking.storage.transaction.path_lock import ( + LOCK_FILE_NAME, + LOCK_TYPE_POINT, + LOCK_TYPE_SUBTREE, + PathLock, + _make_fencing_token, + _parse_fencing_token, +) +from openviking.storage.transaction.transaction_record import TransactionRecord + + +class TestFencingToken: + def test_make_parse_roundtrip(self): + token = _make_fencing_token("tx-123") + tx_id, ts, lock_type = _parse_fencing_token(token) + assert tx_id == "tx-123" + assert ts > 0 + assert lock_type == LOCK_TYPE_POINT + + def test_make_parse_subtree_roundtrip(self): + token = _make_fencing_token("tx-456", LOCK_TYPE_SUBTREE) + tx_id, ts, lock_type = _parse_fencing_token(token) + assert tx_id == "tx-456" + assert ts > 0 + assert lock_type == LOCK_TYPE_SUBTREE + + def test_parse_legacy_format_two_part(self): + """Legacy two-part token "{tx_id}:{ts}" defaults to POINT.""" + tx_id, ts, lock_type = _parse_fencing_token("tx-old:1234567890") + assert tx_id == "tx-old" + assert ts == 1234567890 + assert lock_type == LOCK_TYPE_POINT + + def test_parse_legacy_format_plain(self): + """Plain tx_id (no colon) defaults to ts=0, lock_type=POINT.""" + tx_id, ts, lock_type = _parse_fencing_token("tx-bare") + assert tx_id == "tx-bare" + assert ts == 0 + assert lock_type == LOCK_TYPE_POINT + + def test_tokens_are_unique(self): + t1 = _make_fencing_token("tx-1") + time.sleep(0.001) + t2 = _make_fencing_token("tx-1") + assert t1 != t2 + + +class TestPathLockStale: + def test_is_lock_stale_no_file(self): + agfs = MagicMock() + agfs.cat.side_effect = Exception("not found") + lock = PathLock(agfs) + assert lock.is_lock_stale("/test/.path.ovlock") is True + + def test_is_lock_stale_legacy_token(self): + agfs = MagicMock() + agfs.cat.return_value = b"tx-old-format" + lock = PathLock(agfs) + assert lock.is_lock_stale("/test/.path.ovlock") is True + + def 
test_is_lock_stale_recent_token(self): + agfs = MagicMock() + token = _make_fencing_token("tx-1") + agfs.cat.return_value = token.encode("utf-8") + lock = PathLock(agfs) + assert lock.is_lock_stale("/test/.path.ovlock", expire_seconds=300.0) is False + + +class TestPathLockBehavior: + """Behavioral tests using real AGFS backend.""" + + async def test_acquire_point_creates_lock_file(self, agfs_client, test_dir): + lock = PathLock(agfs_client) + tx = TransactionRecord(id="tx-point-1") + + ok = await lock.acquire_point(test_dir, tx, timeout=3.0) + assert ok is True + + lock_path = f"{test_dir}/{LOCK_FILE_NAME}" + content = agfs_client.cat(lock_path) + token = content.decode("utf-8") if isinstance(content, bytes) else content + assert ":P" in token + assert "tx-point-1" in token + + await lock.release(tx) + + async def test_acquire_subtree_creates_lock_file(self, agfs_client, test_dir): + lock = PathLock(agfs_client) + tx = TransactionRecord(id="tx-subtree-1") + + ok = await lock.acquire_subtree(test_dir, tx, timeout=3.0) + assert ok is True + + lock_path = f"{test_dir}/{LOCK_FILE_NAME}" + content = agfs_client.cat(lock_path) + token = content.decode("utf-8") if isinstance(content, bytes) else content + assert ":S" in token + assert "tx-subtree-1" in token + + await lock.release(tx) + + async def test_acquire_point_dir_not_found(self, agfs_client): + lock = PathLock(agfs_client) + tx = TransactionRecord(id="tx-no-dir") + + ok = await lock.acquire_point("/local/nonexistent-path-xyz", tx, timeout=0.5) + assert ok is False + assert len(tx.locks) == 0 + + async def test_release_removes_lock_file(self, agfs_client, test_dir): + lock = PathLock(agfs_client) + tx = TransactionRecord(id="tx-release-1") + + await lock.acquire_point(test_dir, tx, timeout=3.0) + lock_path = f"{test_dir}/{LOCK_FILE_NAME}" + + await lock.release(tx) + + # Lock file should be gone + try: + agfs_client.cat(lock_path) + raise AssertionError("Lock file should have been removed") + except Exception: + 
+                pass  # Expected: file not found
+
+    async def test_sequential_acquire_works(self, agfs_client, test_dir):
+        lock = PathLock(agfs_client)
+
+        tx1 = TransactionRecord(id="tx-seq-1")
+        ok1 = await lock.acquire_point(test_dir, tx1, timeout=3.0)
+        assert ok1 is True
+
+        await lock.release(tx1)
+
+        tx2 = TransactionRecord(id="tx-seq-2")
+        ok2 = await lock.acquire_point(test_dir, tx2, timeout=3.0)
+        assert ok2 is True
+
+        await lock.release(tx2)
+
+    async def test_point_blocked_by_ancestor_subtree(self, agfs_client, test_dir):
+        """POINT on child blocked while ancestor holds SUBTREE lock."""
+        import uuid as _uuid
+
+        child = f"{test_dir}/child-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(child)
+
+        lock = PathLock(agfs_client)
+        tx_parent = TransactionRecord(id="tx-parent-subtree")
+        ok = await lock.acquire_subtree(test_dir, tx_parent, timeout=3.0)
+        assert ok is True
+
+        tx_child = TransactionRecord(id="tx-child-point")
+        blocked = await lock.acquire_point(child, tx_child, timeout=0.5)
+        assert blocked is False
+
+        await lock.release(tx_parent)
+
+    async def test_subtree_blocked_by_descendant_point(self, agfs_client, test_dir):
+        """SUBTREE on parent blocked while descendant holds POINT lock."""
+        import uuid as _uuid
+
+        child = f"{test_dir}/child-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(child)
+
+        lock = PathLock(agfs_client)
+        tx_child = TransactionRecord(id="tx-desc-point")
+        ok = await lock.acquire_point(child, tx_child, timeout=3.0)
+        assert ok is True
+
+        tx_parent = TransactionRecord(id="tx-parent-sub")
+        blocked = await lock.acquire_subtree(test_dir, tx_parent, timeout=0.5)
+        assert blocked is False
+
+        await lock.release(tx_child)
+
+    async def test_acquire_mv_creates_subtree_locks(self, agfs_client, test_dir):
+        """acquire_mv puts SUBTREE on both src and dst."""
+        import uuid as _uuid
+
+        src = f"{test_dir}/src-{_uuid.uuid4().hex}"
+        dst = f"{test_dir}/dst-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(src)
+        agfs_client.mkdir(dst)
+
+        lock = PathLock(agfs_client)
+        tx = TransactionRecord(id="tx-mv-1")
+        ok = await lock.acquire_mv(src, dst, tx, timeout=3.0)
+        assert ok is True
+
+        src_token_bytes = agfs_client.cat(f"{src}/{LOCK_FILE_NAME}")
+        src_token = (
+            src_token_bytes.decode("utf-8")
+            if isinstance(src_token_bytes, bytes)
+            else src_token_bytes
+        )
+        assert ":S" in src_token
+
+        dst_token_bytes = agfs_client.cat(f"{dst}/{LOCK_FILE_NAME}")
+        dst_token = (
+            dst_token_bytes.decode("utf-8")
+            if isinstance(dst_token_bytes, bytes)
+            else dst_token_bytes
+        )
+        assert ":S" in dst_token
+
+        await lock.release(tx)
+
+    async def test_point_does_not_block_sibling_point(self, agfs_client, test_dir):
+        """POINT locks on different directories do not conflict."""
+        import uuid as _uuid
+
+        dir_a = f"{test_dir}/sibling-a-{_uuid.uuid4().hex}"
+        dir_b = f"{test_dir}/sibling-b-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(dir_a)
+        agfs_client.mkdir(dir_b)
+
+        lock = PathLock(agfs_client)
+        tx_a = TransactionRecord(id="tx-sib-a")
+        tx_b = TransactionRecord(id="tx-sib-b")
+
+        ok_a = await lock.acquire_point(dir_a, tx_a, timeout=3.0)
+        ok_b = await lock.acquire_point(dir_b, tx_b, timeout=3.0)
+
+        assert ok_a is True
+        assert ok_b is True
+
+        await lock.release(tx_a)
+        await lock.release(tx_b)
+
+    async def test_stale_lock_auto_removed_on_acquire(self, agfs_client, test_dir):
+        """A stale lock (expired fencing token) is auto-removed, allowing a new acquire."""
+        import uuid as _uuid
+
+        target = f"{test_dir}/stale-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(target)
+
+        lock_path = f"{target}/{LOCK_FILE_NAME}"
+
+        # Write a lock file with a very old timestamp (simulate crashed process)
+        old_ts = time.time_ns() - int(600 * 1e9)  # 600 seconds ago
+        stale_token = f"tx-dead:{old_ts}:{LOCK_TYPE_POINT}"
+        agfs_client.write(lock_path, stale_token.encode("utf-8"))
+
+        # New transaction should succeed by auto-removing the stale lock
+        lock = PathLock(agfs_client, lock_expire=300.0)
+        tx = TransactionRecord(id="tx-new-owner")
+        ok = await lock.acquire_point(target, tx, timeout=2.0)
+        assert ok is True
+
+        # Verify new lock is owned by our transaction
+        content = agfs_client.cat(lock_path)
+        token = content.decode("utf-8") if isinstance(content, bytes) else content
+        assert "tx-new-owner" in token
+
+        await lock.release(tx)
+
+    async def test_stale_subtree_ancestor_auto_removed(self, agfs_client, test_dir):
+        """A stale SUBTREE lock on ancestor is auto-removed when child acquires POINT."""
+        import uuid as _uuid
+
+        child = f"{test_dir}/child-stale-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(child)
+
+        # Write stale SUBTREE lock on parent
+        parent_lock = f"{test_dir}/{LOCK_FILE_NAME}"
+        old_ts = time.time_ns() - int(600 * 1e9)
+        stale_token = f"tx-dead-parent:{old_ts}:{LOCK_TYPE_SUBTREE}"
+        agfs_client.write(parent_lock, stale_token.encode("utf-8"))
+
+        lock = PathLock(agfs_client, lock_expire=300.0)
+        tx = TransactionRecord(id="tx-child-new")
+        ok = await lock.acquire_point(child, tx, timeout=2.0)
+        assert ok is True
+
+        await lock.release(tx)
+        # Clean up stale parent lock if still present
+        try:
+            agfs_client.rm(parent_lock)
+        except Exception:
+            pass
+
+    async def test_point_same_path_no_wait_fails_immediately(self, agfs_client, test_dir):
+        """With timeout=0, a conflicting lock fails immediately."""
+        import uuid as _uuid
+
+        target = f"{test_dir}/nowait-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(target)
+
+        lock = PathLock(agfs_client)
+        tx1 = TransactionRecord(id="tx-hold")
+        ok1 = await lock.acquire_point(target, tx1, timeout=3.0)
+        assert ok1 is True
+
+        # Second acquire with timeout=0 should fail immediately
+        tx2 = TransactionRecord(id="tx-blocked")
+        t0 = time.monotonic()
+        ok2 = await lock.acquire_point(target, tx2, timeout=0.0)
+        elapsed = time.monotonic() - t0
+
+        assert ok2 is False
+        assert elapsed < 1.0  # Should not wait
+
+        await lock.release(tx1)
+
+    async def test_subtree_same_path_mutual_exclusion(self, agfs_client, test_dir):
+        """Two SUBTREE locks on the same path: second one blocked until first releases."""
+        import uuid as _uuid
+
+        target = f"{test_dir}/sub-excl-{_uuid.uuid4().hex}"
+        agfs_client.mkdir(target)
+
+        lock = PathLock(agfs_client)
+        tx1 = TransactionRecord(id="tx-sub1")
+        ok1 = await lock.acquire_subtree(target, tx1, timeout=3.0)
+        assert ok1 is True
+
+        tx2 = TransactionRecord(id="tx-sub2")
+        ok2 = await lock.acquire_subtree(target, tx2, timeout=0.5)
+        assert ok2 is False
+
+        await lock.release(tx1)
+
+        # Now tx2 should succeed
+        ok2_retry = await lock.acquire_subtree(target, tx2, timeout=3.0)
+        assert ok2_retry is True
+        await lock.release(tx2)
diff --git a/tests/transaction/test_post_actions.py b/tests/transaction/test_post_actions.py
new file mode 100644
index 000000000..2ae3c12be
--- /dev/null
+++ b/tests/transaction/test_post_actions.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for post_actions execution and replay."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from openviking.storage.transaction.transaction_manager import TransactionManager
+
+
+class TestPostActions:
+    def _make_manager(self):
+        agfs = MagicMock()
+        manager = TransactionManager(agfs_client=agfs, timeout=3600)
+        manager._journal = MagicMock()
+        return manager, agfs
+
+    async def test_execute_enqueue_semantic(self):
+        manager, _ = self._make_manager()
+
+        mock_queue = AsyncMock()
+        mock_queue_manager = MagicMock()
+        mock_queue_manager.get_queue.return_value = mock_queue
+
+        with patch(
+            "openviking.storage.queuefs.get_queue_manager",
+            return_value=mock_queue_manager,
+        ):
+            await manager._execute_post_actions(
+                [
+                    {
+                        "type": "enqueue_semantic",
+                        "params": {
+                            "uri": "viking://resources/test",
+                            "context_type": "resource",
+                            "account_id": "acc-1",
+                        },
+                    }
+                ]
+            )
+
+        mock_queue.enqueue.assert_called_once()
+        msg = mock_queue.enqueue.call_args[0][0]
+        assert msg.uri == "viking://resources/test"
+        assert msg.context_type == "resource"
+        assert msg.account_id == "acc-1"
+
+    async def test_execute_unknown_action_logged(self):
+        manager, _ = self._make_manager()
+        # Should not raise, just log
+        await manager._execute_post_actions(
+            [
+                {"type": "unknown_action", "params": {}},
+            ]
+        )
+
+    async def test_execute_multiple_actions(self):
+        manager, _ = self._make_manager()
+
+        mock_queue = AsyncMock()
+        mock_queue_manager = MagicMock()
+        mock_queue_manager.get_queue.return_value = mock_queue
+
+        with patch(
+            "openviking.storage.queuefs.get_queue_manager",
+            return_value=mock_queue_manager,
+        ):
+            await manager._execute_post_actions(
+                [
+                    {
+                        "type": "enqueue_semantic",
+                        "params": {
+                            "uri": "viking://a",
+                            "context_type": "resource",
+                            "account_id": "acc-1",
+                        },
+                    },
+                    {
+                        "type": "enqueue_semantic",
+                        "params": {
+                            "uri": "viking://b",
+                            "context_type": "memory",
+                            "account_id": "acc-2",
+                        },
+                    },
+                ]
+            )
+
+        assert mock_queue.enqueue.call_count == 2
+
+    async def test_post_action_failure_does_not_crash(self):
+        manager, _ = self._make_manager()
+
+        mock_queue_manager = MagicMock()
+        mock_queue_manager.get_queue.side_effect = Exception("queue not available")
+
+        with patch(
+            "openviking.storage.queuefs.get_queue_manager",
+            return_value=mock_queue_manager,
+        ):
+            # Should not raise
+            await manager._execute_post_actions(
+                [
+                    {
+                        "type": "enqueue_semantic",
+                        "params": {
+                            "uri": "viking://test",
+                            "context_type": "resource",
+                            "account_id": "",
+                        },
+                    },
+                ]
+            )
diff --git a/tests/transaction/test_rm_rollback.py b/tests/transaction/test_rm_rollback.py
new file mode 100644
index 000000000..604b5f50c
--- /dev/null
+++ b/tests/transaction/test_rm_rollback.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Integration tests: multi-step rollback covering FS + VectorDB coordination."""
+
+import uuid
+
+from openviking.storage.transaction.undo import UndoEntry, execute_rollback
+
+from .conftest import VECTOR_DIM, _mkdir_ok, file_exists
+
+
+class TestRmRollback:
+    async def test_fs_rm_not_reversible(self, agfs_client, test_dir):
+        """fs_rm is intentionally irreversible: even completed=True is a no-op."""
+        path = f"{test_dir}/rm-target"
+        _mkdir_ok(agfs_client, path)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_rm", params={"uri": path}, completed=True),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # Directory still exists — fs_rm rollback does nothing
+        assert file_exists(agfs_client, path)
+
+
+class TestMvRollback:
+    async def test_mv_reversed_on_rollback(self, agfs_client, test_dir):
+        """Real mv → rollback → content back at original location."""
+        src = f"{test_dir}/mv-src"
+        dst = f"{test_dir}/mv-dst"
+        _mkdir_ok(agfs_client, src)
+        agfs_client.write(f"{src}/payload.txt", b"important data")
+
+        # Forward mv
+        agfs_client.mv(src, dst)
+        assert not file_exists(agfs_client, src)
+        content = agfs_client.cat(f"{dst}/payload.txt")
+        assert content == b"important data"
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="fs_mv",
+                params={"src": src, "dst": dst},
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        assert file_exists(agfs_client, src)
+        restored = agfs_client.cat(f"{src}/payload.txt")
+        assert restored == b"important data"
+
+
+class TestRecoverAll:
+    async def test_recover_all_reverses_incomplete(self, agfs_client, test_dir):
+        """recover_all=True also reverses entries with completed=False."""
+        new_dir = f"{test_dir}/recover-all-dir"
+        _mkdir_ok(agfs_client, new_dir)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": new_dir}, completed=False),
+        ]
+        await execute_rollback(undo_log, agfs_client, recover_all=True)
+
+        assert not file_exists(agfs_client, new_dir)
+
+    async def test_recover_all_false_skips_incomplete(self, agfs_client, test_dir):
+        """recover_all=False skips entries with completed=False."""
+        new_dir = f"{test_dir}/skip-incomplete"
+        _mkdir_ok(agfs_client, new_dir)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": new_dir}, completed=False),
+        ]
+        await execute_rollback(undo_log, agfs_client, recover_all=False)
+
+        assert file_exists(agfs_client, new_dir)
+
+
+class TestMultiStepRollback:
+    async def test_reverse_order_nested_dirs(self, agfs_client, test_dir):
+        """parent + child → rollback reverses in reverse sequence order."""
+        parent = f"{test_dir}/multi-parent"
+        child = f"{test_dir}/multi-parent/child"
+        _mkdir_ok(agfs_client, parent)
+        _mkdir_ok(agfs_client, child)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": parent}, completed=True),
+            UndoEntry(sequence=1, op_type="fs_mkdir", params={"uri": child}, completed=True),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        assert not file_exists(agfs_client, child)
+        assert not file_exists(agfs_client, parent)
+
+    async def test_write_new_rollback(self, agfs_client, test_dir):
+        """New file → rollback → file deleted."""
+        file_path = f"{test_dir}/new-file.txt"
+        agfs_client.write(file_path, b"new content")
+        assert file_exists(agfs_client, file_path)
+
+        undo_log = [
+            UndoEntry(
+                sequence=0, op_type="fs_write_new", params={"uri": file_path}, completed=True
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        assert not file_exists(agfs_client, file_path)
+
+    async def test_best_effort_continues(self, agfs_client, test_dir):
+        """If one step fails, subsequent steps still execute."""
+        real_dir = f"{test_dir}/best-effort-real"
+        _mkdir_ok(agfs_client, real_dir)
+
+        undo_log = [
+            # seq=0: mkdir rollback on real dir → should succeed
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": real_dir}, completed=True),
+            # seq=1: mkdir rollback on nonexistent dir → fails silently
+            UndoEntry(
+                sequence=1,
+                op_type="fs_mkdir",
+                params={"uri": f"{test_dir}/no-such-dir-{uuid.uuid4().hex}"},
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # seq=0 still executed despite seq=1 failure (reversed order: 1 runs first, then 0)
+        assert not file_exists(agfs_client, real_dir)
+
+    async def test_unknown_op_type_no_crash(self, agfs_client, test_dir):
+        """Unknown op_type is logged but doesn't raise."""
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="some_future_op",
+                params={"foo": "bar"},
+                completed=True,
+            ),
+        ]
+        # Should not raise
+        await execute_rollback(undo_log, agfs_client)
+
+
+class TestVectorDBRollback:
+    async def test_vectordb_delete_rollback_restores(self, agfs_client, vector_store, request_ctx):
+        """upsert → delete → rollback(vectordb_delete) → record restored."""
+        record_id = str(uuid.uuid4())
+        record = {
+            "id": record_id,
+            "uri": f"viking://resources/del-restore-{record_id}.md",
+            "parent_uri": "viking://resources/",
+            "account_id": "default",
+            "context_type": "resource",
+            "level": 2,
+            "vector": [0.3] * VECTOR_DIM,
+            "name": "del-restore",
+            "description": "test",
+            "abstract": "test",
+        }
+        await vector_store.upsert(record, ctx=request_ctx)
+
+        # Snapshot before delete
+        snapshot = await vector_store.get([record_id], ctx=request_ctx)
+        assert len(snapshot) == 1
+
+        # Forward: delete
+        await vector_store.delete([record_id], ctx=request_ctx)
+        assert len(await vector_store.get([record_id], ctx=request_ctx)) == 0
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_delete",
+                params={
+                    "uris": [record["uri"]],
+                    "records_snapshot": snapshot,
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+        results = await vector_store.get([record_id], ctx=request_ctx)
+        assert len(results) == 1
+
+    async def test_vectordb_delete_multi_record(self, agfs_client, vector_store, request_ctx):
+        """3 records in snapshot → rollback → all restored."""
+        records = []
+        for i in range(3):
+            rid = str(uuid.uuid4())
+            rec = {
+                "id": rid,
+                "uri": f"viking://resources/multi-{rid}.md",
+                "parent_uri": "viking://resources/",
+                "account_id": "default",
+                "context_type": "resource",
+                "level": 2,
+                "vector": [0.1 * (i + 1)] * VECTOR_DIM,
+                "name": f"multi-{i}",
+                "description": "test",
+                "abstract": "test",
+            }
+            await vector_store.upsert(rec, ctx=request_ctx)
+            records.append(rec)
+
+        ids = [r["id"] for r in records]
+        snapshot = await vector_store.get(ids, ctx=request_ctx)
+        assert len(snapshot) == 3
+
+        # Delete all
+        await vector_store.delete(ids, ctx=request_ctx)
+        assert len(await vector_store.get(ids, ctx=request_ctx)) == 0
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_delete",
+                params={
+                    "uris": [r["uri"] for r in records],
+                    "records_snapshot": snapshot,
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+        results = await vector_store.get(ids, ctx=request_ctx)
+        assert len(results) == 3
+
+    async def test_vectordb_delete_empty_snapshot(self, agfs_client, vector_store, request_ctx):
+        """Empty snapshot → no-op, no error."""
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_delete",
+                params={
+                    "uris": [],
+                    "records_snapshot": [],
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        # Should not raise
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+    async def test_vectordb_upsert_rollback_deletes(self, agfs_client, vector_store, request_ctx):
+        """upsert → rollback(vectordb_upsert) → record deleted."""
+        record_id = str(uuid.uuid4())
+        record = {
+            "id": record_id,
+            "uri": f"viking://resources/upsert-del-{record_id}.md",
+            "parent_uri": "viking://resources/",
+            "account_id": "default",
+            "context_type": "resource",
+            "level": 2,
+            "vector": [0.4] * VECTOR_DIM,
+            "name": "upsert-del",
+            "description": "test",
+            "abstract": "test",
+        }
+        await vector_store.upsert(record, ctx=request_ctx)
+        assert len(await vector_store.get([record_id], ctx=request_ctx)) == 1
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_upsert",
+                params={
+                    "record_id": record_id,
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+        results = await vector_store.get([record_id], ctx=request_ctx)
+        assert len(results) == 0
diff --git a/tests/transaction/test_transaction_manager.py b/tests/transaction/test_transaction_manager.py
new file mode 100644
index 000000000..ef0f0b3eb
--- /dev/null
+++ b/tests/transaction/test_transaction_manager.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TransactionManager: CRUD, lifecycle, commit/rollback flows, timeout cleanup."""
+
+import time
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from openviking.storage.transaction.transaction_manager import TransactionManager
+from openviking.storage.transaction.transaction_record import TransactionRecord, TransactionStatus
+
+
+def _make_manager(**kwargs):
+    """Create a TransactionManager with mocked AGFS and journal."""
+    agfs = MagicMock()
+    defaults = {"agfs_client": agfs, "timeout": 3600, "lock_timeout": 0.0, "lock_expire": 300.0}
+    defaults.update(kwargs)
+    manager = TransactionManager(**defaults)
+    manager._journal = MagicMock()
+    manager._journal.list_all.return_value = []
+    return manager, agfs
+
+
+class TestCreateAndGet:
+    def test_create_transaction_returns_record(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction(init_info={"operation": "rm"})
+        assert isinstance(tx, TransactionRecord)
+        assert tx.status == TransactionStatus.INIT
+        assert tx.init_info == {"operation": "rm"}
+
+    def test_create_assigns_unique_ids(self):
+        manager, _ = _make_manager()
+        tx1 = manager.create_transaction()
+        tx2 = manager.create_transaction()
+        assert tx1.id != tx2.id
+
+    def test_get_transaction_found(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        assert manager.get_transaction(tx.id) is tx
+
+    def test_get_transaction_not_found(self):
+        manager, _ = _make_manager()
+        assert manager.get_transaction("nonexistent") is None
+
+    def test_get_transaction_count(self):
+        manager, _ = _make_manager()
+        assert manager.get_transaction_count() == 0
+        manager.create_transaction()
+        assert manager.get_transaction_count() == 1
+        manager.create_transaction()
+        assert manager.get_transaction_count() == 2
+
+    def test_get_active_transactions(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        active = manager.get_active_transactions()
+        assert tx.id in active
+        # Returned copy, not the internal dict
+        active.pop(tx.id)
+        assert manager.get_transaction(tx.id) is tx
+
+
+class TestBegin:
+    async def test_begin_updates_status(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        ok = await manager.begin(tx.id)
+        assert ok is True
+        assert tx.status == TransactionStatus.ACQUIRE
+
+    async def test_begin_unknown_tx(self):
+        manager, _ = _make_manager()
+        ok = await manager.begin("unknown-tx")
+        assert ok is False
+
+
+class TestCommitFlow:
+    async def test_commit_full_lifecycle(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+
+        # Simulate lock acquisition
+        tx.update_status(TransactionStatus.EXEC)
+        tx.add_lock("/test/.path.ovlock")
+
+        ok = await manager.commit(tx.id)
+        assert ok is True
+        assert tx.status == TransactionStatus.RELEASED
+        # Removed from active transactions
+        assert manager.get_transaction(tx.id) is None
+        # Journal cleaned up
+        manager._journal.delete.assert_called_once_with(tx.id)
+
+    async def test_commit_persists_journal_before_release(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        call_order = []
+        original_update = manager._journal.update
+
+        def track_update(data):
+            call_order.append(("journal_update", data.get("status")))
+            return original_update(data)
+
+        manager._journal.update = track_update
+        manager._journal.delete = MagicMock(
+            side_effect=lambda _: call_order.append(("journal_delete",))
+        )
+
+        await manager.commit(tx.id)
+        # Journal update (COMMIT) happens before delete
+        assert call_order[0] == ("journal_update", "COMMIT")
+
+    async def test_commit_executes_post_actions(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+        tx.post_actions.append({"type": "enqueue_semantic", "params": {"uri": "viking://x"}})
+
+        with patch.object(manager, "_execute_post_actions", new_callable=AsyncMock) as mock_post:
+            await manager.commit(tx.id)
+            mock_post.assert_called_once()
+
+    async def test_commit_unknown_tx(self):
+        manager, _ = _make_manager()
+        ok = await manager.commit("nonexistent")
+        assert ok is False
+
+    async def test_commit_releases_locks(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+        tx.add_lock("/a/.path.ovlock")
+        tx.add_lock("/b/.path.ovlock")
+
+        with patch.object(manager._path_lock, "release", new_callable=AsyncMock) as mock_release:
+            await manager.commit(tx.id)
+            mock_release.assert_called_once()
+
+
+class TestRollbackFlow:
+    async def test_rollback_executes_undo_log(self):
+        manager, agfs = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        from openviking.storage.transaction.undo import UndoEntry
+
+        tx.undo_log.append(
+            UndoEntry(
+                sequence=0, op_type="fs_mv", params={"src": "/a", "dst": "/b"}, completed=True
+            )
+        )
+
+        ok = await manager.rollback(tx.id)
+        assert ok is True
+        assert tx.status == TransactionStatus.RELEASED
+        agfs.mv.assert_called_once_with("/b", "/a")
+
+    async def test_rollback_removes_from_active(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        await manager.rollback(tx.id)
+        assert manager.get_transaction(tx.id) is None
+
+    async def test_rollback_cleans_journal(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        await manager.rollback(tx.id)
+        manager._journal.delete.assert_called_once_with(tx.id)
+
+    async def test_rollback_unknown_tx(self):
+        manager, _ = _make_manager()
+        ok = await manager.rollback("nonexistent")
+        assert ok is False
+
+    async def test_rollback_undo_failure_does_not_prevent_cleanup(self):
+        """Undo failure is best-effort; lock release and journal cleanup still happen."""
+        manager, agfs = _make_manager()
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        from openviking.storage.transaction.undo import UndoEntry
+
+        tx.undo_log.append(
+            UndoEntry(
+                sequence=0, op_type="fs_mv", params={"src": "/a", "dst": "/b"}, completed=True
+            )
+        )
+        agfs.mv.side_effect = Exception("disk error")
+
+        ok = await manager.rollback(tx.id)
+        assert ok is True
+        manager._journal.delete.assert_called_once()
+
+
+class TestLockAcquisitionWrappers:
+    async def test_acquire_lock_point_success_transitions_to_exec(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_point", new_callable=AsyncMock, return_value=True
+        ):
+            ok = await manager.acquire_lock_point(tx.id, "/test")
+        assert ok is True
+        assert tx.status == TransactionStatus.EXEC
+
+    async def test_acquire_lock_point_failure_transitions_to_fail(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_point", new_callable=AsyncMock, return_value=False
+        ):
+            ok = await manager.acquire_lock_point(tx.id, "/test")
+        assert ok is False
+        assert tx.status == TransactionStatus.FAIL
+
+    async def test_acquire_lock_subtree_success(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_subtree", new_callable=AsyncMock, return_value=True
+        ):
+            ok = await manager.acquire_lock_subtree(tx.id, "/test")
+        assert ok is True
+        assert tx.status == TransactionStatus.EXEC
+
+    async def test_acquire_lock_subtree_uses_config_timeout(self):
+        manager, _ = _make_manager(lock_timeout=5.0)
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_subtree", new_callable=AsyncMock, return_value=True
+        ) as mock_acquire:
+            await manager.acquire_lock_subtree(tx.id, "/test")
+        mock_acquire.assert_called_once_with("/test", tx, timeout=5.0)
+
+    async def test_acquire_lock_subtree_override_timeout(self):
+        manager, _ = _make_manager(lock_timeout=5.0)
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_subtree", new_callable=AsyncMock, return_value=True
+        ) as mock_acquire:
+            await manager.acquire_lock_subtree(tx.id, "/test", timeout=10.0)
+        mock_acquire.assert_called_once_with("/test", tx, timeout=10.0)
+
+    async def test_acquire_lock_mv_success(self):
+        manager, _ = _make_manager()
+        tx = manager.create_transaction()
+
+        with patch.object(
+            manager._path_lock, "acquire_mv", new_callable=AsyncMock, return_value=True
+        ):
+            ok = await manager.acquire_lock_mv(tx.id, "/src", "/dst")
+        assert ok is True
+        assert tx.status == TransactionStatus.EXEC
+
+    async def test_acquire_lock_unknown_tx(self):
+        manager, _ = _make_manager()
+        ok = await manager.acquire_lock_point("nonexistent", "/test")
+        assert ok is False
+
+
+class TestLifecycle:
+    async def test_start_sets_running(self):
+        manager, _ = _make_manager()
+        await manager.start()
+        assert manager._running is True
+        await manager.stop()
+
+    async def test_start_idempotent(self):
+        manager, _ = _make_manager()
+        await manager.start()
+        await manager.start()  # Should not error
+        assert manager._running is True
+        await manager.stop()
+
+    async def test_stop_clears_state(self):
+        manager, _ = _make_manager()
+        await manager.start()
+        manager.create_transaction()
+        await manager.stop()
+        assert manager._running is False
+        assert manager.get_transaction_count() == 0
+
+    async def test_stop_idempotent(self):
+        manager, _ = _make_manager()
+        await manager.stop()
+        await manager.stop()  # Should not error
+
+
+class TestTimeoutCleanup:
+    async def test_cleanup_timed_out_rolls_back(self):
+        manager, _ = _make_manager(timeout=1)
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+        # Simulate old updated_at
+        tx.updated_at = time.time() - 10
+
+        with patch.object(
+            manager, "rollback", new_callable=AsyncMock, return_value=True
+        ) as mock_rb:
+            await manager._cleanup_timed_out()
+        mock_rb.assert_called_once_with(tx.id)
+
+    async def test_cleanup_skips_fresh_transactions(self):
+        manager, _ = _make_manager(timeout=3600)
+        tx = manager.create_transaction()
+        tx.update_status(TransactionStatus.EXEC)
+
+        with patch.object(manager, "rollback", new_callable=AsyncMock) as mock_rb:
+            await manager._cleanup_timed_out()
+        mock_rb.assert_not_called()
diff --git a/tests/transaction/test_undo.py b/tests/transaction/test_undo.py
new file mode 100644
index 000000000..aff57887f
--- /dev/null
+++ b/tests/transaction/test_undo.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for undo log and rollback executor."""
+
+import uuid
+
+from openviking.storage.transaction.undo import UndoEntry, execute_rollback
+
+from .conftest import VECTOR_DIM, _mkdir_ok, file_exists
+
+
+class TestUndoEntry:
+    def test_to_dict(self):
+        entry = UndoEntry(sequence=0, op_type="fs_mv", params={"src": "/a", "dst": "/b"})
+        d = entry.to_dict()
+        assert d["sequence"] == 0
+        assert d["op_type"] == "fs_mv"
+        assert d["params"] == {"src": "/a", "dst": "/b"}
+        assert d["completed"] is False
+
+    def test_from_dict(self):
+        data = {"sequence": 1, "op_type": "fs_rm", "params": {"uri": "/x"}, "completed": True}
+        entry = UndoEntry.from_dict(data)
+        assert entry.sequence == 1
+        assert entry.op_type == "fs_rm"
+        assert entry.completed is True
+
+    def test_roundtrip(self):
+        entry = UndoEntry(
+            sequence=5, op_type="vectordb_upsert", params={"record_id": "r1"}, completed=True
+        )
+        restored = UndoEntry.from_dict(entry.to_dict())
+        assert restored.sequence == entry.sequence
+        assert restored.op_type == entry.op_type
+        assert restored.params == entry.params
+        assert restored.completed == entry.completed
+
+
+class TestExecuteRollback:
+    """Integration tests for execute_rollback using real AGFS and VectorDB backends."""
+
+    async def test_rollback_fs_mv(self, agfs_client, test_dir):
+        src = f"{test_dir}/src"
+        dst = f"{test_dir}/dst"
+        _mkdir_ok(agfs_client, src)
+        agfs_client.write(f"{src}/data.txt", b"hello")
+
+        # Forward: mv src → dst
+        agfs_client.mv(src, dst)
+        assert not file_exists(agfs_client, src)
+        assert file_exists(agfs_client, dst)
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="fs_mv",
+                params={"src": src, "dst": dst},
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # src restored, dst gone
+        assert file_exists(agfs_client, src)
+        assert not file_exists(agfs_client, dst)
+
+    async def test_rollback_fs_rm_skipped(self, agfs_client, test_dir):
+        path = f"{test_dir}/will-not-delete"
+        _mkdir_ok(agfs_client, path)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_rm", params={"uri": path}, completed=True),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # fs_rm rollback is a no-op; directory still exists
+        assert file_exists(agfs_client, path)
+
+    async def test_rollback_fs_mkdir(self, agfs_client, test_dir):
+        new_dir = f"{test_dir}/created"
+        _mkdir_ok(agfs_client, new_dir)
+        assert file_exists(agfs_client, new_dir)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": new_dir}, completed=True),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        assert not file_exists(agfs_client, new_dir)
+
+    async def test_rollback_fs_write_new(self, agfs_client, test_dir):
+        file_path = f"{test_dir}/new-file.txt"
+        agfs_client.write(file_path, b"content")
+        assert file_exists(agfs_client, file_path)
+
+        undo_log = [
+            UndoEntry(
+                sequence=0, op_type="fs_write_new", params={"uri": file_path}, completed=True
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        assert not file_exists(agfs_client, file_path)
+
+    async def test_rollback_reverse_order(self, agfs_client, test_dir):
+        """mkdir parent + child → rollback → both removed in reverse order."""
+        parent = f"{test_dir}/parent"
+        child = f"{test_dir}/parent/child"
+        _mkdir_ok(agfs_client, parent)
+        _mkdir_ok(agfs_client, child)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": parent}, completed=True),
+            UndoEntry(sequence=1, op_type="fs_mkdir", params={"uri": child}, completed=True),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # child removed first (seq=1), then parent (seq=0)
+        assert not file_exists(agfs_client, child)
+        assert not file_exists(agfs_client, parent)
+
+    async def test_rollback_skips_incomplete(self, agfs_client, test_dir):
+        new_dir = f"{test_dir}/incomplete"
+        _mkdir_ok(agfs_client, new_dir)
+
+        undo_log = [
+            UndoEntry(sequence=0, op_type="fs_mkdir", params={"uri": new_dir}, completed=False),
+        ]
+        await execute_rollback(undo_log, agfs_client)
+
+        # completed=False → not rolled back
+        assert file_exists(agfs_client, new_dir)
+
+    async def test_rollback_best_effort(self, agfs_client, test_dir):
+        """A failing rollback entry should not prevent others from running."""
+        real_dir = f"{test_dir}/real-dir"
+        _mkdir_ok(agfs_client, real_dir)
+
+        src = f"{test_dir}/be-src"
+        dst = f"{test_dir}/be-dst"
+        _mkdir_ok(agfs_client, dst)
+
+        undo_log = [
+            # seq=0: fs_mv rollback will succeed
+            UndoEntry(sequence=0, op_type="fs_mv", params={"src": src, "dst": dst}, completed=True),
+            # seq=1: fs_mkdir rollback will fail (rm on non-empty or non-existent path)
+            UndoEntry(
+                sequence=1,
+                op_type="fs_mkdir",
+                params={"uri": f"{test_dir}/nonexistent-dir-xyz"},
+                completed=True,
+            ),
+        ]
+        # Should not raise
+        await execute_rollback(undo_log, agfs_client)
+
+        # seq=0 mv rollback should have executed (dst → src)
+        assert file_exists(agfs_client, src)
+
+    async def test_rollback_vectordb_upsert(self, agfs_client, vector_store, request_ctx):
+        """Real upsert → rollback → record deleted."""
+        record_id = str(uuid.uuid4())
+        record = {
+            "id": record_id,
+            "uri": f"viking://resources/test-upsert-{record_id}.md",
+            "parent_uri": "viking://resources/",
+            "account_id": "default",
+            "context_type": "resource",
+            "level": 2,
+            "vector": [0.1] * VECTOR_DIM,
+            "name": "test",
+            "description": "test record",
+            "abstract": "test",
+        }
+        await vector_store.upsert(record, ctx=request_ctx)
+
+        # Confirm it exists
+        results = await vector_store.get([record_id], ctx=request_ctx)
+        assert len(results) == 1
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_upsert",
+                params={
+                    "record_id": record_id,
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+        results = await vector_store.get([record_id], ctx=request_ctx)
+        assert len(results) == 0
+
+    async def test_rollback_vectordb_update_uri(self, agfs_client, vector_store, request_ctx):
+        """Real upsert → update_uri_mapping → rollback → URI restored."""
+        record_id = str(uuid.uuid4())
+        old_uri = f"viking://resources/old-{record_id}.md"
+        new_uri = f"viking://resources/new-{record_id}.md"
+        record = {
+            "id": record_id,
+            "uri": old_uri,
+            "parent_uri": "viking://resources/",
+            "account_id": "default",
+            "context_type": "resource",
+            "level": 2,
+            "vector": [0.2] * VECTOR_DIM,
+            "name": "test",
+            "description": "test",
+            "abstract": "test",
+        }
+        await vector_store.upsert(record, ctx=request_ctx)
+
+        # Forward: update URI mapping
+        await vector_store.update_uri_mapping(
+            ctx=request_ctx,
+            uri=old_uri,
+            new_uri=new_uri,
+            new_parent_uri="viking://resources/",
+        )
+
+        # Verify forward operation
+        result = await vector_store.fetch_by_uri(new_uri, ctx=request_ctx)
+        assert result is not None
+
+        undo_log = [
+            UndoEntry(
+                sequence=0,
+                op_type="vectordb_update_uri",
+                params={
+                    "old_uri": old_uri,
+                    "new_uri": new_uri,
+                    "old_parent_uri": "viking://resources/",
+                    "_ctx_account_id": "default",
+                    "_ctx_user_id": "test_user",
+                    "_ctx_role": "root",
+                },
+                completed=True,
+            ),
+        ]
+        await execute_rollback(undo_log, agfs_client, vector_store=vector_store)
+
+        # URI should be restored to old_uri
+        result = await vector_store.fetch_by_uri(old_uri, ctx=request_ctx)
+        assert result is not None
diff --git a/tests/unit/test_openai_embedder.py b/tests/unit/test_openai_embedder.py
index 8e9b72a82..a4efd892c 100644
--- a/tests/unit/test_openai_embedder.py
+++ b/tests/unit/test_openai_embedder.py
@@ -4,8 +4,6 @@
 
 from unittest.mock import MagicMock, patch
 
-import pytest
-
 from openviking.models.embedder import OpenAIDenseEmbedder
 
 
@@ -294,10 +292,11 @@ def test_telemetry_called_when_module_available(self, mock_openai_class):
         )
 
         mock_telemetry = MagicMock()
-        mock_telemetry_module = MagicMock()
-        mock_telemetry_module.get_current_telemetry.return_value = mock_telemetry
 
-        with patch("importlib.import_module", return_value=mock_telemetry_module):
+        with patch(
+            "openviking.models.embedder.openai_embedders.get_current_telemetry",
+            return_value=mock_telemetry,
+        ):
             result = embedder.embed("Hello world")
 
         assert result.dense_vector is not None
diff --git a/third_party/agfs/agfs-server/pkg/plugins/queuefs/backend.go b/third_party/agfs/agfs-server/pkg/plugins/queuefs/backend.go
index f2ccde995..c20fdc662 100644
--- a/third_party/agfs/agfs-server/pkg/plugins/queuefs/backend.go
+++ b/third_party/agfs/agfs-server/pkg/plugins/queuefs/backend.go
@@ -24,9 +24,18 @@ type QueueBackend interface {
 	// Enqueue adds a message to a queue
 	Enqueue(queueName string, msg QueueMessage) error
 
-	// Dequeue removes and returns the first message from a queue
+	// Dequeue marks the first pending message as 'processing' and returns it.
+	// Call Ack after successful processing to permanently delete the message.
 	Dequeue(queueName string) (QueueMessage, bool, error)
 
+	// Ack permanently deletes a message that has been successfully processed.
+	Ack(queueName string, messageID string) error
+
+	// RecoverStale resets messages stuck in 'processing' state back to 'pending'.
+	// staleSec: minimum age in seconds; pass 0 to reset all processing messages.
+	// Returns the number of messages recovered.
+	RecoverStale(staleSec int64) (int, error)
+
 	// Peek returns the first message without removing it
 	Peek(queueName string) (QueueMessage, bool, error)
 
@@ -124,6 +133,16 @@ func (b *MemoryBackend) Dequeue(queueName string) (QueueMessage, bool, error) {
 	return msg, true, nil
 }
 
+// Ack is a no-op for the memory backend (messages are already removed on Dequeue).
+func (b *MemoryBackend) Ack(queueName string, messageID string) error {
+	return nil
+}
+
+// RecoverStale is a no-op for the memory backend (no persistence across restarts).
+func (b *MemoryBackend) RecoverStale(staleSec int64) (int, error) {
+	return 0, nil
+}
+
 func (b *MemoryBackend) Peek(queueName string) (QueueMessage, bool, error) {
 	queue, exists := b.queues[queueName]
 	if !exists {
@@ -345,6 +364,16 @@ func (b *TiDBBackend) Enqueue(queueName string, msg QueueMessage) error {
 	return nil
 }
 
+// Ack is not yet implemented for TiDB backend (messages are already soft-deleted on Dequeue).
+func (b *TiDBBackend) Ack(queueName string, messageID string) error {
+	return nil
+}
+
+// RecoverStale is not yet implemented for TiDB backend.
+func (b *TiDBBackend) RecoverStale(staleSec int64) (int, error) { + return 0, nil +} + func (b *TiDBBackend) Dequeue(queueName string) (QueueMessage, bool, error) { // Get table name from cache (lazy loading) tableName, err := b.getTableName(queueName, false) diff --git a/third_party/agfs/agfs-server/pkg/plugins/queuefs/db_backend.go b/third_party/agfs/agfs-server/pkg/plugins/queuefs/db_backend.go index 03b7342f9..9639531c0 100644 --- a/third_party/agfs/agfs-server/pkg/plugins/queuefs/db_backend.go +++ b/third_party/agfs/agfs-server/pkg/plugins/queuefs/db_backend.go @@ -63,16 +63,22 @@ func (b *SQLiteDBBackend) GetInitSQL() []string { last_updated INTEGER DEFAULT (strftime('%s', 'now')) )`, // Queue messages table + // status: 'pending' (waiting) | 'processing' (dequeued, not yet acked) + // processing_started_at: Unix timestamp when dequeued; NULL if pending `CREATE TABLE IF NOT EXISTS queue_messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, queue_name TEXT NOT NULL, message_id TEXT NOT NULL, data TEXT NOT NULL, timestamp INTEGER NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + processing_started_at INTEGER, created_at INTEGER DEFAULT (strftime('%s', 'now')) )`, `CREATE INDEX IF NOT EXISTS idx_queue_name ON queue_messages(queue_name)`, `CREATE INDEX IF NOT EXISTS idx_queue_order ON queue_messages(queue_name, id)`, + `CREATE INDEX IF NOT EXISTS idx_queue_status ON queue_messages(queue_name, status, id)`, + `CREATE INDEX IF NOT EXISTS idx_queue_message_id ON queue_messages(queue_name, message_id)`, } } diff --git a/third_party/agfs/agfs-server/pkg/plugins/queuefs/queuefs.go b/third_party/agfs/agfs-server/pkg/plugins/queuefs/queuefs.go index d8d481b0d..052a8f19d 100644 --- a/third_party/agfs/agfs-server/pkg/plugins/queuefs/queuefs.go +++ b/third_party/agfs/agfs-server/pkg/plugins/queuefs/queuefs.go @@ -137,7 +137,9 @@ func (q *QueueFSPlugin) Initialize(cfg map[string]interface{}) error { switch backendType { case "memory": backend = NewMemoryBackend() - case 
"tidb", "mysql", "sqlite", "sqlite3": + case "sqlite", "sqlite3": + backend = NewSQLiteQueueBackend() + case "tidb", "mysql": backend = NewTiDBBackend() default: return fmt.Errorf("unsupported backend: %s", backendType) @@ -384,6 +386,7 @@ var queueOperations = map[string]bool{ "peek": true, "size": true, "clear": true, + "ack": true, // write message_id to confirm processing complete (at-least-once delivery) } // parseQueuePath parses a path like "/queue_name/operation" or "/dir/queue_name/operation" @@ -529,7 +532,7 @@ func (qfs *queueFS) Read(path string, offset int64, size int64) ([]byte, error) data, err = qfs.peek(queueName) case "size": data, err = qfs.size(queueName) - case "enqueue", "clear": + case "enqueue", "clear", "ack": // Write-only files return []byte(""), fmt.Errorf("permission denied: %s is write-only", path) default: @@ -573,6 +576,12 @@ func (qfs *queueFS) Write(path string, data []byte, offset int64, flags filesyst return 0, err } return 0, nil + case "ack": + msgID := strings.TrimSpace(string(data)) + if err := qfs.ackMessage(queueName, msgID); err != nil { + return 0, err + } + return int64(len(data)), nil default: return 0, fmt.Errorf("cannot write to: %s", path) } @@ -844,7 +853,7 @@ func (qfs *queueFS) Stat(p string) (*filesystem.FileInfo, error) { } mode := uint32(0644) - if operation == "enqueue" || operation == "clear" { + if operation == "enqueue" || operation == "clear" || operation == "ack" { mode = 0222 } else { mode = 0444 @@ -992,6 +1001,13 @@ func (qfs *queueFS) clear(queueName string) error { return qfs.plugin.backend.Clear(queueName) } +func (qfs *queueFS) ackMessage(queueName string, msgID string) error { + qfs.plugin.mu.Lock() + defer qfs.plugin.mu.Unlock() + + return qfs.plugin.backend.Ack(queueName, msgID) +} + // Ensure QueueFSPlugin implements ServicePlugin var _ plugin.ServicePlugin = (*QueueFSPlugin)(nil) var _ filesystem.FileSystem = (*queueFS)(nil) diff --git 
a/third_party/agfs/agfs-server/pkg/plugins/queuefs/sqlite_backend.go b/third_party/agfs/agfs-server/pkg/plugins/queuefs/sqlite_backend.go new file mode 100644 index 000000000..2a0c4dbed --- /dev/null +++ b/third_party/agfs/agfs-server/pkg/plugins/queuefs/sqlite_backend.go @@ -0,0 +1,321 @@ +package queuefs + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// SQLiteQueueBackend implements QueueBackend using SQLite with a single-table schema. +// +// Schema: +// - queue_metadata: tracks all queues (including empty ones created via mkdir) +// - queue_messages: stores all messages, filtered by queue_name column +// - status: 'pending' (waiting to be processed) | 'processing' (dequeued, awaiting ack) +// - processing_started_at: Unix timestamp when dequeued; NULL while pending +// +// Delivery semantics: at-least-once +// - Dequeue marks message as 'processing' (does NOT delete) +// - Ack deletes the message after successful processing +// - On startup, RecoverStale resets all 'processing' messages back to 'pending' +// so that messages from a previous crashed run are automatically retried +type SQLiteQueueBackend struct { + db *sql.DB +} + +func NewSQLiteQueueBackend() *SQLiteQueueBackend { + return &SQLiteQueueBackend{} +} + +func (b *SQLiteQueueBackend) Initialize(config map[string]interface{}) error { + dbBackend := NewSQLiteDBBackend() + + db, err := dbBackend.Open(config) + if err != nil { + return fmt.Errorf("failed to open SQLite database: %w", err) + } + b.db = db + + for _, sqlStmt := range dbBackend.GetInitSQL() { + if _, err := db.Exec(sqlStmt); err != nil { + db.Close() + return fmt.Errorf("failed to initialize schema: %w", err) + } + } + + // Migrate existing databases: add new columns if they don't exist yet. + b.runMigrations() + + // Reset any messages left in 'processing' state by a previous crashed process. 
+ // staleSec=0 resets ALL processing messages — safe at startup because no workers + // are running yet. + if n, err := b.RecoverStale(0); err != nil { + log.Warnf("[queuefs] Failed to recover stale messages on startup: %v", err) + } else if n > 0 { + log.Infof("[queuefs] Recovered %d in-flight message(s) from previous run", n) + } + + log.Info("[queuefs] SQLite backend initialized") + return nil +} + +// runMigrations applies schema changes needed to upgrade an existing database. +// Each ALTER TABLE is executed and "duplicate column name" errors are silently ignored. +func (b *SQLiteQueueBackend) runMigrations() { + migrations := []string{ + `ALTER TABLE queue_messages ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'`, + `ALTER TABLE queue_messages ADD COLUMN processing_started_at INTEGER`, + `CREATE INDEX IF NOT EXISTS idx_queue_status ON queue_messages(queue_name, status, id)`, + `CREATE INDEX IF NOT EXISTS idx_queue_message_id ON queue_messages(queue_name, message_id)`, + } + for _, stmt := range migrations { + if _, err := b.db.Exec(stmt); err != nil { + // "duplicate column name" means the column already exists — that's fine. 
+ if !strings.Contains(err.Error(), "duplicate column name") && + !strings.Contains(err.Error(), "already exists") { + log.Warnf("[queuefs] Migration warning: %v", err) + } + } + } +} + +func (b *SQLiteQueueBackend) Close() error { + if b.db != nil { + return b.db.Close() + } + return nil +} + +func (b *SQLiteQueueBackend) GetType() string { + return "sqlite" +} + +func (b *SQLiteQueueBackend) Enqueue(queueName string, msg QueueMessage) error { + msgData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + _, err = b.db.Exec( + "INSERT INTO queue_messages (queue_name, message_id, data, timestamp, status) VALUES (?, ?, ?, ?, 'pending')", + queueName, msg.ID, string(msgData), msg.Timestamp.Unix(), + ) + if err != nil { + return fmt.Errorf("failed to enqueue message: %w", err) + } + return nil +} + +// Dequeue marks the first pending message as 'processing' and returns it. +// The message remains in the database until Ack is called. +// If the process crashes before Ack, RecoverStale on the next startup will +// reset the message back to 'pending' so it is retried. +func (b *SQLiteQueueBackend) Dequeue(queueName string) (QueueMessage, bool, error) { + tx, err := b.db.Begin() + if err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to start transaction: %w", err) + } + defer tx.Rollback() + + var id int64 + var data string + err = tx.QueryRow( + "SELECT id, data FROM queue_messages WHERE queue_name = ? AND status = 'pending' ORDER BY id LIMIT 1", + queueName, + ).Scan(&id, &data) + + if err == sql.ErrNoRows { + return QueueMessage{}, false, nil + } else if err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to query message: %w", err) + } + + // Mark as processing instead of deleting. + _, err = tx.Exec( + "UPDATE queue_messages SET status = 'processing', processing_started_at = ? 
WHERE id = ?", + time.Now().Unix(), id, + ) + if err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to mark message as processing: %w", err) + } + + if err := tx.Commit(); err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to commit transaction: %w", err) + } + + var msg QueueMessage + if err := json.Unmarshal([]byte(data), &msg); err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to unmarshal message: %w", err) + } + + return msg, true, nil +} + +// Ack deletes a message that has been successfully processed. +// Should be called after the consumer has finished processing the message. +func (b *SQLiteQueueBackend) Ack(queueName string, messageID string) error { + result, err := b.db.Exec( + "DELETE FROM queue_messages WHERE queue_name = ? AND message_id = ? AND status = 'processing'", + queueName, messageID, + ) + if err != nil { + return fmt.Errorf("failed to ack message: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + log.Warnf("[queuefs] Ack found no matching processing message: queue=%s msg=%s", queueName, messageID) + } + return nil +} + +// RecoverStale resets messages stuck in 'processing' state back to 'pending'. +// staleSec is the minimum age (in seconds) of a processing message before it +// is considered stale. Pass 0 to reset ALL processing messages immediately +// (appropriate at startup before any workers have started). +// Returns the number of messages recovered. 
+func (b *SQLiteQueueBackend) RecoverStale(staleSec int64) (int, error) { + cutoff := time.Now().Unix() - staleSec + result, err := b.db.Exec( + "UPDATE queue_messages SET status = 'pending', processing_started_at = NULL WHERE status = 'processing' AND processing_started_at <= ?", + cutoff, + ) + if err != nil { + return 0, fmt.Errorf("failed to recover stale messages: %w", err) + } + n, _ := result.RowsAffected() + return int(n), nil +} + +func (b *SQLiteQueueBackend) Peek(queueName string) (QueueMessage, bool, error) { + var data string + err := b.db.QueryRow( + "SELECT data FROM queue_messages WHERE queue_name = ? AND status = 'pending' ORDER BY id LIMIT 1", + queueName, + ).Scan(&data) + + if err == sql.ErrNoRows { + return QueueMessage{}, false, nil + } else if err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to peek message: %w", err) + } + + var msg QueueMessage + if err := json.Unmarshal([]byte(data), &msg); err != nil { + return QueueMessage{}, false, fmt.Errorf("failed to unmarshal message: %w", err) + } + + return msg, true, nil +} + +// Size returns the number of pending (not yet dequeued) messages. +func (b *SQLiteQueueBackend) Size(queueName string) (int, error) { + var count int + err := b.db.QueryRow( + "SELECT COUNT(*) FROM queue_messages WHERE queue_name = ? AND status = 'pending'", + queueName, + ).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to get queue size: %w", err) + } + return count, nil +} + +func (b *SQLiteQueueBackend) Clear(queueName string) error { + _, err := b.db.Exec("DELETE FROM queue_messages WHERE queue_name = ?", queueName) + if err != nil { + return fmt.Errorf("failed to clear queue: %w", err) + } + return nil +} + +func (b *SQLiteQueueBackend) ListQueues(prefix string) ([]string, error) { + var query string + var args []interface{} + + if prefix == "" { + query = "SELECT queue_name FROM queue_metadata" + } else { + query = "SELECT queue_name FROM queue_metadata WHERE queue_name = ? 
OR queue_name LIKE ?" + args = []interface{}{prefix, prefix + "/%"} + } + + rows, err := b.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("failed to list queues: %w", err) + } + defer rows.Close() + + var queues []string + for rows.Next() { + var qName string + if err := rows.Scan(&qName); err != nil { + return nil, fmt.Errorf("failed to scan queue name: %w", err) + } + queues = append(queues, qName) + } + return queues, nil +} + +func (b *SQLiteQueueBackend) GetLastEnqueueTime(queueName string) (time.Time, error) { + var timestamp sql.NullInt64 + err := b.db.QueryRow( + "SELECT MAX(timestamp) FROM queue_messages WHERE queue_name = ? AND status = 'pending'", + queueName, + ).Scan(×tamp) + + if err != nil || !timestamp.Valid { + return time.Time{}, nil + } + return time.Unix(timestamp.Int64, 0), nil +} + +func (b *SQLiteQueueBackend) RemoveQueue(queueName string) error { + if queueName == "" { + if _, err := b.db.Exec("DELETE FROM queue_messages"); err != nil { + return err + } + _, err := b.db.Exec("DELETE FROM queue_metadata") + return err + } + + if _, err := b.db.Exec( + "DELETE FROM queue_messages WHERE queue_name = ? OR queue_name LIKE ?", + queueName, queueName+"/%", + ); err != nil { + return fmt.Errorf("failed to remove queue messages: %w", err) + } + + _, err := b.db.Exec( + "DELETE FROM queue_metadata WHERE queue_name = ? 
OR queue_name LIKE ?", + queueName, queueName+"/%", + ) + return err +} + +func (b *SQLiteQueueBackend) CreateQueue(queueName string) error { + _, err := b.db.Exec( + "INSERT OR IGNORE INTO queue_metadata (queue_name) VALUES (?)", + queueName, + ) + if err != nil { + return fmt.Errorf("failed to create queue: %w", err) + } + log.Infof("[queuefs] Created queue '%s' (SQLite)", queueName) + return nil +} + +func (b *SQLiteQueueBackend) QueueExists(queueName string) (bool, error) { + var count int + err := b.db.QueryRow( + "SELECT COUNT(*) FROM queue_metadata WHERE queue_name = ?", + queueName, + ).Scan(&count) + if err != nil { + return false, fmt.Errorf("failed to check queue existence: %w", err) + } + return count > 0, nil +}