diff --git a/examples/google/toolcall.php b/examples/google/toolcall.php new file mode 100644 index 000000000..b10517aeb --- /dev/null +++ b/examples/google/toolcall.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\Clock; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Bridge\Google\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\Component\Dotenv\Dotenv; + +require_once dirname(__DIR__, 2).'/vendor/autoload.php'; +(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$llm = new Gemini(Gemini::GEMINI_2_FLASH); + +$toolbox = Toolbox::create(new Clock()); +$processor = new AgentProcessor($toolbox); +$chain = new Agent($platform, $llm, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What time is it?')); +$response = $chain->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php index 11663a747..81346942b 100644 --- a/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php +++ b/src/platform/src/Bridge/Google/Contract/AssistantMessageNormalizer.php @@ -15,16 +15,12 @@ use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Model; -use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface; -use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait; /** * @author Christopher Hertel */ -final class AssistantMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +final class AssistantMessageNormalizer extends ModelContractNormalizer { - use NormalizerAwareTrait; - protected function supportedDataClass(): string { return AssistantMessage::class; @@ -42,8 +38,23 @@ protected function supportsModel(Model $model): bool */ public function normalize(mixed $data, ?string $format = null, array $context = []): array { - return [ - ['text' => $data->content], - ]; + $normalized = []; + + if (isset($data->content)) { + $normalized['text'] = $data->content; + } + + if (isset($data->toolCalls[0])) { + $normalized['functionCall'] = [ + 'id' => $data->toolCalls[0]->id, + 'name' => $data->toolCalls[0]->name, + ]; + + if ($data->toolCalls[0]->arguments) { + $normalized['functionCall']['args'] = $data->toolCalls[0]->arguments; + } + } + + return [$normalized]; } } diff --git a/src/platform/src/Bridge/Google/Contract/ToolCallMessageNormalizer.php b/src/platform/src/Bridge/Google/Contract/ToolCallMessageNormalizer.php new file mode 100644 index 000000000..1bb4d375d --- /dev/null +++ b/src/platform/src/Bridge/Google/Contract/ToolCallMessageNormalizer.php @@ -0,0 +1,59 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google\Contract; + +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Model; + +/** + * @author Valtteri R + */ +final class ToolCallMessageNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return ToolCallMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param ToolCallMessage $data + * + * @return array{ + * functionResponse: array{ + * id: string, + * name: string, + * response: array + * } + * }[] + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content; + + return [[ + 'functionResponse' => array_filter([ + 'id' => $data->toolCall->id, + 'name' => $data->toolCall->name, + 'response' => \is_array($responseContent) ? $responseContent : [ + 'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses. + ], + ]), + ]]; + } +} diff --git a/src/platform/src/Bridge/Google/Contract/ToolNormalizer.php b/src/platform/src/Bridge/Google/Contract/ToolNormalizer.php new file mode 100644 index 000000000..d8fb94b3d --- /dev/null +++ b/src/platform/src/Bridge/Google/Contract/ToolNormalizer.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Google\Contract; + +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract\JsonSchema\Factory; +use Symfony\AI\Platform\Contract\Normalizer\ModelContractNormalizer; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Tool\Tool; + +/** + * @author Valtteri R + * + * @phpstan-import-type JsonSchema from Factory + */ +final class ToolNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Tool::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param Tool $data + * + * @return array{ + * functionDeclarations: array{ + * name: string, + * description: string, + * parameters: JsonSchema|array{type: 'object'} + * }[] + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $parameters = $data->parameters; + unset($parameters['additionalProperties']); + + return [ + 'functionDeclarations' => [ + [ + 'description' => $data->description, + 'name' => $data->name, + 'parameters' => $parameters, + ], + ], + ]; + } +} diff --git a/src/platform/src/Bridge/Google/Gemini.php b/src/platform/src/Bridge/Google/Gemini.php index ec52fc787..50473ceaf 100644 --- a/src/platform/src/Bridge/Google/Gemini.php +++ b/src/platform/src/Bridge/Google/Gemini.php @@ -34,6 +34,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options = Capability::INPUT_MESSAGES, Capability::INPUT_IMAGE, Capability::OUTPUT_STREAMING, + Capability::TOOL_CALLING, ]; parent::__construct($name, $capabilities, $options); diff --git a/src/platform/src/Bridge/Google/ModelHandler.php b/src/platform/src/Bridge/Google/ModelHandler.php index 57bc4a0ed..e807a97d8 100644 --- a/src/platform/src/Bridge/Google/ModelHandler.php +++ b/src/platform/src/Bridge/Google/ModelHandler.php @@ -14,9 +14,13 @@ use Symfony\AI\Platform\Exception\RuntimeException; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Response\Choice; +use Symfony\AI\Platform\Response\ChoiceResponse; use Symfony\AI\Platform\Response\ResponseInterface as LlmResponse; use Symfony\AI\Platform\Response\StreamResponse; use Symfony\AI\Platform\Response\TextResponse; +use Symfony\AI\Platform\Response\ToolCall; +use Symfony\AI\Platform\Response\ToolCallResponse; use Symfony\AI\Platform\ResponseConverterInterface; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; @@ -59,6 +63,12 @@ public function request(Model $model, array|string $payload, array $options = [] $generationConfig = ['generationConfig' => $options]; unset($generationConfig['generationConfig']['stream']); + unset($generationConfig['generationConfig']['tools']); + + if (isset($options['tools'])) { + $generationConfig['tools'] = $options['tools']; + unset($options['tools']); + } return $this->httpClient->request('POST', $url, [ 'headers' => [ @@ -83,11 +93,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe $data = $response->toArray(); - if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + if (!isset($data['candidates'][0]['content']['parts'][0])) { throw new RuntimeException('Response does not contain any content'); } - return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']); + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['candidates']); + + if (1 !== \count($choices)) { + return new ChoiceResponse(...$choices); + } + + if ($choices[0]->hasToolCall()) { + return new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + return new TextResponse($choices[0]->getContent()); } private function convertStream(ResponseInterface $response): \Generator @@ -121,12 +142,68 @@ private function convertStream(ResponseInterface $response): \Generator throw new RuntimeException('Failed to decode JSON response', 0, $e); } - if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['candidates'] ?? []); + + if (!$choices) { continue; } - yield $data['candidates'][0]['content']['parts'][0]['text']; + if (1 !== \count($choices)) { + yield new ChoiceResponse(...$choices); + continue; + } + + if ($choices[0]->hasToolCall()) { + yield new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + if ($choices[0]->hasContent()) { + yield $choices[0]->getContent(); + } } } } + + /** + * @param array{ + * finishReason?: string, + * content: array{ + * parts: array{ + * functionCall?: array{ + * id: string, + * name: string, + * args: mixed[] + * }, + * text?: string + * }[] + * } + * } $choice + */ + private function convertChoice(array $choice): Choice + { + $contentPart = $choice['content']['parts'][0] ?? []; + + if (isset($contentPart['functionCall'])) { + return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]); + } + + if (isset($contentPart['text'])) { + return new Choice($contentPart['text']); + } + + throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason'])); + } + + /** + * @param array{ + * id: string, + * name: string, + * args: mixed[] + * } $toolCall + */ + private function convertToolCall(array $toolCall): ToolCall + { + return new ToolCall($toolCall['id'], $toolCall['name'], $toolCall['args']); + } } diff --git a/src/platform/src/Bridge/Google/PlatformFactory.php b/src/platform/src/Bridge/Google/PlatformFactory.php index cab1c3b3d..49fe8e326 100644 --- a/src/platform/src/Bridge/Google/PlatformFactory.php +++ b/src/platform/src/Bridge/Google/PlatformFactory.php @@ -13,6 +13,8 @@ use Symfony\AI\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; use Symfony\AI\Platform\Bridge\Google\Contract\MessageBagNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Contract\ToolNormalizer; use Symfony\AI\Platform\Bridge\Google\Contract\UserMessageNormalizer; use Symfony\AI\Platform\Contract; use Symfony\AI\Platform\Platform; @@ -35,6 +37,8 @@ public static function create( return new Platform([$responseHandler], [$responseHandler], Contract::create( new AssistantMessageNormalizer(), new MessageBagNormalizer(), + new ToolNormalizer(), + new ToolCallMessageNormalizer(), new UserMessageNormalizer(), )); } diff --git a/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php index 27101c318..a79ab4242 100644 --- a/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php +++ b/src/platform/tests/Bridge/Google/Contract/AssistantMessageNormalizerTest.php @@ -12,6 +12,7 @@ namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; use PHPUnit\Framework\Attributes\Small; use PHPUnit\Framework\Attributes\Test; use PHPUnit\Framework\Attributes\UsesClass; @@ -21,12 +22,14 @@ use Symfony\AI\Platform\Contract; use Symfony\AI\Platform\Message\AssistantMessage; use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ToolCall; #[Small] #[CoversClass(AssistantMessageNormalizer::class)] #[UsesClass(Gemini::class)] #[UsesClass(AssistantMessage::class)] #[UsesClass(Model::class)] +#[UsesClass(ToolCall::class)] final class AssistantMessageNormalizerTest extends TestCase { #[Test] @@ -49,13 +52,32 @@ public function getSupportedTypes(): void } #[Test] - public function normalize(): void + #[DataProvider('normalizeDataProvider')] + public function normalize(AssistantMessage $message, array $expectedOutput): void { $normalizer = new AssistantMessageNormalizer(); - $message = new AssistantMessage('Great to meet you. What would you like to know?'); $normalized = $normalizer->normalize($message); - self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized); + self::assertSame($expectedOutput, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'assistant message' => [ + new AssistantMessage('Great to meet you. What would you like to know?'), + [['text' => 'Great to meet you. What would you like to know?']], + ]; + yield 'function call' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]), + [['functionCall' => ['id' => 'id1', 'name' => 'name1', 'args' => ['arg1' => '123']]]], + ]; + yield 'function call without parameters' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1')]), + [['functionCall' => ['id' => 'id1', 'name' => 'name1']]], + ]; } } diff --git a/src/platform/tests/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php new file mode 100644 index 000000000..35700d6d4 --- /dev/null +++ b/src/platform/tests/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php @@ -0,0 +1,102 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\ToolCallMessage; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Response\ToolCall; + +#[Small] +#[CoversClass(ToolCallMessageNormalizer::class)] +#[UsesClass(Model::class)] +#[UsesClass(Gemini::class)] +#[UsesClass(ToolCallMessage::class)] +#[UsesClass(ToolCall::class)] +final class ToolCallMessageNormalizerTest extends TestCase +{ + #[Test] + public function supportsNormalization(): void + { + $normalizer = new ToolCallMessageNormalizer(); + + self::assertTrue($normalizer->supportsNormalization(new ToolCallMessage(new ToolCall('', '', []), ''), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a tool call')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new ToolCallMessageNormalizer(); + + $expected = [ + ToolCallMessage::class => true, + ]; + + self::assertSame($expected, $normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('normalizeDataProvider')] + public function normalize(ToolCallMessage $message, array $expected): void + { + $normalizer = new ToolCallMessageNormalizer(); + + $normalized = $normalizer->normalize($message); + + self::assertEquals($expected, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'scalar' => [ + new ToolCallMessage( + new ToolCall('id1', 'name1', ['foo' => 'bar']), + 'true', + ), + [[ + 'functionResponse' => [ + 'id' => 'id1', + 'name' => 'name1', + 'response' => ['rawResponse' => 'true'], + ], + ]], + ]; + + yield 'structured response' => [ + new ToolCallMessage( + new ToolCall('id1', 'name1', ['foo' => 'bar']), + '{"structured":"response"}', + ), + [[ + 'functionResponse' => [ + 'id' => 'id1', + 'name' => 'name1', + 'response' => ['structured' => 'response'], + ], + ]], + ]; + } +} diff --git a/src/platform/tests/Bridge/Google/Contract/ToolNormalizerTest.php b/src/platform/tests/Bridge/Google/Contract/ToolNormalizerTest.php new file mode 100644 index 000000000..ec20bc53c --- /dev/null +++ b/src/platform/tests/Bridge/Google/Contract/ToolNormalizerTest.php @@ -0,0 +1,138 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Google\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\Test; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Fixtures\Tool\ToolNoParams; +use Symfony\AI\Fixtures\Tool\ToolRequiredParams; +use Symfony\AI\Platform\Bridge\Google\Contract\ToolNormalizer; +use Symfony\AI\Platform\Bridge\Google\Gemini; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Tool\ExecutionReference; +use Symfony\AI\Platform\Tool\Tool; + +#[Small] +#[CoversClass(ToolNormalizer::class)] +#[UsesClass(Model::class)] +#[UsesClass(Gemini::class)] +#[UsesClass(Tool::class)] +final class ToolNormalizerTest extends TestCase +{ + #[Test] + public function supportsNormalization(): void + { + $normalizer = new ToolNormalizer(); + + self::assertTrue($normalizer->supportsNormalization(new Tool(new ExecutionReference(ToolNoParams::class), 'test', 'test'), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a tool')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new ToolNormalizer(); + + $expected = [ + Tool::class => true, + ]; + + self::assertSame($expected, $normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('normalizeDataProvider')] + public function normalize(Tool $tool, array $expected): void + { + $normalizer = new ToolNormalizer(); + + $normalized = $normalizer->normalize($tool); + + self::assertEquals($expected, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'call with params' => [ + new Tool( + new ExecutionReference(ToolRequiredParams::class, 'bar'), + 'tool_required_params', + 'A tool with required parameters', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'Text parameter', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'Number parameter', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ), + [ + 'functionDeclarations' => [ + [ + 'description' => 'A tool with required parameters', + 'name' => 'tool_required_params', + 'parameters' => [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'Text parameter', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'Number parameter', + ], + ], + 'required' => ['text', 'number'], + ], + ], + ], + ], + ]; + + yield 'call without params' => [ + new Tool( + new ExecutionReference(ToolNoParams::class, 'bar'), + 'tool_no_params', + 'A tool without parameters', + null, + ), + [ + 'functionDeclarations' => [ + [ + 'description' => 'A tool without parameters', + 'name' => 'tool_no_params', + 'parameters' => null, + ], + ], + ], + ]; + } +}