Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions dspy/adapters/baml_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def _render_type_str(
if origin in (types.UnionType, Union):
non_none_args = [arg for arg in args if arg is not type(None)]
# Render the non-None part of the union
type_render = " or ".join([_render_type_str(arg, depth + 1, indent) for arg in non_none_args])
type_render = " or ".join(
_render_type_str(arg, depth + 1, indent, seen_models) for arg in non_none_args
)
# Add "or null" if None was part of the union
if len(non_none_args) < len(args):
return f"{type_render} or null"
Expand All @@ -73,11 +75,14 @@ def _render_type_str(
current_indent = " " * indent
return f"[\n{inner_schema}\n{current_indent}]"
else:
return f"{_render_type_str(inner_type, depth + 1, indent)}[]"
return f"{_render_type_str(inner_type, depth + 1, indent, seen_models)}[]"

# dict[T1, T2]
if origin is dict:
return f"dict[{_render_type_str(args[0], depth + 1, indent)}, {_render_type_str(args[1], depth + 1, indent)}]"
return (
f"dict[{_render_type_str(args[0], depth + 1, indent, seen_models)}, "
f"{_render_type_str(args[1], depth + 1, indent, seen_models)}]"
)

# fallback
if hasattr(annotation, "__name__"):
Expand All @@ -102,32 +107,35 @@ def _build_simplified_schema(
if pydantic_model in seen_models:
raise ValueError("BAMLAdapter cannot handle recursive pydantic models, please use a different adapter.")

# Add `pydantic_model` to `seen_models` with a placeholder value to avoid infinite recursion.
# Track models currently being traversed so we only flag true recursion.
seen_models.add(pydantic_model)

lines = []
current_indent = " " * indent
next_indent = " " * (indent + 1)

lines.append(f"{current_indent}{{")

fields = pydantic_model.model_fields
if not fields:
lines.append(f"{next_indent}{COMMENT_SYMBOL} No fields defined")
for name, field in fields.items():
if field.description:
lines.append(f"{next_indent}{COMMENT_SYMBOL} {field.description}")
elif field.alias and field.alias != name:
# If there's an alias but no description, show the alias as a comment
lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}")

rendered_type = _render_type_str(field.annotation, indent=indent + 1, seen_models=seen_models)
line = f"{next_indent}{name}: {rendered_type},"

lines.append(line)

lines.append(f"{current_indent}}}")
return "\n".join(lines)
try:
lines = []
current_indent = " " * indent
next_indent = " " * (indent + 1)

lines.append(f"{current_indent}{{")

fields = pydantic_model.model_fields
if not fields:
lines.append(f"{next_indent}{COMMENT_SYMBOL} No fields defined")
for name, field in fields.items():
if field.description:
lines.append(f"{next_indent}{COMMENT_SYMBOL} {field.description}")
elif field.alias and field.alias != name:
# If there's an alias but no description, show the alias as a comment
lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}")

rendered_type = _render_type_str(field.annotation, indent=indent + 1, seen_models=seen_models)
line = f"{next_indent}{name}: {rendered_type},"

lines.append(line)

lines.append(f"{current_indent}}}")
return "\n".join(lines)
finally:
# Remove the model so other branches can reuse it without triggering a false recursion.
seen_models.remove(pydantic_model)


class BAMLAdapter(JSONAdapter):
Expand Down
69 changes: 69 additions & 0 deletions tests/adapters/test_baml_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,48 @@ class TestSignature(dspy.Signature):
assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value)


def test_baml_adapter_detects_recursive_models_with_union_annotations():
    """A self-referencing model must be rejected even when the reference sits inside a union."""

    class RecursiveNode(pydantic.BaseModel):
        label: str
        # The string annotation refers back to the class being defined.
        child: "RecursiveNode | None" = None

    RecursiveNode.model_rebuild()

    class TestSignature(dspy.Signature):
        prompt: str = dspy.InputField()
        node: RecursiveNode = dspy.OutputField()

    baml_adapter = BAMLAdapter()

    with pytest.raises(ValueError) as error:
        baml_adapter.format_field_structure(TestSignature)

    assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value)


def test_baml_adapter_detects_deep_recursive_cycles():
    """A cycle that spans two models (StepTwo -> StepThree -> StepTwo) must be rejected, unions included."""

    class StepThree(pydantic.BaseModel):
        # Forward reference: StepTwo is declared just below.
        back: "StepTwo | None" = None

    class StepTwo(pydantic.BaseModel):
        next_step: StepThree | None = None

    StepTwo.model_rebuild()
    StepThree.model_rebuild()

    class Wrapper(pydantic.BaseModel):
        # Non-recursive entry point into the cycle.
        branch: StepTwo

    class TestSignature(dspy.Signature):
        query: str = dspy.InputField()
        wrapper: Wrapper = dspy.OutputField()

    baml_adapter = BAMLAdapter()

    with pytest.raises(ValueError) as error:
        baml_adapter.format_field_structure(TestSignature)

    assert "BAMLAdapter cannot handle recursive pydantic models" in str(error.value)


def test_baml_adapter_formats_pydantic_inputs_as_clean_json():
"""Test that Pydantic input instances are formatted as clean JSON."""

Expand Down Expand Up @@ -220,6 +262,33 @@ class TestSignature(dspy.Signature):
pass


def test_baml_adapter_handles_model_reuse_in_schema():
    """Using one model for two sibling fields is reuse, not recursion, and must render twice."""

    class CommonFields(pydantic.BaseModel):
        name: str = pydantic.Field(description="A descriptive name")
        metadata: list[str] = pydantic.Field(description="Associated metadata tags")

    class DocumentInfo(pydantic.BaseModel):
        source: CommonFields = pydantic.Field(description="Source information")
        target: CommonFields = pydantic.Field(description="Target information")
        version: int

    class TestSignature(dspy.Signature):
        query: str = dspy.InputField()
        document: DocumentInfo = dspy.OutputField()

    rendered = BAMLAdapter().format_field_structure(TestSignature)

    # Each CommonFields line and description is expected once under `source` and once under `target`.
    rendered_twice = (
        "name: string,",
        "metadata: string[],",
        f"{COMMENT_SYMBOL} A descriptive name",
        f"{COMMENT_SYMBOL} Associated metadata tags",
    )
    for fragment in rendered_twice:
        assert rendered.count(fragment) == 2

    # The top-level scalar field of DocumentInfo is still rendered.
    assert "version: int," in rendered


def test_baml_adapter_raises_on_missing_fields():
"""Test that missing required fields raise appropriate errors."""

Expand Down