Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ mcp-rag/
└── .github/workflows/ CI (lint + test)
```

## Development

Built as part of a local AI development infrastructure, extracted and open-sourced as a standalone tool. Development uses a structured review process — each commit addresses specific findings from code review passes (SQL injection safety, transaction correctness, logging hygiene). The suite contains 61 tests, with CI running lint ([ruff](https://github.com/astral-sh/ruff)) and pytest on every push.

See [commit history](https://github.com/JMRussas/mcp-rag/commits/main) for the review-driven development trail.

## Limitations

- All embeddings loaded into memory at startup — practical up to ~50k chunks (~150 MB)
Expand Down
10 changes: 9 additions & 1 deletion config.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@
"default_top_k": 8,
"max_top_k": 20,
"embed_dimensions": 768,
"min_score": 0.0
"min_score": 0.0,
"hybrid": false,
"retrieval_depth": 20,
"rrf_k": 60
},
"reranker": {
"enabled": false,
"model": "cross-encoder/ms-marco-MiniLM-L6-v2",
"backend": "onnx"
},
"sources": {
"repos_dir": "data/repos",
Expand Down
4 changes: 4 additions & 0 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def cmd_chunk(config: dict):
log.warning(f"No path or local_dir for {repo['name']}, skipping.")
continue

# Allow diving into a subdirectory (e.g. "Modules/FortniteGame")
if repo.get("source_subdir"):
source_dir = source_dir / repo["source_subdir"]

try:
chunker = get_chunker(chunker_type)
except ValueError as e:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mcp>=1.0.0
httpx>=0.27.0
numpy>=1.26.0
sentence-transformers>=4.0.0
pytest>=8.0.0
72 changes: 72 additions & 0 deletions reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# mcp-rag - Cross-Encoder Reranker
#
# Reranks search candidates using a cross-encoder model for improved
# relevance. Lazy-loads the model on first use to avoid startup cost
# when reranking is disabled.
#
# Uses sentence-transformers CrossEncoder with ONNX backend by default
# (avoids pulling in full PyTorch for inference).
#
# Depends on: sentence-transformers[onnx] (optional, only when enabled)
# Used by: server.py (when config reranker.enabled = true)

import logging

log = logging.getLogger("mcp-rag-server")

_reranker = None


def get_reranker(config: dict):
    """Load the cross-encoder model on first use.

    Subsequent calls return the cached instance. The model is downloaded
    from HuggingFace on first load if not cached locally.

    NOTE: the instance is cached at module level, so later calls that pass
    a different 'reranker' config section still reuse the first model.

    Args:
        config: Full config dict (uses 'reranker' section).

    Returns:
        A CrossEncoder instance.
    """
    global _reranker
    if _reranker is None:
        from sentence_transformers import CrossEncoder

        reranker_config = config.get("reranker", {})
        model_name = reranker_config.get("model", "cross-encoder/ms-marco-MiniLM-L6-v2")
        backend = reranker_config.get("backend", "onnx")

        log.info(f"Loading reranker: {model_name} (backend={backend})")
        try:
            _reranker = CrossEncoder(model_name, backend=backend)
        except Exception as e:
            # The configured backend (onnx by default) may be unavailable,
            # e.g. the optional onnxruntime extra is missing. Retry with the
            # library's default backend; a failure there propagates.
            # (Fixed: the warning previously said "ONNX backend failed" even
            # when a non-ONNX backend was configured.)
            log.warning(f"Backend '{backend}' failed ({e}), falling back to default backend")
            _reranker = CrossEncoder(model_name)
        log.info("Reranker loaded")

    return _reranker


def rerank(query: str, candidates: list, config: dict) -> list:
    """Reorder candidate chunks by cross-encoder relevance to the query.

    Args:
        query: The user's search query.
        candidates: List of sqlite3.Row or dict objects (must have 'id' and 'text' keys).
        config: Full config dict.

    Returns:
        The same candidates, best match first according to the cross-encoder.
    """
    # Nothing to reorder for zero or one candidate; also avoids model load.
    if len(candidates) <= 1:
        return candidates

    model = get_reranker(config)
    query_doc_pairs = [(query, candidate["text"]) for candidate in candidates]
    relevance = model.predict(query_doc_pairs)

    # Stable sort of candidate indices by descending cross-encoder score.
    order = sorted(range(len(candidates)), key=lambda i: relevance[i], reverse=True)
    return [candidates[i] for i in order]
147 changes: 125 additions & 22 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
# Tool names, descriptions, and server identity are all config-driven,
# so one codebase serves any project.
#
# Supports hybrid search (BM25 + vector via RRF) and optional
# cross-encoder reranking, both config-gated.
#
# Depends on: config.json, data/*.db, numpy, httpx, mcp
# Used by: MCP clients (registered via `claude mcp add`)

Expand Down Expand Up @@ -46,6 +49,26 @@ def _sanitize_fts(query: str) -> str:
return result.strip()


def _rrf_fuse(ranked_lists: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
"""Reciprocal Rank Fusion across multiple ranked ID lists.

Combines rankings from different retrieval methods (e.g. vector + BM25)
without requiring score normalization. Higher RRF score = better.

Args:
ranked_lists: List of ranked ID lists (best-first).
k: RRF constant (default 60, standard value from the paper).

Returns:
List of (chunk_id, rrf_score) tuples, sorted by score descending.
"""
scores: dict[str, float] = {}
for ranked in ranked_lists:
for rank, doc_id in enumerate(ranked):
scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
return sorted(scores.items(), key=lambda x: x[1], reverse=True)


def format_results(rows: list, scores: dict[str, float] | None = None) -> str:
"""Format search results as readable text."""
if not rows:
Expand Down Expand Up @@ -113,6 +136,11 @@ def create_server(config_path: Path | None = None) -> FastMCP:
default_top_k = config["search"]["default_top_k"]
max_top_k = config["search"]["max_top_k"]
min_score = config["search"].get("min_score", 0.0)
hybrid_enabled = config["search"].get("hybrid", False)
retrieval_depth = config["search"].get("retrieval_depth", 20)
rrf_k = config["search"].get("rrf_k", 60)
reranker_config = config.get("reranker", {})
reranker_enabled = reranker_config.get("enabled", False)
db_path = SCRIPT_DIR / config["database"]["path"]

mcp_server = FastMCP(server_name)
Expand Down Expand Up @@ -242,52 +270,74 @@ async def search(
return "Error: Database not loaded. Run 'python pipeline.py rebuild' first."

top_k = min(max(1, top_k), max_top_k)
depth = retrieval_depth if hybrid_enabled else top_k

query_vec = await get_query_embedding(query)
if query_vec is None:
return "Error: Failed to generate query embedding. Is Ollama running?"

# Cosine similarity (embeddings are pre-normalized)
# --- Vector retrieval ---
similarities = embeddings @ query_vec

# Apply filters using pre-loaded metadata (vectorized where possible)
filter_mask = None
if source_filter or module_filter:
mask = np.ones(len(chunk_ids), dtype=bool)
filter_mask = np.ones(len(chunk_ids), dtype=bool)
if source_filter:
mask &= chunk_sources == source_filter
filter_mask &= chunk_sources == source_filter
if module_filter:
module_filter_lower = module_filter.lower()
mask &= np.array([module_filter_lower in m for m in chunk_modules])
similarities[~mask] = -1

# Get top-k indices
top_indices = np.argsort(similarities)[::-1][:top_k]

# Collect qualifying IDs in ranked order
ranked_ids = []
result_scores = {}
for idx in top_indices:
if similarities[idx] < min_score:
continue
chunk_id = chunk_ids[idx]
ranked_ids.append(chunk_id)
result_scores[chunk_id] = float(similarities[idx])
filter_mask &= np.array([module_filter_lower in m for m in chunk_modules])
similarities[~filter_mask] = -1

top_indices = np.argsort(similarities)[::-1][:depth]
vector_ranked = [
chunk_ids[idx] for idx in top_indices
if similarities[idx] >= min_score
]

# --- BM25 retrieval (hybrid mode) ---
if hybrid_enabled:
bm25_ranked = _bm25_retrieve(
conn, query, depth * 2, source_filter, module_filter,
chunk_ids, chunk_sources, chunk_modules, filter_mask,
)
# Fuse with RRF
fused = _rrf_fuse([vector_ranked, bm25_ranked], k=rrf_k)
ranked_ids = [doc_id for doc_id, _ in fused[:depth]]
result_scores = {doc_id: score for doc_id, score in fused[:depth]}
else:
ranked_ids = vector_ranked[:top_k]
result_scores = {chunk_ids[idx]: float(similarities[idx]) for idx in top_indices
if chunk_ids[idx] in ranked_ids}

if not ranked_ids:
return format_results([])

# Single batch query instead of N individual queries
# Fetch full rows
placeholders = ",".join("?" for _ in ranked_ids)
rows = conn.execute(
f"SELECT * FROM chunks WHERE id IN ({placeholders})",
ranked_ids,
).fetchall()

# Re-order to match similarity ranking
row_map = {row["id"]: row for row in rows}
results = [row_map[cid] for cid in ranked_ids if cid in row_map]

return format_results(results, result_scores)
# --- Reranking (optional) ---
if reranker_enabled and len(results) > 1:
try:
from reranker import rerank
results = rerank(query, results, config)
# Reassign scores based on reranked order
result_scores = {r["id"]: 1.0 / (i + 1) for i, r in enumerate(results)}
except Exception as e:
log.warning(f"Reranker failed, using original ranking: {e}")

# Apply final top_k
results = results[:top_k]
final_scores = {r["id"]: result_scores.get(r["id"], 0.0) for r in results}

return format_results(results, final_scores)

@mcp_server.tool(name=lookup_name, description=lookup_desc)
async def lookup(
Expand Down Expand Up @@ -343,6 +393,59 @@ async def lookup(

return format_results(rows)

def _bm25_retrieve(
    db_conn: sqlite3.Connection,
    query: str,
    limit: int,
    source_filter: str,
    module_filter: str,
    all_chunk_ids: list[str],
    all_chunk_sources: np.ndarray,
    all_chunk_modules: list[str],
    precomputed_mask: np.ndarray | None,
) -> list[str]:
    """Retrieve chunk IDs ranked by BM25 relevance via FTS5.

    Filters are applied post-query in Python (consistent with vector path).
    Returns a ranked list of chunk IDs (best-first).

    Args:
        db_conn: Open connection with `chunks` and `chunks_fts` tables.
        query: Raw user query; sanitized via _sanitize_fts before matching.
        limit: Maximum number of rows pulled from the FTS index.
        source_filter: If truthy, triggers mask-based filtering below.
        module_filter: If truthy, triggers mask-based filtering below.
        all_chunk_ids: Chunk IDs in in-memory index order; maps IDs to
            mask positions.
        all_chunk_sources: Unused in this body; kept for signature parity
            with the vector-retrieval path.
        all_chunk_modules: Unused in this body; kept for signature parity
            with the vector-retrieval path.
        precomputed_mask: Boolean keep-mask aligned with all_chunk_ids,
            or None (then only ID membership is checked).

    Returns:
        Chunk IDs ordered best-first by BM25; empty list when the query
        sanitizes to nothing, the FTS query fails, or nothing matches.
    """
    safe_query = _sanitize_fts(query)
    if not safe_query:
        return []

    try:
        # FTS5 bm25() returns negative scores (lower = better match)
        # NOTE(review): the sanitized query is wrapped in double quotes, so
        # FTS5 treats it as a single phrase rather than independent terms —
        # confirm phrase matching is the intended behavior.
        bm25_rows = db_conn.execute(
            """SELECT chunks.id, bm25(chunks_fts) as bm25_score
            FROM chunks_fts
            JOIN chunks ON chunks.rowid = chunks_fts.rowid
            WHERE chunks_fts MATCH ?
            ORDER BY bm25(chunks_fts)
            LIMIT ?""",
            (f'"{safe_query}"', limit),
        ).fetchall()
    except sqlite3.OperationalError as e:
        # Malformed FTS syntax or a missing chunks_fts table degrades to
        # "no BM25 results" instead of failing the whole search call.
        log.warning(f"BM25 query failed: {e}")
        return []

    if not bm25_rows:
        return []

    # Build ID-to-index lookup for filter checking
    if source_filter or module_filter:
        id_to_idx = {cid: i for i, cid in enumerate(all_chunk_ids)}
        filtered = []
        for row in bm25_rows:
            idx = id_to_idx.get(row[0])
            if idx is None:
                # Present in the DB but not in the in-memory index; skip.
                continue
            if precomputed_mask is not None and not precomputed_mask[idx]:
                continue
            filtered.append(row[0])
        return filtered
    else:
        return [row[0] for row in bm25_rows]

return mcp_server


Expand Down
Loading
Loading