Skip to content

Commit 3fb05eb

Browse files
ankane, Ngalstyan4, and dqii
committed
Added casts for arrays to sparsevec - pgvector#604
Co-authored-by: Narek Galstyan <[email protected]> Co-authored-by: Di Qi <[email protected]>
1 parent b738ffe commit 3fb05eb

File tree

6 files changed

+245
-0
lines changed

6 files changed

+245
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.8.0 (unreleased)
22

3+
- Added casts for arrays to `sparsevec`
34
- Reduced memory usage for HNSW index scans
45
- Dropped support for Postgres 12
56

sql/vector--0.7.4--0.8.0.sql

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- complain if script is sourced in psql, rather than via ALTER EXTENSION ... UPDATE
2+
\echo Use "ALTER EXTENSION vector UPDATE TO '0.8.0'" to load this file. \quit
3+
4+
CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec -- cast-function signature: (source value, target typmod, explicit-cast flag)
5+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
6+
7+
CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec -- no link symbol is given, so every overload binds to the C symbol named like the SQL function
8+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
9+
10+
CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec
11+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
12+
13+
CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec
14+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
15+
16+
CREATE CAST (integer[] AS sparsevec) -- AS ASSIGNMENT: applied on explicit casts and on assignment to a sparsevec target, not implicitly inside expressions
17+
WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT;
18+
19+
CREATE CAST (real[] AS sparsevec)
20+
WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT;
21+
22+
CREATE CAST (double precision[] AS sparsevec)
23+
WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT;
24+
25+
CREATE CAST (numeric[] AS sparsevec)
26+
WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT;

sql/vector.sql

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,18 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve
782782
CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec
783783
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
784784

785+
CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec -- cast-function signature: (source value, target typmod, explicit-cast flag)
786+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
787+
788+
CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec -- no link symbol is given, so every overload binds to the C symbol named like the SQL function
789+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
790+
791+
CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec
792+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
793+
794+
CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec
795+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
796+
785797
-- sparsevec casts
786798

787799
CREATE CAST (sparsevec AS sparsevec)
@@ -799,6 +811,18 @@ CREATE CAST (sparsevec AS halfvec)
799811
CREATE CAST (halfvec AS sparsevec)
800812
WITH FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) AS IMPLICIT;
801813

814+
CREATE CAST (integer[] AS sparsevec) -- AS ASSIGNMENT: applied on explicit casts and on assignment to a sparsevec target, not implicitly inside expressions
815+
WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT;
816+
817+
CREATE CAST (real[] AS sparsevec)
818+
WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT;
819+
820+
CREATE CAST (double precision[] AS sparsevec)
821+
WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT;
822+
823+
CREATE CAST (numeric[] AS sparsevec)
824+
WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT;
825+
802826
-- sparsevec operators
803827

804828
CREATE OPERATOR <-> (

src/sparsevec.c

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <limits.h>
44
#include <math.h>
55

6+
#include "catalog/pg_type.h"
67
#include "common/string.h"
78
#include "fmgr.h"
89
#include "halfutils.h"
@@ -11,6 +12,7 @@
1112
#include "sparsevec.h"
1213
#include "utils/array.h"
1314
#include "utils/builtins.h"
15+
#include "utils/lsyscache.h"
1416
#include "vector.h"
1517

1618
#if PG_VERSION_NUM >= 120000
@@ -670,6 +672,126 @@ halfvec_to_sparsevec(PG_FUNCTION_ARGS)
670672
PG_RETURN_POINTER(result);
671673
}
672674

675+
/*
676+
* Convert a 1-D integer, real, double precision, or numeric array to a sparse vector
677+
*/
678+
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_sparsevec);
679+
Datum
680+
array_to_sparsevec(PG_FUNCTION_ARGS) /* only arguments 0 and 1 are read below */
681+
{
682+
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); /* argument 0: the source array */
683+
int32 typmod = PG_GETARG_INT32(1); /* argument 1: typmod, checked against the element count below */
684+
SparseVector *result;
685+
int16 typlen;
686+
bool typbyval;
687+
char typalign;
688+
Datum *elemsp;
689+
int nelemsp;
690+
int nnz = 0; /* number of elements that are non-zero once converted to float */
691+
float *values;
692+
int j = 0; /* next free slot in the result's indices/values */
693+
694+
if (ARR_NDIM(array) > 1) /* multi-dimensional input is rejected */
695+
ereport(ERROR,
696+
(errcode(ERRCODE_DATA_EXCEPTION),
697+
errmsg("array must be 1-D")));
698+
699+
if (ARR_HASNULL(array) && array_contains_nulls(array))
700+
ereport(ERROR,
701+
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
702+
errmsg("array must not contain nulls")));
703+
704+
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); /* storage properties of the element type, needed by deconstruct_array */
705+
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp); /* elemsp receives nelemsp element Datums */
706+
707+
CheckDim(nelemsp); /* NOTE(review): defined elsewhere; presumably enforces the allowed dimension range */
708+
CheckExpectedDim(typmod, nelemsp); /* NOTE(review): defined elsewhere; presumably errors when typmod and element count disagree */
709+
710+
if (ARR_ELEMTYPE(array) == INT4OID) /* pass 1: count non-zero elements so the result can be sized */
711+
{
712+
for (int i = 0; i < nelemsp; i++)
713+
nnz += ((float) DatumGetInt32(elemsp[i])) != 0; /* compare after the float cast, as PROCESS_ARRAY_ELEM does */
714+
}
715+
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
716+
{
717+
for (int i = 0; i < nelemsp; i++)
718+
nnz += ((float) DatumGetFloat8(elemsp[i])) != 0; /* compare after the float cast, as PROCESS_ARRAY_ELEM does */
719+
}
720+
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
721+
{
722+
for (int i = 0; i < nelemsp; i++)
723+
nnz += (DatumGetFloat4(elemsp[i]) != 0);
724+
}
725+
else if (ARR_ELEMTYPE(array) == NUMERICOID)
726+
{
727+
for (int i = 0; i < nelemsp; i++)
728+
nnz += (DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])) != 0); /* numeric is converted via numeric_float4 before the comparison */
729+
}
730+
else /* any other element type is rejected */
731+
{
732+
ereport(ERROR,
733+
(errcode(ERRCODE_DATA_EXCEPTION),
734+
errmsg("unsupported array type")));
735+
}
736+
737+
result = InitSparseVector(nelemsp, nnz); /* NOTE(review): presumably (dimensions, stored entries); confirm against InitSparseVector */
738+
values = SPARSEVEC_VALUES(result);
739+
740+
#define PROCESS_ARRAY_ELEM(elem) \
741+
do { \
742+
float v = (float) (elem); \
743+
if (v != 0) { \
744+
/* Safety check */ \
745+
if (j >= result->nnz) \
746+
elog(ERROR, "safety check failed"); \
747+
result->indices[j] = i; \
748+
values[j] = v; \
749+
j++; \
750+
} \
751+
} while (0)
752+
753+
if (ARR_ELEMTYPE(array) == INT4OID) /* pass 2: store index and value of each non-zero element, same per-type conversion as pass 1 */
754+
{
755+
for (int i = 0; i < nelemsp; i++)
756+
PROCESS_ARRAY_ELEM(DatumGetInt32(elemsp[i]));
757+
}
758+
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
759+
{
760+
for (int i = 0; i < nelemsp; i++)
761+
PROCESS_ARRAY_ELEM(DatumGetFloat8(elemsp[i]));
762+
}
763+
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
764+
{
765+
for (int i = 0; i < nelemsp; i++)
766+
PROCESS_ARRAY_ELEM(DatumGetFloat4(elemsp[i]));
767+
}
768+
else if (ARR_ELEMTYPE(array) == NUMERICOID)
769+
{
770+
for (int i = 0; i < nelemsp; i++)
771+
PROCESS_ARRAY_ELEM(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])));
772+
}
773+
else /* mirrors the pass-1 check for unsupported element types */
774+
{
775+
ereport(ERROR,
776+
(errcode(ERRCODE_DATA_EXCEPTION),
777+
errmsg("unsupported array type")));
778+
}
779+
780+
#undef PROCESS_ARRAY_ELEM
781+
782+
/*
783+
* Free allocation from deconstruct_array. Do not free individual elements
784+
* when pass-by-reference since they point to original array.
785+
*/
786+
pfree(elemsp);
787+
788+
/* Check elements: only the stored (non-zero) values are passed to CheckElement */
789+
for (int i = 0; i < result->nnz; i++)
790+
CheckElement(values[i]);
791+
792+
PG_RETURN_POINTER(result);
793+
}
794+
673795
/*
674796
* Get the L2 squared distance between sparse vectors
675797
*/

test/expected/cast.out

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,62 @@ SELECT '{1:1e-8}/1'::sparsevec::halfvec;
208208
[0]
209209
(1 row)
210210

211+
SELECT ARRAY[1,0,2,0,3,0]::sparsevec;
212+
array
213+
-----------------
214+
{1:1,3:2,5:3}/6
215+
(1 row)
216+
217+
SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec;
218+
array
219+
-----------------
220+
{1:1,3:2,5:3}/6
221+
(1 row)
222+
223+
SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec;
224+
array
225+
-----------------
226+
{1:1,3:2,5:3}/6
227+
(1 row)
228+
229+
SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec;
230+
array
231+
-----------------
232+
{1:1,3:2,5:3}/6
233+
(1 row)
234+
235+
SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec;
236+
array
237+
-----------------
238+
{1:1,3:2,5:3}/6
239+
(1 row)
240+
241+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec;
242+
sparsevec
243+
-----------------
244+
{1:1,3:2,5:3}/6
245+
(1 row)
246+
247+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6);
248+
sparsevec
249+
-----------------
250+
{1:1,3:2,5:3}/6
251+
(1 row)
252+
253+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5);
254+
ERROR: expected 5 dimensions, not 6
255+
SELECT '{NULL}'::real[]::sparsevec;
256+
ERROR: array must not contain nulls
257+
SELECT '{NaN}'::real[]::sparsevec;
258+
ERROR: NaN not allowed in sparsevec
259+
SELECT '{Infinity}'::real[]::sparsevec;
260+
ERROR: infinite value not allowed in sparsevec
261+
SELECT '{-Infinity}'::real[]::sparsevec;
262+
ERROR: infinite value not allowed in sparsevec
263+
SELECT '{}'::real[]::sparsevec;
264+
ERROR: sparsevec must have at least 1 dimension
265+
SELECT '{{1}}'::real[]::sparsevec;
266+
ERROR: array must be 1-D
211267
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
212268
ERROR: vector cannot have more than 16000 dimensions
213269
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;

test/sql/cast.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@ SELECT '{}/16001'::sparsevec::halfvec;
5858
SELECT '{1:65520}/1'::sparsevec::halfvec;
5959
SELECT '{1:1e-8}/1'::sparsevec::halfvec;
6060

61+
SELECT ARRAY[1,0,2,0,3,0]::sparsevec;
62+
SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec;
63+
SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec;
64+
SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec;
65+
SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec;
66+
67+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec;
68+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6);
69+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5);
70+
SELECT '{NULL}'::real[]::sparsevec;
71+
SELECT '{NaN}'::real[]::sparsevec;
72+
SELECT '{Infinity}'::real[]::sparsevec;
73+
SELECT '{-Infinity}'::real[]::sparsevec;
74+
SELECT '{}'::real[]::sparsevec;
75+
SELECT '{{1}}'::real[]::sparsevec;
76+
6177
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
6278
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;
6379

0 commit comments

Comments
 (0)