|
14 | 14 |
|
15 | 15 | import dspy |
16 | 16 | from dspy import Example |
| 17 | +from dspy.utils.dummies import DummyLM |
17 | 18 |
|
18 | 19 | # Load fixture |
19 | 20 | with open("tests/teleprompt/gepa_dummy_lm_react_opt.json") as f: |
@@ -172,3 +173,262 @@ def metric(example, prediction, trace=None, pred_name=None, pred_trace=None): |
172 | 173 | "toolB argument description should be optimized" |
173 | 174 | assert optimized.tools["toolC"].arg_desc != baseline_toolC_arg_desc, \ |
174 | 175 | "toolC argument description should be optimized" |
| 176 | + |
| 177 | + |
def setup_spy_for_base_program(monkeypatch):
    """Install a recording wrapper around ``gepa.optimize``.

    The wrapper copies the entries of the ``seed_candidate`` it receives into
    a dict and then delegates to the real ``gepa.optimize`` with the same
    arguments, so the optimization itself is unaffected.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used to replace the
            ``optimize`` attribute on the ``gepa`` module.

    Returns:
        A dict that starts empty and is updated in place with the seed
        candidate every time the patched ``gepa.optimize`` is called.
    """
    import gepa

    # Grab the real implementation before it is replaced below.
    real_optimize = gepa.optimize
    seen_seed = {}

    def recording_optimize(seed_candidate, **kwargs):
        # Record before delegating so the seed is kept even if the real
        # optimize call raises.
        seen_seed.update(seed_candidate)
        return real_optimize(seed_candidate=seed_candidate, **kwargs)

    monkeypatch.setattr(gepa, "optimize", recording_optimize)
    return seen_seed
| 192 | + |
| 193 | + |
def create_gepa_optimizer_for_detection():
    """Build the GEPA optimizer and one-example trainset used by the detection tests.

    Side effect: configures a ``DummyLM`` as the LM in ``dspy.settings``.

    Returns:
        A ``(optimizer, trainset)`` tuple, where ``optimizer`` is a
        ``dspy.GEPA`` with ``optimize_react_components=True`` and
        ``max_metric_calls=2``, and ``trainset`` is a list holding a single
        ``Example`` whose only input field is ``question``.
    """
    canned_answers = [{"answer": "test"}] * 10
    canned_reflections = [{"improved_instruction": "optimized"}] * 10

    program_lm = DummyLM(canned_answers)
    proposer_lm = DummyLM(canned_reflections)
    dspy.settings.configure(lm=program_lm)

    def simple_metric(example, pred, trace=None, pred_name=None, pred_trace=None):
        # Constant score and feedback: the metric ignores all of its inputs.
        return dspy.Prediction(score=0.5, feedback="ok")

    gepa_optimizer = dspy.GEPA(
        metric=simple_metric,
        reflection_lm=proposer_lm,
        max_metric_calls=2,
        optimize_react_components=True,
    )

    single_example = Example(question="test", answer="test").with_inputs("question")
    return gepa_optimizer, [single_example]
| 213 | + |
| 214 | + |
def assert_react_module_detected(captured_base_program, module_path, expected_tools):
    """Check that a ReAct module appears in the captured seed with all its parts.

    Args:
        captured_base_program: Mapping of component keys to their captured
            values; the ReAct entry is expected to be a JSON-encoded string.
        module_path: Dotted path of the module inside the program, or ``""``
            for a top-level ReAct module (the key is then the bare prefix).
        expected_tools: Mapping of tool name to the exact ``desc`` expected
            for that tool.

    Returns:
        The decoded JSON config of the ReAct module.

    Raises:
        AssertionError: If the key is missing, any of the ``react``,
            ``extract`` or ``tools`` sections is absent, or a tool is missing,
            lacks ``desc``/``arg_desc``, or has a different ``desc``.
    """
    from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX

    if module_path == "":
        module_key = REACT_MODULE_PREFIX
    else:
        module_key = f"{REACT_MODULE_PREFIX}:{module_path}"

    assert module_key in captured_base_program, f"Expected '{module_key}' to be detected"

    config = json.loads(captured_base_program[module_key])

    # Each required top-level section, paired with the wording used in its
    # failure message.
    required_sections = (
        ("react", "react instruction"),
        ("extract", "extract instruction"),
        ("tools", "tools"),
    )
    for section, label in required_sections:
        assert section in config, f"{module_key} should have {label}"

    detected_tools = config["tools"]
    for tool_name, expected_desc in expected_tools.items():
        assert tool_name in detected_tools, f"{module_key} should have '{tool_name}' tool"
        tool_entry = detected_tools[tool_name]
        assert "desc" in tool_entry, f"{tool_name} should have desc"
        assert tool_entry["desc"] == expected_desc, f"{tool_name} desc should match"
        assert "arg_desc" in tool_entry, f"{tool_name} should have arg_desc"

    return config
| 237 | + |
| 238 | + |
def assert_regular_module_detected(captured_base_program, module_key):
    """Assert that a non-ReAct module was captured as a plain instruction string.

    ReAct modules are captured as JSON-encoded objects, which are also
    ``str`` values, so a bare ``isinstance(..., str)`` check cannot tell the
    two apart. This helper additionally rejects any value that decodes to a
    JSON object.

    Args:
        captured_base_program: Mapping of component keys to captured values.
        module_key: Key of the regular (non-ReAct) predictor to look up.

    Returns:
        The captured instruction string.

    Raises:
        AssertionError: If the key is missing, the value is not a ``str``, or
            the value is a JSON-encoded object.
    """
    import json

    assert module_key in captured_base_program, f"Expected '{module_key}' to be detected"
    instruction = captured_base_program[module_key]
    assert isinstance(instruction, str), f"{module_key} should be string instruction, not JSON"

    # A plain instruction normally is not valid JSON at all; only a decoded
    # dict indicates a ReAct-style config. Scalars such as "42" stay allowed.
    try:
        decoded = json.loads(instruction)
    except ValueError:
        decoded = None
    assert not isinstance(decoded, dict), f"{module_key} should be string instruction, not JSON"

    return instruction
| 245 | + |
| 246 | + |
def test_single_react_module_detection(monkeypatch):
    """Test GEPA detects a single top-level ReAct module.

    Builds a ReAct program with two tools, runs ``compile`` with
    ``gepa.optimize`` spied on, and checks that the seed candidate contains
    the ReAct entry with both tools and their descriptions.
    """
    captured_base_program = setup_spy_for_base_program(monkeypatch)

    def search_tool(query: str) -> str:
        """Search for information."""
        return f"Results for: {query}"

    def calculate_tool(expr: str) -> str:
        """Calculate math expression."""
        return "42"

    program = dspy.ReAct(
        "question -> answer",
        tools=[
            dspy.Tool(search_tool, name="search", desc="Search the web"),
            dspy.Tool(calculate_tool, name="calc", desc="Calculate math"),
        ],
        max_iters=3,
    )

    optimizer, trainset = create_gepa_optimizer_for_detection()

    # Best effort: the assertions below only inspect the seed candidate that
    # was handed to gepa.optimize, so a failure later in compile is tolerated.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
    # still propagate, and keep the error for the diagnostic below.
    compile_error = None
    try:
        optimizer.compile(program, trainset=trainset, valset=trainset)
    except Exception as exc:
        compile_error = exc

    assert captured_base_program, (
        f"gepa.optimize never received a seed candidate (compile error: {compile_error!r})"
    )

    # An empty module path selects the bare ReAct prefix key (top-level module).
    assert_react_module_detected(
        captured_base_program,
        "",
        {"search": "Search the web", "calc": "Calculate math"},
    )
| 285 | + |
| 286 | + |
def test_multi_react_workflow_detection(monkeypatch):
    """Test GEPA detects multiple ReAct modules (tests bug fix for path truncation).

    The program nests a workflow holding two ReAct modules and one
    ChainOfThought under an outer module, so every detected key must carry
    the full dotted path (``workflow.<name>``).
    """
    from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX

    captured_base_program = setup_spy_for_base_program(monkeypatch)

    class ResearchWorkflow(dspy.Module):
        def __init__(self):
            super().__init__()

            def search_papers(query: str) -> str:
                return f"Papers: {query}"

            def analyze_data(data: str) -> str:
                return f"Analysis: {data}"

            self.coordinator = dspy.ReAct(
                "task -> plan",
                tools=[dspy.Tool(search_papers, name="search", desc="Search tool")],
                max_iters=2,
            )

            self.researcher = dspy.ReAct(
                "plan -> findings",
                tools=[dspy.Tool(analyze_data, name="analyze", desc="Analysis tool")],
                max_iters=2,
            )

            self.summarizer = dspy.ChainOfThought("findings -> summary")

        def forward(self, question):
            plan = self.coordinator(task=question)
            findings = self.researcher(plan=plan.plan)
            summary = self.summarizer(findings=findings.findings)
            return dspy.Prediction(answer=summary.summary)

    class MixedWorkflowSystem(dspy.Module):
        def __init__(self):
            super().__init__()
            self.workflow = ResearchWorkflow()

        def forward(self, question):
            return self.workflow(question=question)

    program = MixedWorkflowSystem()

    optimizer, trainset = create_gepa_optimizer_for_detection()

    # Best effort: the assertions below only inspect the seed candidate that
    # was handed to gepa.optimize, so a failure later in compile is tolerated.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
    # still propagate, and keep the error for the diagnostic below.
    compile_error = None
    try:
        optimizer.compile(program, trainset=trainset, valset=trainset)
    except Exception as exc:
        compile_error = exc

    assert captured_base_program, (
        f"gepa.optimize never received a seed candidate (compile error: {compile_error!r})"
    )

    assert f"{REACT_MODULE_PREFIX}:workflow.coordinator" in captured_base_program
    assert f"{REACT_MODULE_PREFIX}:workflow.researcher" in captured_base_program

    react_modules = [k for k in captured_base_program if k.startswith(REACT_MODULE_PREFIX)]
    assert len(react_modules) == 2, f"Expected 2 ReAct modules, got {len(react_modules)}"

    assert_react_module_detected(captured_base_program, "workflow.coordinator", {"search": "Search tool"})
    assert_react_module_detected(captured_base_program, "workflow.researcher", {"analyze": "Analysis tool"})
    assert_regular_module_detected(captured_base_program, "workflow.summarizer.predict")
| 349 | + |
| 350 | + |
def test_nested_react_orchestrator_worker_detection(monkeypatch):
    """Test GEPA detects orchestrator with 2 worker ReAct modules as tools.

    The orchestrator ReAct module exposes the two worker ReAct modules through
    plain wrapper functions registered as tools; all three modules must be
    detected under the ``multi_agent.<name>`` path.
    """
    from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX

    captured_base_program = setup_spy_for_base_program(monkeypatch)

    class OrchestratorWorkerSystem(dspy.Module):
        def __init__(self):
            super().__init__()

            def search_web(query: str) -> str:
                return f"Search results: {query}"

            def analyze_data(data: str) -> str:
                return f"Analysis: {data}"

            def research_topic(topic: str) -> str:
                return f"Research: {topic}"

            self.analyst = dspy.ReAct(
                "data -> analysis",
                tools=[dspy.Tool(analyze_data, name="analyze", desc="Analyze data")],
                max_iters=2,
            )

            self.researcher = dspy.ReAct(
                "topic -> findings",
                tools=[dspy.Tool(research_topic, name="research", desc="Research topic")],
                max_iters=2,
            )

            # Wrappers that let the orchestrator call the worker modules as tools.
            def use_analyst(data: str) -> str:
                result = self.analyst(data=data)
                return str(result.analysis) if hasattr(result, 'analysis') else str(result)

            def use_researcher(topic: str) -> str:
                result = self.researcher(topic=topic)
                return str(result.findings) if hasattr(result, 'findings') else str(result)

            self.orchestrator = dspy.ReAct(
                "question -> answer",
                tools=[
                    dspy.Tool(search_web, name="search", desc="Search tool"),
                    dspy.Tool(use_analyst, name="analyst", desc="Use analyst"),
                    dspy.Tool(use_researcher, name="researcher", desc="Use researcher"),
                ],
                max_iters=3,
            )

        def forward(self, question):
            result = self.orchestrator(question=question)
            return dspy.Prediction(answer=result.answer)

    class MultiAgentSystem(dspy.Module):
        def __init__(self):
            super().__init__()
            self.multi_agent = OrchestratorWorkerSystem()

        def forward(self, question):
            return self.multi_agent(question=question)

    program = MultiAgentSystem()

    optimizer, trainset = create_gepa_optimizer_for_detection()

    # Best effort: the assertions below only inspect the seed candidate that
    # was handed to gepa.optimize, so a failure later in compile is tolerated.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
    # still propagate, and keep the error for the diagnostic below.
    compile_error = None
    try:
        optimizer.compile(program, trainset=trainset, valset=trainset)
    except Exception as exc:
        compile_error = exc

    assert captured_base_program, (
        f"gepa.optimize never received a seed candidate (compile error: {compile_error!r})"
    )

    assert f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator" in captured_base_program
    assert f"{REACT_MODULE_PREFIX}:multi_agent.analyst" in captured_base_program
    assert f"{REACT_MODULE_PREFIX}:multi_agent.researcher" in captured_base_program

    react_modules = [k for k in captured_base_program if k.startswith(REACT_MODULE_PREFIX)]
    assert len(react_modules) == 3, f"Expected 3 ReAct modules, got {len(react_modules)}"

    assert_react_module_detected(
        captured_base_program,
        "multi_agent.orchestrator",
        {"search": "Search tool", "analyst": "Use analyst", "researcher": "Use researcher"},
    )
    assert_react_module_detected(captured_base_program, "multi_agent.analyst", {"analyze": "Analyze data"})
    assert_react_module_detected(captured_base_program, "multi_agent.researcher", {"research": "Research topic"})
0 commit comments