Skip to content

Commit 41a0313

Browse files
automatically turn on reasoning for COT on reasoning model
1 parent ec2fbe4 commit 41a0313

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

dspy/adapters/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ def _call_preprocess(
7373

7474
return signature_for_native_function_calling
7575

76-
# Handle custom types that use native response
76+
# Handle custom types that use native LM features, e.g., reasoning, citations, etc.
7777
for name, field in signature.output_fields.items():
7878
if (
7979
isinstance(field.annotation, type)
8080
and issubclass(field.annotation, Type)
8181
and field.annotation in self.native_response_types
82-
and field.annotation.is_natively_supported(lm, lm_kwargs)
82+
and field.annotation.adapt_to_native_lm_feature(lm, lm_kwargs)
8383
):
8484
signature = signature.delete(name)
8585

@@ -134,7 +134,7 @@ def _call_postprocess(
134134
isinstance(field.annotation, type)
135135
and issubclass(field.annotation, Type)
136136
and field.annotation in self.native_response_types
137-
and field.annotation.is_natively_supported(lm, lm_kwargs)
137+
and field.annotation.adapt_to_native_lm_feature(lm, lm_kwargs)
138138
):
139139
value[name] = field.annotation.parse_lm_response(output)
140140

dspy/adapters/types/base_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def serialize_model(self):
7171
return formatted
7272

7373
@classmethod
74-
def is_natively_supported(cls, lm, lm_kwargs) -> bool:
75-
"""Whether the custom type is natively supported by the LM."""
74+
def adapt_to_native_lm_feature(cls, lm, lm_kwargs) -> bool:
75+
"""Check whether the custom type is natively supported by the LM and adapt to the native feature if possible."""
7676
return False
7777

7878
@classmethod

dspy/adapters/types/citation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,7 @@ def __getitem__(self, index):
166166
return self.citations[index]
167167

168168
@classmethod
169-
def is_natively_supported(cls, lm, lm_kwargs) -> bool:
170-
"""Whether the Citations type is natively supported by the LM."""
169+
def adapt_to_native_lm_feature(cls, lm, lm_kwargs) -> bool:
171170
return lm.model.startswith("anthropic/")
172171

173172
@classmethod

dspy/adapters/types/reasoning.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def validate_input(cls, data: Any):
4040
raise ValueError(f"Received invalid value for `dspy.Reasoning`: {data}")
4141

4242
@classmethod
43-
def is_natively_supported(cls, lm, lm_kwargs) -> bool:
44-
"""Whether the Reasoning type is natively supported by the LM."""
43+
def adapt_to_native_lm_feature(cls, lm, lm_kwargs) -> bool:
4544
if not litellm.supports_reasoning(lm.model):
4645
return False
4746

@@ -53,7 +52,11 @@ def is_natively_supported(cls, lm, lm_kwargs) -> bool:
5352
else:
5453
reasoning_effort = None
5554

56-
return reasoning_effort is not None
55+
if reasoning_effort is None:
56+
# Turn on the native reasoning
57+
lm_kwargs["reasoning_effort"] = "low"
58+
59+
return True
5760

5861
@classmethod
5962
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Reasoning"]:

tests/predict/test_chain_of_thought.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_async_chain_of_thought():
3131
def test_chain_of_thought_with_native_reasoning():
3232
"""Test ChainOfThought with native reasoning support where LM returns reasoning natively."""
3333

34-
lm = dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low", cache=False)
34+
lm = dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", cache=False)
3535
dspy.settings.configure(lm=lm)
3636

3737
with mock.patch("litellm.completion") as mock_completion:
@@ -53,6 +53,11 @@ def test_chain_of_thought_with_native_reasoning():
5353
assert isinstance(result.reasoning, dspy.Reasoning)
5454
assert result.reasoning.content == "Step-by-step thinking about the capital of France"
5555

56+
# Check that the reasoning_effort is automatically set to "low" when the LM supports native reasoning and not
57+
# provided in the LM kwargs
58+
args, kwargs = mock_completion.call_args
59+
assert kwargs["reasoning_effort"] == "low"
60+
5661

5762
def test_chain_of_thought_with_manual_reasoning():
5863
"""Test ChainOfThought with manual reasoning where LM doesn't support native reasoning."""

0 commit comments

Comments
 (0)