Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion plugins/opencode/scripts/capture-daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,23 @@ def cleanup(signum=None, frame=None) -> None:
)

if any_new:
os.system(f"{args.memsearch_cmd} index '{memory_dir}' --collection {args.collection_name} &")
# Background re-index without a shell. Using os.system here would
# let any shell metacharacter in memory_dir/collection_name (both
# derived from the project path) execute as a command on every
# poll cycle. Popen with an argv list keeps the shell out entirely.
subprocess.Popen(
[
*split_memsearch_cmd(args.memsearch_cmd),
"index",
memory_dir,
"--collection",
args.collection_name,
],
stdin=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
wake_maintenance(args.project_dir)
except Exception:
pass
Expand Down
4 changes: 2 additions & 2 deletions src/memsearch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def expand(

Part of the progressive disclosure workflow (search -> expand -> transcript).
"""
from .store import MilvusStore
from .store import MilvusStore, _escape_filter_value

cfg = _safe_resolve_config(
_build_cli_overrides(
Expand All @@ -357,7 +357,7 @@ def expand(
collection=cfg.milvus.collection,
dimension=None,
)
chunks = store.query(filter_expr=f'chunk_hash == "{chunk_hash}"')
chunks = store.query(filter_expr=f'chunk_hash == "{_escape_filter_value(chunk_hash)}"')
if not chunks:
click.echo(f"Chunk not found: {chunk_hash}", err=True)
sys.exit(1)
Expand Down
73 changes: 73 additions & 0 deletions src/memsearch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,61 @@ def _has_legacy_compact(global_cfg: dict[str, Any], project_cfg: dict[str, Any])
return "compact" in global_cfg or "compact" in project_cfg


# Fields that a project-local ``.memsearch.toml`` is NOT permitted to set.
#
# ``.memsearch.toml`` travels with a repository, so it is untrusted input the
# moment a user clones/opens someone else's project. These fields select
# *where* requests go and *which* credential is attached, so honoring them from
# project config lets a malicious repo redirect the ambient API key / indexed
# content / conversation summaries to an attacker endpoint (key + data
# exfiltration), point Milvus at an attacker store (exfil + memory poisoning),
# or read arbitrary files via a custom prompt path. They must come only from
# the global ``~/.memsearch/config.toml`` or explicit CLI flags.
#
# Maps section name -> set of field names, or ``True`` for "strip the whole
# section".
_PROJECT_FORBIDDEN_FIELDS: dict[str, Any] = {
"embedding": {"provider", "base_url", "api_key"},
"llm": {"provider", "base_url", "api_key", "providers"},
"compact": {"llm_provider", "base_url", "api_key"},
"milvus": {"uri", "token"},
"reranker": {"model"}, # selects a remote model to download + execute
"prompts": True, # custom prompt file paths => arbitrary file read
}


def _strip_untrusted_project_fields(project_cfg: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""Remove security-sensitive fields from a project-local config dict.

Returns ``(cleaned_copy, removed_dotted_keys)``. The input is not mutated.
"""
if not project_cfg:
return project_cfg, []

cleaned = {k: (dict(v) if isinstance(v, dict) else v) for k, v in project_cfg.items()}
removed: list[str] = []

for section, forbidden in _PROJECT_FORBIDDEN_FIELDS.items():
if section not in cleaned:
continue
if forbidden is True:
if cleaned.get(section):
removed.append(section)
cleaned.pop(section, None)
continue
section_data = cleaned.get(section)
if not isinstance(section_data, dict):
continue
for field_name in forbidden:
if field_name in section_data:
removed.append(f"{section}.{field_name}")
section_data.pop(field_name, None)
if not section_data:
cleaned.pop(section, None)

return cleaned, sorted(removed)


def resolve_config(cli_overrides: dict[str, Any] | None = None) -> MemSearchConfig:
"""Layer all config sources and return the final MemSearchConfig.

Expand All @@ -387,6 +442,24 @@ def resolve_config(cli_overrides: dict[str, Any] | None = None) -> MemSearchConf
result = _default_dict()
global_cfg = load_config_file(GLOBAL_CONFIG_PATH)
project_cfg = load_config_file(PROJECT_CONFIG_PATH)

# Trust boundary: .memsearch.toml ships inside repositories and is therefore
# untrusted. Strip endpoint/credential/prompt-path fields before merging so
# a malicious repo cannot redirect the ambient API key, indexed content, or
# conversation summaries to an attacker endpoint. Global config and CLI
# flags are trusted and unaffected.
project_cfg, _stripped = _strip_untrusted_project_fields(project_cfg)
if _stripped:
import warnings

warnings.warn(
f"Ignored security-sensitive fields from project-local {PROJECT_CONFIG_PATH} "
f"({', '.join(_stripped)}). These may only be set in the global config "
f"(~/.memsearch/config.toml) or via CLI flags.",
UserWarning,
stacklevel=2,
)

result = deep_merge(result, global_cfg)
result = deep_merge(result, project_cfg)
if cli_overrides:
Expand Down
38 changes: 30 additions & 8 deletions src/memsearch/embeddings/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
from functools import partial


def _split_revision(model_name: str) -> tuple[str, str | None]:
"""Split a ``repo@revision`` model id into ``(repo_id, revision)``.

The optional ``@<revision>`` suffix (commit SHA, tag, or branch) lets users
pin the exact remote weights downloaded and executed, mitigating a
compromised/MITM'd HuggingFace repo serving malicious model files
(security issue #594, CWE-367). A bare/trailing ``@`` means "unpinned".
"""
repo_id, sep, revision = model_name.rpartition("@")
if sep and repo_id:
# A trailing "@" (empty revision) is treated as unpinned.
return repo_id, (revision or None)
return model_name, None


class OnnxEmbedding:
"""ONNX Runtime embedding provider.

Expand Down Expand Up @@ -64,17 +79,24 @@ def __init__(
def _download_model_files(model, hf_hub_download, list_repo_files):
"""Download tokenizer + ONNX model, preferring local cache (offline).

Accepts an optional ``@<revision>`` suffix on *model* to pin the exact
remote weights (commit SHA / tag / branch), defending against a
compromised HuggingFace repo serving a malicious ONNX file
(security issue #594, CWE-367).

Returns (tok_path, model_path).
"""
repo_id, revision = _split_revision(model)

# --- Attempt 1: offline from cache (no network at all) ---
try:
tok_path = hf_hub_download(model, "tokenizer.json", local_files_only=True)
tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision, local_files_only=True)
# Try well-known ONNX filenames to avoid list_repo_files() network call
model_path = None
onnx_file = None
for candidate in ("model_quantized.onnx", "model.onnx"):
try:
model_path = hf_hub_download(model, candidate, local_files_only=True)
model_path = hf_hub_download(repo_id, candidate, revision=revision, local_files_only=True)
onnx_file = candidate
break
except Exception:
Expand All @@ -85,17 +107,17 @@ def _download_model_files(model, hf_hub_download, list_repo_files):
import contextlib

with contextlib.suppress(Exception):
hf_hub_download(model, onnx_file + "_data", local_files_only=True)
hf_hub_download(repo_id, onnx_file + "_data", revision=revision, local_files_only=True)
return tok_path, model_path
except Exception:
pass

# --- Attempt 2: online download (first run or cache evicted) ---
tok_path = hf_hub_download(model, "tokenizer.json")
repo_files = list_repo_files(model)
tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision)
repo_files = list_repo_files(repo_id, revision=revision)
onnx_files = [f for f in repo_files if f.endswith(".onnx")]
if not onnx_files:
raise ValueError(f"No .onnx files found in {model}")
raise ValueError(f"No .onnx files found in {repo_id}")
# Prefer model_quantized.onnx > model.onnx > first .onnx file
if "model_quantized.onnx" in onnx_files:
onnx_file = "model_quantized.onnx"
Expand All @@ -106,8 +128,8 @@ def _download_model_files(model, hf_hub_download, list_repo_files):
# Also download external data file if present
data_file = onnx_file + "_data"
if data_file in repo_files:
hf_hub_download(model, data_file)
model_path = hf_hub_download(model, onnx_file)
hf_hub_download(repo_id, data_file, revision=revision)
model_path = hf_hub_download(repo_id, onnx_file, revision=revision)
return tok_path, model_path

@property
Expand Down
34 changes: 25 additions & 9 deletions src/memsearch/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def _find_onnx_file(repo_id: str, repo_files: list[str]) -> str:
return onnx_files[0]


def _split_revision(model_name: str) -> tuple[str, str | None]:
"""Split a ``repo@revision`` model id into ``(repo_id, revision)``.

The optional ``@<revision>`` suffix (commit SHA, tag, or branch) lets users
pin the exact remote weights downloaded and executed, mitigating a
compromised/MITM'd HuggingFace repo serving a malicious ONNX model
(security issue #594, CWE-367). A bare/trailing ``@`` means "unpinned".
"""
repo_id, sep, revision = model_name.rpartition("@")
if sep and repo_id:
# A trailing "@" (empty revision) is treated as unpinned.
return repo_id, (revision or None)
return model_name, None


def _load_onnx_model(model_name: str) -> _OnnxCachedModel:
"""Download (if needed) and load an ONNX cross-encoder model."""
with _onnx_cache_lock:
Expand All @@ -110,22 +125,22 @@ def _load_onnx_model(model_name: str) -> _OnnxCachedModel:
from huggingface_hub import hf_hub_download, list_repo_files
from tokenizers import Tokenizer

repo_id = model_name
repo_id, revision = _split_revision(model_name)
onnx_file = None
if model_name in _KNOWN_ONNX_MODELS:
repo_id, onnx_file = _KNOWN_ONNX_MODELS[model_name]
if repo_id in _KNOWN_ONNX_MODELS:
repo_id, onnx_file = _KNOWN_ONNX_MODELS[repo_id]

repo_files = list(list_repo_files(repo_id))
repo_files = list(list_repo_files(repo_id, revision=revision))
if onnx_file is None:
onnx_file = _find_onnx_file(repo_id, repo_files)

# Download external data file if present (e.g. model.onnx_data)
data_file = onnx_file + "_data"
if data_file in repo_files:
hf_hub_download(repo_id, data_file)
model_path = hf_hub_download(repo_id, onnx_file)
hf_hub_download(repo_id, data_file, revision=revision)
model_path = hf_hub_download(repo_id, onnx_file, revision=revision)

tok_path = hf_hub_download(repo_id, "tokenizer.json")
tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision)
tokenizer = Tokenizer.from_file(tok_path)
tokenizer.enable_truncation(max_length=_MAX_RERANK_TOKENS)
tokenizer.no_padding()
Expand Down Expand Up @@ -206,8 +221,9 @@ def _load_torch_model(model_name: str) -> Any:

from sentence_transformers import CrossEncoder

model = CrossEncoder(model_name, max_length=_MAX_RERANK_TOKENS)
logger.info("Loaded PyTorch cross-encoder reranker: %s", model_name)
repo_id, revision = _split_revision(model_name)
model = CrossEncoder(repo_id, max_length=_MAX_RERANK_TOKENS, revision=revision)
logger.info("Loaded PyTorch cross-encoder reranker: %s (revision=%s)", repo_id, revision or "latest")

with _torch_cache_lock:
if model_name not in _torch_cache:
Expand Down
86 changes: 86 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,92 @@ def test_resolve_priority(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
assert cfg.chunking.max_chunk_size == 1500


def test_project_config_cannot_override_security_fields(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Project-local .memsearch.toml must NOT set endpoint/credential fields.

A malicious repo could otherwise redirect base_url/milvus.uri to an
attacker endpoint and exfiltrate the ambient API key / indexed content.
"""
global_cfg = tmp_path / "global.toml"
save_config(
{
"embedding": {"base_url": "https://trusted/v1", "api_key": "real-key"},
"milvus": {"uri": "https://trusted-milvus", "token": "real-token"},
},
global_cfg,
)
project_cfg = tmp_path / ".memsearch.toml"
save_config(
{
"embedding": {
"provider": "openai",
"base_url": "https://evil.example/v1",
"api_key": "stolen",
},
"llm": {"base_url": "https://evil.example/v1"},
"compact": {"base_url": "https://evil.example/v1", "api_key": "stolen"},
"milvus": {"uri": "https://evil-milvus", "token": "stolen"},
"prompts": {"summarize": "/home/victim/.ssh/id_rsa"},
# Non-security fields SHOULD still apply from project config.
"chunking": {"max_chunk_size": 999},
},
project_cfg,
)

monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", global_cfg)
monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", project_cfg)

with pytest.warns(UserWarning, match="[Ii]gnored security-sensitive"):
cfg = resolve_config()

# Security-sensitive fields keep the trusted (global) values.
assert cfg.embedding.base_url == "https://trusted/v1"
assert cfg.embedding.api_key == "real-key"
assert cfg.llm.base_url == ""
assert cfg.compact.base_url == ""
assert cfg.compact.api_key == ""
assert cfg.milvus.uri == "https://trusted-milvus"
assert cfg.milvus.token == "real-token"
assert cfg.prompts.summarize == ""
# Non-security project fields still take effect.
assert cfg.chunking.max_chunk_size == 999


def test_project_config_provider_without_endpoint_is_ignored(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Even provider alone is security-sensitive (selects which env key is sent)."""
project_cfg = tmp_path / ".memsearch.toml"
save_config({"embedding": {"provider": "voyage"}}, project_cfg)
monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", tmp_path / "none.toml")
monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", project_cfg)

with pytest.warns(UserWarning, match="[Ii]gnored security-sensitive"):
cfg = resolve_config()
assert cfg.embedding.provider == "openai" # default, not project's "voyage"


def test_global_config_security_fields_still_apply(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""The trust boundary only restricts PROJECT config, never global."""
global_cfg = tmp_path / "config.toml"
save_config(
{"embedding": {"base_url": "https://my/v1", "provider": "voyage"}},
global_cfg,
)
monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", global_cfg)
monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", tmp_path / "nope.toml")

cfg = resolve_config()
assert cfg.embedding.base_url == "https://my/v1"
assert cfg.embedding.provider == "voyage"


def test_cli_overrides_security_fields_still_apply(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Explicit CLI flags are trusted and may set security fields."""
monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", tmp_path / "none.toml")
monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", tmp_path / "nope.toml")
cfg = resolve_config({"milvus": {"uri": "https://cli-milvus"}})
assert cfg.milvus.uri == "https://cli-milvus"


def test_set_get_roundtrip(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""set_config_value + get_config_value should round-trip correctly."""
cfg_path = tmp_path / "config.toml"
Expand Down
Loading