10 changes: 10 additions & 0 deletions spoon_ai/rag/config.py
@@ -33,6 +33,9 @@ class RagConfig:
# Use DeepSeek as LLM for QA generation, and other models for embeddings.
embeddings_provider: Optional[str] = None
embeddings_model: str = "text-embedding-3-small" # Generic model name for all embedding providers
# Reranking
rerank_provider: Optional[str] = None
rerank_model: Optional[str] = None
# Storage paths
rag_dir: str = ".rag_store"

@@ -62,6 +65,11 @@ def get_default_config() -> RagConfig:
embeddings_model = os.getenv("RAG_EMBEDDINGS_MODEL", "").strip()


rerank_provider = os.getenv("RAG_RERANK_PROVIDER")
if rerank_provider:
rerank_provider = rerank_provider.strip().lower()
rerank_model = os.getenv("RAG_RERANK_MODEL")

return RagConfig(
backend=backend,
collection=collection,
@@ -71,5 +79,7 @@ def get_default_config() -> RagConfig:
min_similarity=min_similarity,
embeddings_provider=embeddings_provider,
embeddings_model=embeddings_model,
rerank_provider=rerank_provider,
rerank_model=rerank_model,
rag_dir=rag_dir,
)
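
For context, a minimal sketch of how the new fields are populated from the environment (a hypothetical usage snippet, not part of the PR; env-var names come from the diff above: the provider string is lowercased, the model is passed through unchanged):

import os
from spoon_ai.rag.config import get_default_config

os.environ["RAG_RERANK_PROVIDER"] = "OpenAI"   # normalized to "openai" by .strip().lower()
os.environ["RAG_RERANK_MODEL"] = "gpt-4o-mini"

config = get_default_config()
assert config.rerank_provider == "openai"
assert config.rerank_model == "gpt-4o-mini"
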
149 changes: 149 additions & 0 deletions spoon_ai/rag/reranker.py
@@ -0,0 +1,149 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional
import os
import re
import requests
import logging

logger = logging.getLogger(__name__)

class RerankClient(ABC):
@abstractmethod
def rerank(self, query: str, docs: List[str]) -> List[float]:
"""
Rerank a list of documents based on a query.
Returns a list of scores (higher is better) corresponding to each doc.
"""
pass

class NoOpRerankClient(RerankClient):
"""Pass-through reranker (does nothing, returns 0.0 scores or keeps original order implicitly)."""
def rerank(self, query: str, docs: List[str]) -> List[float]:
return [0.0] * len(docs)

class LLMRerankClient(RerankClient):
"""
Uses an LLM (via OpenAI-compatible API) to score documents.
This is dependency-free (uses requests) and flexible.
"""
def __init__(
self,
api_key: str,
base_url: str,
model: str,
timeout: int = 10
):
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.model = model
self.timeout = timeout

def _score_single(self, query: str, doc: str) -> float:
# Prompt engineering for scoring
prompt = (
f"Query: {query}\n"
f"Document: {doc[:4000]}...\n\n" # Truncate to avoid context unexpected limits
"Rate the relevance of the document to the query on a continuous scale from 0.0 (irrelevant) to 10.0 (exact match).\n"
"Output ONLY the number, nothing else."
)

try:
url = f"{self.base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
if "openrouter" in self.base_url:
# Optional headers for OpenRouter
headers["HTTP-Referer"] = "https://spoon.ai"
headers["X-Title"] = "Spoon AI Reranker"



payload = {
"model": self.model,
"messages": [
{"role": "system", "content": "You are a helpful relevance ranking assistant. Output only a float score."},
{"role": "user", "content": prompt}
],
"temperature": 0.0,
"max_tokens": 10
}

resp = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
resp.raise_for_status()
data = resp.json()
content = data["choices"][0]["message"]["content"].strip()

# Simple parsing
try:
score = float(content)
return score
except ValueError:
                # Fallback: extract the first number if the model adds extra prose
match = re.search(r"(\d+(\.\d+)?)", content)
if match:
return float(match.group(1))
return 0.0

except Exception as e:
logger.warning(f"Rerank failed for doc: {e}")
return 0.0

    def rerank(self, query: str, docs: List[str]) -> List[float]:
        # One request per document, processed sequentially. A production setup
        # might batch requests or use a dedicated rerank endpoint; since callers
        # only pass the top N candidates (e.g. 5-10), sequential calls keep this
        # simple. A parallel variant is sketched after this class.

scores = []
for doc in docs:
scores.append(self._score_single(query, doc))
return scores
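
The comment in rerank above mentions a parallel variant; a minimal sketch under the assumption that concurrent requests to the backend are acceptable (illustration only, not part of the PR):

from concurrent.futures import ThreadPoolExecutor

def rerank_parallel(self, query: str, docs: List[str], max_workers: int = 4) -> List[float]:
    # Hypothetical drop-in alternative to LLMRerankClient.rerank.
    # executor.map preserves input order, so scores still line up with docs.
    if not docs:
        return []
    with ThreadPoolExecutor(max_workers=min(max_workers, len(docs))) as executor:
        return list(executor.map(lambda doc: self._score_single(query, doc), docs))
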

def get_rerank_client(
provider: Optional[str],
*,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None
) -> RerankClient:
"""
Factory to get a rerank client.
"""
url = base_url
if not provider or provider == "none":
return NoOpRerankClient()

if provider in ("openai", "openrouter", "openai_compatible", "deepseek"):
        # Resolve credentials with logic similar to EmbeddingClient.
        # Ideally this would reuse a centralized config manager, but the RAG
        # module is kept self-contained.

        # Defaults
        if provider == "openai":
            key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("RERANK_API_KEY")
            url = url or "https://api.openai.com/v1"
        elif provider == "openrouter":
            key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("RERANK_API_KEY") or os.getenv("OPENAI_API_KEY")
            url = url or "https://openrouter.ai/api/v1"
        elif provider == "deepseek":
            # DeepSeek exposes an OpenAI-compatible endpoint; URL assumed from its public docs.
            key = api_key or os.getenv("DEEPSEEK_API_KEY") or os.getenv("RERANK_API_KEY")
            url = url or "https://api.deepseek.com/v1"
        else:
            # Generic fallback ("openai_compatible"): the base URL must be supplied by the caller.
            key = api_key or os.getenv("RERANK_API_KEY") or os.getenv("OPENAI_API_KEY")

final_model = model or os.getenv("RERANK_MODEL") or "gpt-4o-mini" # default to a cheap fast model

if not key:
logger.warning("No API key found for reranker, disabling.")
return NoOpRerankClient()

if not url:
logger.warning("No Base URL found for reranker, disabling.")
return NoOpRerankClient()

return LLMRerankClient(api_key=key, base_url=url, model=final_model)

return NoOpRerankClient()
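
For reference, a minimal usage sketch of the factory (assumes OPENAI_API_KEY is set in the environment; illustration only, not part of the PR):

from spoon_ai.rag.reranker import get_rerank_client, NoOpRerankClient

client = get_rerank_client("openai", model="gpt-4o-mini")
scores = client.rerank("how is the BM25 index loaded?", ["BM25 pickle loading", "CSS grid layout"])
# One float per document, higher = more relevant; a failed API call scores 0.0.

# With no provider configured, callers silently get the pass-through client:
assert isinstance(get_rerank_client(None), NoOpRerankClient)
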
40 changes: 36 additions & 4 deletions spoon_ai/rag/retriever.py
@@ -8,6 +8,7 @@
from .config import RagConfig
from .embeddings import EmbeddingClient
from .vectorstores import VectorStore
from .reranker import get_rerank_client


@dataclass
@@ -32,6 +33,10 @@ def __init__(
self.bm25 = None
self.bm25_data = None
self._load_bm25()
self.reranker = get_rerank_client(
config.rerank_provider,
model=config.rerank_model
)

def _load_bm25(self):
try:
@@ -41,9 +46,14 @@ def _load_bm25(self):
with open(bm2_file, "rb") as f:
self.bm25_data = pickle.load(f)

# Simple whitespace tokenization
tokenized_corpus = [doc.lower().split() for doc in self.bm25_data["texts"]]
# Better tokenization for code
import re
def tokenizer(text):
return re.findall(r"\w+", text.lower())

tokenized_corpus = [tokenizer(doc) for doc in self.bm25_data["texts"]]
self.bm25 = BM25Okapi(tokenized_corpus)
self._tokenizer = tokenizer
except ImportError:
pass # BM25 optional
except Exception as e:
@@ -119,8 +129,13 @@ def retrieve(
bm25_chunks: List[RetrievedChunk] = []
if self.bm25 and self.bm25_data:
try:
tokenized_query = query.lower().split()
if hasattr(self, '_tokenizer'):
tokenized_query = self._tokenizer(query)
else:
tokenized_query = query.lower().split()

scores = self.bm25.get_scores(tokenized_query)
# print(f"DEBUG: BM25 Max Score: {max(scores) if len(scores)>0 else 0}", flush=True)
top_n_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:max(k * 3, 20)]

for idx in top_n_indices:
@@ -151,7 +166,24 @@ def retrieve(
seen.add(key)
deduped.append(c)

return deduped[:k]
        # 4. Reranking (if enabled)
        # Rerank the top max(k * 2, 10) of the fused/deduped list to improve precision.
        candidates = deduped[:max(k * 2, 10)]

        # Every selected candidate is rescored; the window above bounds the cost.
        if hasattr(self, 'reranker') and self.config.rerank_provider:
candidate_texts = [c.text for c in candidates]
if candidate_texts:
new_scores = self.reranker.rerank(query, candidate_texts)
for c, s in zip(candidates, new_scores):
c.score = s

# Sort by new score
candidates.sort(key=lambda x: x.score, reverse=True)

return candidates[:k]
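
As a sanity check on the tokenizer change in _load_bm25 above, regex tokenization splits identifier-heavy text that plain whitespace splitting keeps glued together (standalone illustration, not part of the PR):

import re

def tokenizer(text):
    return re.findall(r"\w+", text.lower())

"get_rerank_client(provider)".lower().split()  # ['get_rerank_client(provider)']
tokenizer("get_rerank_client(provider)")       # ['get_rerank_client', 'provider']
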

def build_context(self, chunks: List[RetrievedChunk]) -> str:
lines: List[str] = []