Skip to content

Commit 1206f38

Browse files
committed
test(gepa): add ReAct module detection tests for nested structures
- Add 3 comprehensive detection tests: single ReAct, mixed workflow (2 ReAct + ChainOfThought), orchestrator with 2 workers - Tests validate full path preservation (bug fix validation) - Uses monkey patching to capture base_program from gepa.optimize - Helper functions for DRY: setup spy, create optimizer, assert detection - Validates all ReAct components: react, extract, tools, tool metadata
1 parent b6cc67b commit 1206f38

File tree

1 file changed

+260
-0
lines changed

1 file changed

+260
-0
lines changed

tests/teleprompt/test_gepa_react_optimization.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import dspy
1616
from dspy import Example
17+
from dspy.utils.dummies import DummyLM
1718

1819
# Load fixture
1920
with open("tests/teleprompt/gepa_dummy_lm_react_opt.json") as f:
@@ -172,3 +173,262 @@ def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
172173
"toolB argument description should be optimized"
173174
assert optimized.tools["toolC"].arg_desc != baseline_toolC_arg_desc, \
174175
"toolC argument description should be optimized"
176+
177+
178+
def setup_spy_for_base_program(monkeypatch):
179+
"""Setup spy to capture base_program from gepa.optimize."""
180+
captured_base_program = {}
181+
182+
from gepa import optimize as original_optimize
183+
184+
def spy_optimize(seed_candidate, **kwargs):
185+
captured_base_program.update(seed_candidate)
186+
return original_optimize(seed_candidate=seed_candidate, **kwargs)
187+
188+
import gepa
189+
monkeypatch.setattr(gepa, "optimize", spy_optimize)
190+
191+
return captured_base_program
192+
193+
194+
def create_gepa_optimizer_for_detection():
195+
"""Create GEPA optimizer with standard test configuration."""
196+
task_lm = DummyLM([{"answer": "test"}] * 10)
197+
reflection_lm = DummyLM([{"improved_instruction": "optimized"}] * 10)
198+
dspy.settings.configure(lm=task_lm)
199+
200+
def simple_metric(example, pred, trace=None, pred_name=None, pred_trace=None):
201+
return dspy.Prediction(score=0.5, feedback="ok")
202+
203+
optimizer = dspy.GEPA(
204+
metric=simple_metric,
205+
reflection_lm=reflection_lm,
206+
max_metric_calls=2,
207+
optimize_react_components=True,
208+
)
209+
210+
trainset = [Example(question="test", answer="test").with_inputs("question")]
211+
212+
return optimizer, trainset
213+
214+
215+
def assert_react_module_detected(captured_base_program, module_path, expected_tools):
216+
"""Assert that a ReAct module was detected with all components."""
217+
from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX
218+
219+
module_key = REACT_MODULE_PREFIX if module_path == "" else f"{REACT_MODULE_PREFIX}:{module_path}"
220+
221+
assert module_key in captured_base_program, f"Expected '{module_key}' to be detected"
222+
223+
config = json.loads(captured_base_program[module_key])
224+
225+
assert "react" in config, f"{module_key} should have react instruction"
226+
assert "extract" in config, f"{module_key} should have extract instruction"
227+
assert "tools" in config, f"{module_key} should have tools"
228+
229+
for tool_name, expected_desc in expected_tools.items():
230+
assert tool_name in config["tools"], f"{module_key} should have '{tool_name}' tool"
231+
tool = config["tools"][tool_name]
232+
assert "desc" in tool, f"{tool_name} should have desc"
233+
assert tool["desc"] == expected_desc, f"{tool_name} desc should match"
234+
assert "arg_desc" in tool, f"{tool_name} should have arg_desc"
235+
236+
return config
237+
238+
239+
def assert_regular_module_detected(captured_base_program, module_key):
240+
"""Assert that a non-ReAct module was detected."""
241+
assert module_key in captured_base_program, f"Expected '{module_key}' to be detected"
242+
instruction = captured_base_program[module_key]
243+
assert isinstance(instruction, str), f"{module_key} should be string instruction, not JSON"
244+
return instruction
245+
246+
247+
def test_single_react_module_detection(monkeypatch):
248+
"""Test GEPA detects a single top-level ReAct module."""
249+
from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX
250+
251+
captured_base_program = setup_spy_for_base_program(monkeypatch)
252+
253+
def search_tool(query: str) -> str:
254+
"""Search for information."""
255+
return f"Results for: {query}"
256+
257+
def calculate_tool(expr: str) -> str:
258+
"""Calculate math expression."""
259+
return "42"
260+
261+
program = dspy.ReAct(
262+
"question -> answer",
263+
tools=[
264+
dspy.Tool(search_tool, name="search", desc="Search the web"),
265+
dspy.Tool(calculate_tool, name="calc", desc="Calculate math"),
266+
],
267+
max_iters=3
268+
)
269+
270+
optimizer, trainset = create_gepa_optimizer_for_detection()
271+
272+
try:
273+
optimizer.compile(program, trainset=trainset, valset=trainset)
274+
except:
275+
pass
276+
277+
module_key = REACT_MODULE_PREFIX
278+
assert module_key in captured_base_program, f"Expected '{module_key}' to be detected"
279+
280+
assert_react_module_detected(
281+
captured_base_program,
282+
"",
283+
{"search": "Search the web", "calc": "Calculate math"}
284+
)
285+
286+
287+
def test_multi_react_workflow_detection(monkeypatch):
288+
"""Test GEPA detects multiple ReAct modules (tests bug fix for path truncation)."""
289+
from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX
290+
291+
captured_base_program = setup_spy_for_base_program(monkeypatch)
292+
293+
class ResearchWorkflow(dspy.Module):
294+
def __init__(self):
295+
super().__init__()
296+
297+
def search_papers(query: str) -> str:
298+
return f"Papers: {query}"
299+
300+
def analyze_data(data: str) -> str:
301+
return f"Analysis: {data}"
302+
303+
self.coordinator = dspy.ReAct(
304+
"task -> plan",
305+
tools=[dspy.Tool(search_papers, name="search", desc="Search tool")],
306+
max_iters=2
307+
)
308+
309+
self.researcher = dspy.ReAct(
310+
"plan -> findings",
311+
tools=[dspy.Tool(analyze_data, name="analyze", desc="Analysis tool")],
312+
max_iters=2
313+
)
314+
315+
self.summarizer = dspy.ChainOfThought("findings -> summary")
316+
317+
def forward(self, question):
318+
plan = self.coordinator(task=question)
319+
findings = self.researcher(plan=plan.plan)
320+
summary = self.summarizer(findings=findings.findings)
321+
return dspy.Prediction(answer=summary.summary)
322+
323+
class MixedWorkflowSystem(dspy.Module):
324+
def __init__(self):
325+
super().__init__()
326+
self.workflow = ResearchWorkflow()
327+
328+
def forward(self, question):
329+
return self.workflow(question=question)
330+
331+
program = MixedWorkflowSystem()
332+
333+
optimizer, trainset = create_gepa_optimizer_for_detection()
334+
335+
try:
336+
optimizer.compile(program, trainset=trainset, valset=trainset)
337+
except:
338+
pass
339+
340+
assert f"{REACT_MODULE_PREFIX}:workflow.coordinator" in captured_base_program
341+
assert f"{REACT_MODULE_PREFIX}:workflow.researcher" in captured_base_program
342+
343+
react_modules = [k for k in captured_base_program.keys() if k.startswith(REACT_MODULE_PREFIX)]
344+
assert len(react_modules) == 2, f"Expected 2 ReAct modules, got {len(react_modules)}"
345+
346+
assert_react_module_detected(captured_base_program, "workflow.coordinator", {"search": "Search tool"})
347+
assert_react_module_detected(captured_base_program, "workflow.researcher", {"analyze": "Analysis tool"})
348+
assert_regular_module_detected(captured_base_program, "workflow.summarizer.predict")
349+
350+
351+
def test_nested_react_orchestrator_worker_detection(monkeypatch):
352+
"""Test GEPA detects orchestrator with 2 worker ReAct modules as tools."""
353+
from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX
354+
355+
captured_base_program = setup_spy_for_base_program(monkeypatch)
356+
357+
class OrchestratorWorkerSystem(dspy.Module):
358+
def __init__(self):
359+
super().__init__()
360+
361+
def search_web(query: str) -> str:
362+
return f"Search results: {query}"
363+
364+
def analyze_data(data: str) -> str:
365+
return f"Analysis: {data}"
366+
367+
def research_topic(topic: str) -> str:
368+
return f"Research: {topic}"
369+
370+
self.analyst = dspy.ReAct(
371+
"data -> analysis",
372+
tools=[dspy.Tool(analyze_data, name="analyze", desc="Analyze data")],
373+
max_iters=2
374+
)
375+
376+
self.researcher = dspy.ReAct(
377+
"topic -> findings",
378+
tools=[dspy.Tool(research_topic, name="research", desc="Research topic")],
379+
max_iters=2
380+
)
381+
382+
def use_analyst(data: str) -> str:
383+
result = self.analyst(data=data)
384+
return str(result.analysis) if hasattr(result, 'analysis') else str(result)
385+
386+
def use_researcher(topic: str) -> str:
387+
result = self.researcher(topic=topic)
388+
return str(result.findings) if hasattr(result, 'findings') else str(result)
389+
390+
self.orchestrator = dspy.ReAct(
391+
"question -> answer",
392+
tools=[
393+
dspy.Tool(search_web, name="search", desc="Search tool"),
394+
dspy.Tool(use_analyst, name="analyst", desc="Use analyst"),
395+
dspy.Tool(use_researcher, name="researcher", desc="Use researcher"),
396+
],
397+
max_iters=3
398+
)
399+
400+
def forward(self, question):
401+
result = self.orchestrator(question=question)
402+
return dspy.Prediction(answer=result.answer)
403+
404+
class MultiAgentSystem(dspy.Module):
405+
def __init__(self):
406+
super().__init__()
407+
self.multi_agent = OrchestratorWorkerSystem()
408+
409+
def forward(self, question):
410+
return self.multi_agent(question=question)
411+
412+
program = MultiAgentSystem()
413+
414+
optimizer, trainset = create_gepa_optimizer_for_detection()
415+
416+
try:
417+
optimizer.compile(program, trainset=trainset, valset=trainset)
418+
except:
419+
pass
420+
421+
assert f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator" in captured_base_program
422+
assert f"{REACT_MODULE_PREFIX}:multi_agent.analyst" in captured_base_program
423+
assert f"{REACT_MODULE_PREFIX}:multi_agent.researcher" in captured_base_program
424+
425+
react_modules = [k for k in captured_base_program.keys() if k.startswith(REACT_MODULE_PREFIX)]
426+
assert len(react_modules) == 3, f"Expected 3 ReAct modules, got {len(react_modules)}"
427+
428+
assert_react_module_detected(
429+
captured_base_program,
430+
"multi_agent.orchestrator",
431+
{"search": "Search tool", "analyst": "Use analyst", "researcher": "Use researcher"}
432+
)
433+
assert_react_module_detected(captured_base_program, "multi_agent.analyst", {"analyze": "Analyze data"})
434+
assert_react_module_detected(captured_base_program, "multi_agent.researcher", {"research": "Research topic"})

0 commit comments

Comments
 (0)