diff --git a/README.md b/README.md index 76ca64f..01aa345 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ llm_chain: model: name: 'Claude' tools: # If undefined, all tools are injected into the chain, use "tools: false" to disable tools. - - 'PhpLlm\LlmChain\Chain\ToolBox\Tool\Wikipedia' + - 'PhpLlm\LlmChain\Chain\Toolbox\Tool\Wikipedia' fault_tolerant_toolbox: false # Disables fault tolerant toolbox, default is true store: # also azure_search, mongodb and pinecone are supported as store type @@ -111,20 +111,21 @@ services: autowire: true autoconfigure: true - PhpLlm\LlmChain\Chain\ToolBox\Tool\Clock: ~ - PhpLlm\LlmChain\Chain\ToolBox\Tool\OpenMeteo: ~ - PhpLlm\LlmChain\Chain\ToolBox\Tool\SerpApi: + PhpLlm\LlmChain\Chain\Toolbox\Tool\Clock: ~ + PhpLlm\LlmChain\Chain\Toolbox\Tool\OpenMeteo: ~ + PhpLlm\LlmChain\Chain\Toolbox\Tool\SerpApi: $apiKey: '%env(SERP_API_KEY)%' - PhpLlm\LlmChain\Chain\ToolBox\Tool\SimilaritySearch: ~ - PhpLlm\LlmChain\Chain\ToolBox\Tool\Tavily: + PhpLlm\LlmChain\Chain\Toolbox\Tool\SimilaritySearch: ~ + PhpLlm\LlmChain\Chain\Toolbox\Tool\Tavily: $apiKey: '%env(TAVILY_API_KEY)%' - PhpLlm\LlmChain\Chain\ToolBox\Tool\Wikipedia: ~ - PhpLlm\LlmChain\Chain\ToolBox\Tool\YouTubeTranscriber: ~ + PhpLlm\LlmChain\Chain\Toolbox\Tool\Wikipedia: ~ + PhpLlm\LlmChain\Chain\Toolbox\Tool\YouTubeTranscriber: ~ ``` Custom tools can be registered by using the `#[AsTool]` attribute: + ```php -use PhpLlm\LlmChain\Chain\ToolBox\Attribute\AsTool; +use PhpLlm\LlmChain\Chain\Toolbox\Attribute\AsTool; #[AsTool('company_name', 'Provides the name of your company')] final class CompanyName @@ -152,7 +153,7 @@ llm_chain: chain: my_chain: tools: - - 'PhpLlm\LlmChain\Chain\ToolBox\Tool\SimilaritySearch' + - 'PhpLlm\LlmChain\Chain\Toolbox\Tool\SimilaritySearch' ``` ### Profiler diff --git a/composer.json b/composer.json index 43e3db1..fd1169e 100644 --- a/composer.json +++ b/composer.json @@ -15,7 +15,7 @@ ], "require": { "php": ">=8.2", - "php-llm/llm-chain": "^0.18", + "php-llm/llm-chain": "^0.19", "symfony/config": "^6.4 || ^7.0", "symfony/dependency-injection": "^6.4 || ^7.0", "symfony/framework-bundle": "^6.4 || ^7.0", diff --git a/src/DependencyInjection/LlmChainExtension.php b/src/DependencyInjection/LlmChainExtension.php index cd0ef6b..eea492e 100644 --- a/src/DependencyInjection/LlmChainExtension.php +++ b/src/DependencyInjection/LlmChainExtension.php @@ -23,9 +23,9 @@ use PhpLlm\LlmChain\Chain\InputProcessor\SystemPromptInputProcessor; use PhpLlm\LlmChain\Chain\OutputProcessor; use PhpLlm\LlmChain\Chain\StructuredOutput\ChainProcessor as StructureOutputProcessor; -use PhpLlm\LlmChain\Chain\ToolBox\Attribute\AsTool; -use PhpLlm\LlmChain\Chain\ToolBox\ChainProcessor as ToolProcessor; -use PhpLlm\LlmChain\Chain\ToolBox\FaultTolerantToolBox; +use PhpLlm\LlmChain\Chain\Toolbox\Attribute\AsTool; +use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor as ToolProcessor; +use PhpLlm\LlmChain\Chain\Toolbox\FaultTolerantToolbox; use PhpLlm\LlmChain\ChainInterface; use PhpLlm\LlmChain\Embedder; use PhpLlm\LlmChain\Model\EmbeddingsModel; @@ -38,7 +38,7 @@ use PhpLlm\LlmChain\Store\VectorStoreInterface; use PhpLlm\LlmChainBundle\Profiler\DataCollector; use PhpLlm\LlmChainBundle\Profiler\TraceablePlatform; -use PhpLlm\LlmChainBundle\Profiler\TraceableToolBox; +use PhpLlm\LlmChainBundle\Profiler\TraceableToolbox; use Symfony\Component\Config\FileLocator; use Symfony\Component\DependencyInjection\ChildDefinition; use Symfony\Component\DependencyInjection\ContainerBuilder; @@ -126,7 +126,7 @@ public function load(array $configs, ContainerBuilder $container): void if (false === $container->getParameter('kernel.debug')) { $container->removeDefinition(DataCollector::class); - $container->removeDefinition(TraceableToolBox::class); + $container->removeDefinition(TraceableToolbox::class); } } @@ -247,16 +247,17 @@ private function processChainConfig(string $name, array $config, ContainerBuilde // TOOL & PROCESSOR if ($config['tools']['enabled']) { - // Create specific tool box and process if tools are explicitly defined + // Create specific toolbox and process if tools are explicitly defined if (0 !== count($config['tools']['services'])) { $tools = array_map(static fn (string $tool) => new Reference($tool), $config['tools']['services']); + $toolboxDefinition = (new ChildDefinition('llm_chain.toolbox.abstract')) ->replaceArgument('$tools', $tools); $container->setDefinition('llm_chain.toolbox.'.$name, $toolboxDefinition); if ($config['fault_tolerant_toolbox']) { $faultTolerantToolboxDefinition = (new Definition('llm_chain.fault_tolerant_toolbox.'.$name)) - ->setClass(FaultTolerantToolBox::class) + ->setClass(FaultTolerantToolbox::class) ->setAutowired(true) ->setDecoratedService('llm_chain.toolbox.'.$name); $container->setDefinition('llm_chain.fault_tolerant_toolbox.'.$name, $faultTolerantToolboxDefinition); @@ -264,7 +265,7 @@ private function processChainConfig(string $name, array $config, ContainerBuilde if ($container->getParameter('kernel.debug')) { $traceableToolboxDefinition = (new Definition('llm_chain.traceable_toolbox.'.$name)) - ->setClass(TraceableToolBox::class) + ->setClass(TraceableToolbox::class) ->setAutowired(true) ->setDecoratedService('llm_chain.toolbox.'.$name) ->addTag('llm_chain.traceable_toolbox'); @@ -272,7 +273,7 @@ private function processChainConfig(string $name, array $config, ContainerBuilde } $toolProcessorDefinition = (new ChildDefinition('llm_chain.tool.chain_processor.abstract')) - ->replaceArgument('$toolBox', new Reference('llm_chain.toolbox.'.$name)); + ->replaceArgument('$toolbox', new Reference('llm_chain.toolbox.'.$name)); $container->setDefinition('llm_chain.tool.chain_processor.'.$name, $toolProcessorDefinition); $inputProcessors[] = new Reference('llm_chain.tool.chain_processor.'.$name); @@ -298,7 +299,7 @@ private function processChainConfig(string $name, array $config, ContainerBuilde if ($config['include_tools']) { $systemPromptInputProcessorDefinition - ->setArgument('$toolBox', new Reference('llm_chain.toolbox.'.$name)); + ->setArgument('$toolbox', new Reference('llm_chain.toolbox.'.$name)); } $inputProcessors[] = $systemPromptInputProcessorDefinition; @@ -424,10 +425,11 @@ private function processEmbedderConfig(int|string $name, array $config, Containe $modelDefinition->addTag('llm_chain.model.embeddings_model'); $container->setDefinition('llm_chain.embedder.'.$name.'.embeddings', $modelDefinition); - $definition = (new ChildDefinition('llm_chain.embedder.abstract')) - ->replaceArgument('$platform', new Reference($config['platform'])) - ->replaceArgument('$store', new Reference($config['store'])) - ->replaceArgument('$embeddings', new Reference('llm_chain.embedder.'.$name.'.embeddings')); + $definition = new Definition(Embedder::class, [ + '$embeddings' => new Reference('llm_chain.embedder.'.$name.'.embeddings'), + '$platform' => new Reference($config['platform']), + '$store' => new Reference($config['store']), + ]); $container->setDefinition('llm_chain.embedder.'.$name, $definition); } diff --git a/src/Profiler/DataCollector.php b/src/Profiler/DataCollector.php index 6feaa7c..bb4ca99 100644 --- a/src/Profiler/DataCollector.php +++ b/src/Profiler/DataCollector.php @@ -4,8 +4,8 @@ namespace PhpLlm\LlmChainBundle\Profiler; -use PhpLlm\LlmChain\Chain\ToolBox\Metadata; -use PhpLlm\LlmChain\Chain\ToolBox\ToolBoxInterface; +use PhpLlm\LlmChain\Chain\Toolbox\Metadata; +use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface; use Symfony\Bundle\FrameworkBundle\DataCollector\AbstractDataCollector; use Symfony\Component\DependencyInjection\Attribute\TaggedIterator; use Symfony\Component\HttpFoundation\Request; @@ -13,7 +13,7 @@ /** * @phpstan-import-type PlatformCallData from TraceablePlatform - * @phpstan-import-type ToolCallData from TraceableToolBox + * @phpstan-import-type ToolCallData from TraceableToolbox */ final class DataCollector extends AbstractDataCollector { @@ -23,23 +23,23 @@ final class DataCollector extends AbstractDataCollector private readonly array $platforms; /** - * @var TraceableToolBox[] + * @var TraceableToolbox[] */ - private readonly array $toolBoxes; + private readonly array $toolboxes; /** * @param TraceablePlatform[] $platforms - * @param TraceableToolBox[] $toolBoxes + * @param TraceableToolbox[] $toolboxes */ public function __construct( #[TaggedIterator('llm_chain.traceable_platform')] iterable $platforms, - private readonly ToolBoxInterface $defaultToolBox, + private readonly ToolboxInterface $defaultToolBox, #[TaggedIterator('llm_chain.traceable_toolbox')] - iterable $toolBoxes, + iterable $toolboxes, ) { $this->platforms = $platforms instanceof \Traversable ? iterator_to_array($platforms) : $platforms; - $this->toolBoxes = $toolBoxes instanceof \Traversable ? iterator_to_array($toolBoxes) : $toolBoxes; + $this->toolboxes = $toolboxes instanceof \Traversable ? iterator_to_array($toolboxes) : $toolboxes; } public function collect(Request $request, Response $response, ?\Throwable $exception = null): void @@ -47,7 +47,7 @@ public function collect(Request $request, Response $response, ?\Throwable $excep $this->data = [ 'tools' => $this->defaultToolBox->getMap(), 'platform_calls' => array_merge(...array_map(fn (TraceablePlatform $platform) => $platform->calls, $this->platforms)), - 'tool_calls' => array_merge(...array_map(fn (TraceableToolBox $toolBox) => $toolBox->calls, $this->toolBoxes)), + 'tool_calls' => array_merge(...array_map(fn (TraceableToolbox $toolbox) => $toolbox->calls, $this->toolboxes)), ]; } diff --git a/src/Profiler/TraceableToolBox.php b/src/Profiler/TraceableToolbox.php similarity index 69% rename from src/Profiler/TraceableToolBox.php rename to src/Profiler/TraceableToolbox.php index 5f0bb3c..ed8d247 100644 --- a/src/Profiler/TraceableToolBox.php +++ b/src/Profiler/TraceableToolbox.php @@ -4,7 +4,7 @@ namespace PhpLlm\LlmChainBundle\Profiler; -use PhpLlm\LlmChain\Chain\ToolBox\ToolBoxInterface; +use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface; use PhpLlm\LlmChain\Model\Response\ToolCall; /** @@ -13,7 +13,7 @@ * result: string, * } */ -final class TraceableToolBox implements ToolBoxInterface +final class TraceableToolbox implements ToolboxInterface { /** * @var ToolCallData[] @@ -21,18 +21,18 @@ final class TraceableToolBox implements ToolBoxInterface public array $calls = []; public function __construct( - private readonly ToolBoxInterface $toolBox, + private readonly ToolboxInterface $toolbox, ) { } public function getMap(): array { - return $this->toolBox->getMap(); + return $this->toolbox->getMap(); } public function execute(ToolCall $toolCall): mixed { - $result = $this->toolBox->execute($toolCall); + $result = $this->toolbox->execute($toolCall); $this->calls[] = [ 'call' => $toolCall, diff --git a/src/Resources/config/services.php b/src/Resources/config/services.php index 19052b6..a7bfba1 100644 --- a/src/Resources/config/services.php +++ b/src/Resources/config/services.php @@ -7,27 +7,19 @@ use PhpLlm\LlmChain\Chain\StructuredOutput\ChainProcessor as StructureOutputProcessor; use PhpLlm\LlmChain\Chain\StructuredOutput\ResponseFormatFactory; use PhpLlm\LlmChain\Chain\StructuredOutput\ResponseFormatFactoryInterface; -use PhpLlm\LlmChain\Chain\ToolBox\ChainProcessor as ToolProcessor; -use PhpLlm\LlmChain\Chain\ToolBox\MetadataFactory; -use PhpLlm\LlmChain\Chain\ToolBox\MetadataFactory\ReflectionFactory; -use PhpLlm\LlmChain\Chain\ToolBox\ToolBox; -use PhpLlm\LlmChain\Chain\ToolBox\ToolBoxInterface; -use PhpLlm\LlmChain\Embedder; +use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor as ToolProcessor; +use PhpLlm\LlmChain\Chain\Toolbox\MetadataFactory; +use PhpLlm\LlmChain\Chain\Toolbox\MetadataFactory\ReflectionFactory; +use PhpLlm\LlmChain\Chain\Toolbox\Toolbox; +use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface; use PhpLlm\LlmChainBundle\Profiler\DataCollector; -use PhpLlm\LlmChainBundle\Profiler\TraceableToolBox; +use PhpLlm\LlmChainBundle\Profiler\TraceableToolbox; return static function (ContainerConfigurator $container): void { $container->services() ->defaults() ->autowire() - // high level feature - ->set('llm_chain.embedder.abstract', Embedder::class) - ->abstract() - ->args([ - '$embeddings' => abstract_arg('Embeddings model'), - ]) - // structured output ->set(ResponseFormatFactory::class) ->alias(ResponseFormatFactoryInterface::class, ResponseFormatFactory::class) @@ -37,40 +29,40 @@ // tools ->set('llm_chain.toolbox.abstract') - ->class(ToolBox::class) + ->class(Toolbox::class) ->autowire() ->abstract() ->args([ '$tools' => abstract_arg('Collection of tools'), ]) - ->set(ToolBox::class) + ->set(Toolbox::class) ->parent('llm_chain.toolbox.abstract') ->args([ '$tools' => tagged_iterator('llm_chain.tool'), ]) - ->alias(ToolBoxInterface::class, ToolBox::class) + ->alias(ToolboxInterface::class, Toolbox::class) ->set(ReflectionFactory::class) ->alias(MetadataFactory::class, ReflectionFactory::class) ->set('llm_chain.tool.chain_processor.abstract') ->class(ToolProcessor::class) ->abstract() ->args([ - '$toolBox' => abstract_arg('Tool box'), + '$toolbox' => abstract_arg('Toolbox'), ]) ->set(ToolProcessor::class) ->parent('llm_chain.tool.chain_processor.abstract') ->tag('llm_chain.chain.input_processor') ->tag('llm_chain.chain.output_processor') ->args([ - '$toolBox' => service(ToolBoxInterface::class), + '$toolbox' => service(ToolboxInterface::class), '$eventDispatcher' => service('event_dispatcher')->nullOnInvalid(), ]) // profiler ->set(DataCollector::class) ->tag('data_collector') - ->set(TraceableToolBox::class) - ->decorate(ToolBoxInterface::class) + ->set(TraceableToolbox::class) + ->decorate(ToolboxInterface::class) ->tag('llm_chain.traceable_toolbox') ; }; diff --git a/src/Resources/views/data_collector.html.twig b/src/Resources/views/data_collector.html.twig index ac7183c..524624f 100644 --- a/src/Resources/views/data_collector.html.twig +++ b/src/Resources/views/data_collector.html.twig @@ -194,7 +194,7 @@ {{ tool.name }} {{ tool.description }} - {{ tool.className }}::{{ tool.method }} + {{ tool.reference.class }}::{{ tool.reference.method }} {% if tool.parameters %}