Skip to content

Commit

Permalink
Merge pull request #14 from gregpriday/feature/multiple-vector-support
Browse files Browse the repository at this point in the history
Multiple Vector Support
  • Loading branch information
hkulekci authored Aug 6, 2023
2 parents 8e62cb1 + cda9f1f commit 5b69412
Show file tree
Hide file tree
Showing 16 changed files with 415 additions and 28 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@ vendor/
.phpunit.result.cache
.phpunit.cache
coverage
/qdrant_storage
/composer.lock
/.idea
2 changes: 1 addition & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"require": {
"php": "^8.1",
"psr/http-client": "^1.0",
"psr/http-message": "^1.0",
"psr/http-message": "^1.0|^2.0",
"psr/log": "^1.0|^2.0|^3.0",
"guzzlehttp/guzzle": "^7.5",
"guzzlehttp/psr7": "^2.0",
Expand Down
2 changes: 1 addition & 1 deletion src/Exception/InvalidArgumentException.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use Qdrant\Response;

class InvalidArgumentException extends \Exception
class InvalidArgumentException extends \InvalidArgumentException
{
protected Response $response;

Expand Down
58 changes: 58 additions & 0 deletions src/Models/MultiVectorStruct.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<?php

namespace Qdrant\Models;

use Qdrant\Exception\InvalidArgumentException;

class MultiVectorStruct implements VectorStructInterface
{
protected array $vectors = [];

public function __construct(array $vectors = [])
{
foreach ($vectors as $name => $vector) {
$this->addVector($name, $vector);
}
}

public function addVector(string $name, array $vector): void
{
$this->vectors[$name] = $vector;
}

public function getName(): string
{
if(empty($this->vectors)) {
throw new InvalidArgumentException("No vectors added yet");
}

return array_key_first($this->vectors);
}

public function toSearchArray(string $name = null): array
{
// Throw an error if no name is given
if ($name === null) {
throw new InvalidArgumentException("Must provide a name to search");
}

if(!isset($this->vectors[$name])) {
throw new InvalidArgumentException("Vector with name $name not found");
}

return [
'name' => $name,
'vector' => $this->vectors[$name],
];
}

public function toArray(): array
{
return $this->vectors;
}

public function count(): int
{
return count($this->vectors);
}
}
28 changes: 13 additions & 15 deletions src/Models/PointStruct.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ class PointStruct
{
use ProtectedPropertyAccessor;

// TODO: we need a solution for point with uuid
protected int|string $id;
protected ?array $payload = null;
protected VectorStruct $vector;
protected VectorStructInterface $vector;

public function __construct(int|string $id, VectorStruct $vector, array $payload = null)
public function __construct(int|string $id, VectorStructInterface $vector, array $payload = null)
{
$this->id = $id;
$this->vector = $vector;
Expand All @@ -33,9 +32,17 @@ public static function createFromArray(array $pointArray): PointStruct
if (count(array_intersect_key(array_flip($required), $pointArray)) !== count($required)) {
throw new InvalidArgumentException('Missing point keys');
}

$vector = $pointArray['vector'];
if (is_array($pointArray['vector'])) {
$vector = new VectorStruct($pointArray['vector'], $pointArray['name'] ?? null);

// Check if it's an array and convert it to a VectorStruct
if (is_array($vector)) {
$vector = new VectorStruct($vector, $pointArray['name'] ?? null);
}

// Check if it's already a VectorStruct or MultiVectorStruct
if (!($vector instanceof VectorStructInterface)) {
throw new InvalidArgumentException('Invalid vector type');
}

return new PointStruct($pointArray['id'], $vector, $pointArray['payload'] ?? null);
Expand All @@ -55,26 +62,17 @@ public function toArray(): array
return $point;
}

/**
* @return int
*/
public function getId(): int|string
{
return $this->id;
}

/**
* @return array|null
*/
public function getPayload(): ?array
{
return $this->payload;
}

/**
* @return VectorStruct
*/
public function getVector(): VectorStruct
public function getVector(): VectorStructInterface
{
return $this->vector;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Models/PointsStruct.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ public function toArray(): array

return $points;
}

public function count(): int
{
return count($this->points);
}
}
14 changes: 11 additions & 3 deletions src/Models/Request/Point.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,26 @@

namespace Qdrant\Models\Request;

use Qdrant\Models\MultiVectorStruct;
use Qdrant\Models\VectorStruct;
use Qdrant\Models\VectorStructInterface;

class Point implements RequestModel
{
protected string $id;
protected array $vector;
protected VectorStructInterface $vector;

/**
* @var array|null Payload values (optional)
*/
protected ?array $payload = null;

public function __construct(string $id, array $vector, array $payload = null)
public function __construct(string $id, VectorStructInterface|array $vector, array $payload = null)
{
if(is_array($vector)) {
$vector = new VectorStruct($vector);
}

$this->id = $id;
$this->vector = $vector;
$this->payload = $payload;
Expand All @@ -29,7 +37,7 @@ public function toArray(): array
{
$data = [
'id' => $this->id,
'vector' => $this->vector,
'vector' => $this->vector->toArray(),
];

if ($this->payload) {
Expand Down
18 changes: 14 additions & 4 deletions src/Models/Request/SearchRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;
use Qdrant\Models\VectorStruct;
use Qdrant\Models\VectorStructInterface;

class SearchRequest
{
Expand All @@ -19,7 +19,7 @@ class SearchRequest

protected array $params = [];

protected VectorStruct $vector;
protected VectorStructInterface $vector;

protected ?int $limit = null;

Expand All @@ -31,11 +31,20 @@ class SearchRequest

protected ?float $scoreThreshold = null;

public function __construct(VectorStruct $vector)
protected ?string $name = null;

public function __construct(VectorStructInterface $vector)
{
$this->vector = $vector;
}

public function setName(string $name): static
{
$this->name = $name;

return $this;
}

public function setFilter(Filter $filter): static
{
$this->filter = $filter;
Expand Down Expand Up @@ -88,8 +97,9 @@ public function setWithVector($withVector): static
public function toArray(): array
{
$body = [
'vector' => $this->vector->toSearch(),
'vector' => $this->vector->toSearchArray($this->name ?? $this->vector->getName()),
];

if ($this->filter !== null && $this->filter->toArray()) {
$body['filter'] = $this->filter->toArray();
}
Expand Down
9 changes: 7 additions & 2 deletions src/Models/VectorStruct.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class VectorStruct
class VectorStruct implements VectorStructInterface
{
use ProtectedPropertyAccessor;

Expand All @@ -26,7 +26,12 @@ public function isNamed(): bool
return $this->name !== null;
}

public function toSearch(): array
public function getName(): string
{
return $this->name;
}

public function toSearchArray(string $name = null): array
{
if ($this->isNamed()) {
return [
Expand Down
28 changes: 28 additions & 0 deletions src/Models/VectorStructInterface.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?php

namespace Qdrant\Models;

interface VectorStructInterface
{
/**
* Get the name of the vector
*
* @return string
*/
public function getName(): string;

/**
* Convert this vector to a search array.
*
* @param string|null $name
* @return array
*/
public function toSearchArray(string $name = null): array;

/**
* Convert this vector an array for Point and PointsBatch.
*
* @return array
*/
public function toArray(): array;
}
3 changes: 2 additions & 1 deletion tests/Integration/AbstractIntegration.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ protected function setUp(): void
private static function sampleCollectionOption(): CreateCollection
{
return (new CreateCollection())
->addVector(new VectorParams(3, VectorParams::DISTANCE_COSINE), 'image');
->addVector(new VectorParams(3, VectorParams::DISTANCE_COSINE), 'image')
->addVector(new VectorParams(3, VectorParams::DISTANCE_COSINE), 'text');
}

protected function createCollections($name, CreateCollection $withConfiguration = null): void
Expand Down
Loading

0 comments on commit 5b69412

Please sign in to comment.