Make dspy.settings and dspy.context safe in async setup #8203

Open · wants to merge 11 commits into base: main
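In short, this change replaces the threading.local-based overrides in dspy/dsp/utils/settings.py with a contextvars.ContextVar, forbids dspy.settings.configure(...) inside a running async task, and directs async code to dspy.context(...) instead. A minimal sketch of that contract, mirroring the updated tests below (the model names are placeholders):

import asyncio
import dspy

# Global configuration: allowed only from the owning synchronous thread.
dspy.configure(lm=dspy.LM("openai/gpt-4o"))

async def answer(question: str):
    # Inside a coroutine, use dspy.context(...) for per-task overrides;
    # calling dspy.settings.configure(...) here now raises RuntimeError.
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
        program = dspy.Predict("question -> answer")
        return await program.acall(question=question)

asyncio.run(answer("What is 1+1?"))
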
37 changes: 23 additions & 14 deletions dspy/dsp/utils/settings.py
@@ -1,3 +1,5 @@
import asyncio
import contextvars
import copy
import threading
from contextlib import contextmanager
@@ -34,13 +36,7 @@
# Global lock for settings configuration
global_lock = threading.Lock()


class ThreadLocalOverrides(threading.local):
def __init__(self):
self.overrides = dotdict()


thread_local_overrides = ThreadLocalOverrides()
thread_local_overrides = contextvars.ContextVar("context_overrides", default=dotdict())


class Settings:
@@ -71,7 +67,7 @@ def lock(self):
return global_lock

def __getattr__(self, name):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
overrides = thread_local_overrides.get()
if name in overrides:
return overrides[name]
elif name in main_thread_config:
@@ -92,7 +88,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
overrides = thread_local_overrides.get()
return key in overrides or key in main_thread_config

def get(self, key, default=None):
@@ -102,7 +98,7 @@ def get(self, key, default=None):
return default

def copy(self):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
overrides = thread_local_overrides.get()
return dotdict({**main_thread_config, **overrides})

@property
@@ -113,6 +109,19 @@ def configure(self, **kwargs):
global main_thread_config, config_owner_thread_id
current_thread_id = threading.get_ident()

# Check if we're actually running in an async task
try:
if asyncio.current_task() is not None:
raise RuntimeError(
"dspy.settings.configure(...) cannot be called from an async task. Use `dspy.context(...)` instead."
)
except RuntimeError as e:
# asyncio.current_task() raises RuntimeError when there is no running event loop (the normal synchronous case); only re-raise the error we deliberately raised above.
if e.args[0].startswith(
"dspy.settings.configure(...) cannot be called from an async task. Use `dspy.context(...)` instead."
):
raise e

with self.lock:
# First configuration establishes ownership; afterwards only the owning thread may configure.
if config_owner_thread_id in [None, current_thread_id]:
@@ -132,17 +141,17 @@ def context(self, **kwargs):
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
"""

original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
original_overrides = thread_local_overrides.get().copy()
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
thread_local_overrides.overrides = new_overrides
token = thread_local_overrides.set(new_overrides)

try:
yield
finally:
thread_local_overrides.overrides = original_overrides
thread_local_overrides.reset(token)

def __repr__(self):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
overrides = thread_local_overrides.get()
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)
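The key property this relies on, shown as a stdlib-only illustrative sketch (not DSPy code): each asyncio task gets its own snapshot of a ContextVar, whereas a threading.local value is shared by every task running on the same thread, so concurrent dspy.context(...) blocks would have clobbered each other.

import asyncio
import contextvars

overrides = contextvars.ContextVar("overrides", default={})

async def worker(name: str) -> str:
    token = overrides.set({"lm": name})   # analogous to entering dspy.context(...)
    try:
        await asyncio.sleep(0.01)          # other tasks run here without clobbering us
        return overrides.get()["lm"]
    finally:
        overrides.reset(token)             # analogous to leaving dspy.context(...)

async def main():
    assert await asyncio.gather(worker("a"), worker("b")) == ["a", "b"]

asyncio.run(main())
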

7 changes: 0 additions & 7 deletions dspy/streaming/streamify.py
@@ -222,16 +222,10 @@ def apply_sync_streaming(async_generator: AsyncGenerator) -> Generator:

# To propagate prediction request ID context to the child thread
context = contextvars.copy_context()
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()

def producer():
"""Runs in a background thread to fetch items asynchronously."""

original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()

async def runner():
try:
async for item in async_generator:
@@ -241,7 +235,6 @@ async def runner():
queue.put(stop_sentinel)

context.run(asyncio.run, runner())
thread_local_overrides.overrides = original_overrides

# Start the producer in a background thread
thread = threading.Thread(target=producer, daemon=True)
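The manual copy of thread_local_overrides into the producer thread is no longer needed because contextvars.copy_context() already snapshots every ContextVar, including the new overrides variable. A standalone sketch of that propagation (illustrative only):

import contextvars
import threading

var = contextvars.ContextVar("overrides", default={})
var.set({"lm": "parent-lm"})

ctx = contextvars.copy_context()   # snapshots the parent's ContextVar values
seen = {}

def producer():
    # Running under the copied context restores the parent's overrides in this thread.
    seen["lm"] = ctx.run(var.get)["lm"]

t = threading.Thread(target=producer)
t.start()
t.join()
assert seen["lm"] == "parent-lm"
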
8 changes: 4 additions & 4 deletions dspy/utils/asyncify.py
@@ -46,17 +46,17 @@ async def async_program(*args, **kwargs) -> Any:
# Capture the current overrides at call-time.
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()
parent_overrides = thread_local_overrides.get().copy()

def wrapped_program(*a, **kw):
from dspy.dsp.utils.settings import thread_local_overrides

original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
original_overrides = thread_local_overrides.get()
token = thread_local_overrides.set({**original_overrides, **parent_overrides.copy()})
try:
return program(*a, **kw)
finally:
thread_local_overrides.overrides = original_overrides
thread_local_overrides.reset(token)

# Create a fresh asyncified callable each time, ensuring the latest context is used.
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
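The set(...)/reset(token) pair above is the standard ContextVar save-and-restore idiom: asyncer.asyncify runs wrapped_program in a worker thread, so the overrides captured at call time are re-applied there and rolled back afterwards. A compact sketch of the same idiom (hypothetical helper, not part of this PR):

import contextvars

overrides = contextvars.ContextVar("overrides", default={})

def call_with_overrides(fn, parent_overrides, *args, **kwargs):
    # Hypothetical helper showing the save/apply/restore pattern used above.
    token = overrides.set({**overrides.get(), **parent_overrides})
    try:
        return fn(*args, **kwargs)
    finally:
        overrides.reset(token)   # restore whatever was active before the call
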
8 changes: 4 additions & 4 deletions dspy/utils/parallelizer.py
@@ -86,16 +86,16 @@ def worker(parent_overrides, submission_id, index, item):
# Apply parent's thread-local overrides
from dspy.dsp.utils.settings import thread_local_overrides

original = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
original = thread_local_overrides.get()
token = thread_local_overrides.set({**original, **parent_overrides.copy()})
if parent_overrides.get("usage_tracker"):
# Usage tracker needs to be deep copied across threads so that each thread tracks its own usage
thread_local_overrides.overrides["usage_tracker"] = copy.deepcopy(parent_overrides["usage_tracker"])

try:
return index, function(item)
finally:
thread_local_overrides.overrides = original
thread_local_overrides.reset(token)

# Handle Ctrl-C in the main thread
@contextlib.contextmanager
@@ -121,7 +121,7 @@ def handler(sig, frame):
with interrupt_manager():
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()
parent_overrides = thread_local_overrides.get().copy()

futures_map = {}
futures_set = set()
5 changes: 2 additions & 3 deletions tests/adapters/test_two_step_adapter.py
@@ -94,9 +94,8 @@ class TestSignature(dspy.Signature):
mock_extraction_lm.kwargs = {"temperature": 1.0}
mock_extraction_lm.model = "openai/gpt-4o"

dspy.configure(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm))

result = await program.acall(question="What is 5 + 7?")
with dspy.context(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm)):
result = await program.acall(question="What is 5 + 7?")

assert result.answer == 12

9 changes: 4 additions & 5 deletions tests/callback/test_callback.py
@@ -189,13 +189,12 @@ def test_callback_complex_module():
@pytest.mark.asyncio
async def test_callback_async_module():
callback = MyCallback()
dspy.settings.configure(
with dspy.context(
lm=DummyLM({"How are you?": {"answer": "test output", "reasoning": "No more responses"}}),
callbacks=[callback],
)

cot = dspy.ChainOfThought("question -> answer", n=3)
result = await cot.acall(question="How are you?")
):
cot = dspy.ChainOfThought("question -> answer", n=3)
result = await cot.acall(question="How are you?")
assert result["answer"] == "test output"
assert result["reasoning"] == "No more responses"

8 changes: 4 additions & 4 deletions tests/predict/test_chain_of_thought.py
@@ -19,7 +19,7 @@ def test_initialization_with_string_signature():
@pytest.mark.asyncio
async def test_async_chain_of_thought():
lm = DummyLM([{"reasoning": "find the number after 1", "answer": "2"}])
dspy.settings.configure(lm=lm)
program = ChainOfThought("question -> answer")
result = await program.acall(question="What is 1+1?")
assert result.answer == "2"
with dspy.context(lm=lm):
program = ChainOfThought("question -> answer")
result = await program.acall(question="What is 1+1?")
assert result.answer == "2"
74 changes: 71 additions & 3 deletions tests/predict/test_predict.py
@@ -6,6 +6,10 @@
import pydantic
import pytest
import ujson
import os
import time
import asyncio
import types
from litellm import ModelResponse

import dspy
@@ -506,6 +510,70 @@ def test_lm_usage():
assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


def test_lm_usage_with_parallel():
program = Predict("question -> answer")

def program_wrapper(question):
# Sleep to widen the window for a race condition between the parallel calls
time.sleep(0.5)
return program(question=question)

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
with patch(
"dspy.clients.lm.litellm_completion",
return_value=ModelResponse(
choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
usage={"total_tokens": 10},
),
):
parallelizer = dspy.Parallel()
input_pairs = [
(program_wrapper, {"question": "What is the capital of France?"}),
(program_wrapper, {"question": "What is the capital of France?"}),
]
results = parallelizer(input_pairs)
assert results[0].answer == "Paris"
assert results[1].answer == "Paris"
assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


@pytest.mark.asyncio
async def test_lm_usage_with_async():
program = Predict("question -> answer")

original_aforward = program.aforward

async def patched_aforward(self, **kwargs):
await asyncio.sleep(1)
return await original_aforward(**kwargs)

program.aforward = types.MethodType(patched_aforward, program)

with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True):
with patch(
"litellm.acompletion",
return_value=ModelResponse(
choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
usage={"total_tokens": 10},
),
):
tasks = []
async with asyncio.TaskGroup() as tg:
tasks.append(tg.create_task(program.acall(question="What is the capital of France?")))
tasks.append(tg.create_task(program.acall(question="What is the capital of France?")))
tasks.append(tg.create_task(program.acall(question="What is the capital of France?")))
tasks.append(tg.create_task(program.acall(question="What is the capital of France?")))

results = await asyncio.gather(*tasks)
assert results[0].answer == "Paris"
assert results[1].answer == "Paris"
assert results[0].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
assert results[1].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
assert results[2].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10
assert results[3].get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


def test_positional_arguments():
program = Predict("question -> answer")
with pytest.raises(ValueError) as e:
@@ -569,9 +637,9 @@ class ConstrainedSignature(dspy.Signature):
@pytest.mark.asyncio
async def test_async_predict():
program = Predict("question -> answer")
dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]))
result = await program.acall(question="What is the capital of France?")
assert result.answer == "Paris"
with dspy.context(lm=DummyLM([{"answer": "Paris"}])):
result = await program.acall(question="What is the capital of France?")
assert result.answer == "Paris"


def test_predicted_outputs_piped_from_predict_to_lm_call():
24 changes: 11 additions & 13 deletions tests/predict/test_react.py
@@ -254,16 +254,15 @@ class InvitationSignature(dspy.Signature):
},
]
)
dspy.settings.configure(lm=lm)

outputs = await react.acall(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants={"Alice": "female", "Bob": "male"},
),
)
with dspy.context(lm=lm):
outputs = await react.acall(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants={"Alice": "female", "Bob": "male"},
),
)
assert outputs.invitation_letter == "It's my honor to invite Alice to the Science Fair event on Friday."

expected_trajectory = {
@@ -309,9 +308,8 @@ async def foo(a, b):
{"reasoning": "I added the numbers successfully", "c": 3},
]
)
dspy.settings.configure(lm=lm)

outputs = await react.acall(a=1, b=2, max_iters=2)
with dspy.context(lm=lm):
outputs = await react.acall(a=1, b=2, max_iters=2)
traj = outputs.trajectory

# Exact-match checks (thoughts + tool calls)
97 changes: 48 additions & 49 deletions tests/streaming/test_streaming.py
@@ -142,8 +142,6 @@ def __call__(self, x: str, **kwargs):
judgement = self.predict2(question=x, answer=answer, **kwargs)
return judgement

# Turn off the cache to ensure the stream is produced.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
my_program = MyProgram()
program = dspy.streamify(
my_program,
@@ -153,11 +151,13 @@ def __call__(self, x: str, **kwargs):
],
include_final_prediction_in_output_stream=False,
)
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
# Turn off the cache to ensure the stream is produced.
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -205,8 +205,6 @@ def __call__(self, x: str, **kwargs):
judgement = self.predict2(question=x, answer=answer, **kwargs)
return judgement

# Turn off the cache to ensure the stream is produced.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
my_program = MyProgram()
program = dspy.streamify(
my_program,
@@ -216,11 +214,13 @@ def __call__(self, x: str, **kwargs):
],
include_final_prediction_in_output_stream=False,
)
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
# Turn off the cache to ensure the stream is produced.
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -241,8 +241,6 @@ def __call__(self, x: str, **kwargs):
judgement = self.predict2(question=x, answer=answer, **kwargs)
return judgement

# Turn off the cache to ensure the stream is produced.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))
my_program = MyProgram()
program = dspy.streamify(
my_program,
@@ -253,11 +251,13 @@ def __call__(self, x: str, **kwargs):
include_final_prediction_in_output_stream=False,
async_streaming=False,
)
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
# Turn off the cache to ensure the stream is produced.
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
output = program(x="why did a chicken cross the kitchen?")
all_chunks = []
for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -350,8 +350,6 @@ async def gpt_4o_mini_stream_2():
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False))

stream_generators = [gpt_4o_mini_stream_1, gpt_4o_mini_stream_2]

async def completion_side_effect(*args, **kwargs):
@@ -365,11 +363,12 @@ async def completion_side_effect(*args, **kwargs):
dspy.streaming.StreamListener(signature_field_name="judgement"),
],
)
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -396,6 +395,7 @@ async def completion_side_effect(*args, **kwargs):
async def test_stream_listener_returns_correct_chunk_json_adapter():
class MyProgram(dspy.Module):
def __init__(self):
super().__init__()
self.predict1 = dspy.Predict("question->answer")
self.predict2 = dspy.Predict("question,answer->judgement")

@@ -404,8 +404,6 @@ def forward(self, question, **kwargs):
judgement = self.predict2(question=question, answer=answer, **kwargs)
return judgement

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())

async def gpt_4o_mini_stream_1(*args, **kwargs):
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
@@ -461,11 +459,12 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
dspy.streaming.StreamListener(signature_field_name="judgement"),
],
)
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -488,6 +487,7 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
async def test_stream_listener_returns_correct_chunk_chat_adapter_untokenized_stream():
class MyProgram(dspy.Module):
def __init__(self):
super().__init__()
self.predict1 = dspy.Predict("question->answer")
self.predict2 = dspy.Predict("question,answer->judgement")

@@ -496,8 +496,6 @@ def forward(self, question, **kwargs):
judgement = self.predict2(question=question, answer=answer, **kwargs)
return judgement

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter())

async def gemini_stream_1(*args, **kwargs):
yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content=" answer ## ]]"))])
@@ -539,11 +537,12 @@ async def gemini_stream_2(*args, **kwargs):
dspy.streaming.StreamListener(signature_field_name="judgement"),
],
)
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash", cache=False), adapter=dspy.ChatAdapter()):
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
@@ -561,6 +560,7 @@ async def gemini_stream_2(*args, **kwargs):
async def test_stream_listener_returns_correct_chunk_json_adapter_untokenized_stream():
class MyProgram(dspy.Module):
def __init__(self):
super().__init__()
self.predict1 = dspy.Predict("question->answer")
self.predict2 = dspy.Predict("question,answer->judgement")

@@ -569,8 +569,6 @@ def forward(self, question, **kwargs):
judgement = self.predict2(question=question, answer=answer, **kwargs)
return judgement

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())

async def gemini_stream_1(*args, **kwargs):
yield ModelResponseStream(model="gemini", choices=[StreamingChoices(delta=Delta(content="{\n"))])
yield ModelResponseStream(
@@ -606,11 +604,12 @@ async def gemini_stream_2(*args, **kwargs):
dspy.streaming.StreamListener(signature_field_name="judgement"),
],
)
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash", cache=False), adapter=dspy.JSONAdapter()):
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)

assert all_chunks[0].predict_name == "predict1"
assert all_chunks[0].signature_field_name == "answer"
19 changes: 9 additions & 10 deletions tests/utils/test_asyncify.py
@@ -14,10 +14,10 @@ async def test_async_limiter():
assert limiter.total_tokens == 8, "Default async capacity should be 8"
assert get_limiter() == limiter, "AsyncLimiter should be a singleton"

dspy.settings.configure(async_max_workers=16)
assert get_limiter() == limiter, "AsyncLimiter should be a singleton"
assert get_limiter().total_tokens == 16, "Async capacity should be 16"
assert get_limiter() == get_limiter(), "AsyncLimiter should be a singleton"
with dspy.context(async_max_workers=16):
assert get_limiter() == limiter, "AsyncLimiter should be a singleton"
assert get_limiter().total_tokens == 16, "Async capacity should be 16"
assert get_limiter() == get_limiter(), "AsyncLimiter should be a singleton"


@pytest.mark.anyio
@@ -32,12 +32,11 @@ async def run_n_tasks(n: int, wait: float):
await asyncio.gather(*[ask_the_question(wait) for _ in range(n)])

async def verify_asyncify(capacity: int, number_of_tasks: int, wait: float = 0.5):
dspy.settings.configure(async_max_workers=capacity)

start = time()
await run_n_tasks(number_of_tasks, wait)
end = time()
total_time = end - start
with dspy.context(async_max_workers=capacity):
start = time()
await run_n_tasks(number_of_tasks, wait)
end = time()
total_time = end - start

# If asyncify is working correctly, the total time should be less than the total number of loops
# `(number_of_tasks / capacity)` times wait time, plus the computational overhead. The lower bound should
162 changes: 162 additions & 0 deletions tests/utils/test_settings.py
@@ -0,0 +1,162 @@
import dspy
from concurrent.futures import ThreadPoolExecutor
from litellm import ModelResponse, Choices, Message
from unittest import mock
import pytest
import asyncio
import time


def test_basic_dspy_settings():
dspy.configure(lm=dspy.LM("openai/gpt-4o"), adapter=dspy.JSONAdapter(), callbacks=[lambda x: x])
assert dspy.settings.lm.model == "openai/gpt-4o"
assert isinstance(dspy.settings.adapter, dspy.JSONAdapter)
assert len(dspy.settings.callbacks) == 1


def test_forbid_configure_call_in_child_thread():
dspy.configure(lm=dspy.LM("openai/gpt-4o"), adapter=dspy.JSONAdapter(), callbacks=[lambda x: x])

def worker():
with pytest.raises(RuntimeError, match="Cannot call dspy.configure in a child thread"):
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), callbacks=[])

with ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(worker)


@pytest.mark.asyncio
async def test_forbid_configure_call_in_async_function():
with pytest.raises(
RuntimeError,
match=r"dspy.settings.configure\(\.\.\.\) cannot be called from*",
):
dspy.configure(lm=dspy.LM("openai/gpt-4o"), adapter=dspy.JSONAdapter(), callbacks=[lambda x: x])

with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), callbacks=[]):
# context is allowed
pass


def test_dspy_context():
dspy.configure(lm=dspy.LM("openai/gpt-4o"), adapter=dspy.JSONAdapter(), callbacks=[lambda x: x])
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), callbacks=[]):
assert dspy.settings.lm.model == "openai/gpt-4o-mini"
assert len(dspy.settings.callbacks) == 0

assert dspy.settings.lm.model == "openai/gpt-4o"
assert len(dspy.settings.callbacks) == 1


def test_dspy_context_parallel():
dspy.configure(lm=dspy.LM("openai/gpt-4o"), adapter=dspy.JSONAdapter(), callbacks=[lambda x: x])

def worker(i):
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), trace=[i], callbacks=[]):
assert dspy.settings.lm.model == "openai/gpt-4o-mini"
assert dspy.settings.trace == [i]
assert len(dspy.settings.callbacks) == 0

with ThreadPoolExecutor(max_workers=5) as executor:
executor.map(worker, range(3))

assert dspy.settings.lm.model == "openai/gpt-4o"
assert len(dspy.settings.callbacks) == 1


def test_dspy_context_with_dspy_parallel():
dspy.configure(lm=dspy.LM("openai/gpt-4o", cache=False), adapter=dspy.ChatAdapter())

class MyModule(dspy.Module):
def __init__(self):
self.predict = dspy.Predict("question -> answer")

def forward(self, question: str) -> str:
lm = dspy.LM("openai/gpt-4o-mini", cache=False) if "France" in question else dspy.settings.lm
with dspy.context(lm=lm):
time.sleep(1)
assert dspy.settings.lm.model == lm.model
return self.predict(question=question)

with mock.patch("litellm.completion") as mock_completion:
mock_completion.return_value = ModelResponse(
choices=[Choices(message=Message(content="[[ ## answer ## ]]\nParis"))],
model="openai/gpt-4o-mini",
)

module = MyModule()
parallelizer = dspy.Parallel()
input_pairs = [
(module, {"question": "What is the capital of France?"}),
(module, {"question": "What is the capital of Germany?"}),
]
parallelizer(input_pairs)

# Verify mock was called correctly
assert mock_completion.call_count == 2
for call_args in mock_completion.call_args_list:
if "France" in call_args.kwargs["messages"][-1]["content"]:
# France question uses gpt-4o-mini
assert call_args.kwargs["model"] == "openai/gpt-4o-mini"
else:
# Germany question uses gpt-4o
assert call_args.kwargs["model"] == "openai/gpt-4o"

# The main thread is not affected by the context
assert dspy.settings.lm.model == "openai/gpt-4o"


@pytest.mark.asyncio
async def test_dspy_context_with_async_task_group():

class MyModule(dspy.Module):
def __init__(self):
self.predict = dspy.Predict("question -> answer")

async def aforward(self, question: str) -> str:
lm = (
dspy.LM("openai/gpt-4o-mini", cache=False)
if "France" in question
else dspy.LM("openai/gpt-4o", cache=False)
)
with dspy.context(lm=lm, trace=[]):
await asyncio.sleep(1)
assert dspy.settings.lm.model == lm.model
result = await self.predict.acall(question=question)
assert len(dspy.settings.trace) == 1
return result

module = MyModule()

with dspy.context(lm=dspy.LM("openai/gpt-4.1", cache=False), adapter=dspy.ChatAdapter()):
with mock.patch("litellm.acompletion") as mock_completion:
mock_completion.return_value = ModelResponse(
choices=[Choices(message=Message(content="[[ ## answer ## ]]\nParis"))],
model="openai/gpt-4o-mini",
)
tasks = []
async with asyncio.TaskGroup() as tg:
tasks.append(tg.create_task(module.acall(question="What is the capital of France?")))
tasks.append(tg.create_task(module.acall(question="What is the capital of France?")))
tasks.append(tg.create_task(module.acall(question="What is the capital of Germany?")))
tasks.append(tg.create_task(module.acall(question="What is the capital of Germany?")))

results = await asyncio.gather(*tasks)

assert results[0].answer == "Paris"
assert results[1].answer == "Paris"
assert results[2].answer == "Paris"
assert results[3].answer == "Paris"

# Verify mock was called correctly
assert mock_completion.call_count == 4
# France question uses gpt-4o-mini
assert mock_completion.call_args_list[0].kwargs["model"] == "openai/gpt-4o-mini"
assert mock_completion.call_args_list[1].kwargs["model"] == "openai/gpt-4o-mini"
# Germany question uses gpt-4o
assert mock_completion.call_args_list[2].kwargs["model"] == "openai/gpt-4o"
assert mock_completion.call_args_list[3].kwargs["model"] == "openai/gpt-4o"

# The main thread is not affected by the context
assert dspy.settings.lm.model == "openai/gpt-4.1"
assert dspy.settings.trace == []