diff --git a/plugins/opencode/scripts/capture-daemon.py b/plugins/opencode/scripts/capture-daemon.py index f05e302c..7df0b473 100755 --- a/plugins/opencode/scripts/capture-daemon.py +++ b/plugins/opencode/scripts/capture-daemon.py @@ -812,7 +812,23 @@ def cleanup(signum=None, frame=None) -> None: ) if any_new: - os.system(f"{args.memsearch_cmd} index '{memory_dir}' --collection {args.collection_name} &") + # Background re-index without a shell. Using os.system here would + # let any shell metacharacter in memory_dir/collection_name (both + # derived from the project path) execute as a command on every + # poll cycle. Popen with an argv list keeps the shell out entirely. + subprocess.Popen( + [ + *split_memsearch_cmd(args.memsearch_cmd), + "index", + memory_dir, + "--collection", + args.collection_name, + ], + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + ) wake_maintenance(args.project_dir) except Exception: pass diff --git a/src/memsearch/cli.py b/src/memsearch/cli.py index b126e87e..32245b99 100644 --- a/src/memsearch/cli.py +++ b/src/memsearch/cli.py @@ -335,7 +335,7 @@ def expand( Part of the progressive disclosure workflow (search -> expand -> transcript). """ - from .store import MilvusStore + from .store import MilvusStore, _escape_filter_value cfg = _safe_resolve_config( _build_cli_overrides( @@ -357,7 +357,7 @@ def expand( collection=cfg.milvus.collection, dimension=None, ) - chunks = store.query(filter_expr=f'chunk_hash == "{chunk_hash}"') + chunks = store.query(filter_expr=f'chunk_hash == "{_escape_filter_value(chunk_hash)}"') if not chunks: click.echo(f"Chunk not found: {chunk_hash}", err=True) sys.exit(1) diff --git a/src/memsearch/config.py b/src/memsearch/config.py index b85c126d..7f14198e 100644 --- a/src/memsearch/config.py +++ b/src/memsearch/config.py @@ -378,6 +378,61 @@ def _has_legacy_compact(global_cfg: dict[str, Any], project_cfg: dict[str, Any]) return "compact" in global_cfg or "compact" in project_cfg +# Fields that a project-local ``.memsearch.toml`` is NOT permitted to set. +# +# ``.memsearch.toml`` travels with a repository, so it is untrusted input the +# moment a user clones/opens someone else's project. These fields select +# *where* requests go and *which* credential is attached, so honoring them from +# project config lets a malicious repo redirect the ambient API key / indexed +# content / conversation summaries to an attacker endpoint (key + data +# exfiltration), point Milvus at an attacker store (exfil + memory poisoning), +# or read arbitrary files via a custom prompt path. They must come only from +# the global ``~/.memsearch/config.toml`` or explicit CLI flags. +# +# Maps section name -> set of field names, or ``True`` for "strip the whole +# section". +_PROJECT_FORBIDDEN_FIELDS: dict[str, Any] = { + "embedding": {"provider", "base_url", "api_key"}, + "llm": {"provider", "base_url", "api_key", "providers"}, + "compact": {"llm_provider", "base_url", "api_key"}, + "milvus": {"uri", "token"}, + "reranker": {"model"}, # selects a remote model to download + execute + "prompts": True, # custom prompt file paths => arbitrary file read +} + + +def _strip_untrusted_project_fields(project_cfg: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """Remove security-sensitive fields from a project-local config dict. + + Returns ``(cleaned_copy, removed_dotted_keys)``. The input is not mutated. + """ + if not project_cfg: + return project_cfg, [] + + cleaned = {k: (dict(v) if isinstance(v, dict) else v) for k, v in project_cfg.items()} + removed: list[str] = [] + + for section, forbidden in _PROJECT_FORBIDDEN_FIELDS.items(): + if section not in cleaned: + continue + if forbidden is True: + if cleaned.get(section): + removed.append(section) + cleaned.pop(section, None) + continue + section_data = cleaned.get(section) + if not isinstance(section_data, dict): + continue + for field_name in forbidden: + if field_name in section_data: + removed.append(f"{section}.{field_name}") + section_data.pop(field_name, None) + if not section_data: + cleaned.pop(section, None) + + return cleaned, sorted(removed) + + def resolve_config(cli_overrides: dict[str, Any] | None = None) -> MemSearchConfig: """Layer all config sources and return the final MemSearchConfig. @@ -387,6 +442,24 @@ def resolve_config(cli_overrides: dict[str, Any] | None = None) -> MemSearchConf result = _default_dict() global_cfg = load_config_file(GLOBAL_CONFIG_PATH) project_cfg = load_config_file(PROJECT_CONFIG_PATH) + + # Trust boundary: .memsearch.toml ships inside repositories and is therefore + # untrusted. Strip endpoint/credential/prompt-path fields before merging so + # a malicious repo cannot redirect the ambient API key, indexed content, or + # conversation summaries to an attacker endpoint. Global config and CLI + # flags are trusted and unaffected. + project_cfg, _stripped = _strip_untrusted_project_fields(project_cfg) + if _stripped: + import warnings + + warnings.warn( + f"Ignored security-sensitive fields from project-local {PROJECT_CONFIG_PATH} " + f"({', '.join(_stripped)}). These may only be set in the global config " + f"(~/.memsearch/config.toml) or via CLI flags.", + UserWarning, + stacklevel=2, + ) + result = deep_merge(result, global_cfg) result = deep_merge(result, project_cfg) if cli_overrides: diff --git a/src/memsearch/embeddings/onnx.py b/src/memsearch/embeddings/onnx.py index de99e31b..c4e0b245 100644 --- a/src/memsearch/embeddings/onnx.py +++ b/src/memsearch/embeddings/onnx.py @@ -11,6 +11,21 @@ from functools import partial +def _split_revision(model_name: str) -> tuple[str, str | None]: + """Split a ``repo@revision`` model id into ``(repo_id, revision)``. + + The optional ``@`` suffix (commit SHA, tag, or branch) lets users + pin the exact remote weights downloaded and executed, mitigating a + compromised/MITM'd HuggingFace repo serving malicious model files + (security issue #594, CWE-367). A bare/trailing ``@`` means "unpinned". + """ + repo_id, sep, revision = model_name.rpartition("@") + if sep and repo_id: + # A trailing "@" (empty revision) is treated as unpinned. + return repo_id, (revision or None) + return model_name, None + + class OnnxEmbedding: """ONNX Runtime embedding provider. @@ -64,17 +79,24 @@ def __init__( def _download_model_files(model, hf_hub_download, list_repo_files): """Download tokenizer + ONNX model, preferring local cache (offline). + Accepts an optional ``@`` suffix on *model* to pin the exact + remote weights (commit SHA / tag / branch), defending against a + compromised HuggingFace repo serving a malicious ONNX file + (security issue #594, CWE-367). + Returns (tok_path, model_path). """ + repo_id, revision = _split_revision(model) + # --- Attempt 1: offline from cache (no network at all) --- try: - tok_path = hf_hub_download(model, "tokenizer.json", local_files_only=True) + tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision, local_files_only=True) # Try well-known ONNX filenames to avoid list_repo_files() network call model_path = None onnx_file = None for candidate in ("model_quantized.onnx", "model.onnx"): try: - model_path = hf_hub_download(model, candidate, local_files_only=True) + model_path = hf_hub_download(repo_id, candidate, revision=revision, local_files_only=True) onnx_file = candidate break except Exception: @@ -85,17 +107,17 @@ def _download_model_files(model, hf_hub_download, list_repo_files): import contextlib with contextlib.suppress(Exception): - hf_hub_download(model, onnx_file + "_data", local_files_only=True) + hf_hub_download(repo_id, onnx_file + "_data", revision=revision, local_files_only=True) return tok_path, model_path except Exception: pass # --- Attempt 2: online download (first run or cache evicted) --- - tok_path = hf_hub_download(model, "tokenizer.json") - repo_files = list_repo_files(model) + tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision) + repo_files = list_repo_files(repo_id, revision=revision) onnx_files = [f for f in repo_files if f.endswith(".onnx")] if not onnx_files: - raise ValueError(f"No .onnx files found in {model}") + raise ValueError(f"No .onnx files found in {repo_id}") # Prefer model_quantized.onnx > model.onnx > first .onnx file if "model_quantized.onnx" in onnx_files: onnx_file = "model_quantized.onnx" @@ -106,8 +128,8 @@ def _download_model_files(model, hf_hub_download, list_repo_files): # Also download external data file if present data_file = onnx_file + "_data" if data_file in repo_files: - hf_hub_download(model, data_file) - model_path = hf_hub_download(model, onnx_file) + hf_hub_download(repo_id, data_file, revision=revision) + model_path = hf_hub_download(repo_id, onnx_file, revision=revision) return tok_path, model_path @property diff --git a/src/memsearch/reranker.py b/src/memsearch/reranker.py index 403a8462..246396b7 100644 --- a/src/memsearch/reranker.py +++ b/src/memsearch/reranker.py @@ -101,6 +101,21 @@ def _find_onnx_file(repo_id: str, repo_files: list[str]) -> str: return onnx_files[0] +def _split_revision(model_name: str) -> tuple[str, str | None]: + """Split a ``repo@revision`` model id into ``(repo_id, revision)``. + + The optional ``@`` suffix (commit SHA, tag, or branch) lets users + pin the exact remote weights downloaded and executed, mitigating a + compromised/MITM'd HuggingFace repo serving a malicious ONNX model + (security issue #594, CWE-367). A bare/trailing ``@`` means "unpinned". + """ + repo_id, sep, revision = model_name.rpartition("@") + if sep and repo_id: + # A trailing "@" (empty revision) is treated as unpinned. + return repo_id, (revision or None) + return model_name, None + + def _load_onnx_model(model_name: str) -> _OnnxCachedModel: """Download (if needed) and load an ONNX cross-encoder model.""" with _onnx_cache_lock: @@ -110,22 +125,22 @@ def _load_onnx_model(model_name: str) -> _OnnxCachedModel: from huggingface_hub import hf_hub_download, list_repo_files from tokenizers import Tokenizer - repo_id = model_name + repo_id, revision = _split_revision(model_name) onnx_file = None - if model_name in _KNOWN_ONNX_MODELS: - repo_id, onnx_file = _KNOWN_ONNX_MODELS[model_name] + if repo_id in _KNOWN_ONNX_MODELS: + repo_id, onnx_file = _KNOWN_ONNX_MODELS[repo_id] - repo_files = list(list_repo_files(repo_id)) + repo_files = list(list_repo_files(repo_id, revision=revision)) if onnx_file is None: onnx_file = _find_onnx_file(repo_id, repo_files) # Download external data file if present (e.g. model.onnx_data) data_file = onnx_file + "_data" if data_file in repo_files: - hf_hub_download(repo_id, data_file) - model_path = hf_hub_download(repo_id, onnx_file) + hf_hub_download(repo_id, data_file, revision=revision) + model_path = hf_hub_download(repo_id, onnx_file, revision=revision) - tok_path = hf_hub_download(repo_id, "tokenizer.json") + tok_path = hf_hub_download(repo_id, "tokenizer.json", revision=revision) tokenizer = Tokenizer.from_file(tok_path) tokenizer.enable_truncation(max_length=_MAX_RERANK_TOKENS) tokenizer.no_padding() @@ -206,8 +221,9 @@ def _load_torch_model(model_name: str) -> Any: from sentence_transformers import CrossEncoder - model = CrossEncoder(model_name, max_length=_MAX_RERANK_TOKENS) - logger.info("Loaded PyTorch cross-encoder reranker: %s", model_name) + repo_id, revision = _split_revision(model_name) + model = CrossEncoder(repo_id, max_length=_MAX_RERANK_TOKENS, revision=revision) + logger.info("Loaded PyTorch cross-encoder reranker: %s (revision=%s)", repo_id, revision or "latest") with _torch_cache_lock: if model_name not in _torch_cache: diff --git a/tests/test_config.py b/tests/test_config.py index f43586be..b2a4bb74 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -127,6 +127,92 @@ def test_resolve_priority(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): assert cfg.chunking.max_chunk_size == 1500 +def test_project_config_cannot_override_security_fields(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Project-local .memsearch.toml must NOT set endpoint/credential fields. + + A malicious repo could otherwise redirect base_url/milvus.uri to an + attacker endpoint and exfiltrate the ambient API key / indexed content. + """ + global_cfg = tmp_path / "global.toml" + save_config( + { + "embedding": {"base_url": "https://trusted/v1", "api_key": "real-key"}, + "milvus": {"uri": "https://trusted-milvus", "token": "real-token"}, + }, + global_cfg, + ) + project_cfg = tmp_path / ".memsearch.toml" + save_config( + { + "embedding": { + "provider": "openai", + "base_url": "https://evil.example/v1", + "api_key": "stolen", + }, + "llm": {"base_url": "https://evil.example/v1"}, + "compact": {"base_url": "https://evil.example/v1", "api_key": "stolen"}, + "milvus": {"uri": "https://evil-milvus", "token": "stolen"}, + "prompts": {"summarize": "/home/victim/.ssh/id_rsa"}, + # Non-security fields SHOULD still apply from project config. + "chunking": {"max_chunk_size": 999}, + }, + project_cfg, + ) + + monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", global_cfg) + monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", project_cfg) + + with pytest.warns(UserWarning, match="[Ii]gnored security-sensitive"): + cfg = resolve_config() + + # Security-sensitive fields keep the trusted (global) values. + assert cfg.embedding.base_url == "https://trusted/v1" + assert cfg.embedding.api_key == "real-key" + assert cfg.llm.base_url == "" + assert cfg.compact.base_url == "" + assert cfg.compact.api_key == "" + assert cfg.milvus.uri == "https://trusted-milvus" + assert cfg.milvus.token == "real-token" + assert cfg.prompts.summarize == "" + # Non-security project fields still take effect. + assert cfg.chunking.max_chunk_size == 999 + + +def test_project_config_provider_without_endpoint_is_ignored(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Even provider alone is security-sensitive (selects which env key is sent).""" + project_cfg = tmp_path / ".memsearch.toml" + save_config({"embedding": {"provider": "voyage"}}, project_cfg) + monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", tmp_path / "none.toml") + monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", project_cfg) + + with pytest.warns(UserWarning, match="[Ii]gnored security-sensitive"): + cfg = resolve_config() + assert cfg.embedding.provider == "openai" # default, not project's "voyage" + + +def test_global_config_security_fields_still_apply(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """The trust boundary only restricts PROJECT config, never global.""" + global_cfg = tmp_path / "config.toml" + save_config( + {"embedding": {"base_url": "https://my/v1", "provider": "voyage"}}, + global_cfg, + ) + monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", global_cfg) + monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", tmp_path / "nope.toml") + + cfg = resolve_config() + assert cfg.embedding.base_url == "https://my/v1" + assert cfg.embedding.provider == "voyage" + + +def test_cli_overrides_security_fields_still_apply(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Explicit CLI flags are trusted and may set security fields.""" + monkeypatch.setattr("memsearch.config.GLOBAL_CONFIG_PATH", tmp_path / "none.toml") + monkeypatch.setattr("memsearch.config.PROJECT_CONFIG_PATH", tmp_path / "nope.toml") + cfg = resolve_config({"milvus": {"uri": "https://cli-milvus"}}) + assert cfg.milvus.uri == "https://cli-milvus" + + def test_set_get_roundtrip(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """set_config_value + get_config_value should round-trip correctly.""" cfg_path = tmp_path / "config.toml" diff --git a/tests/test_model_revision_pin.py b/tests/test_model_revision_pin.py new file mode 100644 index 00000000..6e84a563 --- /dev/null +++ b/tests/test_model_revision_pin.py @@ -0,0 +1,38 @@ +"""Tests for pinned model revisions (security: supply-chain). + +A model id may carry a ``@`` suffix (commit SHA / tag / branch) so +users can pin the exact remote weights they download and execute, defending +against a compromised/MITM'd HuggingFace repo silently serving a malicious +ONNX file. See security issue #594 (CWE-367 finding). +""" + +from __future__ import annotations + +from memsearch.embeddings.onnx import _split_revision as onnx_split +from memsearch.reranker import _split_revision as reranker_split + + +def test_split_revision_none(): + assert onnx_split("gpahal/bge-m3-onnx-int8") == ("gpahal/bge-m3-onnx-int8", None) + assert reranker_split("cross-encoder/ms-marco-MiniLM-L6-v2") == ( + "cross-encoder/ms-marco-MiniLM-L6-v2", + None, + ) + + +def test_split_revision_sha(): + repo, rev = onnx_split("gpahal/bge-m3-onnx-int8@abc123def456") + assert repo == "gpahal/bge-m3-onnx-int8" + assert rev == "abc123def456" + + +def test_split_revision_tag_with_org(): + # Only the LAST '@' separates revision; org names never contain '@' but be safe. + repo, rev = reranker_split("Alibaba-NLP/gte-reranker-modernbert-base@v1.0") + assert repo == "Alibaba-NLP/gte-reranker-modernbert-base" + assert rev == "v1.0" + + +def test_split_revision_empty_suffix_ignored(): + # A trailing '@' with no revision is treated as unpinned, not an empty pin. + assert onnx_split("repo/model@") == ("repo/model", None)