Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions packages/ai/src/microsoft/teams/ai/chat_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,12 @@ async def on_chunk_fn(chunk: str):

return ChatSendResult(response=current_response)

def _wrap_function_handler(self, original_handler: FunctionHandlers, function_name: str) -> FunctionHandlers:
def _wrap_function_handler(
self,
original_handler: FunctionHandlers,
function_name: str,
parameter_schema: type[BaseModel] | Dict[str, Any] | None,
) -> FunctionHandlers:
"""
Wrap a function handler with plugin before/after hooks.

Expand Down Expand Up @@ -231,6 +236,9 @@ async def wrapped_handler(params: Optional[BaseModel]) -> str:

return current_result

if parameter_schema is None:
return lambda: wrapped_handler(None)

return wrapped_handler

async def _run_before_send_hooks(self, input: Message) -> Message:
Expand Down Expand Up @@ -287,7 +295,7 @@ async def _build_wrapped_functions(self) -> dict[str, Function[BaseModel]] | Non
name=func.name,
description=func.description,
parameter_schema=func.parameter_schema,
handler=self._wrap_function_handler(func.handler, func.name),
handler=self._wrap_function_handler(func.handler, func.name, func.parameter_schema),
)

return wrapped_functions
Expand Down
55 changes: 55 additions & 0 deletions packages/ai/tests/test_chat_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,61 @@ def handler_no_params() -> str:
result = await prompt.send("Test message")
assert result.response.content == "GENERATED - Test message"

@pytest.mark.asyncio
async def test_function_with_no_parameters_wrapped_with_plugins(self) -> None:
"""Test that functions with parameter_schema=None work correctly when called by the model"""

class MockModelThatCallsFunction:
"""Mock model that simulates calling a function with no parameters"""

async def generate_text(
self,
input: Any,
*,
system: SystemMessage | None = None,
memory: Memory | None = None,
functions: dict[str, Function[BaseModel]] | None = None,
on_chunk: Callable[[str], Awaitable[None]] | None = None,
) -> ModelMessage:
# Simulate model deciding to call a function
if functions and "no_param_func" in functions:
function = functions["no_param_func"]

# Call the function handler the way the model would
# When parameter_schema is None, handler should be callable with no args
handler = cast(Callable[[], str | Awaitable[str]], function.handler)
result = handler()
if isawaitable(result):
result = await result

return ModelMessage(
content=f"Function returned: {result}",
function_calls=None,
)

return ModelMessage(content="No function called", function_calls=None)

plugin = MockPlugin("test_plugin")
handler_called = False

def handler_no_params() -> str:
nonlocal handler_called
handler_called = True
return "Success"

no_param_function = Function(
name="no_param_func",
description="Function with no parameters",
parameter_schema=None,
handler=handler_no_params,
)

prompt = ChatPrompt(MockModelThatCallsFunction(), functions=[no_param_function], plugins=[plugin])
result = await prompt.send("Call the function")

assert handler_called
assert result.response.content == "Function returned: Success"


class MockPlugin(BaseAIPlugin):
"""Mock plugin for testing that tracks all hook calls"""
Expand Down
2 changes: 1 addition & 1 deletion packages/apps/src/microsoft/teams/apps/http_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def custom_server_factory(app: FastAPI) -> uvicorn.Server:
return uvicorn.Server(config=uvicorn.Config(app, host="0.0.0.0", port=8000))


http_plugin = HttpPlugin(app_id="your-app-id", server_factory=custom_server_factory)
http_plugin = HttpPlugin(server_factory=custom_server_factory)
```
"""
super().__init__()
Expand Down
Loading