Skip to content

Commit 142ad29

Browse files
committed
feat: add grok search for general knowledge
1 parent 9024245 commit 142ad29

File tree

10 files changed

+368
-49
lines changed

10 files changed

+368
-49
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ANTHROPIC_API_KEY=""
2424
GEMINI_API_KEY=""
2525
DEEPSEEK_API_KEY=""
2626
GROQ_API_KEY=""
27+
XAI_API_KEY=""
2728

2829
# Version Configuration
2930
STARKNET_FOUNDRY_VERSION="0.47.0"

python/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ dependencies = [
5252
"toml>=0.10.2",
5353
"tqdm>=4.66.0",
5454
"typer>=0.19.2",
55+
"xai_sdk>=1.3.1",
5556
]
5657

5758
[project.optional-dependencies]

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram
2828
from cairo_coder.dspy.generation_program import GenerationProgram, McpGenerationProgram
29+
from cairo_coder.dspy.grok_search import GrokSearchProgram
2930
from cairo_coder.dspy.query_processor import QueryProcessorProgram
3031
from cairo_coder.dspy.retrieval_judge import RetrievalJudge
3132

@@ -73,6 +74,8 @@ def __init__(self, config: RagPipelineConfig):
7374
self.generation_program = config.generation_program
7475
self.mcp_generation_program = config.mcp_generation_program
7576
self.retrieval_judge = RetrievalJudge()
77+
self.grok_search = GrokSearchProgram()
78+
self._grok_citations: list[str] = []
7679

7780
# Pipeline state
7881
self._current_processed_query: ProcessedQuery | None = None
@@ -96,6 +99,22 @@ async def _aprocess_query_and_retrieve_docs(
9699
processed_query=processed_query, sources=retrieval_sources
97100
)
98101

102+
# Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources.
103+
try:
104+
if DocumentSource.STARKNET_BLOG in retrieval_sources:
105+
grok_docs = await self.grok_search.aforward(processed_query)
106+
self._grok_citations = list(self.grok_search.last_citations)
107+
if grok_docs:
108+
documents.extend(grok_docs)
109+
grok_summary_doc = next((d for d in grok_docs if d.metadata.get("name") == "grok-answer"), None)
110+
else:
111+
self._grok_citations = []
112+
grok_summary_doc = None
113+
except Exception as e:
114+
logger.warning("Grok augmentation failed; continuing without it", error=str(e), exc_info=True)
115+
grok_summary_doc = None
116+
self._grok_citations = []
117+
99118
try:
100119
with dspy.context(
101120
lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5),
@@ -110,6 +129,16 @@ async def _aprocess_query_and_retrieve_docs(
110129
)
111130
# documents already contains all retrieved docs, no action needed
112131

132+
# Ensure Grok summary is present and first in order (for generation context)
133+
try:
134+
if grok_summary_doc is not None:
135+
if grok_summary_doc in documents:
136+
documents = [grok_summary_doc] + [d for d in documents if d is not grok_summary_doc]
137+
else:
138+
documents = [grok_summary_doc] + documents
139+
except Exception:
140+
pass
141+
113142
self._current_documents = documents
114143

115144
return processed_query, documents
@@ -290,14 +319,34 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
290319
List of dicts: [{"title": str, "url": str}, ...]
291320
"""
292321
sources: list[dict[str, str]] = []
322+
323+
# Helper to extract domain title
324+
def title_from_url(url: str) -> str:
325+
try:
326+
import urllib.parse as _up
327+
328+
host = _up.urlparse(url).netloc
329+
return host or url
330+
except Exception:
331+
return url
332+
333+
# 1) Vector store and other docs (skip Grok summary virtual doc)
293334
for doc in documents:
335+
if doc.metadata.get("name") == "grok-answer" or doc.metadata.get("is_virtual"):
336+
continue
294337
if doc.source_link is None:
295338
logger.warning(f"Document {doc.title} has no source link")
296-
to_append = ({"metadata": {"title": doc.title, "url": ""}})
339+
to_append = {"metadata": {"title": doc.title, "url": ""}}
297340
else:
298-
to_append = ({"metadata": {"title": doc.title, "url": doc.source_link}})
341+
to_append = {"metadata": {"title": doc.title, "url": doc.source_link}}
299342
sources.append(to_append)
300343

344+
# 2) Append Grok citations (raw URLs)
345+
for url in self._grok_citations:
346+
if not url:
347+
continue
348+
sources.append({"metadata": {"title": title_from_url(url), "url": url}})
349+
301350
return sources
302351

303352
def _prepare_context(self, documents: list[Document]) -> str:
@@ -325,11 +374,12 @@ def _prepare_context(self, documents: list[Document]) -> str:
325374
for i, doc in enumerate(documents, 1):
326375
source_name = doc.metadata.get("source_display", "Unknown Source")
327376
title = doc.metadata.get("title", f"Document {i}")
328-
url = doc.metadata.get("url", "#")
377+
url = doc.metadata.get("url")
329378

330379
context_parts.append(f"## {i}. {title}")
331380
context_parts.append(f"Source: {source_name}")
332-
context_parts.append(f"URL: {url}")
381+
if url:
382+
context_parts.append(f"URL: {url}")
333383
context_parts.append("")
334384
context_parts.append(doc.page_content)
335385
context_parts.append("")

python/src/cairo_coder/dspy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
create_generation_program,
1616
create_mcp_generation_program,
1717
)
18+
from .grok_search import GrokSearchProgram
1819
from .query_processor import QueryProcessorProgram, create_query_processor
1920
from .retrieval_judge import RetrievalJudge
2021
from .suggestion_program import SuggestionGeneration
@@ -29,4 +30,5 @@
2930
"create_mcp_generation_program",
3031
"RetrievalJudge",
3132
"SuggestionGeneration",
33+
"GrokSearchProgram",
3234
]
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Grok Web/X Search module for Cairo Coder.
3+
4+
Uses the xAI SDK (agentic server‑side tools: web_search, x_search) to fetch
5+
fresh context and a synthesized answer. The output is provided as a single
6+
virtual Document for the generator and a list of citation URLs that the
7+
pipeline emits via SOURCES.
8+
9+
Behavior:
10+
- Activated upstream when DocumentSource.STARKNET_BLOG is in the requested sources.
11+
- Returns one primary virtual Document containing the Grok-composed answer
12+
plus an inline source list inside the content.
13+
- Does not create per-citation documents; citations are emitted via SOURCES.
14+
15+
Environment:
16+
- Set XAI_API_KEY with a valid xAI API key.
17+
"""
18+
19+
from __future__ import annotations
20+
21+
import hashlib
22+
import os
23+
from urllib.parse import urlparse
24+
25+
import dspy
26+
import structlog
27+
from xai_sdk import AsyncClient as XaiClient
28+
from xai_sdk.chat import Response, user
29+
from xai_sdk.tools import web_search, x_search
30+
31+
from cairo_coder.core.types import Document, DocumentSource, ProcessedQuery
32+
33+
logger = structlog.get_logger(__name__)
34+
35+
36+
DEFAULT_GROK_MODEL = "grok-4-fast"
37+
38+
39+
def _sha1(text: str) -> str:
40+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
41+
42+
43+
def _mk_unique_id(prefix: str, content: str, idx: int = 0) -> str:
44+
return f"{prefix}-{_sha1(content)[:10]}-{idx}"
45+
46+
47+
48+
class GrokSearchProgram(dspy.Module):
49+
"""
50+
DSPy module that queries xAI's Grok Responses API with web and X search tools.
51+
52+
aforward returns a list[Document] suitable for inclusion in the RAG pipeline.
53+
"""
54+
55+
def __init__(
56+
self,
57+
) -> None:
58+
super().__init__()
59+
api_key = os.getenv("XAI_API_KEY")
60+
if not api_key:
61+
raise RuntimeError("XAI_API_KEY must be set for GrokSearchProgram")
62+
self.client = XaiClient(api_key=api_key)
63+
self.last_citations: list[str] = []
64+
65+
@staticmethod
66+
def _domain_from_url(url: str) -> str:
67+
try:
68+
return urlparse(url).netloc or url
69+
except Exception:
70+
return url
71+
72+
async def aforward(self, processed_query: ProcessedQuery) -> list[Document]:
73+
formatted_query = f"""Answer the following query: {processed_query.original}. \
74+
For more context, here are some semantic terms associated with the question: \
75+
{', '.join(processed_query.search_queries)}.
76+
"""
77+
chat = self.client.chat.create(
78+
model=DEFAULT_GROK_MODEL,
79+
tools=[web_search(), x_search()],
80+
)
81+
logger.info(f"Formatted query: {formatted_query}")
82+
chat.append(user(formatted_query))
83+
response: Response = await chat.sample()
84+
answer: str = response.content
85+
citations_urls: list[str] = response.citations
86+
self.last_citations = list(citations_urls or [])
87+
logger.info(f"Answer: {answer}")
88+
logger.info(f"Citations URLs: {citations_urls}")
89+
90+
# Assemble a compact source list at the end
91+
cite_lines = []
92+
for i, url in enumerate(citations_urls):
93+
title = self._domain_from_url(url) or f"Source {i}"
94+
cite_lines.append(f"[{i}] {title}: {url}")
95+
if cite_lines:
96+
answer_with_sources = f"{answer}\n\nSources:\n" + "\n".join(cite_lines)
97+
else:
98+
answer_with_sources = answer
99+
100+
documents: list[Document] = []
101+
unique_id = _mk_unique_id("grok-answer", answer)
102+
documents.append(
103+
Document(
104+
page_content=answer_with_sources,
105+
metadata={
106+
"name": "grok-answer",
107+
"title": "Grok Web/X Summary",
108+
"uniqueId": unique_id,
109+
"contentHash": _sha1(answer_with_sources),
110+
"chunkNumber": 0,
111+
# Treat as Starknet blog related to gate activation
112+
"source": DocumentSource.STARKNET_BLOG,
113+
"source_display": "Grok Web/X",
114+
"sourceLink": "",
115+
"url": "",
116+
"is_virtual": True,
117+
},
118+
)
119+
)
120+
121+
return documents

python/tests/conftest.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from cairo_coder.core.agent_factory import AgentFactory
1616
from cairo_coder.core.config import VectorStoreConfig
17+
from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineConfig
1718
from cairo_coder.core.types import (
1819
Document,
1920
DocumentSource,
@@ -29,7 +30,7 @@
2930
from cairo_coder.server.app import CairoCoderServer, get_agent_factory
3031

3132

32-
@pytest.fixture(scope="session")
33+
@pytest.fixture(scope="function")
3334
def mock_returned_documents(sample_documents):
3435
"""DSPy Examples derived from sample_documents for DRY content."""
3536
return [dspy.Example(content=doc.page_content, metadata=doc.metadata) for doc in sample_documents]
@@ -251,7 +252,7 @@ def sample_processed_query():
251252
)
252253

253254

254-
@pytest.fixture(scope="session")
255+
@pytest.fixture(scope="function")
255256
def sample_documents():
256257
"""
257258
Create a collection of sample documents for testing.
@@ -334,6 +335,9 @@ def clean_config_env_vars(monkeypatch):
334335
original_values[var] = os.environ.get(var)
335336
monkeypatch.delenv(var, raising=False)
336337

338+
# Ensure xAI SDK clients can initialize in tests (no real network calls occur).
339+
monkeypatch.setenv("XAI_API_KEY", "test")
340+
337341
yield
338342

339343
# Restore original values after test
@@ -450,3 +454,41 @@ async def async_filter_docs(query: str, documents: list[Document]) -> list[Docum
450454
judge.get_lm_usage = Mock(return_value={})
451455

452456
return judge
457+
458+
@pytest.fixture
459+
def pipeline_config(
460+
mock_vector_store_config,
461+
mock_query_processor,
462+
mock_document_retriever,
463+
mock_generation_program,
464+
mock_mcp_generation_program,
465+
):
466+
"""Create a pipeline configuration."""
467+
return RagPipelineConfig(
468+
name="test_pipeline",
469+
vector_store_config=mock_vector_store_config,
470+
query_processor=mock_query_processor,
471+
document_retriever=mock_document_retriever,
472+
generation_program=mock_generation_program,
473+
mcp_generation_program=mock_mcp_generation_program,
474+
sources=list(DocumentSource),
475+
max_source_count=10,
476+
similarity_threshold=0.4,
477+
)
478+
479+
480+
481+
@pytest.fixture(scope="function")
482+
def pipeline(pipeline_config):
483+
"""Create a RagPipeline instance."""
484+
with patch("cairo_coder.core.rag_pipeline.RetrievalJudge") as mock_judge_class:
485+
mock_judge = Mock()
486+
mock_judge.get_lm_usage.return_value = {}
487+
mock_judge.aforward = AsyncMock(side_effect=lambda query, documents: documents)
488+
mock_judge_class.return_value = mock_judge
489+
return RagPipeline(pipeline_config)
490+
491+
@pytest.fixture(scope="function")
492+
def rag_pipeline(pipeline_config):
493+
"""Alias fixture for pipeline to maintain backward compatibility."""
494+
return RagPipeline(pipeline_config)

python/tests/unit/test_document_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def retriever(
2929
similarity_threshold=0.4,
3030
)
3131

32-
@pytest.fixture(scope="session")
32+
@pytest.fixture(scope="function")
3333
def mock_dspy_examples(self, sample_documents: list[Document]) -> list[dspy.Example]:
3434
"""Create mock DSPy Example objects from sample documents."""
3535
examples = []

0 commit comments

Comments
 (0)