Skip to content

Commit 57e502e

Browse files
committed
feat: introduction completion chat and RAG chat
1 parent 21c732b commit 57e502e

12 files changed

+229
-20
lines changed

src/Model/Completion/Chat/History.php

+10-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
namespace Devscast\Lugha\Model\Completion\Chat;
1515

16+
use Devscast\Lugha\Assert;
17+
1618
/**
1719
* Class History.
1820
*
@@ -23,22 +25,26 @@ class History
2325
/**
2426
* @param Message[] $messages
2527
*/
26-
private array $messages = [];
28+
private function __construct(
29+
private array $messages = []
30+
) {
31+
Assert::allIsInstanceOf($messages, Message::class);
32+
}
2733

2834
/**
2935
* @param Message[] $messages
3036
*/
31-
public function fromMessages(array $messages): void
37+
public static function fromMessages(array $messages): self
3238
{
33-
$this->messages = $messages;
39+
return new self($messages);
3440
}
3541

3642
public function getHistory(bool $excludeSystemInstruction = false): array
3743
{
3844
return \array_map(
3945
callback: fn (Message $message) => $message->toArray(),
4046
array: $excludeSystemInstruction ?
41-
array_filter($this->messages, fn (Message $message) => $message->role !== Role::SYSTEM) :
47+
\array_filter($this->messages, fn (Message $message) => $message->role !== Role::SYSTEM) :
4248
$this->messages
4349
);
4450
}

src/Model/Completion/Chat/Message.php

+5-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
namespace Devscast\Lugha\Model\Completion\Chat;
1515

16+
use Devscast\Lugha\Assert;
17+
1618
/**
1719
* Class Message.
1820
*
@@ -21,17 +23,18 @@
2123
final readonly class Message implements \Stringable
2224
{
2325
public function __construct(
24-
public ?string $content,
26+
public string $content,
2527
public Role $role = Role::USER,
2628
public ?string $toolCallId = null,
2729
public ?array $toolCalls = null
2830
) {
31+
Assert::notEmpty($content);
2932
}
3033

3134
#[\Override]
3235
public function __toString(): string
3336
{
34-
return (string) $this->content;
37+
return $this->content;
3538
}
3639

3740
public static function fromResponse(array $message): self

src/Model/Completion/Chat/ToolCalled.php

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public static function fromResponse(array $data): self
3737
{
3838
try {
3939
/** @var array<string, mixed> $arguments */
40-
$arguments = json_decode($data['function']['arguments'], true, flags: JSON_THROW_ON_ERROR);
40+
$arguments = \json_decode($data['function']['arguments'], true, flags: \JSON_THROW_ON_ERROR);
4141

4242
return new self($data['id'], $data['type'], $data['function']['name'], $arguments);
4343
} catch (\JsonException $e) {

src/Model/Completion/ChatInterface.php

+2-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
namespace Devscast\Lugha\Model\Completion;
1515

1616
use Devscast\Lugha\Model\Completion\Chat\History;
17-
use Devscast\Lugha\Model\Completion\Chat\Message;
18-
use Devscast\Lugha\Model\Completion\Prompt\PromptTemplate;
1917

2018
/**
2119
* Interface ChatInterface.
@@ -24,7 +22,7 @@
2422
*/
2523
interface ChatInterface
2624
{
27-
public function setSystemMessage(PromptTemplate|Message $message): void;
25+
public function completion(string $input, array $tools = []): string;
2826

29-
public function completion(PromptTemplate|History|string $input): string;
27+
public function completionWithHistory(string $input, History $history, array $tools = []): string;
3028
}
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Lugha package.
5+
*
6+
* (c) Bernard Ngandu <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
declare(strict_types=1);
13+
14+
namespace Devscast\Lugha\Model\Completion;
15+
16+
use Devscast\Lugha\Exception\ServiceIntegrationException;
17+
use Devscast\Lugha\Model\Completion\Chat\History;
18+
use Devscast\Lugha\Model\Completion\Chat\Message;
19+
use Devscast\Lugha\Model\Completion\Chat\Role;
20+
use Devscast\Lugha\Provider\Service\HasCompletionSupport;
21+
22+
/**
23+
* Class Chatter.
24+
*
25+
* @author bernard-ng <[email protected]>
26+
*/
27+
final readonly class CompletionChat implements ChatInterface
28+
{
29+
public function __construct(
30+
private HasCompletionSupport $client,
31+
private CompletionConfig $completionConfig,
32+
) {
33+
}
34+
35+
/**
36+
* @throws ServiceIntegrationException
37+
*/
38+
#[\Override]
39+
public function completion(string $input, array $tools = []): string
40+
{
41+
return $this->client->completion($input, $this->completionConfig, $tools)->completion;
42+
}
43+
44+
/**
45+
* @throws ServiceIntegrationException
46+
*/
47+
#[\Override]
48+
public function completionWithHistory(string $input, History $history, array $tools = []): string
49+
{
50+
$history->append(new Message($input, Role::USER));
51+
52+
return $this->client->completion($history, $this->completionConfig, $tools)->completion;
53+
}
54+
}

src/Model/Completion/CompletionConfig.php

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
namespace Devscast\Lugha\Model\Completion;
1515

1616
use Devscast\Lugha\Assert;
17+
use Devscast\Lugha\Model\Embedding\Distance;
1718

1819
/**
1920
* Class CompletionConfig.
@@ -37,6 +38,8 @@
3738
* @param float|null $frequencyPenalty The value used to penalize new tokens based on their frequency in the training data.
3839
* @param float|null $presencePenalty The value used to penalize new tokens based on whether they are already present in the text.
3940
* @param array|null $stopSequences A list of sequences where the model should stop generating the text.
41+
* @param int $similarityK The number of similar examples to use for retrieval-augmented generation.
42+
* @param Distance $similarityDistance The distance metric to use for retrieval-augmented generation.
4043
* @param array $additionalParameters Additional parameters to pass to the API.
4144
*/
4245
public function __construct(
@@ -48,6 +51,8 @@ public function __construct(
4851
public ?float $frequencyPenalty = null,
4952
public ?float $presencePenalty = null,
5053
public ?array $stopSequences = null,
54+
public int $similarityK = 4,
55+
public Distance $similarityDistance = Distance::L2,
5156
public array $additionalParameters = []
5257
) {
5358
Assert::notEmpty($this->model);
@@ -57,5 +62,6 @@ public function __construct(
5762
Assert::nullOrPositiveInteger($this->topK);
5863
Assert::nullOrRange($this->frequencyPenalty, -2, 2);
5964
Assert::nullOrRange($this->presencePenalty, -2, 2);
65+
Assert::greaterThan($this->similarityK, 0);
6066
}
6167
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Lugha package.
5+
*
6+
* (c) Bernard Ngandu <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
declare(strict_types=1);
13+
14+
namespace Devscast\Lugha\Model\Completion;
15+
16+
use Devscast\Lugha\Exception\ServiceIntegrationException;
17+
use Devscast\Lugha\Model\Completion\Chat\History;
18+
use Devscast\Lugha\Model\Completion\Chat\Message;
19+
use Devscast\Lugha\Model\Completion\Chat\Role;
20+
use Devscast\Lugha\Model\Completion\Prompt\PromptTemplate;
21+
use Devscast\Lugha\Provider\Service\HasCompletionSupport;
22+
use Devscast\Lugha\Retrieval\VectorStore\VectorStoreInterface;
23+
24+
/**
25+
* Class RetrievalAugmentedChatter.
26+
*
27+
* @author bernard-ng <[email protected]>
28+
*/
29+
final readonly class RetrievalAugmentedChat implements RetrievalAugmentedInterface
30+
{
31+
public function __construct(
32+
private HasCompletionSupport $client,
33+
private CompletionConfig $config,
34+
private VectorStoreInterface $vectorStore,
35+
) {
36+
}
37+
38+
/**
39+
* @throws ServiceIntegrationException
40+
*/
41+
#[\Override]
42+
public function augmentedCompletion(string $query, PromptTemplate $prompt): string
43+
{
44+
$prompt->setParameter(':CONTEXT', $this->createContext($query));
45+
46+
$history = History::fromMessages([
47+
new Message((string) $prompt, Role::SYSTEM),
48+
new Message($query, Role::USER),
49+
]);
50+
51+
return $this->client->completion($history, $this->config)->completion;
52+
}
53+
54+
/**
55+
* @throws ServiceIntegrationException
56+
*/
57+
#[\Override]
58+
public function augmentedCompletionWithHistory(string $query, PromptTemplate $prompt, History $history): string
59+
{
60+
$prompt->setParameter(':CONTEXT', $this->createContext($query));
61+
62+
// TODO: not sure if a chat history can have multiple system messages
63+
// TODO: further investigation needed
64+
$history->append(new Message((string) $prompt, Role::SYSTEM));
65+
$history->append(new Message($query, Role::USER));
66+
67+
return $this->client->completion($history, $this->config)->completion;
68+
}
69+
70+
private function createContext(string $query): string
71+
{
72+
$documents = $this->vectorStore->similaritySearch(
73+
query: $query,
74+
k: $this->config->similarityK,
75+
distance: $this->config->similarityDistance
76+
);
77+
78+
$context = '';
79+
foreach ($documents as $document) {
80+
$context .= $document->content . "\n";
81+
}
82+
83+
return $context;
84+
}
85+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Lugha package.
5+
*
6+
* (c) Bernard Ngandu <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
declare(strict_types=1);
13+
14+
namespace Devscast\Lugha\Model\Completion;
15+
16+
use Devscast\Lugha\Model\Completion\Chat\History;
17+
use Devscast\Lugha\Model\Completion\Prompt\PromptTemplate;
18+
19+
/**
20+
* Interface RagInterface.
21+
*
22+
* @author bernard-ng <[email protected]>
23+
*/
24+
interface RetrievalAugmentedInterface
25+
{
26+
public function augmentedCompletion(string $query, PromptTemplate $prompt): string;
27+
28+
public function augmentedCompletionWithHistory(string $query, PromptTemplate $prompt, History $history): string;
29+
}

src/Provider/Service/Common/OpenAICompatibilitySupport.php

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ public function handleToolCalls(array $response, CompletionConfig $config): Comp
2727
return $this->handleCompletion($response, $config);
2828
}
2929

30-
$history = new History();
31-
$history->append(Message::fromResponse($message));
30+
$history = History::fromMessages([Message::fromResponse($message)]);
3231
$history->merge($this->callTools($message['tool_calls']));
3332

3433
return $this->completion($history, $config);

src/Provider/Service/Common/ToolCallingSupport.php

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
use Devscast\Lugha\Assert;
1717
use Devscast\Lugha\Model\Completion\Chat\History;
18+
use Devscast\Lugha\Model\Completion\Chat\Message;
1819
use Devscast\Lugha\Model\Completion\Chat\ToolCalled;
1920
use Devscast\Lugha\Model\Completion\Tools\ToolReference;
2021
use Devscast\Lugha\Model\Completion\Tools\ToolRunner;
@@ -55,11 +56,12 @@ public function callTools(array $message): History
5556
$message['tool_calls']
5657
);
5758

58-
$history = new History();
59-
foreach ($tools as $tool) {
60-
$history->append(ToolRunner::run($tool, $this->references));
61-
}
59+
$messages = \array_map(
60+
fn (ToolCalled $tool): ?Message => ToolRunner::run($tool, $this->references),
61+
$tools
62+
);
63+
$messages = \array_filter($message, fn (?Message $message): bool => $message !== null);
6264

63-
return $history;
65+
return History::fromMessages($messages);
6466
}
6567
}

src/Retrieval/Document.php

+14-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
*
2020
* @author bernard-ng <[email protected]>
2121
*/
22-
class Document implements \Stringable
22+
class Document implements \Stringable, \JsonSerializable
2323
{
2424
public function __construct(
2525
public string $content,
@@ -47,4 +47,17 @@ public function hasEmbeddings(): bool
4747
{
4848
return \count($this->embeddings) !== 0;
4949
}
50+
51+
/**
52+
* @throws \JsonException
53+
*/
54+
#[\Override]
55+
public function jsonSerialize(): string
56+
{
57+
return \json_encode([
58+
'content' => $this->content,
59+
'embeddings' => $this->embeddings,
60+
'metadata' => $this->metadata,
61+
], \JSON_THROW_ON_ERROR);
62+
}
5063
}

src/Retrieval/Metadata.php

+15-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
*
2121
* @author bernard-ng <[email protected]>
2222
*/
23-
class Metadata
23+
class Metadata implements \JsonSerializable
2424
{
2525
public function __construct(
2626
public ?string $hash = null,
@@ -39,4 +39,18 @@ public static function from(array $metadata): self
3939
$metadata['chunkNumber'],
4040
);
4141
}
42+
43+
/**
44+
* @throws \JsonException
45+
*/
46+
#[\Override]
47+
public function jsonSerialize(): string
48+
{
49+
return \json_encode([
50+
'hash' => $this->hash,
51+
'sourceType' => $this->sourceType,
52+
'sourceName' => $this->sourceName,
53+
'chunkNumber' => $this->chunkNumber,
54+
], \JSON_THROW_ON_ERROR);
55+
}
4256
}

0 commit comments

Comments
 (0)