diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a434c0e..cbdfcc8 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -361,4 +361,4 @@ jobs: - run: cargo publish --no-verify working-directory: ./bindings/rust env: - CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/upload-deno-assets.js b/.github/workflows/upload-deno-assets.js index 9ab67af..dee6e23 100644 --- a/.github/workflows/upload-deno-assets.js +++ b/.github/workflows/upload-deno-assets.js @@ -55,4 +55,4 @@ module.exports = async ({ github, context }) => { }) ); return outputAssetChecksums.map((d) => `${d.checksum} ${d.name}`).join("\n"); -}; +}; \ No newline at end of file diff --git a/.github/workflows/upload.js b/.github/workflows/upload.js index 56eb4ec..6ac55d5 100644 --- a/.github/workflows/upload.js +++ b/.github/workflows/upload.js @@ -144,4 +144,4 @@ module.exports = async ({ github, context }) => { name: "spm.json", data: JSON.stringify(spm_json), }); -}; +}; \ No newline at end of file diff --git a/src/sqlite-vector.cpp b/src/sqlite-vector.cpp index 897ec09..a780fcf 100644 --- a/src/sqlite-vector.cpp +++ b/src/sqlite-vector.cpp @@ -1,235 +1,289 @@ + +#pragma region Includes and typedefs + +#include "sqlite-vector.h" +#include "sqlite-vss.h" #include #include +#include #include #include #include -#include "sqlite-vector.h" -#include "sqlite-vss.h" #include "sqlite3ext.h" SQLITE_EXTENSION_INIT1 +using namespace std; +typedef unique_ptr> vec_ptr; // https://github.com/sqlite/sqlite/blob/master/src/json.c#L88-L89 -#define JSON_SUBTYPE 74 /* Ascii for "J" */ +#define JSON_SUBTYPE 74 /* Ascii for "J" */ #include using json = nlohmann::json; char VECTOR_BLOB_HEADER_BYTE = 'v'; char VECTOR_BLOB_HEADER_TYPE = 1; -const char * VECTOR_FLOAT_POINTER_NAME = "vectorf32v0"; +const char *VECTOR_FLOAT_POINTER_NAME = "vectorf32v0"; -struct VecX { - int64_t size; - float * data; -}; +#pragma endregion +#pragma region Generic -void del(void*p) { - VectorFloat * vx = (VectorFloat *)p; - sqlite3_free(vx->data); - delete vx; +void delVectorFloat(void *p) { + + auto vx = static_cast(p); + sqlite3_free(vx->data); + delete vx; } -struct Vector0Global { - vector0_api api; - sqlite3 *db; -}; +void resultVector(sqlite3_context *context, vector *vecIn) { -static void resultVector(sqlite3_context * context, std::vector* v) { - VectorFloat * vx = new VectorFloat(); - vx->size = v->size(); - vx->data = (float *) sqlite3_malloc(v->size()*sizeof(float)); - memcpy(vx->data, v->data(), v->size()*sizeof(float)); - sqlite3_result_pointer(context, vx, VECTOR_FLOAT_POINTER_NAME, del); -} + auto vecRes = new VectorFloat(); + + vecRes->size = vecIn->size(); + vecRes->data = (float *)sqlite3_malloc(vecIn->size() * sizeof(float)); + + memcpy(vecRes->data, vecIn->data(), vecIn->size() * sizeof(float)); -#pragma region generic -std::vector * vectorFromBlobValue(sqlite3_value*value, const char ** pzErrMsg) { - int n = sqlite3_value_bytes(value); - const void * b; - char header; - char type; - - if(n < (2)) { - *pzErrMsg = "Vector blob size less than header length"; - return NULL; - } - b = sqlite3_value_blob(value); - memcpy(&header, ((char *) b + 0), sizeof(char)); - memcpy(&type, ((char *) b + 1), sizeof(char)); - - if(header != VECTOR_BLOB_HEADER_BYTE) { - *pzErrMsg = "Blob not well-formatted vector blob"; - return NULL; - } - if(type != VECTOR_BLOB_HEADER_TYPE) { - *pzErrMsg = "Blob type not right"; - return NULL; - } - int numElements = (n - 2)/sizeof(float); - float * v = (float *) ((char *)b + 2); - return new std::vector(v, v+numElements); + sqlite3_result_pointer(context, vecRes, VECTOR_FLOAT_POINTER_NAME, delVectorFloat); } -std::vector * vectorFromRawBlobValue(sqlite3_value*value, const char ** pzErrMsg) { - int n = sqlite3_value_bytes(value); - // must be divisible by 4 - if(n%4) { - *pzErrMsg = "Invalid raw blob length, must be divisible by 4"; - return NULL; - } - const void * b = sqlite3_value_blob(value); - - float * v = (float *) ((char *)b); - return new std::vector(v, v+ (n / 4)); +vec_ptr vectorFromBlobValue(sqlite3_value *value, const char **pzErrMsg) { + + int size = sqlite3_value_bytes(value); + char header; + char type; + + if (size < (2)) { + *pzErrMsg = "Vector blob size less than header length"; + return nullptr; + } + + const void *pBlob = sqlite3_value_blob(value); + memcpy(&header, ((char *)pBlob + 0), sizeof(char)); + memcpy(&type, ((char *)pBlob + 1), sizeof(char)); + + if (header != VECTOR_BLOB_HEADER_BYTE) { + *pzErrMsg = "Blob not well-formatted vector blob"; + return nullptr; + } + + if (type != VECTOR_BLOB_HEADER_TYPE) { + *pzErrMsg = "Blob type not right"; + return nullptr; + } + + int numElements = (size - 2) / sizeof(float); + float *vec = (float *)((char *)pBlob + 2); + return vec_ptr(new vector(vec, vec + numElements)); } -std::vector * vectorFromTextValue(sqlite3_value*value) { - std::vector v; +vec_ptr vectorFromRawBlobValue(sqlite3_value *value, const char **pzErrMsg) { - try { - json data = json::parse(sqlite3_value_text(value)); - data.get_to(v); - } catch (const json::exception&) { - return NULL; - } + int size = sqlite3_value_bytes(value); + + // Must be divisible by 4 + if (size % 4) { + *pzErrMsg = "Invalid raw blob length, blob must be divisible by 4"; + return nullptr; + } + const void *pBlob = sqlite3_value_blob(value); - return new std::vector(v); + float *vec = (float *)((char *)pBlob); + return vec_ptr(new vector(vec, vec + (size / 4))); } -// Returns vector pointer MUST be deleted -static std::vector* valueAsVector(sqlite3_value*value) { - // Option 1: If the value is a "vectorf32v0" pointer, create vector from that - VectorFloat* v = (VectorFloat*) sqlite3_value_pointer(value, VECTOR_FLOAT_POINTER_NAME); - if (v!=NULL) return new std::vector(v->data, v->data + v->size); - std::vector * vec; - - // Option 2: value is a blob in vector format - if(sqlite3_value_type(value) == SQLITE_BLOB) { - const char * pzErrMsg = 0; - if((vec = vectorFromBlobValue(value, &pzErrMsg)) != NULL) { - return vec; +vec_ptr vectorFromTextValue(sqlite3_value *value) { + + try { + + json json = json::parse(sqlite3_value_text(value)); + vec_ptr pVec(new vector()); + json.get_to(*pVec); + return pVec; + + } catch (const json::exception &) { + return nullptr; } - if((vec = vectorFromRawBlobValue(value, &pzErrMsg)) != NULL) { - return vec; + + return nullptr; +} + +static vec_ptr valueAsVector(sqlite3_value *value) { + + // Option 1: If the value is a "vectorf32v0" pointer, create vector from + // that + auto vec = (VectorFloat *)sqlite3_value_pointer(value, VECTOR_FLOAT_POINTER_NAME); + + if (vec != nullptr) + return vec_ptr(new vector(vec->data, vec->data + vec->size)); + + vec_ptr pVec; + + // Option 2: value is a blob in vector format + if (sqlite3_value_type(value) == SQLITE_BLOB) { + + const char *pzErrMsg = nullptr; + + if ((pVec = vectorFromBlobValue(value, &pzErrMsg)) != nullptr) + return pVec; + + if ((pVec = vectorFromRawBlobValue(value, &pzErrMsg)) != nullptr) + return pVec; } - } - // Option 3: if value is a JSON array coercible to float vector, use that - //if(sqlite3_value_subtype(value) == JSON_SUBTYPE) { - if(sqlite3_value_type(value) == SQLITE_TEXT) { - if((vec = vectorFromTextValue(value)) != NULL) { - return vec; - }else { - return NULL; + + // Option 3: if value is a JSON array coercible to float vector, use that + if (sqlite3_value_type(value) == SQLITE_TEXT) { + + if ((pVec = vectorFromTextValue(value)) != nullptr) + return pVec; + else + return nullptr; } - } - // else, value isn't a vector - return NULL; + // Else, value isn't a vector + return nullptr; } + #pragma endregion -#pragma region meta -static void vector_version(sqlite3_context *context, int argc, sqlite3_value **argv) { - sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); +#pragma region Meta + +static void vector_version(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); } -static void vector_debug(sqlite3_context *context, int argc, sqlite3_value **argv) { - if(argc){ - std::vector* v = valueAsVector(argv[0]); - if(v==NULL) { - sqlite3_result_error(context, "value not a vector", -1); - return; + +static void vector_debug(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + + if (pVec == nullptr) { + + sqlite3_result_error(context, "Value not a vector", -1); + return; } - sqlite3_str * str = sqlite3_str_new(0); - sqlite3_str_appendf(str, "size: %lld [", v->size()); - for(int i = 0; i < v->size(); i++) { - if(i==0) sqlite3_str_appendf(str, "%f", v->at(i)); - else sqlite3_str_appendf(str, ", %f", v->at(i)); + + sqlite3_str *str = sqlite3_str_new(0); + sqlite3_str_appendf(str, "size: %lld [", pVec->size()); + + for (int i = 0; i < pVec->size(); i++) { + + if (i == 0) + sqlite3_str_appendf(str, "%f", pVec->at(i)); + else + sqlite3_str_appendf(str, ", %f", pVec->at(i)); } + sqlite3_str_appendchar(str, 1, ']'); sqlite3_result_text(context, sqlite3_str_finish(str), -1, sqlite3_free); - delete v; - - }else { - sqlite3_result_text(context, "yo", -1, SQLITE_STATIC); - } } + #pragma endregion +#pragma region Vector generation + +// TODO should return fvec, ivec, or bvec depending on input. How do bvec, +// though? +static void vector_from(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vector vec; + vec.reserve(argc); + for (int i = 0; i < argc; i++) { + vec.push_back(sqlite3_value_double(argv[i])); + } -#pragma region vector generation -// TODO should return fvec, ivec, or bvec depending on input. How do bvec, though? -static void vector_from(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector * v = new std::vector(); - v->reserve(argc); - for(int i = 0; i < argc; i++) { - v->push_back(sqlite3_value_double(argv[i])); - } - resultVector(context, v); - delete v; + resultVector(context, &vec); } + #pragma endregion -#pragma region vector general -static void vector_value_at(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector*v = valueAsVector(argv[0]); - if(v == NULL) return; - int at = sqlite3_value_int(argv[1]); - try { - float result = v->at(at); - sqlite3_result_double(context, result); - } - catch (const std::out_of_range& oor) { - char * errmsg = sqlite3_mprintf("%d out of range: %s", at, oor.what()); - if(errmsg != NULL){ - sqlite3_result_error(context, errmsg, -1); - sqlite3_free(errmsg); +#pragma region Vector general + +static void vector_value_at(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + + if (pVec == nullptr) + return; + + int pos = sqlite3_value_int(argv[1]); + + try { + + float result = pVec->at(pos); + sqlite3_result_double(context, result); + + } catch (const out_of_range &oor) { + + char *errmsg = sqlite3_mprintf("%d out of range: %s", pos, oor.what()); + + if (errmsg != nullptr) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + } else { + sqlite3_result_error_nomem(context); + } } - else sqlite3_result_error_nomem(context); - } } -static void vector_length(sqlite3_context *context, int argc, sqlite3_value **argv) { - VectorFloat* v = (VectorFloat*) sqlite3_value_pointer(argv[0], VECTOR_FLOAT_POINTER_NAME); - if(v==NULL) return; - sqlite3_result_int64(context, v->size); +static void vector_length(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto pVec = (VectorFloat *)sqlite3_value_pointer(argv[0], VECTOR_FLOAT_POINTER_NAME); + if (pVec == nullptr) + return; + + sqlite3_result_int64(context, pVec->size); } + #pragma endregion +#pragma region Json +static void vector_to_json(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; -#pragma region json -static void vector_to_json(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector*v = valueAsVector(argv[0]); - if(v == NULL) return; + json j = json(*pVec); - json j = json(*v); - sqlite3_result_text(context, j.dump().c_str(), -1, SQLITE_TRANSIENT); - sqlite3_result_subtype(context, JSON_SUBTYPE); + sqlite3_result_text(context, j.dump().c_str(), -1, SQLITE_TRANSIENT); + sqlite3_result_subtype(context, JSON_SUBTYPE); } -static void vector_from_json(sqlite3_context *context, int argc, sqlite3_value **argv) { - const char * text = (const char *) sqlite3_value_text(argv[0]); - std::vector* v =vectorFromTextValue(argv[0]); - if(v == NULL) { - sqlite3_result_error(context, "input not valid json, or contains non-float data", -1); - }else { - resultVector(context, v); - delete v; - } - //json j = json::parse(text); - //std::vector *v = new std::vector(); - //j.get_to(*v); +static void vector_from_json(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + const char *text = (const char *)sqlite3_value_text(argv[0]); + vec_ptr pVec = vectorFromTextValue(argv[0]); + if (pVec == nullptr) { + sqlite3_result_error( + context, "input not valid json, or contains non-float data", -1); + } else { + resultVector(context, pVec.get()); + } } + #pragma endregion -#pragma region blob +#pragma region Blob /* @@ -237,359 +291,410 @@ static void vector_from_json(sqlite3_context *context, int argc, sqlite3_value * |-|-|- |a|a|A */ -static void vector_to_blob(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector*v = valueAsVector(argv[0]); - if(v == NULL) return; +static void vector_to_blob(sqlite3_context *context, + int argc, + sqlite3_value **argv) { - int sz = v->size(); - int n = (sizeof(char)) + (sizeof(char)) + (sz * 4); - void * b = sqlite3_malloc(n); - memset(b, 0, n); + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; - memcpy((void *) ((char *) b+0), (void *) &VECTOR_BLOB_HEADER_BYTE, sizeof(char)); - memcpy((void *) ((char *) b+1), (void *) &VECTOR_BLOB_HEADER_TYPE, sizeof(char)); - memcpy((void *) ((char *) b+2), (void *) v->data(), sz*4); - sqlite3_result_blob64(context, b, n, sqlite3_free); - delete v; + int size = pVec->size(); + int memSize = (sizeof(char)) + (sizeof(char)) + (size * 4); + void *pBlob = sqlite3_malloc(memSize); + memset(pBlob, 0, memSize); -} + memcpy((void *)((char *)pBlob + 0), (void *)&VECTOR_BLOB_HEADER_BYTE, sizeof(char)); + memcpy((void *)((char *)pBlob + 1), (void *)&VECTOR_BLOB_HEADER_TYPE, sizeof(char)); + memcpy((void *)((char *)pBlob + 2), (void *)pVec->data(), size * 4); -static void vector_from_blob(sqlite3_context *context, int argc, sqlite3_value **argv) { - const char * pzErrMsg; - std::vector * vec = vectorFromBlobValue(argv[0], &pzErrMsg); - if(vec == NULL) { - sqlite3_result_error(context, pzErrMsg, -1); - } else { - resultVector(context, vec); - delete vec; - } + sqlite3_result_blob64(context, pBlob, memSize, sqlite3_free); } +static void vector_from_blob(sqlite3_context *context, + int argc, + sqlite3_value **argv) { -static void vector_to_raw(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector*v = valueAsVector(argv[0]); - if(v == NULL) return; + const char *pzErrMsg; - int sz = v->size(); - int n = sz * sizeof(float); - void * b = sqlite3_malloc(n); - memset(b, 0, n); - memcpy((void *) ((char *) b), (void *) v->data(), n); - sqlite3_result_blob64(context, b, n, sqlite3_free); - delete v; + vec_ptr pVec = vectorFromBlobValue(argv[0], &pzErrMsg); + if (pVec == nullptr) + sqlite3_result_error(context, pzErrMsg, -1); + else + resultVector(context, pVec.get()); +} +static void vector_to_raw(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; + + int size = pVec->size(); + int n = size * sizeof(float); + void *pBlob = sqlite3_malloc(n); + memset(pBlob, 0, n); + memcpy((void *)((char *)pBlob), (void *)pVec->data(), n); + sqlite3_result_blob64(context, pBlob, n, sqlite3_free); } -static void vector_from_raw(sqlite3_context *context, int argc, sqlite3_value **argv) { - const char * pzErrMsg; - std::vector * vec = vectorFromRawBlobValue(argv[0], &pzErrMsg); - if(vec == NULL) { - sqlite3_result_error(context, pzErrMsg, -1); - } else { - resultVector(context, vec); - delete vec; - } + +static void vector_from_raw(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + const char *pzErrMsg; // TODO: Shouldn't we have like error messages here? + + vec_ptr pVec = vectorFromRawBlobValue(argv[0], &pzErrMsg); + if (pVec == nullptr) + sqlite3_result_error(context, pzErrMsg, -1); + else + resultVector(context, pVec.get()); } + #pragma endregion +#pragma region fvecs vtab +struct fvecsEach_vtab : public sqlite3_vtab { + + fvecsEach_vtab() { + pModule = nullptr; + nRef = 0; + zErrMsg = nullptr; + } -#pragma region fvecs vtab + ~fvecsEach_vtab() { -typedef struct fvecsEach_vtab fvecsEach_vtab; -struct fvecsEach_vtab { - sqlite3_vtab base; /* Base class - must be first */ + if (zErrMsg != nullptr) { + sqlite3_free(zErrMsg); + } + } }; -typedef struct fvecsEach_cursor fvecsEach_cursor; -struct fvecsEach_cursor { - sqlite3_vtab_cursor base; /* Base class - must be first */ - sqlite3_int64 iRowid; - // malloc'ed copy of fvecs input blob - void * pBlob; - // total size of pBlob in bytes - sqlite3_int64 iBlobN; - sqlite3_int64 p; - - // current dimensions - int iCurrentD; - // pointer to current vector being read in - std::vector* pCurrentVector; +struct fvecsEach_cursor : public sqlite3_vtab_cursor { + + fvecsEach_cursor(sqlite3_vtab *pVtab) { + + this->pVtab = pVtab; + iRowid = 0; + pBlob = nullptr; + iBlobN = 0; + p = 0; + iCurrentD = 0; + } + + ~fvecsEach_cursor() { + if (pBlob != nullptr) + sqlite3_free(pBlob); + } + + sqlite3_int64 iRowid; + + // Copy of fvecs input blob + void *pBlob; + + // Total size of pBlob in bytes + sqlite3_int64 iBlobN; + sqlite3_int64 p; + + // Current dimensions + int iCurrentD; + + // Pointer to current vector being read in + vec_ptr pCurrentVector; }; -static int fvecsEachConnect( - sqlite3 *db, - void *pAux, - int argc, const char *const*argv, - sqlite3_vtab **ppVtab, - char **pzErr -){ - fvecsEach_vtab *pNew; - int rc; - - rc = sqlite3_declare_vtab(db, - "CREATE TABLE x(dimensions, vector, input hidden)" - ); -#define FVECS_EACH_DIMENSIONS 0 -#define FVECS_EACH_VECTOR 1 -#define FVECS_EACH_INPUT 2 - if( rc==SQLITE_OK ){ - pNew = (fvecsEach_vtab *) sqlite3_malloc( sizeof(*pNew) ); - *ppVtab = (sqlite3_vtab*)pNew; - if( pNew==0 ) return SQLITE_NOMEM; - memset(pNew, 0, sizeof(*pNew)); - } - return rc; +static int fvecsEachConnect(sqlite3 *db, + void *pAux, + int argc, + const char *const *argv, + sqlite3_vtab **ppVtab, + char **pzErr) { + + int rc; + + rc = sqlite3_declare_vtab(db, "create table x(dimensions, vector, input hidden)"); + +#define FVECS_EACH_DIMENSIONS 0 +#define FVECS_EACH_VECTOR 1 +#define FVECS_EACH_INPUT 2 + + if (rc == SQLITE_OK) { + + auto pNew = new fvecsEach_vtab(); + if (pNew == 0) + return SQLITE_NOMEM; + + *ppVtab = pNew; + } + return rc; } -static int fvecsEachDisconnect(sqlite3_vtab *pVtab){ - fvecsEach_vtab *p = (fvecsEach_vtab*)pVtab; - sqlite3_free(p); - return SQLITE_OK; +static int fvecsEachDisconnect(sqlite3_vtab *pVtab) { + + auto pTable = static_cast(pVtab); + delete pTable; + return SQLITE_OK; } -static int fvecsEachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor){ - fvecsEach_cursor *pCur; - pCur = (fvecsEach_cursor *)sqlite3_malloc( sizeof(*pCur) ); - if( pCur==0 ) return SQLITE_NOMEM; - memset(pCur, 0, sizeof(*pCur)); - *ppCursor = &pCur->base; - return SQLITE_OK; +static int fvecsEachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { + + auto pCur = new fvecsEach_cursor(p); + if (pCur == nullptr) + return SQLITE_NOMEM; + + *ppCursor = pCur; + return SQLITE_OK; } -static int fvecsEachClose(sqlite3_vtab_cursor *cur){ - fvecsEach_cursor *pCur = (fvecsEach_cursor*)cur; - sqlite3_free(pCur); - return SQLITE_OK; +static int fvecsEachClose(sqlite3_vtab_cursor *cur) { + + auto pCur = static_cast(cur); + delete pCur; + return SQLITE_OK; } +static int fvecsEachBestIndex(sqlite3_vtab *tab, sqlite3_index_info *pIdxInfo) { + + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + + auto pCons = pIdxInfo->aConstraint[i]; -static int fvecsEachBestIndex( - sqlite3_vtab *tab, - sqlite3_index_info *pIdxInfo -){ - for (int i = 0; i < pIdxInfo->nConstraint; i++) { - auto pCons = pIdxInfo->aConstraint[i]; - switch (pCons.iColumn) { - case FVECS_EACH_INPUT: { - if (pCons.op == SQLITE_INDEX_CONSTRAINT_EQ && pCons.usable) { - pIdxInfo->aConstraintUsage[i].argvIndex = 1; - pIdxInfo->aConstraintUsage[i].omit = 1; + switch (pCons.iColumn) { + + case FVECS_EACH_INPUT: + if (pCons.op == SQLITE_INDEX_CONSTRAINT_EQ && pCons.usable) { + + pIdxInfo->aConstraintUsage[i].argvIndex = 1; + pIdxInfo->aConstraintUsage[i].omit = 1; + } + break; } - break; - } - } } - pIdxInfo->estimatedCost = (double)10; - pIdxInfo->estimatedRows = 10; - return SQLITE_OK; + + pIdxInfo->estimatedCost = (double)10; + pIdxInfo->estimatedRows = 10; + return SQLITE_OK; } -static int fvecsEachFilter( - sqlite3_vtab_cursor *pVtabCursor, - int idxNum, const char *idxStr, - int argc, sqlite3_value **argv -){ - fvecsEach_cursor *pCur = (fvecsEach_cursor *)pVtabCursor; +static int fvecsEachFilter(sqlite3_vtab_cursor *pVtabCursor, + int idxNum, + const char *idxStr, + int argc, + sqlite3_value **argv) { + + auto pCur = static_cast(pVtabCursor); + + int size = sqlite3_value_bytes(argv[0]); + const void *blob = sqlite3_value_blob(argv[0]); + + if (pCur->pBlob) + sqlite3_free(pCur->pBlob); + + pCur->pBlob = sqlite3_malloc(size); + pCur->iBlobN = size; + pCur->iRowid = 1; + memcpy(pCur->pBlob, blob, size); - int n = sqlite3_value_bytes(argv[0]); - const void * b = sqlite3_value_blob(argv[0]); + memcpy(&pCur->iCurrentD, pCur->pBlob, sizeof(int)); + float *vecBegin = (float *)((char *)pCur->pBlob + sizeof(int)); - pCur->pBlob = sqlite3_malloc(n); - pCur->iBlobN = n; - pCur->iRowid = 1; - memcpy(pCur->pBlob, b, n); + // TODO: Shouldn't this multiply by sizeof(float)? + pCur->pCurrentVector = vec_ptr(new vector(vecBegin, vecBegin + pCur->iCurrentD)); - memcpy(&pCur->iCurrentD, pCur->pBlob, sizeof(int)); - float * v = (float *) ((char *)pCur->pBlob + sizeof(int)); - pCur->pCurrentVector = new std::vector(v, v+pCur->iCurrentD); - pCur->p = sizeof(int) + (pCur->iCurrentD*sizeof(float)); + pCur->p = sizeof(int) + (pCur->iCurrentD * sizeof(float)); - return SQLITE_OK; + return SQLITE_OK; } -static int fvecsEachNext(sqlite3_vtab_cursor *cur){ - fvecsEach_cursor *pCur = (fvecsEach_cursor*)cur; +static int fvecsEachNext(sqlite3_vtab_cursor *cur) { + + auto pCur = static_cast(cur); + + // TODO: Shouldn't this multiply by sizeof(float)? + memcpy(&pCur->iCurrentD, ((char *)pCur->pBlob + pCur->p), sizeof(int)); + float *vecBegin = (float *)(((char *)pCur->pBlob + pCur->p) + sizeof(int)); - memcpy(&pCur->iCurrentD, ((char *)pCur->pBlob + pCur->p), sizeof(int)); - float * v = (float *) (((char *)pCur->pBlob + pCur->p) + sizeof(int)); - pCur->pCurrentVector->clear(); - pCur->pCurrentVector->reserve(pCur->iCurrentD);// = new std::vector(v, v+pCur->iCurrentD); - pCur->pCurrentVector->insert(pCur->pCurrentVector->begin(), v, v+pCur->iCurrentD); + pCur->pCurrentVector->clear(); + pCur->pCurrentVector->shrink_to_fit(); + pCur->pCurrentVector->reserve(pCur->iCurrentD); + pCur->pCurrentVector->insert(pCur->pCurrentVector->begin(), + vecBegin, + vecBegin + pCur->iCurrentD); - pCur->p += (sizeof(int) + (pCur->iCurrentD*sizeof(float))); - pCur->iRowid++; - return SQLITE_OK; + pCur->p += (sizeof(int) + (pCur->iCurrentD * sizeof(float))); + pCur->iRowid++; + return SQLITE_OK; } -static int fvecsEachEof(sqlite3_vtab_cursor *cur){ - fvecsEach_cursor *pCur = (fvecsEach_cursor*)cur; +static int fvecsEachEof(sqlite3_vtab_cursor *cur) { - return pCur->p > pCur->iBlobN; + auto pCur = (fvecsEach_cursor *)cur; + return pCur->p > pCur->iBlobN; } -static int fvecsEachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid){ - fvecsEach_cursor *pCur = (fvecsEach_cursor*)cur; - *pRowid = pCur->iRowid; - return SQLITE_OK; +static int fvecsEachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + + fvecsEach_cursor *pCur = (fvecsEach_cursor *)cur; + *pRowid = pCur->iRowid; + return SQLITE_OK; } -static int fvecsEachColumn( - sqlite3_vtab_cursor *cur, /* The cursor */ - sqlite3_context *context, /* First argument to sqlite3_result_...() */ - int i /* Which column to return */ -){ - fvecsEach_cursor *pCur = (fvecsEach_cursor*)cur; - switch( i ){ - case FVECS_EACH_DIMENSIONS: - sqlite3_result_int(context, pCur->iCurrentD); - break; - case FVECS_EACH_VECTOR: - resultVector(context, pCur->pCurrentVector); - break; - case FVECS_EACH_INPUT: - sqlite3_result_null(context); - break; - } - return SQLITE_OK; +static int fvecsEachColumn(sqlite3_vtab_cursor *cur, + sqlite3_context *context, + int i) { + + auto pCur = static_cast(cur); + + switch (i) { + + case FVECS_EACH_DIMENSIONS: + sqlite3_result_int(context, pCur->iCurrentD); + break; + + case FVECS_EACH_VECTOR: + resultVector(context, pCur->pCurrentVector.get()); + break; + + case FVECS_EACH_INPUT: + sqlite3_result_null(context); + break; + } + return SQLITE_OK; } /* -** This following structure defines all the methods for the -** virtual table. -*/ + * This following structure defines all the methods for the + * virtual table. + */ static sqlite3_module fvecsEachModule = { - /* iVersion */ 0, - /* xCreate */ fvecsEachConnect, - /* xConnect */ fvecsEachConnect, - /* xBestIndex */ fvecsEachBestIndex, - /* xDisconnect */ fvecsEachDisconnect, - /* xDestroy */ 0, - /* xOpen */ fvecsEachOpen, - /* xClose */ fvecsEachClose, - /* xFilter */ fvecsEachFilter, - /* xNext */ fvecsEachNext, - /* xEof */ fvecsEachEof, - /* xColumn */ fvecsEachColumn, - /* xRowid */ fvecsEachRowid, - /* xUpdate */ 0, - /* xBegin */ 0, - /* xSync */ 0, - /* xCommit */ 0, - /* xRollback */ 0, - /* xFindMethod */ 0, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ 0 -}; + /* iVersion */ 0, + /* xCreate */ fvecsEachConnect, + /* xConnect */ fvecsEachConnect, + /* xBestIndex */ fvecsEachBestIndex, + /* xDisconnect */ fvecsEachDisconnect, + /* xDestroy */ 0, + /* xOpen */ fvecsEachOpen, + /* xClose */ fvecsEachClose, + /* xFilter */ fvecsEachFilter, + /* xNext */ fvecsEachNext, + /* xEof */ fvecsEachEof, + /* xColumn */ fvecsEachColumn, + /* xRowid */ fvecsEachRowid, + /* xUpdate */ 0, + /* xBegin */ 0, + /* xSync */ 0, + /* xCommit */ 0, + /* xRollback */ 0, + /* xFindMethod */ 0, + /* xRename */ 0, + /* xSavepoint */ 0, + /* xRelease */ 0, + /* xRollbackTo */ 0, + /* xShadowName */ 0}; #pragma endregion +#pragma region Entrypoint + +static void vector0(sqlite3_context *context, int argc, sqlite3_value **argv) { -#pragma region fvecs - -static void vector_fvecs(sqlite3_context *context, int argc, sqlite3_value **argv) { - sqlite3_int64 sz = sqlite3_value_bytes(argv[0]); - const void * blob = sqlite3_value_blob(argv[0]); - int d; - memcpy((void *) &d, (void *) blob, sizeof(int)); - if(d <= 0 || d >= 1000000) { - sqlite3_result_error(context, "unreasonable dimensions size", -1); - return; - } - if( sz % ((d + 1) * 4) != 0) { - sqlite3_result_error(context, "wrong blob size", -1); - return; - } - size_t n = sz / ((d + 1) * 4); - printf("sz=%lld, d=%d n=%zu\n", sz, d, n); - - float* x = new float[n * (d + 1)]; - memcpy(x, ((char*) blob) + sizeof(int), (sizeof(float)) * (n * (d + 1))); - //for (size_t i = 0; i < n; i++) - // memmove(x + i * d, x + 1 + i * (d + 1), d * sizeof(*x)); - - printf("x[0]=%f \n", x[0]); - printf("x[1]=%f \n", x[1]); - //sqlite3_result_text(context, "yo", -1, SQLITE_STATIC); + auto api = (vector0_api *)sqlite3_user_data(context); + vector0_api **ppApi = (vector0_api **)sqlite3_value_pointer(argv[0], "vector0_api_ptr"); + if (ppApi) + *ppApi = api; } -#pragma endregion +static void delete_api(void * pApi) { -#pragma region entrypoint -static void vector0(sqlite3_context *context, int argc, sqlite3_value **argv) { - Vector0Global *pGlobal = (Vector0Global*)sqlite3_user_data(context); - vector0_api **ppApi; - ppApi = (vector0_api**)sqlite3_value_pointer(argv[0], "vector0_api_ptr"); - if( ppApi ) *ppApi = &pGlobal->api; + auto api = static_cast(pApi); + delete api; } extern "C" { - #ifdef _WIN32 - __declspec(dllexport) - #endif - int sqlite3_vector_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) { - int rc = SQLITE_OK; - SQLITE_EXTENSION_INIT2(pApi); - Vector0Global *pGlobal = 0; - pGlobal = (Vector0Global*)sqlite3_malloc(sizeof(Vector0Global)); - if( pGlobal==0 ){ - return SQLITE_NOMEM; - } - void *p = (void*)pGlobal; - memset(pGlobal, 0, sizeof(Vector0Global)); - pGlobal->db = db; - pGlobal->api.iVersion = 0; - pGlobal->api.xValueAsVector = valueAsVector; - pGlobal->api.xResultVector = resultVector; - rc = sqlite3_create_function_v2(db, "vector0", 1, - SQLITE_UTF8, - p, - vector0, 0, 0, sqlite3_free); - static const struct { - char *zFName; - int nArg; - void* pAux; - void (*xFunc)(sqlite3_context*,int,sqlite3_value**); - int flags; - } aFunc[] = { - //{ (char*) "vector0", 1, p, vector0, SQLITE_UTF8 }, - { (char*) "vector_version", 0, NULL, vector_version, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS }, - { (char*) "vector_debug", 0, NULL, vector_debug, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS }, - { (char*) "vector_debug", 1, NULL, vector_debug, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS }, - { (char*) "vector_length", 1, NULL, vector_length, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS }, - { (char*) "vector_value_at", 2, NULL, vector_value_at, SQLITE_UTF8|SQLITE_INNOCUOUS}, - { (char*) "vector_from_json", 1, NULL, vector_from_json, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - { (char*) "vector_to_json", 1, NULL, vector_to_json, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - { (char*) "vector_from_blob", 1, NULL, vector_from_blob, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - { (char*) "vector_to_blob", 1, NULL, vector_to_blob, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - { (char*) "vector_from_raw", 1, NULL, vector_from_raw, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - { (char*) "vector_to_raw", 1, NULL, vector_to_raw, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS}, - }; - for(int i=0; iiVersion = 0; + api->xValueAsVector = valueAsVector; + api->xResultVector = resultVector; + + rc = sqlite3_create_function_v2(db, + "vector0", + 1, + SQLITE_UTF8, + api, + vector0, + 0, + 0, + delete_api); + + if (rc != SQLITE_OK) { + + *pzErrMsg = sqlite3_mprintf("%s: %s", "vector0", sqlite3_errmsg(db)); + return rc; + } - rc = sqlite3_create_module(db, "vector_fvecs_each", &fvecsEachModule, 0); - if(rc != SQLITE_OK) goto fail; + static const struct { + + char *zFName; + int nArg; + void *pAux; + void (*xFunc)(sqlite3_context *, int, sqlite3_value **); + int flags; + + } aFunc[] = { + + { (char *)"vector_version", 0, nullptr, vector_version, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_debug", 1, nullptr, vector_debug, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_length", 1, nullptr, vector_length, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_value_at", 2, nullptr, vector_value_at, SQLITE_UTF8 | SQLITE_INNOCUOUS }, + { (char *)"vector_from_json", 1, nullptr, vector_from_json, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_to_json", 1, nullptr, vector_to_json, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_from_blob", 1, nullptr, vector_from_blob, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_to_blob", 1, nullptr, vector_to_blob, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_from_raw", 1, nullptr, vector_from_raw, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + { (char *)"vector_to_raw", 1, nullptr, vector_to_raw, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS }, + }; + + for (int i = 0; i < sizeof(aFunc) / sizeof(aFunc[0]) && rc == SQLITE_OK; i++) { + + rc = sqlite3_create_function_v2(db, + aFunc[i].zFName, + aFunc[i].nArg, + aFunc[i].flags, + aFunc[i].pAux, + aFunc[i].xFunc, + 0, + 0, + 0); + + if (rc != SQLITE_OK) { + + *pzErrMsg = sqlite3_mprintf("%s: %s", aFunc[i].zFName, sqlite3_errmsg(db)); + return rc; + } + } + rc = sqlite3_create_module_v2(db, "vector_fvecs_each", &fvecsEachModule, nullptr, nullptr); + if (rc != SQLITE_OK) { - return SQLITE_OK; + *pzErrMsg = sqlite3_mprintf("%s", sqlite3_errmsg(db)); + return rc; + } - fail: - *pzErrMsg = sqlite3_mprintf("%s", sqlite3_errmsg(db)); - return rc; - } + return SQLITE_OK; + } } #pragma endregion diff --git a/src/sqlite-vector.h.in b/src/sqlite-vector.h.in index 9a468f7..3249334 100644 --- a/src/sqlite-vector.h.in +++ b/src/sqlite-vector.h.in @@ -4,29 +4,32 @@ #include "sqlite3ext.h" #ifdef __cplusplus +#include #include #include struct VectorFloat { - int64_t size; - float * data; + int64_t size; + float *data; }; -typedef struct vector0_api vector0_api; struct vector0_api { - int iVersion; - std::vector* (*xValueAsVector)(sqlite3_value*value); - void (*xResultVector)(sqlite3_context*context, std::vector*); + + int iVersion; + std::unique_ptr> (*xValueAsVector)(sqlite3_value *value); + void (*xResultVector)(sqlite3_context *context, std::vector *); }; + #endif /* end of C++ specific APIs*/ #ifdef __cplusplus extern "C" { #endif -int sqlite3_vector_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi); +int sqlite3_vector_init(sqlite3 *db, char **pzErrMsg, + const sqlite3_api_routines *pApi); #ifdef __cplusplus -} /* end of the 'extern "C"' block */ +} /* end of the 'extern "C"' block */ #endif #endif /* ifndef _SQLITE_VECTOR_H */ diff --git a/src/sqlite-vss.cpp b/src/sqlite-vss.cpp index 5a363b7..6baa0f4 100644 --- a/src/sqlite-vss.cpp +++ b/src/sqlite-vss.cpp @@ -1,6 +1,6 @@ +#include "sqlite-vss.h" #include #include -#include "sqlite-vss.h" #include "sqlite3ext.h" SQLITE_EXTENSION_INIT1 @@ -10,1308 +10,1643 @@ SQLITE_EXTENSION_INIT1 #include #include - #include #include -#include -#include -#include -#include #include #include +#include +#include +#include #include +#include #include "sqlite-vector.h" -// https://github.com/sqlite/sqlite/blob/master/src/json.c#L88-L89 -//#define JSON_SUBTYPE 74 /* Ascii for "J" */ +using namespace std; -#pragma region work +typedef unique_ptr> vec_ptr; -static void vss_version(sqlite3_context *context, int argc, sqlite3_value **argv) { - sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); -} -static void vss_debug(sqlite3_context *context, int argc, sqlite3_value **argv) { - const char * debug = sqlite3_mprintf("version: %s\nfaiss version: %d.%d.%d\nfaiss compile options: %s", - SQLITE_VSS_VERSION, FAISS_VERSION_MAJOR, FAISS_VERSION_MINOR, FAISS_VERSION_PATCH, - faiss::get_compile_options().c_str()); - sqlite3_result_text(context, debug, -1, SQLITE_TRANSIENT); - sqlite3_free((void *) debug); +#pragma region Meta + +static void vss_version(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); } +static void vss_debug(sqlite3_context *context, + int argc, + sqlite3_value **argv) { -#pragma endregion + auto resTxt = sqlite3_mprintf( + "version: %s\nfaiss version: %d.%d.%d\nfaiss compile options: %s", + SQLITE_VSS_VERSION, + FAISS_VERSION_MAJOR, + FAISS_VERSION_MINOR, + FAISS_VERSION_PATCH, + faiss::get_compile_options().c_str()); -#pragma region distances - -static void vss_distance_l1(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - delete b; - return; - } - int d = a->size(); - sqlite3_result_double(context, faiss::fvec_L1(a->data(), b->data(), d)); - delete a; - delete b; + sqlite3_result_text(context, resTxt, -1, SQLITE_TRANSIENT); + sqlite3_free(resTxt); } -static void vss_distance_l2(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "TODO", -1); - delete a; - delete b; - return; - } - int d = a->size(); - sqlite3_result_double(context, faiss::fvec_L2sqr(a->data(), b->data(), d)); - delete a; - delete b; +#pragma endregion + +#pragma region Distances + +static void vss_distance_l1(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_L1(lhs->data(), rhs->data(), lhs->size())); } -static void vss_distance_linf(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - delete b; - return; - } - int d = a->size(); - sqlite3_result_double(context, faiss::fvec_Linf(a->data(), b->data(), d)); - delete a; - delete b; +static void vss_distance_l2(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_L2sqr(lhs->data(), rhs->data(), lhs->size())); } -static void vss_inner_product(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - delete b; - return; - } - int d = a->size(); - sqlite3_result_double(context, faiss::fvec_inner_product(a->data(), b->data(), d)); - delete a; - delete b; +static void vss_distance_linf(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_Linf(lhs->data(), rhs->data(), lhs->size())); } -static void vss_fvec_add(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - std::vector* c; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - delete b; - return; - } - int d = a->size(); - c = new std::vector(d); - faiss::fvec_add(d, a->data(), b->data(), c->data()); - vector_api->xResultVector(context, c); - delete a; - delete b; - delete c; +static void vss_inner_product(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, + faiss::fvec_inner_product(lhs->data(), rhs->data(), lhs->size())); } -static void vss_fvec_sub(sqlite3_context *context, int argc, sqlite3_value **argv) { - std::vector* a; - std::vector* b; - std::vector* c; - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - a = vector_api->xValueAsVector(argv[0]); - if(a==NULL) { - sqlite3_result_error(context, "a is not a vector", -1); - return; - } - b = vector_api->xValueAsVector(argv[1]); - if(b==NULL) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - return; - } - if(a->size() != b->size()) { - sqlite3_result_error(context, "b is not a vector", -1); - delete a; - delete b; - return; - } - int d = a->size(); - c = new std::vector(d); - faiss::fvec_sub(d, a->data(), b->data(), c->data()); - vector_api->xResultVector(context, c); - delete a; - delete b; - delete c; +static void vss_fvec_add(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + auto size = lhs->size(); + vec_ptr c(new vector(size)); + faiss::fvec_add(size, lhs->data(), rhs->data(), c->data()); + + vector_api->xResultVector(context, c.get()); } +static void vss_fvec_sub(sqlite3_context *context, int argc, + sqlite3_value **argv) { -#pragma endregion + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } -#pragma region vtab + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", -1); + return; + } + + int size = lhs->size(); + vec_ptr c = vec_ptr(new vector(size)); + faiss::fvec_sub(size, lhs->data(), rhs->data(), c->data()); + vector_api->xResultVector(context, c.get()); +} +#pragma endregion +#pragma region Structs and cleanup functions struct VssSearchParams { - std::vector * vector; - sqlite3_int64 k; + + vec_ptr vector; + sqlite3_int64 k; }; -static void VssSearchParamsFunc( - sqlite3_context *context, - int argc, - sqlite3_value **argv -){ - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - - std::vector * vector = vector_api->xValueAsVector(argv[0]); - if(vector==NULL) { - sqlite3_result_error(context, "1st argument is not a vector", -1); - return; - } - sqlite3_int64 k = sqlite3_value_int64(argv[1]); - VssSearchParams* params = new VssSearchParams(); - params->vector = vector; - params->k = k; - sqlite3_result_pointer(context, params, "vss0_searchparams", 0); +void delVssSearchParams(void *p) { + + VssSearchParams *self = (VssSearchParams *)p; + delete self; } struct VssRangeSearchParams { - std::vector * vector; - float distance; + + vec_ptr vector; + float distance; }; -static void VssRangeSearchParamsFunc( - sqlite3_context *context, - int argc, - sqlite3_value **argv -){ - vector0_api * vector_api = (vector0_api*) sqlite3_user_data(context); - std::vector * vector = vector_api->xValueAsVector(argv[0]); - if(vector==NULL) { - sqlite3_result_error(context, "1st argument is not a vector", -1); - return; - } - float distance = sqlite3_value_double(argv[1]); - VssRangeSearchParams* params = new VssRangeSearchParams(); - params->vector = vector; - params->distance = distance; - sqlite3_result_pointer(context, params, "vss0_rangesearchparams", 0); +void delVssRangeSearchParams(void *p) { + + auto self = (VssRangeSearchParams *)p; + delete self; } -static int write_index_insert(faiss::Index * index, sqlite3*db, char * schema, char * name, int i) { - faiss::VectorIOWriter * w = new faiss::VectorIOWriter(); - faiss::write_index(index, w); - sqlite3_int64 nIn = w->data.size(); - - // First try to insert into xyz_index. If that fails with a rowid constraint error, - // that means the index is already on disk, we just have to UPDATE instead. - - sqlite3_stmt *stmt; - char * q = sqlite3_mprintf("insert into \"%w\".\"%w_index\"(rowid, idx) values (?, ?)", schema, name); - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt==0) { - //printf("error prepping stmt\n"); - return SQLITE_ERROR; - } - rc = sqlite3_bind_int64(stmt, 1, i); - rc = sqlite3_bind_blob64(stmt, 2, w->data.data(), nIn, SQLITE_TRANSIENT); - if (rc != SQLITE_OK) { - //printf("error binding blob: %s\n", sqlite3_errmsg(db)); - sqlite3_free(q); - return SQLITE_ERROR; - } - int result = sqlite3_step(stmt); - sqlite3_finalize(stmt); - sqlite3_free(q); - - - // INSERT was success, index wasn't written yet, all good to exit - if(result == SQLITE_DONE) { - return SQLITE_OK; - } - // INSERT failed for another unknown reason, bad, return error - else if(sqlite3_extended_errcode(db) != SQLITE_CONSTRAINT_ROWID) { - return SQLITE_ERROR; - } - - // INSERT failed because index already is on disk, so do UPDATE instead - - q = sqlite3_mprintf("UPDATE \"%w\".\"%w_index\" SET idx = ? WHERE rowid = ?", schema, name); - rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt==0) { - return SQLITE_ERROR; - } - rc = sqlite3_bind_blob64(stmt, 1, w->data.data(), nIn, SQLITE_TRANSIENT); - if (rc != SQLITE_OK) { - sqlite3_free(q); - return SQLITE_ERROR; - } - rc = sqlite3_bind_int64(stmt, 2, i); - if (rc != SQLITE_OK) { - sqlite3_free(q); - return SQLITE_ERROR; - } - result = sqlite3_step(stmt); - sqlite3_finalize(stmt); - sqlite3_free(q); - - if(result == SQLITE_DONE) { - return SQLITE_OK; - } - delete w; - return result; +#pragma endregion + +#pragma region Vtab + +static void vssSearchParamsFunc(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr vector = vector_api->xValueAsVector(argv[0]); + if (vector == nullptr) { + sqlite3_result_error(context, "1st argument is not a vector", -1); + return; + } + + auto limit = sqlite3_value_int64(argv[1]); + auto params = new VssSearchParams(); + params->vector = vec_ptr(vector.release()); + params->k = limit; + sqlite3_result_pointer(context, params, "vss0_searchparams", delVssSearchParams); } +static void vssRangeSearchParamsFunc(sqlite3_context *context, int argc, + sqlite3_value **argv) { -static int shadow_data_insert(sqlite3*db, char * schema, char * name, sqlite3_int64 *rowid, sqlite3_int64 *retRowid) { - sqlite3_stmt *stmt; - if(rowid == NULL) { - char * q = sqlite3_mprintf("insert into \"%w\".\"%w_data\"(x) values (?)", schema, name); - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - sqlite3_free(q); + auto vector_api = (vector0_api *)sqlite3_user_data(context); - if (rc != SQLITE_OK || stmt==0) { - //printf("error prepping stmt: %s \n", sqlite3_errmsg(db)); - return SQLITE_ERROR; + vec_ptr vector = vector_api->xValueAsVector(argv[0]); + if (vector == nullptr) { + sqlite3_result_error(context, "1st argument is not a vector", -1); + return; } - sqlite3_bind_null(stmt, 1); - if(sqlite3_step(stmt) != SQLITE_DONE) { - //printf("error inserting?\n"); - sqlite3_finalize(stmt); - return SQLITE_ERROR; + + auto params = new VssRangeSearchParams(); + + params->vector = vec_ptr(vector.release()); + params->distance = sqlite3_value_double(argv[1]); + + sqlite3_result_pointer(context, params, "vss0_rangesearchparams", delVssRangeSearchParams); +} + +static int write_index_insert(faiss::Index *index, + sqlite3 *db, + char *schema, + char *name, + int rowId) { + + faiss::VectorIOWriter writer; + faiss::write_index(index, &writer); + sqlite3_int64 indexSize = writer.data.size(); + + // First try to insert into xyz_index. If that fails with a rowid constraint + // error, that means the index is already on disk, we just have to UPDATE + // instead. + + sqlite3_stmt *stmt; + char *sql = sqlite3_mprintf( + "insert into \"%w\".\"%w_index\"(rowid, idx) values (?, ?)", + schema, + name); + + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + if (rc != SQLITE_OK || stmt == nullptr) { + sqlite3_free(sql); + return SQLITE_ERROR; } - } else { - char * q = sqlite3_mprintf("insert into \"%w\".\"%w_data\"(rowid, x) values (?, ?);", schema, name); - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - sqlite3_free(q); - - if (rc != SQLITE_OK || stmt==0) { - //printf("error prepping stmt: %s \n", sqlite3_errmsg(db)); - return SQLITE_ERROR; + + rc = sqlite3_bind_int64(stmt, 1, rowId); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return SQLITE_ERROR; } - sqlite3_bind_int64(stmt, 1, *rowid); - sqlite3_bind_null(stmt, 2); - if(sqlite3_step(stmt) != SQLITE_DONE) { - //printf("error inserting: %s\n", sqlite3_errmsg(db)); - sqlite3_finalize(stmt); - return SQLITE_ERROR; + + rc = sqlite3_bind_blob64(stmt, 2, writer.data.data(), indexSize, SQLITE_TRANSIENT); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return SQLITE_ERROR; } - if(retRowid != NULL) - *retRowid = sqlite3_last_insert_rowid(db); - } - sqlite3_finalize(stmt); - return SQLITE_OK; -} -static int shadow_data_delete(sqlite3*db, char * schema, char * name, sqlite3_int64 rowid) { - sqlite3_stmt *stmt; - sqlite3_str *query = sqlite3_str_new(0); - sqlite3_str_appendf(query, "delete from \"%w\".\"%w_data\" where rowid = ?", schema, name); - char * q = sqlite3_str_finish(query); - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt==0) { - return SQLITE_ERROR; - } - sqlite3_bind_int64(stmt, 1, rowid); - if(sqlite3_step(stmt) != SQLITE_DONE) { + int result = sqlite3_step(stmt); sqlite3_finalize(stmt); - return SQLITE_ERROR; - } - sqlite3_free(q); - sqlite3_finalize(stmt); - return SQLITE_OK; -} + sqlite3_free(sql); + + if (result == SQLITE_DONE) { + + // INSERT was success, index wasn't written yet, all good to exit + return SQLITE_OK; + + } else if (sqlite3_extended_errcode(db) != SQLITE_CONSTRAINT_ROWID) { + + // INSERT failed for another unknown reason, bad, return error + return SQLITE_ERROR; + } + + // INSERT failed because index already is on disk, so we do an UPDATE instead + + sql = sqlite3_mprintf( + "update \"%w\".\"%w_index\" set idx = ? where rowid = ?", schema, name); + + rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + if (rc != SQLITE_OK || stmt == nullptr) { + sqlite3_free(sql); + return SQLITE_ERROR; + } + + rc = sqlite3_bind_blob64(stmt, 1, writer.data.data(), indexSize, SQLITE_TRANSIENT); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return SQLITE_ERROR; + } + + rc = sqlite3_bind_int64(stmt, 2, rowId); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return SQLITE_ERROR; + } -static faiss::Index * read_index_select(sqlite3 * db, const char * name, int i) { - sqlite3_stmt *stmt; - char * q = sqlite3_mprintf("select idx from \"%w_index\" where rowid = ?", name); - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt==0) { - //printf("error prepping stmt: %s\n", sqlite3_errmsg(db)); + result = sqlite3_step(stmt); sqlite3_finalize(stmt); - return 0; - } - sqlite3_bind_int64(stmt, 1, i); - if(sqlite3_step(stmt) != SQLITE_ROW) { - //printf("connect no row??\n"); + sqlite3_free(sql); + + if (result == SQLITE_DONE) { + return SQLITE_OK; + } + + return result; +} + +static int shadow_data_insert(sqlite3 *db, + char *schema, + char *name, + sqlite3_int64 *rowid, + sqlite3_int64 *retRowid) { + + sqlite3_stmt *stmt; + + if (rowid == nullptr) { + + auto sql = sqlite3_mprintf( + "insert into \"%w\".\"%w_data\"(x) values (?)", schema, name); + + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + sqlite3_free(sql); + + if (rc != SQLITE_OK || stmt == nullptr) { + return SQLITE_ERROR; + } + + sqlite3_bind_null(stmt, 1); + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + + } else { + + auto sql = sqlite3_mprintf( + "insert into \"%w\".\"%w_data\"(rowid, x) values (?, ?);", schema, + name); + + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + sqlite3_free(sql); + + if (rc != SQLITE_OK || stmt == nullptr) + return SQLITE_ERROR; + + sqlite3_bind_int64(stmt, 1, *rowid); + sqlite3_bind_null(stmt, 2); + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + + if (retRowid != nullptr) + *retRowid = sqlite3_last_insert_rowid(db); + } + sqlite3_finalize(stmt); - return 0; - } - const void * index_data = sqlite3_column_blob(stmt, 0); - int64_t n = sqlite3_column_bytes(stmt, 0); - faiss::VectorIOReader * r = new faiss::VectorIOReader(); - std::copy((const uint8_t*) index_data, ((const uint8_t*)index_data) + n, std::back_inserter(r->data)); - sqlite3_free(q); - sqlite3_finalize(stmt); - return faiss::read_index(r); + return SQLITE_OK; } -static int create_shadow_tables(sqlite3 * db, const char * schema, const char * name, int n) { - /*sqlite3_str *create_index_str = sqlite3_str_new(db); - sqlite3_str_appendf(create_index_str, "CREATE TABLE \"%w\".\"%w_index\"(", schema, name); - for(int i = 0; i < n; i++) { - const char * format; - if(i==0) { - format = "c%d"; - }else { - format = ", c%d"; +static int shadow_data_delete(sqlite3 *db, + char *schema, + char *name, + sqlite3_int64 rowid) { + sqlite3_stmt *stmt; + + // TODO: We should strive to use only one concept and idea while creating + // SQL statements. + auto query = sqlite3_str_new(0); + + sqlite3_str_appendf(query, "delete from \"%w\".\"%w_data\" where rowid = ?", + schema, name); + + auto sql = sqlite3_str_finish(query); + + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + if (rc != SQLITE_OK || stmt == nullptr) + return SQLITE_ERROR; + + sqlite3_bind_int64(stmt, 1, rowid); + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; } - sqlite3_str_appendf(create_index_str, format, i); - } - sqlite3_str_appendall(create_index_str, ")");*/ - const char * zCreateIndex = sqlite3_mprintf("CREATE TABLE \"%w\".\"%w_index\"(idx)", schema, name);//sqlite3_str_finish(create_index_str); - int rc = sqlite3_exec(db, zCreateIndex, 0, 0, 0); - sqlite3_free((void *) zCreateIndex); - if (rc != SQLITE_OK) return rc; - - const char * zCreateData = sqlite3_mprintf("CREATE TABLE \"%w\".\"%w_data\"(x);", schema, name); - rc = sqlite3_exec(db, zCreateData, 0, 0, 0); - sqlite3_free((void *) zCreateData); - return rc; + sqlite3_free(sql); + sqlite3_finalize(stmt); + return SQLITE_OK; } -static int drop_shadow_tables(sqlite3*db, char * name) { - const char* drops[2] = {"drop table \"%w_index\";", "drop table \"%w_data\";"}; - for(int i = 0; i < 2; i++) { - const char * s = drops[i]; +static faiss::Index *read_index_select(sqlite3 *db, const char *name, int indexId) { sqlite3_stmt *stmt; + auto sql = sqlite3_mprintf("select idx from \"%w_index\" where rowid = ?", name); - sqlite3_str *query = sqlite3_str_new(0); - sqlite3_str_appendf(query, s, name); - char * q = sqlite3_str_finish(query); - - int rc = sqlite3_prepare_v2(db, q, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt==0) { - //printf("error prepping stmt\n"); - return SQLITE_ERROR; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr); + if (rc != SQLITE_OK || stmt == nullptr) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return nullptr; } - if(sqlite3_step(stmt) != SQLITE_DONE) { - //printf("error dropping?\n"); - sqlite3_finalize(stmt); - return SQLITE_ERROR; + + sqlite3_bind_int64(stmt, 1, indexId); + if (sqlite3_step(stmt) != SQLITE_ROW) { + sqlite3_finalize(stmt); + sqlite3_free(sql); + return nullptr; } - sqlite3_free(q); + + auto index_data = sqlite3_column_blob(stmt, 0); + int64_t size = sqlite3_column_bytes(stmt, 0); + + faiss::VectorIOReader reader; + copy((const uint8_t *)index_data, + ((const uint8_t *)index_data) + size, + back_inserter(reader.data)); + + sqlite3_free(sql); sqlite3_finalize(stmt); - } - return SQLITE_OK; + return faiss::read_index(&reader); } -#define VSS_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION +static int create_shadow_tables(sqlite3 *db, + const char *schema, + const char *name, + int n) { + + auto sql = sqlite3_mprintf("create table \"%w\".\"%w_index\"(idx)", + schema, + name); + + auto rc = sqlite3_exec(db, sql, 0, 0, 0); + sqlite3_free(sql); + if (rc != SQLITE_OK) + return rc; + + sql = sqlite3_mprintf("create table \"%w\".\"%w_data\"(x);", + schema, + name); + + rc = sqlite3_exec(db, sql, nullptr, nullptr, nullptr); + sqlite3_free(sql); + return rc; +} + +static int drop_shadow_tables(sqlite3 *db, char *name) { + + const char *drops[2] = {"drop table \"%w_index\";", + "drop table \"%w_data\";"}; + + for (int i = 0; i < 2; i++) { + + auto curSql = drops[i]; + + sqlite3_stmt *stmt; + + // TODO: Use of one construct to create SQL statements. + sqlite3_str *query = sqlite3_str_new(0); + sqlite3_str_appendf(query, curSql, name); + char *sql = sqlite3_str_finish(query); + + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); + if (rc != SQLITE_OK || stmt == nullptr) { + sqlite3_free(sql); + return SQLITE_ERROR; + } + + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_free(sql); + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + + sqlite3_free(sql); + sqlite3_finalize(stmt); + } + return SQLITE_OK; +} + +#define VSS_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION #define VSS_RANGE_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION + 1 -typedef struct vss_index_vtab vss_index_vtab; -struct vss_index_vtab { - sqlite3_vtab base; /* Base class - must be first */ - sqlite3 * db; - vector0_api * vector_api; - // name of the virtual table. Must be freed during disconnect - char * name; - // name of the schema the virtual table exists in. Must be freed during disconnect - char * schema; - // number of index columns in the virtual table - sqlite3_int64 indexCount; - // vector holding all the faiss Indices the vtab uses. This, and the elements, must be freed during disconnect. - std::vector *indexes; - - // float vector that holds training vectors for indices that require it. - // This, and the elements, must be freed at disconnect. - std::vector*> * trainings; - std::vector*> * insert_to_add_data; - std::vector*> * insert_to_add_ids; - std::vector*> * delete_to_delete_ids; - - // whether the current transaction is inserting training data for at least 1 column - bool isTraining; - // whether the current transaction is inserting data for at least 1 column - bool isInsertData; +// Wrapper around a single faiss index, with training data, insert records, and +// delete records. +struct vss_index { + + explicit vss_index(faiss::Index *index) : index(index) {} + + ~vss_index() { + if (index != nullptr) { + delete index; + } + } + + faiss::Index *index; + vector trainings; + vector insert_data; + vector insert_ids; + vector delete_ids; }; -enum QueryType {search, range_search, fullscan}; +struct vss_index_vtab : public sqlite3_vtab { + + vss_index_vtab(sqlite3 *db, vector0_api *vector_api, char *schema, char *name) + : db(db), + vector_api(vector_api), + schema(schema), + name(name) { } -typedef struct vss_index_cursor vss_index_cursor; -struct vss_index_cursor { - sqlite3_vtab_cursor base; /* Base class - must be first */ - vss_index_vtab * table; + ~vss_index_vtab() { - sqlite3_int64 iCurrent; - sqlite3_int64 iRowid; + if (name) + sqlite3_free(name); + if (schema) + sqlite3_free(schema); + for (auto iter = indexes.begin(); iter != indexes.end(); ++iter) { + delete (*iter); + } + } - QueryType query_type; + sqlite3 *db; + vector0_api *vector_api; - // for query_type == QueryType::search - sqlite3_int64 search_k; - std::vector *search_ids; - std::vector *search_distances; + // Name of the virtual table. Must be freed during disconnect + char *name; - // for query_type == QueryType::range_search - faiss::RangeSearchResult * range_search_result; + // Name of the schema the virtual table exists in. Must be freed during + // disconnect + char *schema; - // for query_type == QueryType::fullscan - sqlite3_stmt * stmt; - int step_result; + // Vector holding all the faiss Indices the vtab uses, and their state, + // implying which items are to be deleted and inserted. + vector indexes; +}; + +enum QueryType { search, range_search, fullscan }; + +struct vss_index_cursor : public sqlite3_vtab_cursor { + + explicit vss_index_cursor(vss_index_vtab *table) + : table(table), + sqlite3_vtab_cursor({0}), + stmt(nullptr) { } + + ~vss_index_cursor() { + if (stmt != nullptr) + sqlite3_finalize(stmt); + } + + vss_index_vtab *table; + + sqlite3_int64 iCurrent; + sqlite3_int64 iRowid; + + QueryType query_type; + + // For query_type == QueryType::search + sqlite3_int64 limit; + vector search_ids; + vector search_distances; + + // For query_type == QueryType::range_search + unique_ptr range_search_result; + + // For query_type == QueryType::fullscan + sqlite3_stmt *stmt; + int step_result; }; struct VssIndexColumn { - std::string name; - sqlite3_int64 dimensions; - std::string factory; + + string name; + sqlite3_int64 dimensions; + string factory; }; -std::vector* parse_constructor(int argc, const char *const*argv) { - std::vector* columns = new std::vector(); - for(int i = 3; i < argc; i++) { - // ' xyz(123) ' - // ' xyz(123) factory="abx" ' - std::string arg = std::string(argv[i]); - - std::size_t lparen = arg.find("("); - std::size_t rparen = arg.find(")"); - if(lparen==std::string::npos || rparen==std::string::npos || lparen >= rparen) { - return NULL; - } - std::string name = arg.substr(0, lparen); - std::string sDimensions = arg.substr(lparen+1, rparen - lparen-1); - //= std::string("x"); - sqlite3_int64 dimensions = std::atoi(sDimensions.c_str()); - - std::size_t factoryStart, factoryStringStartFrom; - std::string factory; - if( - (factoryStart = arg.find("factory", rparen)) != std::string::npos - && (factoryStringStartFrom = arg.find("=", factoryStart)) != std::string::npos - ) { - std::size_t lquote = arg.find("\"", factoryStringStartFrom); - std::size_t rquote = arg.find_last_of("\""); - if(lquote==std::string::npos || rquote==std::string::npos || lquote >= rquote) { - delete columns; - return NULL; - } - factory = arg.substr(lquote+1, rquote - lquote-1); +unique_ptr> parse_constructor(int argc, + const char *const *argv) { - } else { - factory = std::string("Flat,IDMap2"); + auto columns = unique_ptr>(new vector()); + + for (int i = 3; i < argc; i++) { + + string arg = string(argv[i]); + + size_t lparen = arg.find("("); + size_t rparen = arg.find(")"); + + if (lparen == string::npos || rparen == string::npos || + lparen >= rparen) { + return nullptr; + } + + string name = arg.substr(0, lparen); + string sDimensions = arg.substr(lparen + 1, rparen - lparen - 1); + + sqlite3_int64 dimensions = atoi(sDimensions.c_str()); + + size_t factoryStart, factoryStringStartFrom; + string factory; + + if ((factoryStart = arg.find("factory", rparen)) != string::npos && + (factoryStringStartFrom = arg.find("=", factoryStart)) != + string::npos) { + + size_t lquote = arg.find("\"", factoryStringStartFrom); + size_t rquote = arg.find_last_of("\""); + + if (lquote == string::npos || rquote == string::npos || + lquote >= rquote) { + return nullptr; + } + factory = arg.substr(lquote + 1, rquote - lquote - 1); + + } else { + factory = string("Flat,IDMap2"); + } + columns->push_back(VssIndexColumn{name, dimensions, factory}); } - columns->push_back(VssIndexColumn{name, dimensions, factory}); - } - // TODO - return columns; + + return columns; } -static int init( - sqlite3 *db, - void *pAux, - int argc, const char *const*argv, - sqlite3_vtab **ppVtab, - char **pzErr, - bool isCreate -) { - sqlite3_vtab_config(db, SQLITE_VTAB_CONSTRAINT_SUPPORT, 1); - int rc; - - sqlite3_str *str = sqlite3_str_new(NULL); - sqlite3_str_appendall(str, "CREATE TABLE x(distance hidden, operation hidden"); - std::vector * columns = parse_constructor(argc, argv); - if(columns == NULL) { - *pzErr = sqlite3_mprintf("Error parsing constructor"); - return rc; - } - for (auto column = columns->begin(); column != columns->end(); ++column) { - sqlite3_str_appendf(str, ", \"%w\"", column->name.c_str()); - } - sqlite3_str_appendall(str, ")"); - const char * query = sqlite3_str_finish(str); - rc = sqlite3_declare_vtab(db, query); - sqlite3_free((void*) query); - - #define VSS_INDEX_COLUMN_DISTANCE 0 - #define VSS_INDEX_COLUMN_OPERATION 1 - #define VSS_INDEX_COLUMN_VECTORS 2 - - if( rc!=SQLITE_OK ) return rc; - - vss_index_vtab *pNew; - pNew = (vss_index_vtab *) sqlite3_malloc( sizeof(*pNew) ); - *ppVtab = (sqlite3_vtab*)pNew; - memset(pNew, 0, sizeof(*pNew)); - - pNew->vector_api = (vector0_api *) pAux; - pNew->indexCount = columns->size(); - pNew->indexes = new std::vector (); - - pNew->schema = sqlite3_mprintf("%s", argv[1]); - pNew->name = sqlite3_mprintf("%s", argv[2]); - pNew->db = db; - - if (isCreate) { - int i; - for (auto column = columns->begin(); column != columns->end(); ++column) { - try { - faiss::Index *index = faiss::index_factory(column->dimensions, column->factory.c_str()); - pNew->indexes->push_back(index); - } catch(faiss::FaissException& e) { - *pzErr = sqlite3_mprintf("Error building index factory for %s: %s", column->name.c_str(), e.msg.c_str()); - return SQLITE_ERROR; - } - } - create_shadow_tables(db, argv[1], argv[2], pNew->indexCount); - - // after shadow tables are created, write the initial index state to shadow _index - for (int i = 0; i < pNew->indexes->size(); i++) { - auto index = pNew->indexes->at(i); - try { - int rc = write_index_insert(index, pNew->db, pNew->schema, pNew->name, i); - if(rc != SQLITE_OK) return rc; - } catch(faiss::FaissException& e) { - return SQLITE_ERROR; - } +static int init(sqlite3 *db, + void *pAux, + int argc, + const char *const *argv, + sqlite3_vtab **ppVtab, + char **pzErr, + bool isCreate) { + + sqlite3_vtab_config(db, SQLITE_VTAB_CONSTRAINT_SUPPORT, 1); + int rc; + + sqlite3_str *str = sqlite3_str_new(nullptr); + sqlite3_str_appendall(str, + "create table x(distance hidden, operation hidden"); + + auto columns = parse_constructor(argc, argv); + + if (columns == nullptr) { + *pzErr = sqlite3_mprintf("Error parsing constructor"); + return rc; } - } else { - for(int i = 0; i < pNew->indexCount; i++) { - auto index = read_index_select(db, argv[2], i); - // index in shadow table should always be available, integrity check to avoid null pointer - if(index == NULL) { - *pzErr = sqlite3_mprintf("Could not read index at position %d", i); - return SQLITE_ERROR; - } - pNew->indexes->push_back(index); + for (auto column = columns->begin(); column != columns->end(); ++column) { + sqlite3_str_appendf(str, ", \"%w\"", column->name.c_str()); } - } - pNew->trainings = new std::vector*>(); - pNew->insert_to_add_data = new std::vector*>(); - pNew->insert_to_add_ids = new std::vector*>(); - pNew->delete_to_delete_ids = new std::vector*>(); + sqlite3_str_appendall(str, ")"); + auto sql = sqlite3_str_finish(str); + rc = sqlite3_declare_vtab(db, sql); + sqlite3_free(sql); + +#define VSS_INDEX_COLUMN_DISTANCE 0 +#define VSS_INDEX_COLUMN_OPERATION 1 +#define VSS_INDEX_COLUMN_VECTORS 2 + + if (rc != SQLITE_OK) + return rc; + + auto pTable = new vss_index_vtab(db, + (vector0_api *)pAux, + sqlite3_mprintf("%s", argv[1]), + sqlite3_mprintf("%s", argv[2])); + *ppVtab = pTable; + + if (isCreate) { + + for (auto iter = columns->begin(); iter != columns->end(); ++iter) { + + try { + + auto index = faiss::index_factory(iter->dimensions, iter->factory.c_str()); + pTable->indexes.push_back(new vss_index(index)); + + } catch (faiss::FaissException &e) { + + *pzErr = sqlite3_mprintf("Error building index factory for %s: %s", + iter->name.c_str(), + e.msg.c_str()); + + return SQLITE_ERROR; + } + } + + rc = create_shadow_tables(db, argv[1], argv[2], columns->size()); + if (rc != SQLITE_OK) + return rc; + + // Shadow tables were successully created. + // After shadow tables are created, write the initial index state to + // shadow _index. + auto i = 0; + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + + try { + + int rc = write_index_insert((*iter)->index, + pTable->db, + pTable->schema, + pTable->name, + i); + + if (rc != SQLITE_OK) + return rc; + + } catch (faiss::FaissException &e) { - for(int i = 0; i < pNew->indexCount; i++) { - pNew->trainings->push_back(new std::vector()); - pNew->insert_to_add_data->push_back(new std::vector()); - pNew->insert_to_add_ids->push_back(new std::vector()); - pNew->delete_to_delete_ids->push_back(new std::vector()); - } + return SQLITE_ERROR; + } + } + + } else { + for (int i = 0; i < columns->size(); i++) { - pNew->isTraining = false; - pNew->isInsertData = false; + auto index = read_index_select(db, argv[2], i); + + // Index in shadow table should always be available, integrity check + // to avoid null pointer + if (index == nullptr) { + *pzErr = sqlite3_mprintf("Could not read index at position %d", i); + return SQLITE_ERROR; + } + pTable->indexes.push_back(new vss_index(index)); + } + } - return SQLITE_OK; + return SQLITE_OK; } -static int vssIndexCreate( - sqlite3 *db, - void *pAux, - int argc, const char *const*argv, - sqlite3_vtab **ppVtab, - char **pzErr -){ - return init(db, pAux, argc, argv, ppVtab, pzErr, true); +static int vssIndexCreate(sqlite3 *db, void *pAux, + int argc, + const char *const *argv, + sqlite3_vtab **ppVtab, + char **pzErr) { + + return init(db, pAux, argc, argv, ppVtab, pzErr, true); } -static int vssIndexConnect( - sqlite3 *db, - void *pAux, - int argc, const char *const*argv, - sqlite3_vtab **ppVtab, - char **pzErr -){ - return init(db, pAux, argc, argv, ppVtab, pzErr, false); +static int vssIndexConnect(sqlite3 *db, + void *pAux, int argc, + const char *const *argv, + sqlite3_vtab **ppVtab, + char **pzErr) { + + return init(db, pAux, argc, argv, ppVtab, pzErr, false); } -static int vssIndexDisconnect(sqlite3_vtab *pVtab){ - vss_index_vtab *p = (vss_index_vtab*)pVtab; - //printf("disconnect\n"); - sqlite3_free(p->name); - sqlite3_free(p->schema); - - for(int i = 0; i < p->indexCount; i++) { - faiss::Index * index = p->indexes->at(i); - delete index; - } - delete p->indexes; - - for(int i = 0; i < p->indexCount; i++) { - delete p->trainings->at(i); - delete p->insert_to_add_data->at(i); - delete p->insert_to_add_ids->at(i); - delete p->delete_to_delete_ids->at(i); - - } - delete p->trainings; - delete p->insert_to_add_data; - delete p->insert_to_add_ids; - delete p->delete_to_delete_ids; - - sqlite3_free(p); - return SQLITE_OK; +static int vssIndexDisconnect(sqlite3_vtab *pVtab) { + + auto pTable = static_cast(pVtab); + delete pTable; + return SQLITE_OK; } -static int vssIndexDestroy(sqlite3_vtab *pVtab){ - vss_index_vtab *p = (vss_index_vtab*)pVtab; - //printf("destroy\n"); - drop_shadow_tables(p->db, p->name); - vssIndexDisconnect(pVtab); - return SQLITE_OK; +static int vssIndexDestroy(sqlite3_vtab *pVtab) { + + auto pTable = static_cast(pVtab); + drop_shadow_tables(pTable->db, pTable->name); + vssIndexDisconnect(pVtab); + return SQLITE_OK; } -static int vssIndexOpen(sqlite3_vtab *pVtab, sqlite3_vtab_cursor **ppCursor){ - vss_index_cursor *pCur; - vss_index_vtab *p = (vss_index_vtab*)pVtab; - pCur = (vss_index_cursor *) sqlite3_malloc( sizeof(*pCur) ); - if( pCur==0 ) return SQLITE_NOMEM; - memset(pCur, 0, sizeof(*pCur)); +static int vssIndexOpen(sqlite3_vtab *pVtab, sqlite3_vtab_cursor **ppCursor) { + + auto pTable = static_cast(pVtab); - *ppCursor = &pCur->base; - pCur->table = p; - //pCur->nns = new std::vector(5); - //pCur->dis = new std::vector(5); + auto pCursor = new vss_index_cursor(pTable); + if (pCursor == nullptr) + return SQLITE_NOMEM; - return SQLITE_OK; + *ppCursor = pCursor; + + return SQLITE_OK; } -static int vssIndexClose(sqlite3_vtab_cursor *cur){ - vss_index_cursor *pCur = (vss_index_cursor*)cur; - //printf("close\n"); - if(pCur->range_search_result) { - delete pCur->range_search_result; - pCur->range_search_result = 0; - } - if(pCur->stmt) sqlite3_finalize(pCur->stmt); - sqlite3_free(pCur); - //printf("b\n"); - return SQLITE_OK; +static int vssIndexClose(sqlite3_vtab_cursor *cur) { + + auto pCursor = static_cast(cur); + delete pCursor; + return SQLITE_OK; } -static int vssIndexBestIndex( - sqlite3_vtab *tab, - sqlite3_index_info *pIdxInfo -){ - int iSearchTerm = -1; - int iRangeSearchTerm = -1; - int iXSearchColumn = -1; - int iLimit = -1; - //printf("best index, %d\n", pIdxInfo->nConstraint); - - for(int i = 0; i < pIdxInfo->nConstraint; i++) { - auto constraint = pIdxInfo->aConstraint[i]; - //printf("\t[%d] col=%d, op=%d \n", i, pIdxInfo->aConstraint[i].iColumn, pIdxInfo->aConstraint[i].op); - - if(!constraint.usable) continue; - if(constraint.op == VSS_SEARCH_FUNCTION) { - iSearchTerm = i; - iXSearchColumn = constraint.iColumn; - } - else if(constraint.op == VSS_RANGE_SEARCH_FUNCTION) { - iRangeSearchTerm = i; - iXSearchColumn = constraint.iColumn; +static int vssIndexBestIndex(sqlite3_vtab *tab, sqlite3_index_info *pIdxInfo) { + + int iSearchTerm = -1; + int iRangeSearchTerm = -1; + int iXSearchColumn = -1; + int iLimit = -1; + + for (int i = 0; i < pIdxInfo->nConstraint; i++) { + + auto constraint = pIdxInfo->aConstraint[i]; + + if (!constraint.usable) + continue; + + if (constraint.op == VSS_SEARCH_FUNCTION) { + + iSearchTerm = i; + iXSearchColumn = constraint.iColumn; + + } else if (constraint.op == VSS_RANGE_SEARCH_FUNCTION) { + + iRangeSearchTerm = i; + iXSearchColumn = constraint.iColumn; + + } else if (constraint.op == SQLITE_INDEX_CONSTRAINT_LIMIT) { + iLimit = i; + } } - else if(constraint.op == SQLITE_INDEX_CONSTRAINT_LIMIT) { - iLimit = i; + + if (iSearchTerm >= 0) { + + pIdxInfo->idxNum = iXSearchColumn - VSS_INDEX_COLUMN_VECTORS; + pIdxInfo->idxStr = (char *)"search"; + pIdxInfo->aConstraintUsage[iSearchTerm].argvIndex = 1; + pIdxInfo->aConstraintUsage[iSearchTerm].omit = 1; + if (iLimit >= 0) { + pIdxInfo->aConstraintUsage[iLimit].argvIndex = 2; + pIdxInfo->aConstraintUsage[iLimit].omit = 1; + } + pIdxInfo->estimatedCost = 300.0; + pIdxInfo->estimatedRows = 10; + + return SQLITE_OK; } - } - if(iSearchTerm >=0) { - pIdxInfo->idxNum = iXSearchColumn - VSS_INDEX_COLUMN_VECTORS; - pIdxInfo->idxStr = (char*) "search"; - pIdxInfo->aConstraintUsage[iSearchTerm].argvIndex = 1; - pIdxInfo->aConstraintUsage[iSearchTerm].omit = 1; - if(iLimit >= 0) { - pIdxInfo->aConstraintUsage[iLimit].argvIndex = 2; - pIdxInfo->aConstraintUsage[iLimit].omit = 1; + + if (iRangeSearchTerm >= 0) { + + pIdxInfo->idxNum = iXSearchColumn - VSS_INDEX_COLUMN_VECTORS; + pIdxInfo->idxStr = (char *)"range_search"; + pIdxInfo->aConstraintUsage[iRangeSearchTerm].argvIndex = 1; + pIdxInfo->aConstraintUsage[iRangeSearchTerm].omit = 1; + pIdxInfo->estimatedCost = 300.0; + pIdxInfo->estimatedRows = 10; + return SQLITE_OK; } - pIdxInfo->estimatedCost = 300.0; - pIdxInfo->estimatedRows = 10; - return SQLITE_OK; - } - if(iRangeSearchTerm >=0) { - pIdxInfo->idxNum = iXSearchColumn - VSS_INDEX_COLUMN_VECTORS; - pIdxInfo->idxStr = (char*) "range_search"; - pIdxInfo->aConstraintUsage[iRangeSearchTerm].argvIndex = 1; - pIdxInfo->aConstraintUsage[iRangeSearchTerm].omit = 1; - pIdxInfo->estimatedCost = 300.0; - pIdxInfo->estimatedRows = 10; + + pIdxInfo->idxNum = -1; + pIdxInfo->idxStr = (char *)"fullscan"; + pIdxInfo->estimatedCost = 3000000.0; + pIdxInfo->estimatedRows = 100000; return SQLITE_OK; - } - pIdxInfo->idxNum = -1; - pIdxInfo->idxStr = (char*)"fullscan"; - pIdxInfo->estimatedCost = 3000000.0; - pIdxInfo->estimatedRows = 100000; - return SQLITE_OK; } -static int vssIndexFilter( - sqlite3_vtab_cursor *pVtabCursor, - int idxNum, const char *idxStr, - int argc, sqlite3_value **argv -){ - //printf("filter argc=%d, idxStr='%s', idxNum=%d\n", argc, idxStr, idxNum); - vss_index_cursor *pCur = (vss_index_cursor *)pVtabCursor; - if (strcmp(idxStr, "search")==0) { - pCur->query_type = QueryType::search; - std::vector * query_vector; - - VssSearchParams* params; - if ( (params = (VssSearchParams*) sqlite3_value_pointer(argv[0], "vss0_searchparams")) != NULL) { - pCur->search_k = params->k; - query_vector = new std::vector(*params->vector); - } - // https://sqlite.org/forum/info/6b32f818ba1d97ef - else if(sqlite3_libversion_number() < 3041000) { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("vss_search() only support vss_search_params() as a 2nd parameter for SQLite versions below 3.41.0"); - return SQLITE_ERROR; - } - else if ((query_vector = pCur->table->vector_api->xValueAsVector(argv[0])) != NULL) { - if(argc > 1) - pCur->search_k = sqlite3_value_int(argv[1]); - else { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("LIMIT required on vss_search() queries"); +static int vssIndexFilter(sqlite3_vtab_cursor *pVtabCursor, + int idxNum, + const char *idxStr, + int argc, + sqlite3_value **argv) { + + auto pCursor = static_cast(pVtabCursor); + + if (strcmp(idxStr, "search") == 0) { + + pCursor->query_type = QueryType::search; + vec_ptr query_vector; + + auto params = static_cast(sqlite3_value_pointer(argv[0], "vss0_searchparams")); + if (params != nullptr) { + + pCursor->limit = params->k; + query_vector = vec_ptr(new vector(*params->vector)); + + } else if (sqlite3_libversion_number() < 3041000) { + + // https://sqlite.org/forum/info/6b32f818ba1d97ef + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + "vss_search() only support vss_search_params() as a " + "2nd parameter for SQLite versions below 3.41.0"); + return SQLITE_ERROR; + + } else if ((query_vector = pCursor->table->vector_api->xValueAsVector( + argv[0])) != nullptr) { + + if (argc > 1) { + pCursor->limit = sqlite3_value_int(argv[1]); + } else { + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + pVtabCursor->pVtab->zErrMsg = + sqlite3_mprintf("LIMIT required on vss_search() queries"); + return SQLITE_ERROR; + } + + } else { + + if (pVtabCursor->pVtab->zErrMsg != nullptr) + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + "2nd argument to vss_search() must be a vector"); + return SQLITE_ERROR; + } + + int nq = 1; + auto index = pCursor->table->indexes.at(idxNum)->index; + + if (query_vector->size() != index->d) { + + // TODO: To support index that transforms vectors + // (to conserve spage, eg?), we should probably + // have some logic in place that transforms the vectors here? + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + "Input query size doesn't match index dimensions: %ld != %ld", + query_vector->size(), + index->d); + return SQLITE_ERROR; + } + + if (pCursor->limit <= 0) { + + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + "Limit must be greater than 0, got %ld", pCursor->limit); + return SQLITE_ERROR; + } + + // To avoid trying to select more records than number of records in index. + auto searchMax = min(static_cast(pCursor->limit) * nq, index->ntotal * nq); + + pCursor->search_distances = vector(searchMax, 0); + pCursor->search_ids = vector(searchMax, 0); + + index->search(nq, + query_vector->data(), + pCursor->limit, + pCursor->search_distances.data(), + pCursor->search_ids.data()); + + } else if (strcmp(idxStr, "range_search") == 0) { + + pCursor->query_type = QueryType::range_search; + + auto params = static_cast( + sqlite3_value_pointer(argv[0], "vss0_rangesearchparams")); + + int nq = 1; + + vector nns(params->distance * nq); + pCursor->range_search_result = unique_ptr(new faiss::RangeSearchResult(nq, true)); + + auto index = pCursor->table->indexes.at(idxNum)->index; + + index->range_search(nq, + params->vector->data(), + params->distance, + pCursor->range_search_result.get()); + + } else if (strcmp(idxStr, "fullscan") == 0) { + + pCursor->query_type = QueryType::fullscan; + sqlite3_stmt *stmt; + + int res = sqlite3_prepare_v2( + pCursor->table->db, + sqlite3_mprintf("select rowid from \"%w_data\"", pCursor->table->name), + -1, &pCursor->stmt, nullptr); + + if (res != SQLITE_OK) + return res; + + pCursor->step_result = sqlite3_step(pCursor->stmt); + + } else { + + if (pVtabCursor->pVtab->zErrMsg != 0) + sqlite3_free(pVtabCursor->pVtab->zErrMsg); + + pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + "%s %s", "vssIndexFilter error: unhandled idxStr", idxStr); return SQLITE_ERROR; - } - } - else { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("2nd argument to vss_search() must be a vector"); - return SQLITE_ERROR; - } - int nq = 1; - faiss::Index* index = pCur->table->indexes->at(idxNum); - if(query_vector->size() != index->d) { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("input query size doesn't match index dimensions: %ld != %ld", query_vector->size(), index->d); - delete query_vector; - return SQLITE_ERROR; - } - if(pCur->search_k <= 0) { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("k must be greater than 0, got %ld", pCur->search_k); - return SQLITE_ERROR; } - pCur->search_distances = new std::vector(pCur->search_k * nq); - pCur->search_ids = new std::vector(pCur->search_k * nq); - /*printf("d %p\n", index); - index->verbose = true; - //printf("k=%d\n", pCur->k); - //printf("v=%p\n", params->vector); - //printf("vsize=%p\n", params->vector->size()); - //printf("index->d=%d\n", index->d); - //printf("pCur->dis->size()=%ld\n", pCur->dis->size()); - //printf("pCur->nns->size()=%ld\n", pCur->nns->size()); - //printf("pls k=%d vsize=%lld index=%lld %lld %lld\n", params->k, params->vector->size(), index->d, pCur->dis->size(), pCur->nns->size());*/ - index->search(nq, query_vector->data(), pCur->search_k, pCur->search_distances->data(), pCur->search_ids->data()); - } - else if (strcmp(idxStr, "range_search")==0) { - pCur->query_type = QueryType::range_search; - VssRangeSearchParams* params = (VssRangeSearchParams*) sqlite3_value_pointer(argv[0], "vss0_rangesearchparams"); - int nq = 1; - std::vector nns(params->distance * nq); - faiss::RangeSearchResult * result = new faiss::RangeSearchResult(nq, true); - //pCur->k = params->k; - faiss::Index* index = pCur->table->indexes->at(idxNum); - index->range_search(nq, params->vector->data(), params->distance, result); - pCur->range_search_result = result; - - }else if (strcmp(idxStr, "fullscan")==0) { - pCur->query_type = QueryType::fullscan; - sqlite3_stmt* stmt; - int res = sqlite3_prepare_v2(pCur->table->db, sqlite3_mprintf("select rowid from \"%w_data\"", pCur->table->name), -1, &pCur->stmt, NULL); - if(res != SQLITE_OK) return res; - pCur->step_result = sqlite3_step(pCur->stmt); - } - else { - if(pVtabCursor->pVtab->zErrMsg != 0) sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf("%s %s", "vssIndexFilter error: unhandled idxStr", idxStr); - return SQLITE_ERROR; - } - - pCur->iCurrent = 0; - return SQLITE_OK; + + pCursor->iCurrent = 0; + return SQLITE_OK; } -static int vssIndexNext(sqlite3_vtab_cursor *cur){ - vss_index_cursor *pCur = (vss_index_cursor*)cur; - switch(pCur->query_type) { - case QueryType::search: - case QueryType::range_search: { - pCur->iCurrent++; - break; - } - case QueryType::fullscan: { - pCur->step_result = sqlite3_step(pCur->stmt); +static int vssIndexNext(sqlite3_vtab_cursor *cur) { + + auto pCursor = static_cast(cur); + + switch (pCursor->query_type) { + + case QueryType::search: + case QueryType::range_search: + pCursor->iCurrent++; + break; + + case QueryType::fullscan: + pCursor->step_result = sqlite3_step(pCursor->stmt); } - } - return SQLITE_OK; + return SQLITE_OK; } -static int vssIndexRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid){ - vss_index_cursor *pCur = (vss_index_cursor*)cur; - switch(pCur->query_type) { - case QueryType::search: { - *pRowid = pCur->search_ids->at(pCur->iCurrent); - break; - } - case QueryType::range_search: { - *pRowid = pCur->range_search_result->labels[pCur->iCurrent]; - break; - } - case QueryType::fullscan: { - *pRowid = sqlite3_column_int64(pCur->stmt, 0); - break; + +static int vssIndexRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { + + auto pCursor = static_cast(cur); + + switch (pCursor->query_type) { + + case QueryType::search: + *pRowid = pCursor->search_ids.at(pCursor->iCurrent); + break; + + case QueryType::range_search: + *pRowid = pCursor->range_search_result->labels[pCursor->iCurrent]; + break; + + case QueryType::fullscan: + *pRowid = sqlite3_column_int64(pCursor->stmt, 0); + break; } - } - return SQLITE_OK; + return SQLITE_OK; } -static int vssIndexEof(sqlite3_vtab_cursor *cur){ - vss_index_cursor *pCur = (vss_index_cursor*)cur; - switch(pCur->query_type) { - case QueryType::search: { - return pCur->iCurrent >= pCur->search_k || (pCur->search_ids->at(pCur->iCurrent) == -1); - } - case QueryType::range_search: { - return pCur->iCurrent >= pCur->range_search_result->lims[1]; - } - case QueryType::fullscan: { - return pCur->step_result != SQLITE_ROW; +static int vssIndexEof(sqlite3_vtab_cursor *cur) { + + auto pCursor = static_cast(cur); + + switch (pCursor->query_type) { + + case QueryType::search: + return pCursor->iCurrent >= pCursor->limit || + pCursor->iCurrent >= pCursor->search_ids.size(); + + case QueryType::range_search: + return pCursor->iCurrent >= pCursor->range_search_result->lims[1]; + + case QueryType::fullscan: + return pCursor->step_result != SQLITE_ROW; } - } - return 1; + return 1; } -static int vssIndexColumn( - sqlite3_vtab_cursor *cur, - sqlite3_context *ctx, - int i -){ - vss_index_cursor *pCur = (vss_index_cursor*)cur; - if(i == VSS_INDEX_COLUMN_DISTANCE) { - switch(pCur->query_type) { - case QueryType::search: { - sqlite3_result_double(ctx, pCur->search_distances->at(pCur->iCurrent)); - break; - } - case QueryType::range_search: { - sqlite3_result_double(ctx, pCur->range_search_result->distances[pCur->iCurrent]); - break; - } - case QueryType::fullscan: { - break; - } - } - } - else if ( i >= VSS_INDEX_COLUMN_VECTORS) { - faiss::Index * index = pCur->table->indexes->at(i-VSS_INDEX_COLUMN_VECTORS); +static int vssIndexColumn(sqlite3_vtab_cursor *cur, + sqlite3_context *ctx, + int i) { - std::vector *v = new std::vector(index->d); - sqlite3_int64 rowid; - vssIndexRowid(cur, &rowid); + auto pCursor = static_cast(cur); - try { - index->reconstruct(rowid, v->data()); + if (i == VSS_INDEX_COLUMN_DISTANCE) { + switch (pCursor->query_type) { + + case QueryType::search: + sqlite3_result_double(ctx, + pCursor->search_distances.at(pCursor->iCurrent)); + break; + + case QueryType::range_search: + sqlite3_result_double(ctx, + pCursor->range_search_result->distances[pCursor->iCurrent]); + break; + + case QueryType::fullscan: + break; + } + + } else if (i >= VSS_INDEX_COLUMN_VECTORS) { + + auto index = + pCursor->table->indexes.at(i - VSS_INDEX_COLUMN_VECTORS)->index; + + vector vec(index->d); + sqlite3_int64 rowId; + vssIndexRowid(cur, &rowId); + + try { + index->reconstruct(rowId, vec.data()); + + } catch (faiss::FaissException &e) { + + char *errmsg = + (char *)sqlite3_mprintf("Error reconstructing vector - Does " + "the column factory string end " + "with IDMap2? Full error: %s", + e.msg.c_str()); + + sqlite3_result_error(ctx, errmsg, -1); + sqlite3_free(errmsg); + return SQLITE_ERROR; + } + pCursor->table->vector_api->xResultVector(ctx, &vec); } - catch(faiss::FaissException& e) { - //printf("%s\n", ex.what()); - char * errmsg = (char *) sqlite3_mprintf("Error reconstructing vector - Does the column factory string end with IDMap2? Full error: %s", e.msg.c_str()); - sqlite3_result_error(ctx, errmsg, -1); - sqlite3_free(errmsg); - return SQLITE_ERROR; - } - pCur->table->vector_api->xResultVector(ctx, v); - } - return SQLITE_OK; + return SQLITE_OK; } - static int vssIndexBegin(sqlite3_vtab *tab) { - //printf("BEGIN\n"); - return SQLITE_OK; + + return SQLITE_OK; } static int vssIndexSync(sqlite3_vtab *pVTab) { - vss_index_vtab *p = (vss_index_vtab*)pVTab; - bool needsWriting = false; - if(p->isTraining) { - //printf("TRAINING %lu\n", p->training->size()); - for (std::size_t i = 0; i != p->trainings->size(); ++i) { - auto training = p->trainings->at(i); - if(!training->empty()) { - faiss::Index * index = p->indexes->at(i); - index->train(training->size() / index->d, training->data()); - training->clear(); - } - } - p->isTraining = false; - } - - for (std::size_t i = 0; i < p->indexCount; ++i) { - auto delete_ids = p->delete_to_delete_ids->at(i); - if(!delete_ids->empty()) { - faiss::IDSelectorBatch * selector = new faiss::IDSelectorBatch(delete_ids->size(), delete_ids->data()); - size_t numRemoved = p->indexes->at(i)->remove_ids(*selector); - delete selector; - needsWriting= true; - delete_ids->clear(); - } - } - - if(p->isInsertData) { - //printf("WRITING INDEX\n"); - p->isInsertData = false; - - for (std::size_t i = 0; i < p->indexCount; ++i) { - auto insert_data = p->insert_to_add_data->at(i); - auto insert_ids = p->insert_to_add_ids->at(i); - if(!insert_data->empty()) { - try { - p->indexes->at(i)->add_with_ids(insert_ids->size(), insert_data->data(), (faiss::idx_t *) insert_ids->data()); + + auto pTable = static_cast(pVTab); + + try { + + bool needsWriting = false; + + auto idxCol = 0; + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, idxCol++) { + + // Checking if index needs training. + if (!(*iter)->trainings.empty()) { + + (*iter)->index->train( + (*iter)->trainings.size() / (*iter)->index->d, + (*iter)->trainings.data()); + + (*iter)->trainings.clear(); + (*iter)->trainings.shrink_to_fit(); + + needsWriting = true; + } + + // Checking if we're deleting records from the index. + if (!(*iter)->delete_ids.empty()) { + + faiss::IDSelectorBatch selector((*iter)->delete_ids.size(), + (*iter)->delete_ids.data()); + + (*iter)->index->remove_ids(selector); + (*iter)->delete_ids.clear(); + (*iter)->delete_ids.shrink_to_fit(); + + needsWriting = true; + } + + // Checking if we're inserting records to the index. + if (!(*iter)->insert_data.empty()) { + + (*iter)->index->add_with_ids( + (*iter)->insert_ids.size(), + (*iter)->insert_data.data(), + (faiss::idx_t *)(*iter)->insert_ids.data()); + + (*iter)->insert_ids.clear(); + (*iter)->insert_ids.shrink_to_fit(); + + (*iter)->insert_data.clear(); + (*iter)->insert_data.shrink_to_fit(); + + needsWriting = true; + } } - catch(faiss::FaissException& e) { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = sqlite3_mprintf("Error adding vector to index at column index %d. Full error: %s", i, e.msg.c_str()); - insert_ids->clear(); - insert_data->clear(); - return SQLITE_ERROR; + + if (needsWriting) { + + int i = 0; + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + + int rc = write_index_insert((*iter)->index, + pTable->db, + pTable->schema, + pTable->name, + i); + + if (rc != SQLITE_OK) { + + sqlite3_free(pVTab->zErrMsg); + pVTab->zErrMsg = sqlite3_mprintf("Error saving index (%d): %s", + rc, sqlite3_errmsg(pTable->db)); + return rc; + } + } } - insert_ids->clear(); - insert_data->clear(); - insert_ids->shrink_to_fit(); - insert_data->shrink_to_fit(); - needsWriting= true; - } - } - } - if(needsWriting) { - for(int i = 0; i < p->indexCount; i++) { - int rc = write_index_insert(p->indexes->at(i), p->db, p->schema, p->name, i); - if(rc != SQLITE_OK) { + + return SQLITE_OK; + + } catch (faiss::FaissException &e) { + sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = sqlite3_mprintf("Error saving index (%d): %s", rc, sqlite3_errmsg(p->db)); - return rc; - } + pVTab->zErrMsg = + sqlite3_mprintf("Error during synchroning index. Full error: %s", + e.msg.c_str()); + + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { + + (*iter)->insert_ids.clear(); + (*iter)->insert_ids.shrink_to_fit(); + + (*iter)->insert_data.clear(); + (*iter)->insert_data.shrink_to_fit(); + + (*iter)->delete_ids.clear(); + (*iter)->delete_ids.shrink_to_fit(); + + (*iter)->trainings.clear(); + (*iter)->trainings.shrink_to_fit(); + } + + return SQLITE_ERROR; } - } - return SQLITE_OK; } -static int vssIndexCommit(sqlite3_vtab *pVTab) { - return SQLITE_OK; -} +static int vssIndexCommit(sqlite3_vtab *pVTab) { return SQLITE_OK; } static int vssIndexRollback(sqlite3_vtab *pVTab) { - vss_index_vtab *p = (vss_index_vtab*)pVTab; - for (std::size_t i = 0; i != p->trainings->size(); ++i) { - auto training = p->trainings->at(i); - training->clear(); - } - for (std::size_t i = 0; i < p->indexCount; ++i) { - auto insert_data = p->insert_to_add_data->at(i); - auto insert_ids = p->insert_to_add_ids->at(i); - insert_ids->clear(); - insert_data->clear(); - auto delete_ids = p->delete_to_delete_ids->at(i); - delete_ids->clear(); - } - return SQLITE_OK; -} -static int vssIndexUpdate( - sqlite3_vtab *pVTab, - int argc, - sqlite3_value **argv, - sqlite_int64 *pRowid -) { - vss_index_vtab *p = (vss_index_vtab*)pVTab; - // DELETE operation - if (argc ==1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { - sqlite3_int64 rowid_to_delete = sqlite3_value_int64(argv[0]); - int rc; - if((rc = shadow_data_delete(p->db, p->schema, p->name, rowid_to_delete)) != SQLITE_OK) { - return rc; - } - for(int i = 0; i < p->indexCount; i++) { - p->delete_to_delete_ids->at(i)->push_back(rowid_to_delete); + auto pTable = static_cast(pVTab); + + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { + + (*iter)->trainings.clear(); + (*iter)->trainings.shrink_to_fit(); + + (*iter)->insert_data.clear(); + (*iter)->insert_data.shrink_to_fit(); + + (*iter)->insert_ids.clear(); + (*iter)->insert_ids.shrink_to_fit(); + + (*iter)->delete_ids.clear(); + (*iter)->delete_ids.shrink_to_fit(); } + return SQLITE_OK; +} - } - // INSERT operation - else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { - // if no operation, we adding it to the index - bool noOperation = sqlite3_value_type(argv[2+VSS_INDEX_COLUMN_OPERATION]) == SQLITE_NULL; - if (noOperation) { - std::vector* vec; - sqlite3_int64 rowid = sqlite3_value_int64(argv[1]); - bool inserted_rowid = false; - for(int i = 0; i < p->indexCount; i++) { - if ( (vec = p->vector_api->xValueAsVector(argv[2+VSS_INDEX_COLUMN_VECTORS + i])) != NULL ) { - // make sure the index is already trained, if it's needed - if(!p->indexes->at(i)->is_trained) { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = sqlite3_mprintf("Index at i=%d requires training before inserting data.", i); - delete vec; - return SQLITE_ERROR; - } - - if(!inserted_rowid) { - sqlite_int64 retrowid; - int rc = shadow_data_insert(p->db, p->schema, p->name, &rowid, &retrowid); - if (rc != SQLITE_OK) { - delete vec; - return rc; - } - inserted_rowid = true; - } - /*try { - p->indexes->at(i)->add_with_ids(1, vec->data(), (faiss::idx_t *) &rowid); - }*/ - p->insert_to_add_data->at(i)->reserve(vec->size() + distance(vec->begin(), vec->end())); - p->insert_to_add_data->at(i)->insert(p->insert_to_add_data->at(i)->end(), vec->begin(), vec->end()); - p->insert_to_add_ids->at(i)->push_back(rowid); - - p->isInsertData = true; - *pRowid = rowid; - delete vec; +static int vssIndexUpdate(sqlite3_vtab *pVTab, + int argc, + sqlite3_value **argv, + sqlite_int64 *pRowid) { + + auto pTable = static_cast(pVTab); + + if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { + + // DELETE operation + sqlite3_int64 rowid_to_delete = sqlite3_value_int64(argv[0]); + + auto rc = shadow_data_delete(pTable->db, + pTable->schema, + pTable->name, + rowid_to_delete); + if (rc != SQLITE_OK) + return rc; + + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { + (*iter)->delete_ids.push_back(rowid_to_delete); } - } + } else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { - } else { - std::string operation ((char *) sqlite3_value_text(argv[2+VSS_INDEX_COLUMN_OPERATION])); - if(operation.compare("training") == 0) { - std::vector* vec; - for(int i = 0; i < p->indexCount; i++) { - if ( (vec = p->vector_api->xValueAsVector(argv[2+VSS_INDEX_COLUMN_VECTORS+i])) != NULL ) { - p->trainings->at(i)->reserve(vec->size() + distance(vec->begin(), vec->end())); - p->trainings->at(i)->insert(p->trainings->at(i)->end(), vec->begin(), vec->end()); - p->isTraining = true; - delete vec; - } + // INSERT operation + // if no operation, we add it to the index + bool noOperation = + sqlite3_value_type(argv[2 + VSS_INDEX_COLUMN_OPERATION]) == + SQLITE_NULL; + + if (noOperation) { + + vec_ptr vec; + sqlite3_int64 rowid = sqlite3_value_int64(argv[1]); + bool inserted_rowid = false; + + auto i = 0; + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + + if ((vec = pTable->vector_api->xValueAsVector( + argv[2 + VSS_INDEX_COLUMN_VECTORS + i])) != nullptr) { + + // Make sure the index is already trained, if it's needed + if (!(*iter)->index->is_trained) { + + sqlite3_free(pVTab->zErrMsg); + pVTab->zErrMsg = + sqlite3_mprintf("Index at i=%d requires training " + "before inserting data.", + i); + + return SQLITE_ERROR; + } + + if (!inserted_rowid) { + + sqlite_int64 retrowid; + auto rc = shadow_data_insert(pTable->db, pTable->schema, pTable->name, + &rowid, &retrowid); + if (rc != SQLITE_OK) + return rc; + + inserted_rowid = true; + } + + (*iter)->insert_data.reserve((*iter)->insert_data.size() + vec->size()); + (*iter)->insert_data.insert( + (*iter)->insert_data.end(), + vec->begin(), + vec->end()); + + (*iter)->insert_ids.push_back(rowid); + + *pRowid = rowid; + } + } + + } else { + + string operation((char *)sqlite3_value_text(argv[2 + VSS_INDEX_COLUMN_OPERATION])); + + if (operation.compare("training") == 0) { + + auto i = 0; + for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + + vec_ptr vec = pTable->vector_api->xValueAsVector(argv[2 + VSS_INDEX_COLUMN_VECTORS + i]); + if (vec != nullptr) { + + (*iter)->trainings.reserve((*iter)->trainings.size() + vec->size()); + (*iter)->trainings.insert( + (*iter)->trainings.end(), + vec->begin(), + vec->end()); + } + } + + } else { + + return SQLITE_ERROR; + } } - } - else { - //printf("unknown operation\n"); + + } else { + + // TODO: Implement - UPDATE operation + sqlite3_free(pVTab->zErrMsg); + + pVTab->zErrMsg = + sqlite3_mprintf("update on vss0 virtual tables not supported yet."); + return SQLITE_ERROR; - } } - } - // some UPDATE operations - else { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = sqlite3_mprintf("UPDATE on vss0 virtual tables not supported yet."); - return SQLITE_ERROR; - } - return SQLITE_OK; + + return SQLITE_OK; } +static void vssSearchFunc(sqlite3_context *context, + int argc, + sqlite3_value **argv) { } -static void vssSearchFunc( - sqlite3_context *context, - int argc, - sqlite3_value **argv -){ - //printf("search?\n"); -} -static void faissMemoryUsageFunc( - sqlite3_context *context, - int argc, - sqlite3_value **argv -){ - sqlite3_result_int64(context, faiss::get_mem_usage_kb()); -} -static void vssRangeSearchFunc( - sqlite3_context *context, - int argc, - sqlite3_value **argv -){ - //printf("range search?\n"); +static void faissMemoryUsageFunc(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + sqlite3_result_int64(context, faiss::get_mem_usage_kb()); } +static void vssRangeSearchFunc(sqlite3_context *context, + int argc, + sqlite3_value **argv) { } + static int vssIndexFindFunction( - sqlite3_vtab *pVtab, - int nArg, - const char *zName, - void (**pxFunc)(sqlite3_context*,int,sqlite3_value**), - void **ppArg -) { - //printf("find function. %d %s %s \n", nArg, zName, (char *) sqlite3_version); - if( sqlite3_stricmp(zName, "vss_search")==0 ){ - *pxFunc = vssSearchFunc; - *ppArg = 0; - return VSS_SEARCH_FUNCTION; - } - if( sqlite3_stricmp(zName, "vss_range_search")==0 ){ - *pxFunc = vssRangeSearchFunc; - *ppArg = 0; - return VSS_RANGE_SEARCH_FUNCTION; - } - return 0; + sqlite3_vtab *pVtab, + int nArg, + const char *zName, + void (**pxFunc)(sqlite3_context *, int, sqlite3_value **), + void **ppArg) { + + if (sqlite3_stricmp(zName, "vss_search") == 0) { + *pxFunc = vssSearchFunc; + *ppArg = 0; + return VSS_SEARCH_FUNCTION; + } + + if (sqlite3_stricmp(zName, "vss_range_search") == 0) { + *pxFunc = vssRangeSearchFunc; + *ppArg = 0; + return VSS_RANGE_SEARCH_FUNCTION; + } + return 0; }; -static int vssIndexShadowName(const char *zName){ - static const char *azName[] = { - "index", - "data" - }; - unsigned int i; - for(i=0; iiVersion); - sqlite3_create_function_v2(db, "vss_version", 0, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, 0, vss_version, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_debug", 0, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, 0, vss_debug, 0, 0, 0); - - sqlite3_create_function_v2(db, "vss_distance_l1", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_distance_l1, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_distance_l2", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_distance_l2, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_distance_linf", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_distance_linf, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_inner_product", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_inner_product, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_fvec_add", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_fvec_add, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_fvec_sub", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vss_fvec_sub, 0, 0, 0); - - sqlite3_create_function_v2(db, "vss_search", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vssSearchFunc, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_search_params", 2, 0, vector_api, VssSearchParamsFunc, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_range_search", 2, SQLITE_UTF8|SQLITE_DETERMINISTIC|SQLITE_INNOCUOUS, vector_api, vssRangeSearchFunc, 0, 0, 0); - sqlite3_create_function_v2(db, "vss_range_search_params", 2, 0, vector_api, VssRangeSearchParamsFunc, 0, 0, 0); - - sqlite3_create_module_v2 (db, "vss0", &vssIndexModule, vector_api, 0); - return 0; - } } #pragma endregion diff --git a/src/sqlite-vss.h.in b/src/sqlite-vss.h.in index 80cc639..63889ed 100644 --- a/src/sqlite-vss.h.in +++ b/src/sqlite-vss.h.in @@ -5,20 +5,21 @@ #include "sqlite3ext.h" #define SQLITE_VSS_VERSION "@SQLITE_VSS_VERSION@" -#define SQLITE_VSS_VERSION_MAJOR @sqlite-vss_VERSION_MAJOR@ -#define SQLITE_VSS_VERSION_MINOR @sqlite-vss_VERSION_MINOR@ -#define SQLITE_VSS_VERSION_PATCH @sqlite-vss_VERSION_PATCH@ -#define SQLITE_VSS_VERSION_TWEAK @sqlite-vss_VERSION_TWEAK@ - +#define SQLITE_VSS_VERSION_MAJOR @sqlite - vss_VERSION_MAJOR @ +#define SQLITE_VSS_VERSION_MINOR @sqlite - vss_VERSION_MINOR @ +#define SQLITE_VSS_VERSION_PATCH @sqlite - vss_VERSION_PATCH @ +#define SQLITE_VSS_VERSION_TWEAK @sqlite - vss_VERSION_TWEAK @ #ifdef __cplusplus extern "C" { #endif -int sqlite3_vss_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi); +int sqlite3_vss_init(sqlite3 *db, + char **pzErrMsg, + const sqlite3_api_routines *pApi); #ifdef __cplusplus -} /* end of the 'extern "C"' block */ +} /* end of the 'extern "C"' block */ #endif #endif /* ifndef _SQLITE_VSS_H */ diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 556bf8d..8cdd374 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -226,16 +226,16 @@ def range_search(column, v, d): {'rowid': 1002, 'distance': 9.0}, ]) - with self.assertRaisesRegex(sqlite3.OperationalError, 'input query size doesn\'t match index dimensions: 0 != 1'): + with self.assertRaisesRegex(sqlite3.OperationalError, 'Input query size doesn\'t match index dimensions: 0 != 1'): search('b', '[]', 2) - with self.assertRaisesRegex(sqlite3.OperationalError, 'input query size doesn\'t match index dimensions: 3 != 1'): + with self.assertRaisesRegex(sqlite3.OperationalError, 'Input query size doesn\'t match index dimensions: 3 != 1'): search('b', '[0.1, 0.2, 0.3]', 2) - with self.assertRaisesRegex(sqlite3.OperationalError, 'k must be greater than 0, got -1'): + with self.assertRaisesRegex(sqlite3.OperationalError, 'Limit must be greater than 0, got -1'): search('b', '[6]', -1) - with self.assertRaisesRegex(sqlite3.OperationalError, 'k must be greater than 0, got 0'): + with self.assertRaisesRegex(sqlite3.OperationalError, 'Limit must be greater than 0, got 0'): search('b', '[6]', 0) self.assertEqual(range_search('a', '[0.5, 0.5]', 1), [ @@ -499,7 +499,6 @@ def test_vss0_issue_29_upsert(self): VECTOR_FUNCTIONS = [ 'vector0', 'vector_debug', - 'vector_debug', 'vector_from_blob', 'vector_from_json', 'vector_from_raw', @@ -526,13 +525,11 @@ def test_vector_version(self): self.assertEqual(db.execute("select vector_version()").fetchone()[0][0], "v") def test_vector_debug(self): - debug = db.execute("select vector_debug()").fetchone()[0].split('\n') - self.assertEqual(len(debug), 1) self.assertEqual( db.execute("select vector_debug(json('[]'))").fetchone()[0], "size: 0 []" ) - with self.assertRaisesRegex(sqlite3.OperationalError, "value not a vector"): + with self.assertRaisesRegex(sqlite3.OperationalError, "Value not a vector"): db.execute("select vector_debug(']')").fetchone() def test_vector0(self): @@ -583,7 +580,7 @@ def test_vector_from_raw(self): 'size: 2 [0.000000, 0.100000]' ) - with self.assertRaisesRegex(sqlite3.OperationalError, "Invalid raw blob length, must be divisible by 4"): + with self.assertRaisesRegex(sqlite3.OperationalError, "Invalid raw blob length, blob must be divisible by 4"): vector_from_raw_blob(b"abc") def test_vector_from_json(self):