Skip to content

Commit 31b5e48

Browse files
ishahrozmdrxy and authored
feat(deepseek): support strict beta structured output (#32727)
**Description:** This PR adds support for DeepSeek's beta strict mode feature for structured outputs and tool calling. It overrides `bind_tools()` and `with_structured_output()` to automatically use DeepSeek's beta endpoint (https://api.deepseek.com/beta) when `strict=True`. Both methods need overriding because they're independent entry points and user can call either directly. When DeepSeek's strict mode graduates from beta, we can just remove both overriden methods. You can read more about the beta feature here: https://api-docs.deepseek.com/guides/function_calling#strict-mode-beta **Issue:** Implements #32670 **Dependencies:** None **Sample Code** ```python from langchain_deepseek import ChatDeepSeek from pydantic import BaseModel, Field from typing import Optional import os # Enter your DeepSeek API Key here API_KEY = "YOUR_API_KEY" # location, temperature, condition are required fields # humidity is optional field with default value class WeatherInfo(BaseModel): location: str = Field(description="City name") temperature: int = Field(description="Temperature in Celsius") condition: str = Field(description="Weather condition (sunny, cloudy, rainy)") humidity: Optional[int] = Field(default=None, description="Humidity percentage") llm = ChatDeepSeek( model="deepseek-chat", api_key=API_KEY, ) # just to confirm that a new instance will use the default base url (instead of beta) print(f"Default API base: {llm.api_base}") # Test 1: bind_tools with strict=True shoud list all the tools calls print("\nTest 1: bind_tools with strict=True") llm_with_tools = llm.bind_tools([WeatherInfo], strict=True) response = llm_with_tools.invoke("Tell me the weather in New York. 
It's 22 degrees, sunny.") print(response.tool_calls) # Test 2: with_structured_output with strict=True print("\nTest 2: with_structured_output with strict=True") structured_llm = llm.with_structured_output(WeatherInfo, strict=True) result = structured_llm.invoke("Tell me the weather in New York.") print(f" Result: {result}") assert isinstance(result, WeatherInfo), "Result should be a WeatherInfo instance" ``` --------- Co-authored-by: Mason Daugherty <[email protected]> Co-authored-by: Mason Daugherty <[email protected]>
1 parent c6801fe commit 31b5e48

File tree

3 files changed

+132
-72
lines changed

3 files changed

+132
-72
lines changed

libs/partners/deepseek/langchain_deepseek/chat_models.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import json
6-
from collections.abc import Iterator
6+
from collections.abc import Callable, Iterator, Sequence
77
from json import JSONDecodeError
88
from typing import Any, Literal, TypeAlias
99

@@ -12,15 +12,17 @@
1212
CallbackManagerForLLMRun,
1313
)
1414
from langchain_core.language_models import LangSmithParams, LanguageModelInput
15-
from langchain_core.messages import AIMessageChunk, BaseMessage
15+
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
1616
from langchain_core.outputs import ChatGenerationChunk, ChatResult
1717
from langchain_core.runnables import Runnable
18+
from langchain_core.tools import BaseTool
1819
from langchain_core.utils import from_env, secret_from_env
1920
from langchain_openai.chat_models.base import BaseChatOpenAI
2021
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
2122
from typing_extensions import Self
2223

2324
DEFAULT_API_BASE = "https://api.deepseek.com/v1"
25+
DEFAULT_BETA_API_BASE = "https://api.deepseek.com/beta"
2426

2527
_DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel]
2628
_DictOrPydantic: TypeAlias = dict[str, Any] | BaseModel
@@ -39,7 +41,7 @@ class ChatDeepSeek(BaseChatOpenAI):
3941
4042
Key init args — completion params:
4143
model:
42-
Name of DeepSeek model to use, e.g. `"deepseek-chat"`.
44+
Name of DeepSeek model to use, e.g. `'deepseek-chat'`.
4345
temperature:
4446
Sampling temperature.
4547
max_tokens:
@@ -368,6 +370,50 @@ def _generate(
368370
e.pos,
369371
) from e
370372

373+
def bind_tools(
374+
self,
375+
tools: Sequence[dict[str, Any] | type | Callable | BaseTool],
376+
*,
377+
tool_choice: dict | str | bool | None = None,
378+
strict: bool | None = None,
379+
parallel_tool_calls: bool | None = None,
380+
**kwargs: Any,
381+
) -> Runnable[LanguageModelInput, AIMessage]:
382+
"""Bind tool-like objects to this chat model.
383+
384+
Overrides parent to use beta endpoint when `strict=True`.
385+
386+
Args:
387+
tools: A list of tool definitions to bind to this chat model.
388+
tool_choice: Which tool to require the model to call.
389+
strict: If True, uses beta API for strict schema validation.
390+
parallel_tool_calls: Set to `False` to disable parallel tool use.
391+
**kwargs: Additional parameters passed to parent `bind_tools`.
392+
393+
Returns:
394+
A Runnable that takes same inputs as a chat model.
395+
"""
396+
# If strict mode is enabled and using default API base, switch to beta endpoint
397+
if strict is True and self.api_base == DEFAULT_API_BASE:
398+
# Create a new instance with beta endpoint
399+
beta_model = self.model_copy(update={"api_base": DEFAULT_BETA_API_BASE})
400+
return beta_model.bind_tools(
401+
tools,
402+
tool_choice=tool_choice,
403+
strict=strict,
404+
parallel_tool_calls=parallel_tool_calls,
405+
**kwargs,
406+
)
407+
408+
# Otherwise use parent implementation
409+
return super().bind_tools(
410+
tools,
411+
tool_choice=tool_choice,
412+
strict=strict,
413+
parallel_tool_calls=parallel_tool_calls,
414+
**kwargs,
415+
)
416+
371417
def with_structured_output(
372418
self,
373419
schema: _DictOrPydanticClass | None = None,
@@ -423,10 +469,14 @@ def with_structured_output(
423469
424470
strict:
425471
Whether to enable strict schema adherence when generating the function
426-
call. This parameter is included for compatibility with other chat
427-
models, and if specified will be passed to the Chat Completions API
428-
in accordance with the OpenAI API specification. However, the DeepSeek
429-
API may ignore the parameter.
472+
call. When set to `True`, DeepSeek will use the beta API endpoint
473+
(`https://api.deepseek.com/beta`) for strict schema validation.
474+
This ensures model outputs exactly match the defined schema.
475+
476+
!!! note
477+
478+
DeepSeek's strict mode requires all object properties to be marked
479+
as required in the schema.
430480
431481
kwargs: Additional keyword args aren't supported.
432482
@@ -448,6 +498,19 @@ def with_structured_output(
448498
# methods) be handled.
449499
if method == "json_schema":
450500
method = "function_calling"
501+
502+
# If strict mode is enabled and using default API base, switch to beta endpoint
503+
if strict is True and self.api_base == DEFAULT_API_BASE:
504+
# Create a new instance with beta endpoint
505+
beta_model = self.model_copy(update={"api_base": DEFAULT_BETA_API_BASE})
506+
return beta_model.with_structured_output(
507+
schema,
508+
method=method,
509+
include_raw=include_raw,
510+
strict=strict,
511+
**kwargs,
512+
)
513+
451514
return super().with_structured_output(
452515
schema,
453516
method=method,

libs/partners/deepseek/tests/unit_tests/test_chat_models.py

Lines changed: 57 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from langchain_tests.unit_tests import ChatModelUnitTests
1010
from openai import BaseModel
1111
from openai.types.chat import ChatCompletionMessage
12-
from pydantic import SecretStr
12+
from pydantic import BaseModel as PydanticBaseModel
13+
from pydantic import Field, SecretStr
1314

14-
from langchain_deepseek.chat_models import ChatDeepSeek
15+
from langchain_deepseek.chat_models import DEFAULT_API_BASE, ChatDeepSeek
1516

1617
MODEL_NAME = "deepseek-chat"
1718

@@ -243,71 +244,66 @@ def test_get_request_payload(self) -> None:
243244
payload = chat_model._get_request_payload([tool_message])
244245
assert payload["messages"][0]["content"] == "test string"
245246

246-
def test_create_chat_result_with_model_provider(self) -> None:
247-
"""Test that `model_provider` is added to `response_metadata`."""
248-
chat_model = ChatDeepSeek(model=MODEL_NAME, api_key=SecretStr("api_key"))
249-
mock_message = MagicMock()
250-
mock_message.content = "Main content"
251-
mock_message.role = "assistant"
252-
mock_response = MockOpenAIResponse(
253-
choices=[MagicMock(message=mock_message)],
254-
error=None,
255-
)
256247

257-
result = chat_model._create_chat_result(mock_response)
258-
assert (
259-
result.generations[0].message.response_metadata.get("model_provider")
260-
== "deepseek"
261-
)
248+
class SampleTool(PydanticBaseModel):
249+
"""Sample tool schema for testing."""
262250

263-
def test_convert_chunk_with_model_provider(self) -> None:
264-
"""Test that `model_provider` is added to `response_metadata` for chunks."""
265-
chat_model = ChatDeepSeek(model=MODEL_NAME, api_key=SecretStr("api_key"))
266-
chunk: dict[str, Any] = {
267-
"choices": [
268-
{
269-
"delta": {
270-
"content": "Main content",
271-
},
272-
},
273-
],
274-
}
251+
value: str = Field(description="A test value")
275252

276-
chunk_result = chat_model._convert_chunk_to_generation_chunk(
277-
chunk,
278-
AIMessageChunk,
279-
None,
253+
254+
class TestChatDeepSeekStrictMode:
255+
"""Tests for DeepSeek strict mode support.
256+
257+
This tests the experimental beta feature that uses the beta API endpoint
258+
when `strict=True` is used. These tests can be removed when strict mode
259+
becomes stable in the default base API.
260+
"""
261+
262+
def test_bind_tools_with_strict_mode_uses_beta_endpoint(self) -> None:
263+
"""Test that bind_tools with strict=True uses the beta endpoint."""
264+
llm = ChatDeepSeek(
265+
model="deepseek-chat",
266+
api_key=SecretStr("test_key"),
280267
)
281-
if chunk_result is None:
282-
msg = "Expected chunk_result not to be None"
283-
raise AssertionError(msg)
284-
assert (
285-
chunk_result.message.response_metadata.get("model_provider") == "deepseek"
268+
269+
# Verify default endpoint
270+
assert llm.api_base == DEFAULT_API_BASE
271+
272+
# Bind tools with strict=True
273+
bound_model = llm.bind_tools([SampleTool], strict=True)
274+
275+
# The bound model should have its internal model using beta endpoint
276+
# We can't directly access the internal model, but we can verify the behavior
277+
# by checking that the binding operation succeeds
278+
assert bound_model is not None
279+
280+
def test_bind_tools_without_strict_mode_uses_default_endpoint(self) -> None:
281+
"""Test bind_tools without strict or with strict=False uses default endpoint."""
282+
llm = ChatDeepSeek(
283+
model="deepseek-chat",
284+
api_key=SecretStr("test_key"),
286285
)
287286

288-
def test_create_chat_result_with_model_provider_multiple_generations(
289-
self,
290-
) -> None:
291-
"""Test that `model_provider` is added to all generations when `n > 1`."""
292-
chat_model = ChatDeepSeek(model=MODEL_NAME, api_key=SecretStr("api_key"))
293-
mock_message_1 = MagicMock()
294-
mock_message_1.content = "First response"
295-
mock_message_1.role = "assistant"
296-
mock_message_2 = MagicMock()
297-
mock_message_2.content = "Second response"
298-
mock_message_2.role = "assistant"
287+
# Test with strict=False
288+
bound_model_false = llm.bind_tools([SampleTool], strict=False)
289+
assert bound_model_false is not None
299290

300-
mock_response = MockOpenAIResponse(
301-
choices=[
302-
MagicMock(message=mock_message_1),
303-
MagicMock(message=mock_message_2),
304-
],
305-
error=None,
291+
# Test with strict=None (default)
292+
bound_model_none = llm.bind_tools([SampleTool])
293+
assert bound_model_none is not None
294+
295+
def test_with_structured_output_strict_mode_uses_beta_endpoint(self) -> None:
296+
"""Test that with_structured_output with strict=True uses beta endpoint."""
297+
llm = ChatDeepSeek(
298+
model="deepseek-chat",
299+
api_key=SecretStr("test_key"),
306300
)
307301

308-
result = chat_model._create_chat_result(mock_response)
309-
assert len(result.generations) == 2 # noqa: PLR2004
310-
for generation in result.generations:
311-
assert (
312-
generation.message.response_metadata.get("model_provider") == "deepseek"
313-
)
302+
# Verify default endpoint
303+
assert llm.api_base == DEFAULT_API_BASE
304+
305+
# Create structured output with strict=True
306+
structured_model = llm.with_structured_output(SampleTool, strict=True)
307+
308+
# The structured model should work with beta endpoint
309+
assert structured_model is not None

libs/partners/deepseek/uv.lock

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)