Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 98 additions & 12 deletions openviking/session/memory_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,92 @@ def _format_message_with_parts(self, msg) -> str:

return "\n".join(lines) if lines else ""

@staticmethod
def _normalize_extraction_payload(data) -> dict:
    """Normalize common LLM response shapes into ``{"memories": [...]}``.

    Some smaller or OpenAI-compatible local models do not reliably return the
    exact outer schema requested by the prompt. In practice we may see:
    - a bare list of memory objects
    - a single memory object
    - ``{"memories": {...}}`` with one object instead of a list
    - ``{"memories": null}`` (treated as "no memories extracted")
    - wrapper keys such as ``items`` / ``results`` / ``data``

    This method keeps the accepted shapes intentionally narrow: a payload is
    only treated as a single memory when it has a ``category`` plus at least
    one memory content field.

    Returns either ``{}`` (unrecognized payload) or a dict whose ``memories``
    key maps to a list of dicts.
    """

    def _is_memory_item(value) -> bool:
        # A memory must carry a category plus at least one content-ish field.
        if not isinstance(value, dict):
            return False
        if not value.get("category"):
            return False
        return any(
            bool(value.get(field))
            for field in ("abstract", "overview", "content", "tool_name", "skill_name")
        )

    def _coerce_sequence(value, source: str) -> list:
        # Lists keep only their dict items; a lone memory object becomes a
        # single-item list; anything else coerces to an empty list.
        if isinstance(value, list):
            dict_items = [item for item in value if isinstance(item, dict)]
            dropped = len(value) - len(dict_items)
            if dropped > 0:
                logger.warning(
                    "Memory extraction ignored %d non-dict item(s) from %s",
                    dropped,
                    source,
                )
            return dict_items
        if _is_memory_item(value):
            logger.debug(
                "Memory extraction normalized single memory object from %s",
                source,
            )
            return [value]
        return []

    if isinstance(data, list):
        logger.debug("Memory extraction normalized bare list response")
        return {"memories": _coerce_sequence(data, "root list")}

    if not isinstance(data, dict):
        # None means "nothing parsed" and is expected; any other type is a
        # surprise worth a warning.
        if data is not None:
            logger.warning(
                "Memory extraction received unexpected normalized payload type %s",
                type(data).__name__,
            )
        return {}

    if _is_memory_item(data):
        logger.debug("Memory extraction normalized bare memory object response")
        return {"memories": [data]}

    # An explicit "memories" key is authoritative even when its value is
    # null: returning the canonical empty schema here avoids falling through
    # to the wrapper-key scan and logging a misleading "unknown schema".
    if "memories" in data:
        memories = data["memories"]
        if memories is None:
            logger.debug("Memory extraction normalized null memories field")
            return {"memories": []}
        if not isinstance(memories, list):
            logger.debug("Memory extraction normalized non-list memories field")
        return {"memories": _coerce_sequence(memories, "memories field")}

    for key in ("items", "results", "data"):
        wrapped = data.get(key)
        if wrapped is None:
            continue
        if isinstance(wrapped, dict) and "memories" in wrapped:
            # Recurse once to unwrap e.g. {"data": {"memories": [...]}}.
            logger.debug("Memory extraction normalized %s wrapper", key)
            return MemoryExtractor._normalize_extraction_payload(wrapped)
        coerced = _coerce_sequence(wrapped, f"{key} wrapper")
        if coerced:
            logger.debug("Memory extraction normalized %s wrapper sequence", key)
            return {"memories": coerced}

    if data:
        logger.debug(
            "Memory extraction payload dict did not match known schemas; keys=%s",
            sorted(str(key) for key in data.keys())[:8],
        )
    return {}

async def extract(
self,
context: dict,
Expand Down Expand Up @@ -301,6 +387,7 @@ async def extract(

try:
from openviking_cli.utils.llm import parse_json_from_response
from openviking.session.memory.utils import parse_json_with_stability

request_summary = {
"user": user._user_id,
Expand All @@ -315,18 +402,17 @@ async def extract(
response = await vlm.get_completion_async(prompt)
logger.debug("Memory extraction LLM raw response: %s", response)
with telemetry.measure("memory.extract.stage.normalize_candidates"):
data = parse_json_from_response(response) or {}
if isinstance(data, list):
logger.warning(
"Memory extraction received list instead of dict; wrapping as memories"
)
data = {"memories": data}
elif not isinstance(data, dict):
logger.warning(
"Memory extraction received unexpected type %s; skipping",
type(data).__name__,
)
data = {}
data = parse_json_from_response(response)
if data is None:
stable_data, stable_error = parse_json_with_stability(response)
if stable_data is not None:
data = stable_data
else:
logger.warning(
"Memory extraction stable parse failed: %s",
stable_error,
)
data = self._normalize_extraction_payload(data)
logger.debug("Memory extraction LLM parsed payload: %s", data)

candidates = []
Expand Down
179 changes: 163 additions & 16 deletions tests/session/test_memory_extractor_response_types.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,106 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0
"""
Tests that memory extraction normalizes common non-canonical LLM response shapes.

Covers:
- issue #605: Ollama models may return a bare JSON list instead of the
  expected {"memories": [...]} dict
- issue #1410: smaller local/OpenAI-compatible models may return a single
  memory object or wrap one under ``memories`` as an object
"""

import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock

try:
    from openviking.session.memory_extractor import MemoryExtractor
except Exception:  # pragma: no cover - fallback for minimal local test env
    # Minimal environments may lack the full dependency tree. Stub just
    # enough sibling modules for memory_extractor.py to import, then load it
    # straight from its file path.
    logger_stub = SimpleNamespace(
        debug=lambda *a, **k: None,
        info=lambda *a, **k: None,
        warning=lambda *a, **k: None,
        error=lambda *a, **k: None,
    )

    modules = {
        "openviking": ModuleType("openviking"),
        "openviking.core": ModuleType("openviking.core"),
        "openviking.core.context": ModuleType("openviking.core.context"),
        "openviking.prompts": ModuleType("openviking.prompts"),
        "openviking.server": ModuleType("openviking.server"),
        "openviking.server.identity": ModuleType("openviking.server.identity"),
        "openviking.storage": ModuleType("openviking.storage"),
        "openviking.storage.viking_fs": ModuleType("openviking.storage.viking_fs"),
        "openviking.telemetry": ModuleType("openviking.telemetry"),
        "openviking_cli": ModuleType("openviking_cli"),
        "openviking_cli.exceptions": ModuleType("openviking_cli.exceptions"),
        "openviking_cli.session": ModuleType("openviking_cli.session"),
        "openviking_cli.session.user_id": ModuleType("openviking_cli.session.user_id"),
        "openviking_cli.utils": ModuleType("openviking_cli.utils"),
        "openviking_cli.utils.config": ModuleType("openviking_cli.utils.config"),
    }

    modules["openviking.core.context"].Context = object
    modules["openviking.core.context"].ContextType = SimpleNamespace(
        MEMORY=SimpleNamespace(value="memory")
    )
    modules["openviking.core.context"].Vectorize = object
    modules["openviking.prompts"].render_prompt = lambda *a, **k: ""
    modules["openviking.server.identity"].RequestContext = object
    modules["openviking.storage.viking_fs"].get_viking_fs = lambda: None

    class _NullMeasure:
        # Must be a real class: the ``with`` statement looks up
        # __enter__/__exit__ on the *type*, so SimpleNamespace instance
        # attributes would raise TypeError when telemetry.measure() is used
        # as a context manager.
        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc, tb):
            return False

    modules["openviking.telemetry"].get_current_telemetry = lambda: SimpleNamespace(
        measure=lambda *_a, **_k: _NullMeasure()
    )

    class _NotFoundError(Exception):
        pass

    modules["openviking_cli.exceptions"].NotFoundError = _NotFoundError
    modules["openviking_cli.session.user_id"].UserIdentifier = object
    modules["openviking_cli.utils"].get_logger = lambda _name: logger_stub
    modules["openviking_cli.utils.config"].get_openviking_config = lambda: SimpleNamespace(
        language_fallback="en",
        vlm=None,
    )

    # setdefault keeps any genuinely importable module the environment has.
    for name, module in modules.items():
        sys.modules.setdefault(name, module)

    module_path = (
        Path(__file__).resolve().parents[2] / "openviking" / "session" / "memory_extractor.py"
    )
    spec = importlib.util.spec_from_file_location(
        "openviking.session.memory_extractor", module_path
    )
    memory_extractor = importlib.util.module_from_spec(spec)
    assert spec and spec.loader
    spec.loader.exec_module(memory_extractor)
    MemoryExtractor = memory_extractor.MemoryExtractor

def _normalize_parsed_data(data):
    """Normalize a parsed LLM payload via the production code path.

    Delegates to ``MemoryExtractor._normalize_extraction_payload`` so these
    tests exercise exactly the normalization logic used by ``extract()``
    rather than a re-implementation that could drift out of sync.
    """
    return MemoryExtractor._normalize_extraction_payload(data)

def _memory_extractor_module():
    """Return the live module object that defines MemoryExtractor."""
    module_name = MemoryExtractor.__module__
    return sys.modules[module_name]


def _make_memory(category="patterns", content="user prefers dark mode"):
return {"category": category, "content": content, "event": "", "emoji": ""}
return {
"category": category,
"abstract": "sample abstract",
"overview": "sample overview",
"content": content,
"event": "",
"emoji": "",
}


class TestExtractResponseTypes:
Expand All @@ -49,6 +124,46 @@ def test_list_response_wrapped_as_memories(self):
assert len(data["memories"]) == 2
assert data["memories"][1]["content"] == "likes Python"

def test_single_memory_object_wrapped_as_memories(self):
    """A bare memory object should be treated as one extracted memory."""
    single = _make_memory(category="preferences", content="likes pour-over coffee")

    normalized = _normalize_parsed_data(single)

    assert isinstance(normalized, dict)
    memories = normalized["memories"]
    assert len(memories) == 1
    assert memories[0]["category"] == "preferences"

def test_memories_object_wrapped_into_single_item_list(self):
    """Some small models emit {"memories": {...}} instead of a list."""
    single = _make_memory(category="entities", content="dog named Wangcai")

    normalized = _normalize_parsed_data({"memories": single})

    assert isinstance(normalized, dict)
    memories = normalized["memories"]
    assert len(memories) == 1
    assert memories[0]["category"] == "entities"

def test_items_wrapper_is_accepted(self):
    """Alternative wrapper keys like ``items`` should be normalized."""
    wrapped = {"items": [_make_memory(category="events", content="scheduled vet visit")]}

    normalized = _normalize_parsed_data(wrapped)

    assert isinstance(normalized, dict)
    memories = normalized["memories"]
    assert len(memories) == 1
    assert memories[0]["category"] == "events"

def test_nested_data_memories_wrapper_is_accepted(self):
    """Nested ``data -> memories`` wrappers should be unwrapped."""
    inner = _make_memory(category="preferences", content="prefers oat milk")
    wrapped = {"data": {"memories": inner}}

    normalized = _normalize_parsed_data(wrapped)

    assert isinstance(normalized, dict)
    memories = normalized["memories"]
    assert len(memories) == 1
    assert memories[0]["content"] == "prefers oat milk"

def test_string_response_yields_empty(self):
"""If parse returns a bare string, treat as empty."""
data = _normalize_parsed_data("some unexpected string")
Expand All @@ -75,3 +190,35 @@ def test_empty_list_wraps_to_empty_memories(self):

assert data == {"memories": []}
assert data.get("memories", []) == []

def test_logs_when_non_dict_items_are_dropped(self, monkeypatch):
    """Dropping non-dict entries from a list payload emits one warning."""
    extractor_module = _memory_extractor_module()
    warning_spy = MagicMock()
    monkeypatch.setattr(extractor_module.logger, "warning", warning_spy)

    normalized = _normalize_parsed_data([_make_memory(), "bad-item", 42])

    assert len(normalized["memories"]) == 1
    warning_spy.assert_called_once()
    assert "ignored" in warning_spy.call_args.args[0]

def test_logs_when_single_memory_object_is_normalized(self, monkeypatch):
    """Wrapping a bare memory object should leave a debug-level trace."""
    extractor_module = _memory_extractor_module()
    debug_spy = MagicMock()
    monkeypatch.setattr(extractor_module.logger, "debug", debug_spy)

    normalized = _normalize_parsed_data(_make_memory(category="preferences"))

    assert len(normalized["memories"]) == 1
    debug_spy.assert_called()

def test_logs_when_payload_type_is_unexpected(self, monkeypatch):
    """Non-list, non-dict payloads yield {} and warn about the type."""
    extractor_module = _memory_extractor_module()
    warning_spy = MagicMock()
    monkeypatch.setattr(extractor_module.logger, "warning", warning_spy)

    normalized = _normalize_parsed_data(42)

    assert normalized == {}
    warning_spy.assert_called_once()
    assert "unexpected normalized payload type" in warning_spy.call_args.args[0]