Skip to content

Commit b6cc67b

Browse files
committed
refactor(gepa): unify ReAct module key handling and use constant
- Replace every occurrence of the magic string 'react_module' with the REACT_MODULE_PREFIX constant
- Unify the path normalization pattern across gepa.py and gepa_utils.py
- Rename 'prefix' to 'normalized_path' for clarity
- Simplify module lookup by using consistent normalization
- Remove the awkward OR clause in the ReAct module matching logic

This makes the codebase more maintainable, with a single source of truth for the module prefix and consistent naming throughout.
1 parent ec6bb7b commit b6cc67b

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dspy.clients.lm import LM
1313
from dspy.predict.react import ReAct
1414
from dspy.primitives import Example, Module, Prediction
15-
from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback
15+
from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, REACT_MODULE_PREFIX, ScoreWithFeedback
1616
from dspy.teleprompt.teleprompt import Teleprompter
1717
from dspy.utils.annotation import experimental
1818

@@ -539,12 +539,14 @@ def feedback_fn(
539539
# Only process ReAct modules
540540
if not isinstance(module, ReAct):
541541
continue
542-
prefix = module_path.removeprefix("self.") if module_path != "self" else ""
542+
normalized_path = module_path.removeprefix("self.") if module_path != "self" else ""
543543

544544
# Get first predictor name as module identifier
545545
for pred_name, _ in module.named_predictors():
546-
comp_name = pred_name if not prefix else f"{prefix}.{pred_name}"
547-
module_key = f"react_module:{comp_name.split('.')[0]}" if prefix else "react_module"
546+
comp_name = pred_name if not normalized_path else f"{normalized_path}.{pred_name}"
547+
# Use full normalized path to avoid collapsing nested modules
548+
# e.g., "multi_agent.coordinator" not "multi_agent"
549+
module_key = f"{REACT_MODULE_PREFIX}:{normalized_path}" if normalized_path else REACT_MODULE_PREFIX
548550

549551
# Build JSON config with tool args for reflection
550552
config = {
@@ -563,15 +565,15 @@ def feedback_fn(
563565

564566
# Replace predictor keys with module key and extract key to prevent duplicates
565567
base_program.pop(comp_name, None)
566-
extract_key = f"{prefix}.extract.predict" if prefix else "extract.predict"
568+
extract_key = f"{normalized_path}.extract.predict" if normalized_path else "extract.predict"
567569
base_program.pop(extract_key, None)
568570
base_program[module_key] = json.dumps(config, indent=2)
569571
break
570572

571573
# Log base_program keys for debugging
572574
logger.info(f"Initialized base_program with {len(base_program)} components:")
573575
for key in sorted(base_program.keys()):
574-
if key.startswith("react_module"):
576+
if key.startswith(REACT_MODULE_PREFIX):
575577
logger.info(f" {key}: <ReAct module JSON config>")
576578
else:
577579
logger.info(f" {key}: <instruction>")

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def propose_component_texts(
162162

163163
# Otherwise, route to appropriate proposers
164164
# Separate react_module components from regular instruction components
165-
react_module_components = [c for c in components_to_update if c.startswith("react_module")]
166-
instruction_components = [c for c in components_to_update if not c.startswith("react_module")]
165+
react_module_components = [c for c in components_to_update if c.startswith(REACT_MODULE_PREFIX)]
166+
instruction_components = [c for c in components_to_update if not c.startswith(REACT_MODULE_PREFIX)]
167167

168168
results: dict[str, str] = {}
169169

@@ -234,8 +234,8 @@ def build_program(self, candidate: dict[str, str]):
234234
continue
235235

236236
# Build module key
237-
prefix = module_path.removeprefix("self.") if module_path != "self" else ""
238-
module_key = "react_module" if prefix == "" else f"react_module:{prefix}"
237+
normalized_path = module_path.removeprefix("self.") if module_path != "self" else ""
238+
module_key = REACT_MODULE_PREFIX if normalized_path == "" else f"{REACT_MODULE_PREFIX}:{normalized_path}"
239239

240240
# Check if this module was optimized
241241
if module_key not in candidate:
@@ -342,16 +342,19 @@ def make_reflective_dataset(
342342
logger.info(f"Processing component: {pred_name}")
343343

344344
# Handle ReAct module components - use extract predictor for final outputs
345-
if pred_name.startswith("react_module"):
345+
if pred_name.startswith(REACT_MODULE_PREFIX):
346346
# Extract the target path from the key
347-
target_path = pred_name.replace("react_module:", "") if ":" in pred_name else ""
347+
target_path = pred_name.removeprefix(f"{REACT_MODULE_PREFIX}:") if ":" in pred_name else ""
348348

349349
# Find the ReAct module by traversing program structure (same as regular predictors)
350350
react_module = None
351351
for module_path, m in program.named_sub_modules():
352-
clean_path = module_path.removeprefix("self.")
353-
# For top-level ReAct (target_path=""), match "self" or empty string
354-
if isinstance(m, ReAct) and (clean_path == target_path or (target_path == "" and clean_path == "self")):
352+
if not isinstance(m, ReAct):
353+
continue
354+
355+
# Normalize path (same pattern as build_program)
356+
normalized_path = module_path.removeprefix("self.") if module_path != "self" else ""
357+
if normalized_path == target_path:
355358
react_module = m
356359
break
357360

@@ -392,7 +395,7 @@ def make_reflective_dataset(
392395
continue
393396

394397
# For ReAct modules, use LAST extract invocation (has trajectory + final outputs)
395-
if pred_name.startswith("react_module"):
398+
if pred_name.startswith(REACT_MODULE_PREFIX):
396399
selected = trace_instances[-1]
397400
logger.debug(f" Using LAST extract call ({len(trace_instances)} total) with trajectory + final outputs")
398401
if "trajectory" in selected[1]:
@@ -485,7 +488,7 @@ def make_reflective_dataset(
485488
items.append(d)
486489

487490
# Log exact reflective example that reflection LM will see
488-
if pred_name.startswith("react_module") and len(items) == 1:
491+
if pred_name.startswith(REACT_MODULE_PREFIX) and len(items) == 1:
489492
logger.info(f" First reflective example for {pred_name}:")
490493
logger.info(f" Inputs: {list(d['Inputs'].keys())}")
491494
if "trajectory" in d["Inputs"]:

0 commit comments

Comments
 (0)