diff --git a/packages/ai/src/microsoft/teams/ai/chat_prompt.py b/packages/ai/src/microsoft/teams/ai/chat_prompt.py index eed8313d..ffbd3c6f 100644 --- a/packages/ai/src/microsoft/teams/ai/chat_prompt.py +++ b/packages/ai/src/microsoft/teams/ai/chat_prompt.py @@ -189,7 +189,12 @@ async def on_chunk_fn(chunk: str): return ChatSendResult(response=current_response) - def _wrap_function_handler(self, original_handler: FunctionHandlers, function_name: str) -> FunctionHandlers: + def _wrap_function_handler( + self, + original_handler: FunctionHandlers, + function_name: str, + parameter_schema: type[BaseModel] | Dict[str, Any] | None, + ) -> FunctionHandlers: """ Wrap a function handler with plugin before/after hooks. @@ -231,6 +236,9 @@ async def wrapped_handler(params: Optional[BaseModel]) -> str: return current_result + if parameter_schema is None: + return lambda: wrapped_handler(None) + return wrapped_handler async def _run_before_send_hooks(self, input: Message) -> Message: @@ -287,7 +295,7 @@ async def _build_wrapped_functions(self) -> dict[str, Function[BaseModel]] | Non name=func.name, description=func.description, parameter_schema=func.parameter_schema, - handler=self._wrap_function_handler(func.handler, func.name), + handler=self._wrap_function_handler(func.handler, func.name, func.parameter_schema), ) return wrapped_functions diff --git a/packages/ai/tests/test_chat_prompt.py b/packages/ai/tests/test_chat_prompt.py index 0f39dab7..eed73792 100644 --- a/packages/ai/tests/test_chat_prompt.py +++ b/packages/ai/tests/test_chat_prompt.py @@ -336,6 +336,61 @@ def handler_no_params() -> str: result = await prompt.send("Test message") assert result.response.content == "GENERATED - Test message" + @pytest.mark.asyncio + async def test_function_with_no_parameters_wrapped_with_plugins(self) -> None: + """Test that functions with parameter_schema=None work correctly when called by the model""" + + class MockModelThatCallsFunction: + """Mock model that simulates calling a function with no parameters""" + + async def generate_text( + self, + input: Any, + *, + system: SystemMessage | None = None, + memory: Memory | None = None, + functions: dict[str, Function[BaseModel]] | None = None, + on_chunk: Callable[[str], Awaitable[None]] | None = None, + ) -> ModelMessage: + # Simulate model deciding to call a function + if functions and "no_param_func" in functions: + function = functions["no_param_func"] + + # Call the function handler the way the model would + # When parameter_schema is None, handler should be callable with no args + handler = cast(Callable[[], str | Awaitable[str]], function.handler) + result = handler() + if isawaitable(result): + result = await result + + return ModelMessage( + content=f"Function returned: {result}", + function_calls=None, + ) + + return ModelMessage(content="No function called", function_calls=None) + + plugin = MockPlugin("test_plugin") + handler_called = False + + def handler_no_params() -> str: + nonlocal handler_called + handler_called = True + return "Success" + + no_param_function = Function( + name="no_param_func", + description="Function with no parameters", + parameter_schema=None, + handler=handler_no_params, + ) + + prompt = ChatPrompt(MockModelThatCallsFunction(), functions=[no_param_function], plugins=[plugin]) + result = await prompt.send("Call the function") + + assert handler_called + assert result.response.content == "Function returned: Success" + class MockPlugin(BaseAIPlugin): """Mock plugin for testing that tracks all hook calls""" diff --git a/packages/apps/src/microsoft/teams/apps/http_plugin.py b/packages/apps/src/microsoft/teams/apps/http_plugin.py index d4f18625..347a7258 100644 --- a/packages/apps/src/microsoft/teams/apps/http_plugin.py +++ b/packages/apps/src/microsoft/teams/apps/http_plugin.py @@ -99,7 +99,7 @@ def custom_server_factory(app: FastAPI) -> uvicorn.Server: return uvicorn.Server(config=uvicorn.Config(app, host="0.0.0.0", port=8000)) - http_plugin = HttpPlugin(app_id="your-app-id", server_factory=custom_server_factory) + http_plugin = HttpPlugin(server_factory=custom_server_factory) ``` """ super().__init__()