Implement CoT no-op for reasoning models #8375

Open · wants to merge 5 commits into main

2 changes: 2 additions & 0 deletions dspy/clients/base_lm.py
@@ -59,6 +59,8 @@ def _process_lm_response(self, response, prompt, messages, **kwargs):
output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
if hasattr(c, "message") and getattr(c.message, "tool_calls", None):
output["tool_calls"] = c.message.tool_calls
if hasattr(c, "message") and hasattr(c.message, "reasoning_content"):
output["reasoning_content"] = c.message.reasoning_content
outputs.append(output)

if all(len(output) == 1 for output in outputs):
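
For context, a minimal sketch (not part of the diff) of what the two added lines do: when a provider returns a choice carrying reasoning_content, it is copied through into the per-choice output dict. The choice object below is a hypothetical stand-in for a litellm response choice.

from types import SimpleNamespace

# Hypothetical stand-in for one choice of a litellm response from a reasoning model.
choice = SimpleNamespace(
    message=SimpleNamespace(
        content="The answer is 2.",
        reasoning_content="1 + 1 equals 2.",
        tool_calls=None,
    ),
    logprobs=None,
)

output = {"text": choice.message.content}
# Mirrors the added lines: pass reasoning_content through when the provider supplies it.
if hasattr(choice, "message") and hasattr(choice.message, "reasoning_content"):
    output["reasoning_content"] = choice.message.reasoning_content

print(output)  # {'text': 'The answer is 2.', 'reasoning_content': '1 + 1 equals 2.'}
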
3 changes: 3 additions & 0 deletions dspy/clients/lm.py
@@ -37,6 +37,7 @@ def __init__(
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 3,
provider: Optional[Provider] = None,
reasoning_model: Optional[bool] = None,
finetuning_model: Optional[str] = None,
launch_kwargs: Optional[dict[str, Any]] = None,
train_kwargs: Optional[dict[str, Any]] = None,
@@ -51,6 +52,7 @@
model_type: The type of the model, either ``"chat"`` or ``"text"``.
temperature: The sampling temperature to use when generating responses.
max_tokens: The maximum number of tokens to generate per response.
reasoning_model: Whether the model is a reasoning model.
cache: Whether to cache the model responses for reuse to improve performance
and reduce costs.
cache_in_memory (deprecated): To enable additional caching with LRU in memory.
@@ -71,6 +73,7 @@
self.callbacks = callbacks or []
self.history = []
self.num_retries = num_retries
self.reasoning_model = reasoning_model if reasoning_model is not None else litellm.supports_reasoning(model)
self.finetuning_model = finetuning_model
self.launch_kwargs = launch_kwargs or {}
self.train_kwargs = train_kwargs or {}
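
A usage sketch (not part of the diff) of the new constructor flag; the model names are placeholders, and the fallback relies on litellm.supports_reasoning as wired above.

import dspy

# Explicit override: force reasoning-model handling for this deployment (name is hypothetical).
lm = dspy.LM("openai/my-reasoning-deployment", reasoning_model=True)

# No flag given: reasoning_model falls back to litellm.supports_reasoning(model).
lm_auto = dspy.LM("openai/gpt-4o-mini")

print(lm.reasoning_model)       # True (explicitly set)
print(lm_auto.reasoning_model)  # whatever litellm reports for this model
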
73 changes: 66 additions & 7 deletions dspy/predict/chain_of_thought.py
@@ -1,3 +1,4 @@
from functools import cached_property
from typing import Any, Optional, Type, Union

from pydantic.fields import FieldInfo
@@ -26,16 +27,74 @@ def __init__(
**config: The configuration for the module.
"""
super().__init__()
signature = ensure_signature(signature)
prefix = "Reasoning: Let's think step by step in order to"
desc = "${reasoning}"
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
self.predict = dspy.Predict(extended_signature, **config)
self._signature = ensure_signature(signature)
self._config = config
self._rationale_field = rationale_field
self._rationale_field_type = rationale_field_type

@cached_property
def predict(self):
"""Returns the appropriate predict instance based on the LM's reasoning model capability."""
lm = dspy.settings.lm
if lm and getattr(lm, "reasoning_model", False):
return dspy.Predict(self._signature, **self._config)
else:
prefix = "Reasoning: Let's think step by step in order to"
desc = "${reasoning}"
rationale_field_type = (
self._rationale_field.annotation if self._rationale_field else self._rationale_field_type
)
rationale_field = (
self._rationale_field if self._rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
)
extended_signature = self._signature.prepend(
name="reasoning", field=rationale_field, type_=rationale_field_type
)
return dspy.Predict(extended_signature, **self._config)

def forward(self, **kwargs):
return self.predict(**kwargs)

async def aforward(self, **kwargs):
return await self.predict.acall(**kwargs)

def load_state(self, state):
"""Override to ensure predict parameter is created before loading state."""
# If predict state exists but predict hasn't been accessed yet, access it first
if "predict" in state and "predict" not in self.__dict__:
_ = self.predict # This creates the predict instance

# Now call the base load_state which will load into all named_parameters
return super().load_state(state)

def __setstate__(self, state):
"""Custom deserialization for cloudpickle to preserve predict instance."""
# Restore the state normally
self.__dict__.update(state)

# If predict was cached and serialized, we don't need to do anything special
# since cloudpickle should have preserved it correctly

def __getstate__(self):
"""Custom serialization for cloudpickle to ensure predict instance is preserved."""
state = self.__dict__.copy()
# Force evaluation of cached property if not already done
if "predict" not in state:
# Access the predict property to cache it before serialization
_ = self.predict
state = self.__dict__.copy()
return state

def named_parameters(self):
"""Override to ensure the predict property is cached and included in named parameters."""
# Force evaluation of the cached_property if not already done
# This ensures it gets stored in __dict__ and picked up by the base implementation
if "predict" not in self.__dict__:
try:
_ = self.predict # This triggers the cached_property
except Exception:
# If accessing predict fails for any reason, continue without it
pass

# Now call the base implementation which will include the cached predict
return super().named_parameters()
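
A behaviour sketch (not part of the diff) of the resulting no-op: with a reasoning-capable LM configured, the cached predict keeps the original signature, while a regular LM still gets the prepended reasoning field. Model names are placeholders.

import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # treated as a non-reasoning model
cot = dspy.ChainOfThought("question -> answer")
print(list(cot.predict.signature.output_fields))  # ['reasoning', 'answer']

dspy.settings.configure(lm=dspy.LM("openai/my-reasoning-deployment", reasoning_model=True))  # hypothetical name
cot_noop = dspy.ChainOfThought("question -> answer")
print(list(cot_noop.predict.signature.output_fields))  # ['answer']
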
5 changes: 4 additions & 1 deletion dspy/streaming/streaming_listener.py
@@ -207,7 +207,10 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: Lis
if listener.predict:
predict_id_to_listener[id(listener.predict)].append(listener)
continue
if listener.signature_field_name not in field_name_to_named_predictor:
if (
listener.signature_field_name not in field_name_to_named_predictor
or field_name_to_named_predictor[listener.signature_field_name] is None
):
raise ValueError(
f"Signature field {listener.signature_field_name} is not a field of any predictor in the program, "
"cannot automatically determine which predictor to use for streaming. Please verify your field name or "
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
"optuna>=3.4.0",
"pydantic>=2.0",
"magicattr>=0.1.6",
"litellm>=1.64.0",
"litellm>=1.72.4",
"diskcache>=5.6.0",
"json-repair>=0.30.0",
"tenacity>=8.2.3",
@@ -59,8 +59,8 @@ dev = [
"pillow>=10.1.0",
"datamodel_code_generator>=0.26.3",
"build>=1.0.3",
"litellm>=1.64.0; sys_platform == 'win32'",
"litellm[proxy]>=1.64.0; sys_platform != 'win32'",
"litellm>=1.72.4; sys_platform == 'win32'",
"litellm[proxy]>=1.72.4; sys_platform != 'win32'",
]
test_extras = [
"mcp; python_version >= '3.10'",
12 changes: 12 additions & 0 deletions tests/predict/test_chain_of_thought.py
@@ -23,3 +23,15 @@ async def test_async_chain_of_thought():
program = ChainOfThought("question -> answer")
result = await program.acall(question="What is 1+1?")
assert result.answer == "2"


def test_cot_skips_with_reasoning_model():
lm = DummyLM([{"answer": "2"}])
lm.reasoning_model = True
dspy.settings.configure(lm=lm)
signature = dspy.Signature("question -> answer")
predict = ChainOfThought(signature)
assert list(predict.predict.signature.output_fields.keys()) == [
"answer",
]
assert predict(question="What is 1+1?").answer == "2"
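
A complementary test sketch (not part of the diff) for the non-reasoning path, following the same DummyLM pattern as the test above; the reasoning field should still be prepended.

def test_cot_keeps_reasoning_without_reasoning_model():
    lm = DummyLM([{"reasoning": "1 plus 1 is 2.", "answer": "2"}])
    lm.reasoning_model = False
    dspy.settings.configure(lm=lm)
    predict = ChainOfThought(dspy.Signature("question -> answer"))
    assert list(predict.predict.signature.output_fields.keys()) == [
        "reasoning",
        "answer",
    ]
    assert predict(question="What is 1+1?").answer == "2"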