From e951c6d6233dfd5dfcb8d859525ddd5ff9f19388 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Wed, 16 Aug 2023 18:02:45 -0700 Subject: [PATCH] new tokenizer for vss0 column defs, test updates, edit faiss_ondisk --- CMakeLists.txt | 2 +- src/sqlite-vss.cpp | 349 ++++++--- tests/test-loadable.py | 1666 +++++++++++++++++++++++++--------------- 3 files changed, 1302 insertions(+), 715 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c59d993..5100d5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ endif() configure_file(src/sqlite-vss.h.in sqlite-vss.h) configure_file(src/sqlite-vector.h.in sqlite-vector.h) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) option(FAISS_ENABLE_GPU "" OFF) diff --git a/src/sqlite-vss.cpp b/src/sqlite-vss.cpp index 765e111..3e17f57 100644 --- a/src/sqlite-vss.cpp +++ b/src/sqlite-vss.cpp @@ -11,6 +11,7 @@ SQLITE_EXTENSION_INIT1 #include #include #include +#include #include #include @@ -184,7 +185,7 @@ static void vss_cosine_similarity(sqlite3_context *context, int argc, -1); return; } - + float inner_product = faiss::fvec_inner_product(lhs->data(), rhs->data(), lhs->size()); float lhs_norm = faiss::fvec_norm_L2sqr(lhs->data(), lhs->size()); float rhs_norm = faiss::fvec_norm_L2sqr(rhs->data(), rhs->size()); @@ -645,14 +646,12 @@ static faiss::Index *read_index_select(sqlite3 *db, const char *schema, const ch int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr); if (rc != SQLITE_OK || stmt == nullptr) { finalize_and_free(stmt, sql); - printf("zz prepare error\n"); return nullptr; } sqlite3_bind_int64(stmt, 1, indexId); if ((rc = sqlite3_step(stmt)) != SQLITE_ROW) { finalize_and_free(stmt, sql); - printf("zz step error %d\n", rc); return nullptr; } @@ -676,10 +675,11 @@ static int create_shadow_tables(sqlite3 *db, vector indices) { - bool skip_shadow_index = false; + // make the _index shadow tables if there's at least 1 column that uses the default faiss_shadow + bool skip_shadow_index = true; for (auto i : indices) { - if (i->storage_type == StorageType::faiss_ondisk) { - skip_shadow_index = true; + if (i->storage_type == StorageType::faiss_shadow) { + skip_shadow_index = false; } } @@ -741,109 +741,275 @@ static int drop_shadow_tables(sqlite3 *db, char *name) { #define VSS_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION #define VSS_RANGE_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION + 1 +// Tokens types when parsing vss0 column definitions +enum TokenType { + LPAREN = 1, + RPAREN = 2, + EQUAL = 3, + IDENTIFIER = 50, + STRING = 51, + INTEGER = 52 +}; -faiss::MetricType parse_metric_type(const std::string& metric_type) { - static const std::unordered_map metric_type_map = { - {"L1", faiss::METRIC_L1}, - {"L2", faiss::METRIC_L2}, - {"INNER_PRODUCT", faiss::METRIC_INNER_PRODUCT}, - {"Linf", faiss::METRIC_Linf}, - // {"Lp", faiss::METRIC_Lp}, unavailable until metric arg is added - {"Canberra", faiss::METRIC_Canberra}, - {"BrayCurtis", faiss::METRIC_BrayCurtis}, - {"JensenShannon", faiss::METRIC_JensenShannon} - }; +// Tokens types when parsing vss0 column definitions +struct Token { +public: + TokenType token_type; + // Only definied when token_type == TokenType::IDENTIFIER + string identifier_value; + // Only definied when token_type == TokenType::STRING + string string_value; + // Only definied when token_type == TokenType::INTEGER + int int_value; + + explicit Token(TokenType token_type) : token_type(token_type) {} + // TODO: maybe these should just be different classes that inherit Token? idk C++ man + static Token IdentifierToken(string value){ + Token token(TokenType::IDENTIFIER); + token.identifier_value = value; + return token; + } + static Token StringToken(string value){ + Token token(TokenType::STRING); + token.string_value = value; + return token; + } + static Token IntToken(int value){ + Token token(TokenType::INTEGER); + token.int_value = value; + return token; + } +}; - auto it = metric_type_map.find(metric_type); - if (it == metric_type_map.end()) { - throw invalid_argument("unknown metric type: " + metric_type); +// scans through a vss0 column definition string +struct Scanner { +public: + explicit Scanner(string source): source(source), idx(0), len(source.length()) {} + char advance() { + return source[idx++]; + } + optional peek() { + if(idx >= len) { + return {}; } + return source[idx]; + } + bool eof() { + return idx >= len; + } - return it->second; -} +private: + string source; + int idx; + int len; + +}; -StorageType parse_storage_type(const string& storage_type) { - if (storage_type == "faiss_ondisk") return StorageType::faiss_ondisk; - else if (storage_type == "faiss_shadow") return StorageType::faiss_shadow; - throw invalid_argument("unknown option for on storage type: " + storage_type); +// Valid chatacters allowed in an identifier, except the 1st character (most be a-zA-Z_) +bool is_identifier_char(optional c) { + if(!c.has_value()) { + return false; + } + return (c >= 'a' && c <= 'z') + || (c >= 'A' && c <= 'Z') + || (c >= '0' && c <= '9') + || (c == '_'); } -string parse_factory(const string& s) { - size_t lquote = s.find_first_of("\""); - size_t rquote = s.find_last_of("\""); +// sus out the tokens in a vss0 column definition +vector tokenize(string source) { + vector tokens; + Scanner scanner(source); + while (!scanner.eof()) { + char c = scanner.advance(); + if(c == ' ' || c == '\n' || c == '\t') { + continue; + } + else if(c=='(') { + tokens.push_back(Token(TokenType::LPAREN)); + } + else if(c==')') { + tokens.push_back(Token(TokenType::RPAREN)); + } + else if(c=='=') { + tokens.push_back(Token(TokenType::EQUAL)); + } + else if((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) { + string identifier; + identifier.push_back(c); + while (is_identifier_char(scanner.peek())) { + identifier.push_back(scanner.advance()); + } + tokens.push_back(Token::IdentifierToken(identifier)); - if (lquote == string::npos || rquote == string::npos || - lquote >= rquote) { - return nullptr; } + else if(c >= '0' && c <= '9') { + string number_literal; + number_literal.push_back(c); + optional next; + while((next = scanner.peek()) && next.has_value() && next >= '0' && next <= '9') { + number_literal.push_back(*next); + scanner.advance(); + } + tokens.push_back(Token::IntToken(atoi(number_literal.c_str()))); + } + else if (c == '"') { + string string_literal; + optional next; + while(!scanner.eof() && (next = scanner.peek()) && next != '"') { + string_literal.push_back(*next); + scanner.advance(); + } + if (next == '"') { + scanner.advance(); + tokens.push_back(Token::StringToken(string_literal)); + }else { + throw invalid_argument("unterminated string"); + } - return s.substr(lquote + 1, rquote - lquote - 1); + } + } + return tokens; } -template -T parse_attribute(const string& arg, const string& keyword, T default_value, - size_t rparen, function parse_func) { - - size_t keywordStart, keywordStringStartFrom; - - if ((keywordStart = arg.find(keyword, rparen)) != string::npos && - (keywordStringStartFrom = arg.find("=", keywordStart)) != string::npos) { - - size_t l = arg.find_first_not_of(" ", keywordStringStartFrom + 1); - if (l == string::npos) { - throw invalid_argument("Invalid " + keyword + " value"); - } +static const std::unordered_map metric_type_map = { + {"L1", faiss::METRIC_L1}, + {"L2", faiss::METRIC_L2}, + {"INNER_PRODUCT", faiss::METRIC_INNER_PRODUCT}, + {"Linf", faiss::METRIC_Linf}, + // {"Lp", faiss::METRIC_Lp}, unavailable until metric arg is added + {"Canberra", faiss::METRIC_Canberra}, + {"BrayCurtis", faiss::METRIC_BrayCurtis}, + {"JensenShannon", faiss::METRIC_JensenShannon} + }; - size_t r; - if ((r = arg.find(' ', l)) && r != string::npos) { - return parse_func(arg.substr(l, r - l)); - } +// parse a vss0 column definition. Throws on errors +VssIndexColumn parse_vss0_column_definition(string source) { + string name; + sqlite3_int64 dimensions; + string factory = "Flat,IDMap2"; + faiss::MetricType metric_type = faiss::MetricType::METRIC_L2; + StorageType storage_type = StorageType::faiss_shadow; + + vector tokens = tokenize(source); + std::vector::iterator it = tokens.begin(); + + // STEP 1: expect a column name as a first token (identifier) + if(it == tokens.end()) { + throw invalid_argument("No tokens available."); + } + if((*it).token_type != TokenType::IDENTIFIER) { + throw invalid_argument("Expected identifier as first token, got "); + } + name = (*it).identifier_value; + + // STEP 2: ensure a '(' immediately follows + it++; + if(it == tokens.end()) { + throw invalid_argument("Expected dimensions after column name declaration"); + } + if((*it).token_type != TokenType::LPAREN) { + throw invalid_argument("Expected left parethesis '('"); + } + + // STEP 3: ensure an integer token (# dimensions) immediately follows + it++; + if(it == tokens.end()) { + throw invalid_argument("Expected dimensions after column name declaration"); + } + if((*it).token_type != TokenType::INTEGER) { + throw invalid_argument("Expected integer"); + } + dimensions = (*it).int_value; + + // STEP 4: ensure a ')' token immediately follows + it++; + if(it == tokens.end()) { + throw invalid_argument("Expected dimensions after column name declaration"); + } + if((*it).token_type != TokenType::RPAREN) { + throw invalid_argument("Expected right parethesis ')'"); + } + + // STEP 5: parse any column options + it++; + while(it != tokens.end()) { + // every column option must start with an identifier + if((*it).token_type != TokenType::IDENTIFIER) { + throw invalid_argument("Expected an identifier for column arguments"); + } + string key = (*it).identifier_value; + if(key != "factory" && key != "metric_type" && key != "storage_type") { + throw invalid_argument("Unknown vss0 column option '" + key + "'"); + } - r = min(arg.find(',', l), arg.size()); + // ensure it is followed by an '=' + it++; + if(it == tokens.end() || (*it).token_type != TokenType::EQUAL) { + throw invalid_argument("Expected '=' "); + } - return parse_func(arg.substr(l, r - l)); + // now parse the value - different legal values for each key + it++; + if(it == tokens.end()) { + throw invalid_argument("Expected column option value for " + key); + } + if (key == "factory") { + if((*it).token_type != TokenType::STRING) { + throw invalid_argument("Expected string value for factory column option, got "); + } + factory = (*it).string_value; } - else { - return default_value; + else if (key == "metric_type") { + if((*it).token_type != TokenType::IDENTIFIER) { + throw invalid_argument("Expected an identifier value for the 'metric_type' column option"); + } + string value = (*it).identifier_value; + auto it2 = metric_type_map.find(value); + + if( it2 == metric_type_map.end()) { + throw invalid_argument("Unknown metric type: " + value); + } + + metric_type = it2->second; } + else if (key == "storage_type") { + if((*it).token_type != TokenType::IDENTIFIER) { + throw invalid_argument("Expected an identifier value for the 'storage_type' column option"); + } + string value = (*it).identifier_value; + if(value == "faiss_shadow") { + storage_type = StorageType::faiss_shadow; + } + else if(value == "faiss_ondisk") { + storage_type = StorageType::faiss_ondisk; + }else { + throw invalid_argument("storage_type value must be one of faiss_shadow or faiss_ondisk"); + } + } + + it++; + } + return VssIndexColumn { + name, + dimensions, + factory, + metric_type, + storage_type + }; } unique_ptr> parse_constructor(int argc, const char* const* argv, sqlite3 *db) { - auto columns = - unique_ptr>(new vector()); + auto columns = unique_ptr>(new vector()); for (int i = 3; i < argc; i++) { - string arg = string(argv[i]); - - size_t lparen = arg.find("("); - size_t rparen = arg.find(")"); - - if (lparen == string::npos || rparen == string::npos || lparen >= rparen) { - return nullptr; - } - - string name = arg.substr(0, lparen); - string sDimensions = arg.substr(lparen + 1, rparen - lparen - 1); - sqlite3_int64 dimensions = atoi(sDimensions.c_str()); - - string factory = - parse_attribute(arg, "factory", "Flat,IDMap2", rparen, - parse_factory); - - faiss::MetricType metric_type = parse_attribute( - arg, "metric_type", faiss::METRIC_L2, rparen, parse_metric_type); - - StorageType storage_type = - parse_attribute(arg, "storage_type", StorageType::faiss_shadow, rparen, parse_storage_type); - - if (storage_type == StorageType::faiss_ondisk && sqlite3_db_filename(db, "main") == nullptr) { + auto column = parse_vss0_column_definition(string(argv[i])); + if (column.storage_type == StorageType::faiss_ondisk && sqlite3_db_filename(db, "main")[0] == '\0') { throw invalid_argument("Cannot use on disk storage for in memory db"); } - - columns->push_back( - VssIndexColumn{ name, dimensions, factory, metric_type, storage_type }); + columns->push_back(column); } return columns; @@ -868,7 +1034,7 @@ static int init(sqlite3 *db, try { columns = parse_constructor(argc, argv, db); } catch (const invalid_argument& e) { - *pzErr = sqlite3_mprintf(e.what()); + *pzErr = sqlite3_mprintf("Error parsing constructor: %s", e.what()); return SQLITE_ERROR; } @@ -919,8 +1085,11 @@ static int init(sqlite3 *db, } rc = create_shadow_tables(db, argv[1], argv[2], pTable->indexes); - if (rc != SQLITE_OK) - return rc; + if (rc != SQLITE_OK){ + *pzErr = sqlite3_mprintf("Error creating shadow tables"); + return rc; + } + // Shadow tables were successully created. // After shadow tables are created, write the initial index state to @@ -938,11 +1107,13 @@ static int init(sqlite3 *db, (*iter)->name, (*iter)->storage_type); - if (rc != SQLITE_OK) - return rc; + if (rc != SQLITE_OK) { + *pzErr = sqlite3_mprintf("Error initializing _index shadow tables"); + return rc; + } } catch (faiss::FaissException &e) { - + *pzErr = sqlite3_mprintf("Faiss error when initializing shadow tables: %s", e.what()); return SQLITE_ERROR; } } diff --git a/tests/test-loadable.py b/tests/test-loadable.py index bec894a..52f5a18 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -4,164 +4,224 @@ import os import tempfile -EXT_VSS_PATH="./dist/debug/vss0" -EXT_VECTOR_PATH="./dist/debug/vector0" +EXT_VSS_PATH = "./dist/debug/vss0" +EXT_VECTOR_PATH = "./dist/debug/vector0" + def connect(path=":memory:"): - db = sqlite3.connect(path) + db = sqlite3.connect(path) - db.enable_load_extension(True) + db.enable_load_extension(True) - db.execute("create temp table base_functions as select name from pragma_function_list") - db.execute("create temp table base_modules as select name from pragma_module_list") - db.load_extension(EXT_VECTOR_PATH) - db.execute("create temp table vector_loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name") - db.execute("create temp table vector_loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name") + db.execute( + "create temp table base_functions as select name from pragma_function_list" + ) + db.execute("create temp table base_modules as select name from pragma_module_list") + db.load_extension(EXT_VECTOR_PATH) + db.execute( + "create temp table vector_loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name" + ) + db.execute( + "create temp table vector_loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name" + ) - db.execute("drop table base_functions") - db.execute("drop table base_modules") + db.execute("drop table base_functions") + db.execute("drop table base_modules") - db.execute("create temp table base_functions as select name from pragma_function_list") - db.execute("create temp table base_modules as select name from pragma_module_list") + db.execute( + "create temp table base_functions as select name from pragma_function_list" + ) + db.execute("create temp table base_modules as select name from pragma_module_list") - db.load_extension(EXT_VSS_PATH) + db.load_extension(EXT_VSS_PATH) - db.execute("create temp table vss_loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name") - db.execute("create temp table vss_loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name") + db.execute( + "create temp table vss_loaded_functions as select name from pragma_function_list where name not in (select name from base_functions) order by name" + ) + db.execute( + "create temp table vss_loaded_modules as select name from pragma_module_list where name not in (select name from base_modules) order by name" + ) - db.row_factory = sqlite3.Row - return db + db.row_factory = sqlite3.Row + return db db = connect() + def explain_query_plan(sql): - return db.execute("explain query plan " + sql).fetchone()["detail"] + return db.execute("explain query plan " + sql).fetchone()["detail"] + def execute_all(cursor, sql, args=None): - if args is None: args = [] - results = cursor.execute(sql, args).fetchall() - return list(map(lambda x: dict(x), results)) + if args is None: + args = [] + results = cursor.execute(sql, args).fetchall() + return list(map(lambda x: dict(x), results)) + VSS_FUNCTIONS = [ - 'vss_cosine_similarity', - 'vss_debug', - 'vss_distance_l1', - 'vss_distance_l2', - 'vss_distance_linf', - 'vss_fvec_add', - 'vss_fvec_sub', - 'vss_inner_product', - 'vss_memory_usage', - 'vss_range_search', - 'vss_range_search_params', - 'vss_search', - 'vss_search_params', - 'vss_version', + "vss_cosine_similarity", + "vss_debug", + "vss_distance_l1", + "vss_distance_l2", + "vss_distance_linf", + "vss_fvec_add", + "vss_fvec_sub", + "vss_inner_product", + "vss_memory_usage", + "vss_range_search", + "vss_range_search_params", + "vss_search", + "vss_search_params", + "vss_version", ] VSS_MODULES = [ - "vss0", + "vss0", ] + + class TestVss(unittest.TestCase): - def test_funcs(self): - funcs = list(map(lambda a: a[0], db.execute("select name from vss_loaded_functions").fetchall())) - self.assertEqual(funcs, VSS_FUNCTIONS) - - def test_modules(self): - modules = list(map(lambda a: a[0], db.execute("select name from vss_loaded_modules").fetchall())) - self.assertEqual(modules, VSS_MODULES) - - - def test_vss_version(self): - self.assertEqual(db.execute("select vss_version()").fetchone()[0][0], "v") - - def test_vss_debug(self): - debug = db.execute("select vss_debug()").fetchone()[0].split('\n') - self.assertEqual(len(debug), 3) - - def test_vss_distance_l1(self): - vss_distance_l1 = lambda a, b: db.execute("select vss_distance_l1(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_distance_l1('[0, 0]', '[0, 0]'), 0.0) - self.assertEqual(vss_distance_l1('[0, 0]', '[0, 1]'), 1.0) - - def test_vss_distance_l2(self): - vss_distance_l2 = lambda a, b: db.execute("select vss_distance_l2(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_distance_l2('[0, 0]', '[0, 0]'), 0.0) - self.assertEqual(vss_distance_l2('[0, 0]', '[0, 1]'), 1.0) - - def test_vss_distance_linf(self): - vss_distance_linf = lambda a, b: db.execute("select vss_distance_linf(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_distance_linf('[0, 0]', '[0, 0]'), 0.0) - self.assertEqual(vss_distance_linf('[0, 0]', '[0, 1]'), 1.0) - - def test_vss_inner_product(self): - vss_inner_product = lambda a, b: db.execute("select vss_inner_product(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_inner_product('[0, 0]', '[0, 0]'), 0.0) - self.assertEqual(vss_inner_product('[0, 0]', '[0, 1]'), 0.0) - - def test_vss_cosine_similarity(self): - vss_cosine_similarity = lambda a, b: db.execute("select vss_cosine_similarity(json(?), json(?))", [a, b]).fetchone()[0] - self.assertAlmostEqual(vss_cosine_similarity('[2, 1]', '[1, 2]'), 0.8) - self.assertEqual(vss_cosine_similarity('[1, 1]', '[-1, 1]'), 0.0) - - def test_vss_fvec_add(self): - vss_fvec_add = lambda a, b: db.execute("select vss_fvec_add(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_fvec_add('[0, 0]', '[0, 0]'), b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.assertEqual(vss_fvec_add('[-1, -1]', '[1, 1]'), b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.assertEqual(vss_fvec_add('[0, 0]', '[1, 1]'), b'\x00\x00\x80?\x00\x00\x80?') - self.skipTest("TODO") - - def test_vss_fvec_sub(self): - vss_fvec_sub = lambda a, b: db.execute("select vss_fvec_sub(json(?), json(?))", [a, b]).fetchone()[0] - self.assertEqual(vss_fvec_sub('[0, 0]', '[0, 0]'), b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.assertEqual(vss_fvec_sub('[1, 1]', '[1, 1]'), b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.assertEqual(vss_fvec_sub('[0, 0]', '[1, 1]'), b'\x00\x00\x80\xbf\x00\x00\x80\xbf') - self.skipTest("TODO") - - def test_vss_search(self): - self.skipTest("TODO") - - def test_vss_search_params(self): - self.skipTest("TODO") - - def test_vss_memory_usage(self): - self.skipTest("TODO") - - def test_vss_range_search(self): - self.skipTest("TODO") - - def test_vss_range_search_params(self): - self.skipTest("TODO") - - - def test_vss0(self): - # - # | - # 1000 -> X - # | 1002 - # | / - # | V - # ---X-----------------X-- - # ^1001 | - # | - # | - # X <- 1003 - # | - # - # - cur = db.cursor() - execute_all(cur, """ + def test_funcs(self): + funcs = list( + map( + lambda a: a[0], + db.execute("select name from vss_loaded_functions").fetchall(), + ) + ) + self.assertEqual(funcs, VSS_FUNCTIONS) + + def test_modules(self): + modules = list( + map( + lambda a: a[0], + db.execute("select name from vss_loaded_modules").fetchall(), + ) + ) + self.assertEqual(modules, VSS_MODULES) + + def test_vss_version(self): + self.assertEqual(db.execute("select vss_version()").fetchone()[0][0], "v") + + def test_vss_debug(self): + debug = db.execute("select vss_debug()").fetchone()[0].split("\n") + self.assertEqual(len(debug), 3) + + def test_vss_distance_l1(self): + vss_distance_l1 = lambda a, b: db.execute( + "select vss_distance_l1(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual(vss_distance_l1("[0, 0]", "[0, 0]"), 0.0) + self.assertEqual(vss_distance_l1("[0, 0]", "[0, 1]"), 1.0) + + def test_vss_distance_l2(self): + vss_distance_l2 = lambda a, b: db.execute( + "select vss_distance_l2(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual(vss_distance_l2("[0, 0]", "[0, 0]"), 0.0) + self.assertEqual(vss_distance_l2("[0, 0]", "[0, 1]"), 1.0) + + def test_vss_distance_linf(self): + vss_distance_linf = lambda a, b: db.execute( + "select vss_distance_linf(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual(vss_distance_linf("[0, 0]", "[0, 0]"), 0.0) + self.assertEqual(vss_distance_linf("[0, 0]", "[0, 1]"), 1.0) + + def test_vss_inner_product(self): + vss_inner_product = lambda a, b: db.execute( + "select vss_inner_product(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual(vss_inner_product("[0, 0]", "[0, 0]"), 0.0) + self.assertEqual(vss_inner_product("[0, 0]", "[0, 1]"), 0.0) + + def test_vss_cosine_similarity(self): + vss_cosine_similarity = lambda a, b: db.execute( + "select vss_cosine_similarity(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertAlmostEqual(vss_cosine_similarity("[2, 1]", "[1, 2]"), 0.8) + self.assertEqual(vss_cosine_similarity("[1, 1]", "[-1, 1]"), 0.0) + + def test_vss_fvec_add(self): + vss_fvec_add = lambda a, b: db.execute( + "select vss_fvec_add(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual( + vss_fvec_add("[0, 0]", "[0, 0]"), b"\x00\x00\x00\x00\x00\x00\x00\x00" + ) + self.assertEqual( + vss_fvec_add("[-1, -1]", "[1, 1]"), b"\x00\x00\x00\x00\x00\x00\x00\x00" + ) + self.assertEqual( + vss_fvec_add("[0, 0]", "[1, 1]"), b"\x00\x00\x80?\x00\x00\x80?" + ) + self.skipTest("TODO") + + def test_vss_fvec_sub(self): + vss_fvec_sub = lambda a, b: db.execute( + "select vss_fvec_sub(json(?), json(?))", [a, b] + ).fetchone()[0] + self.assertEqual( + vss_fvec_sub("[0, 0]", "[0, 0]"), b"\x00\x00\x00\x00\x00\x00\x00\x00" + ) + self.assertEqual( + vss_fvec_sub("[1, 1]", "[1, 1]"), b"\x00\x00\x00\x00\x00\x00\x00\x00" + ) + self.assertEqual( + vss_fvec_sub("[0, 0]", "[1, 1]"), b"\x00\x00\x80\xbf\x00\x00\x80\xbf" + ) + self.skipTest("TODO") + + def test_vss_search(self): + self.skipTest("TODO") + + def test_vss_search_params(self): + self.skipTest("TODO") + + def test_vss_memory_usage(self): + self.skipTest("TODO") + + def test_vss_range_search(self): + self.skipTest("TODO") + + def test_vss_range_search_params(self): + self.skipTest("TODO") + + def test_vss0(self): + # + # | + # 1000 -> X + # | 1002 + # | / + # | V + # ---X-----------------X-- + # ^1001 | + # | + # | + # X <- 1003 + # | + # + # + cur = db.cursor() + execute_all( + cur, + """ create virtual table x using vss0( a(2) factory="Flat,IDMap2", b(1) factory="Flat,IDMap2"); - """) - execute_all(cur, """ + """, + ) + execute_all( + cur, + """ insert into x(rowid, a, b) select key + 1000, json_extract(value, '$[0]'), json_extract(value, '$[1]') from json_each(?); - """, [""" + """, + [ + """ [ [[0, 1], [1]], [[0, -1], [2]], @@ -169,293 +229,567 @@ def test_vss0(self): [[-1, 0], [4]], [[0, 0], [5]] ] - """]) - db.commit() - - self.assertEqual(cur.lastrowid, 1004) - self.assertEqual(execute_all(cur, "select rowid, length(idx) from x_index"), [ - {'rowid': 0, 'length(idx)': 170}, - {'rowid': 1, 'length(idx)': 150} - ]) - self.assertEqual(execute_all(cur, "select rowid from x_data"), [ - {"rowid": 1000}, - {"rowid": 1001}, - {"rowid": 1002}, - {"rowid": 1003}, - {"rowid": 1004}, - ]) - - with self.subTest("delete + commit"): - execute_all(cur, "delete from x where rowid = 1004") - db.commit() - - self.assertEqual(execute_all(cur, "select rowid from x_data"), [ - {"rowid": 1000}, - {"rowid": 1001}, - {"rowid": 1002}, - {"rowid": 1003}, - ]) - - self.assertEqual(execute_all(cur, "select rowid, length(idx) from x_index"), [ - {'rowid': 0, 'length(idx)': 154}, - {'rowid': 1, 'length(idx)': 138} - ]) - - with self.subTest("delete + rollback"): - execute_all(cur, "delete from x where rowid = 1003") - self.assertEqual(execute_all(cur, "select rowid from x_data"), [ - {"rowid": 1000}, - {"rowid": 1001}, - {"rowid": 1002}, - ]) - - db.rollback() - - self.assertEqual(execute_all(cur, "select rowid from x_data"), [ - {"rowid": 1000}, - {"rowid": 1001}, - {"rowid": 1002}, - {"rowid": 1003}, - ]) - - self.assertEqual(execute_all(cur, "select rowid, length(idx) from x_index"), [ - {'rowid': 0, 'length(idx)': 154}, - {'rowid': 1, 'length(idx)': 138} - ]) - - def search(column, v, k): - return execute_all(cur, f"select rowid, distance from x where vss_search({column}, vss_search_params(json(?), ?))", [v, k]) - - def range_search(column, v, d): - return execute_all(cur, f"select rowid, distance from x where vss_range_search({column}, vss_range_search_params(json(?), ?))", [v, d]) - - self.assertEqual(search('a', '[0.9, 0]', 5), [ - {'rowid': 1002, 'distance': 0.010000004433095455}, - {'rowid': 1000, 'distance': 1.809999942779541}, - {'rowid': 1001, 'distance': 1.809999942779541}, - {'rowid': 1003, 'distance': 3.609999895095825} - ]) - self.assertEqual(search('b', '[6]', 2), [ - {'rowid': 1003, 'distance': 4.0}, - {'rowid': 1002, 'distance': 9.0}, - ]) - - with self.assertRaisesRegex(sqlite3.OperationalError, 'Input query size doesn\'t match index dimensions: 0 != 1'): - search('b', '[]', 2) - - with self.assertRaisesRegex(sqlite3.OperationalError, 'Input query size doesn\'t match index dimensions: 3 != 1'): - search('b', '[0.1, 0.2, 0.3]', 2) - - with self.assertRaisesRegex(sqlite3.OperationalError, 'Limit must be greater than 0, got -1'): - search('b', '[6]', -1) - - with self.assertRaisesRegex(sqlite3.OperationalError, 'Limit must be greater than 0, got 0'): - search('b', '[6]', 0) - - self.assertEqual(range_search('a', '[0.5, 0.5]', 1), [ - {'rowid': 1000, 'distance': 0.5}, - {'rowid': 1002, 'distance': 0.5}, - ]) - self.assertEqual(range_search('b', '[2.5]', 1), [ - {'rowid': 1001, 'distance': 0.25}, - {'rowid': 1002, 'distance': 0.25}, - ]) - - self.assertEqual(execute_all(cur, 'select rowid, a, b, distance from x'), [ - {'rowid': 1000, "a": b'\x00\x00\x00\x00\x00\x00\x80?', "b": b'\x00\x00\x80?', "distance": None}, - {'rowid': 1001, "a": b'\x00\x00\x00\x00\x00\x00\x80\xbf', "b": b'\x00\x00\x00@', "distance": None}, - {'rowid': 1002, "a": b'\x00\x00\x80?\x00\x00\x00\x00', "b": b'\x00\x00@@', "distance": None}, - {'rowid': 1003, "a": b'\x00\x00\x80\xbf\x00\x00\x00\x00', "b": b'\x00\x00\x80@', "distance": None}, - ]) - self.assertEqual(execute_all(cur, 'select rowid, vector_debug(a) as a, vector_debug(b) as b, distance from x'), [ - {'rowid': 1000, "a": "size: 2 [0.000000, 1.000000]", "b": "size: 1 [1.000000]", "distance": None}, - {'rowid': 1001, "a": "size: 2 [0.000000, -1.000000]", "b": "size: 1 [2.000000]", "distance": None}, - {'rowid': 1002, "a": "size: 2 [1.000000, 0.000000]", "b": "size: 1 [3.000000]", "distance": None}, - {'rowid': 1003, "a": "size: 2 [-1.000000, 0.000000]", "b": "size: 1 [4.000000]", "distance": None}, - ]) - - with self.assertRaisesRegex(sqlite3.OperationalError, "UPDATE statements on vss0 virtual tables not supported yet."): - execute_all(cur, 'update x set b = json(?) where rowid = ?', ['[444]', 1003]) - if sqlite3.sqlite_version_info[1] >= 41: - self.assertEqual( - execute_all( - cur, - f"select rowid, distance from x where vss_search(a, json(?)) limit ?", - ['[0.9, 0]', 2] - ), - [ - {'distance': 0.010000004433095455, 'rowid': 1002}, - {'distance': 1.809999942779541, 'rowid': 1000} - ] - ) - with self.assertRaisesRegex(sqlite3.OperationalError, "2nd argument to vss_search\(\) must be a vector"): + """ + ], + ) + db.commit() + + self.assertEqual(cur.lastrowid, 1004) + self.assertEqual( + execute_all(cur, "select rowid, length(idx) from x_index"), + [{"rowid": 0, "length(idx)": 170}, {"rowid": 1, "length(idx)": 150}], + ) + self.assertEqual( + execute_all(cur, "select rowid from x_data"), + [ + {"rowid": 1000}, + {"rowid": 1001}, + {"rowid": 1002}, + {"rowid": 1003}, + {"rowid": 1004}, + ], + ) + + with self.subTest("delete + commit"): + execute_all(cur, "delete from x where rowid = 1004") + db.commit() + + self.assertEqual( + execute_all(cur, "select rowid from x_data"), + [ + {"rowid": 1000}, + {"rowid": 1001}, + {"rowid": 1002}, + {"rowid": 1003}, + ], + ) + + self.assertEqual( + execute_all(cur, "select rowid, length(idx) from x_index"), + [{"rowid": 0, "length(idx)": 154}, {"rowid": 1, "length(idx)": 138}], + ) + + with self.subTest("delete + rollback"): + execute_all(cur, "delete from x where rowid = 1003") + self.assertEqual( + execute_all(cur, "select rowid from x_data"), + [ + {"rowid": 1000}, + {"rowid": 1001}, + {"rowid": 1002}, + ], + ) + + db.rollback() + + self.assertEqual( + execute_all(cur, "select rowid from x_data"), + [ + {"rowid": 1000}, + {"rowid": 1001}, + {"rowid": 1002}, + {"rowid": 1003}, + ], + ) + + self.assertEqual( + execute_all(cur, "select rowid, length(idx) from x_index"), + [{"rowid": 0, "length(idx)": 154}, {"rowid": 1, "length(idx)": 138}], + ) + + def search(column, v, k): + return execute_all( + cur, + f"select rowid, distance from x where vss_search({column}, vss_search_params(json(?), ?))", + [v, k], + ) + + def range_search(column, v, d): + return execute_all( + cur, + f"select rowid, distance from x where vss_range_search({column}, vss_range_search_params(json(?), ?))", + [v, d], + ) + + self.assertEqual( + search("a", "[0.9, 0]", 5), + [ + {"rowid": 1002, "distance": 0.010000004433095455}, + {"rowid": 1000, "distance": 1.809999942779541}, + {"rowid": 1001, "distance": 1.809999942779541}, + {"rowid": 1003, "distance": 3.609999895095825}, + ], + ) + self.assertEqual( + search("b", "[6]", 2), + [ + {"rowid": 1003, "distance": 4.0}, + {"rowid": 1002, "distance": 9.0}, + ], + ) + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Input query size doesn't match index dimensions: 0 != 1", + ): + search("b", "[]", 2) + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Input query size doesn't match index dimensions: 3 != 1", + ): + search("b", "[0.1, 0.2, 0.3]", 2) + + with self.assertRaisesRegex( + sqlite3.OperationalError, "Limit must be greater than 0, got -1" + ): + search("b", "[6]", -1) + + with self.assertRaisesRegex( + sqlite3.OperationalError, "Limit must be greater than 0, got 0" + ): + search("b", "[6]", 0) + + self.assertEqual( + range_search("a", "[0.5, 0.5]", 1), + [ + {"rowid": 1000, "distance": 0.5}, + {"rowid": 1002, "distance": 0.5}, + ], + ) + self.assertEqual( + range_search("b", "[2.5]", 1), + [ + {"rowid": 1001, "distance": 0.25}, + {"rowid": 1002, "distance": 0.25}, + ], + ) + + self.assertEqual( + execute_all(cur, "select rowid, a, b, distance from x"), + [ + { + "rowid": 1000, + "a": b"\x00\x00\x00\x00\x00\x00\x80?", + "b": b"\x00\x00\x80?", + "distance": None, + }, + { + "rowid": 1001, + "a": b"\x00\x00\x00\x00\x00\x00\x80\xbf", + "b": b"\x00\x00\x00@", + "distance": None, + }, + { + "rowid": 1002, + "a": b"\x00\x00\x80?\x00\x00\x00\x00", + "b": b"\x00\x00@@", + "distance": None, + }, + { + "rowid": 1003, + "a": b"\x00\x00\x80\xbf\x00\x00\x00\x00", + "b": b"\x00\x00\x80@", + "distance": None, + }, + ], + ) + self.assertEqual( + execute_all( + cur, + "select rowid, vector_debug(a) as a, vector_debug(b) as b, distance from x", + ), + [ + { + "rowid": 1000, + "a": "size: 2 [0.000000, 1.000000]", + "b": "size: 1 [1.000000]", + "distance": None, + }, + { + "rowid": 1001, + "a": "size: 2 [0.000000, -1.000000]", + "b": "size: 1 [2.000000]", + "distance": None, + }, + { + "rowid": 1002, + "a": "size: 2 [1.000000, 0.000000]", + "b": "size: 1 [3.000000]", + "distance": None, + }, + { + "rowid": 1003, + "a": "size: 2 [-1.000000, 0.000000]", + "b": "size: 1 [4.000000]", + "distance": None, + }, + ], + ) + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "UPDATE statements on vss0 virtual tables not supported yet.", + ): + execute_all( + cur, "update x set b = json(?) where rowid = ?", ["[444]", 1003] + ) + if sqlite3.sqlite_version_info[1] >= 41: + self.assertEqual( + execute_all( + cur, + f"select rowid, distance from x where vss_search(a, json(?)) limit ?", + ["[0.9, 0]", 2], + ), + [ + {"distance": 0.010000004433095455, "rowid": 1002}, + {"distance": 1.809999942779541, "rowid": 1000}, + ], + ) + with self.assertRaisesRegex( + sqlite3.OperationalError, + "2nd argument to vss_search\(\) must be a vector", + ): + execute_all( + cur, f"select rowid, distance from x where vss_search(a, 3) limit 1" + ) + + else: + with self.assertRaisesRegex( + sqlite3.OperationalError, + "vss_search\(\) only support vss_search_params\(\) as a 2nd parameter for SQLite versions below 3.41.0", + ): + execute_all( + cur, + f"select rowid, distance from x where vss_search(a, json(?)) limit ?", + ["[0.9, 0]", 2], + ) + + self.assertRegex( + explain_query_plan("select * from x where vss_search(a, null);"), + r"SCAN (TABLE )?x VIRTUAL TABLE INDEX 0:search", + ) + self.assertRegex( + explain_query_plan("select * from x where vss_search(b, null);"), + r"SCAN (TABLE )?x VIRTUAL TABLE INDEX 1:search", + ) + self.assertRegex( + explain_query_plan("select * from x where vss_range_search(a, null);"), + r"SCAN (TABLE )?x VIRTUAL TABLE INDEX 0:range_search", + ) + self.assertRegex( + explain_query_plan("select * from x where vss_range_search(b, null);"), + r"SCAN (TABLE )?x VIRTUAL TABLE INDEX 1:range_search", + ) + self.assertRegex( + explain_query_plan("select * from x"), + r"SCAN (TABLE )?x VIRTUAL TABLE INDEX -1:fullscan", + ) + + # TODO support rowid point queries + + self.assertEqual(db.execute("select count(*) from x_data").fetchone()[0], 4) + + execute_all(cur, "drop table x;") + with self.assertRaisesRegex(sqlite3.OperationalError, "no such table: x_data"): + self.assertEqual(db.execute("select count(*) from x_data").fetchone()[0], 4) + + with self.assertRaisesRegex(sqlite3.OperationalError, "no such table: x_index"): + self.assertEqual( + db.execute("select count(*) from x_index").fetchone()[0], 2 + ) + + def test_vss0_persistent(self): + tf = tempfile.NamedTemporaryFile(delete=False) + tf.close() + + db = connect(tf.name) + db.execute("create table t as select 1 as a") + cur = db.cursor() execute_all( cur, - f"select rowid, distance from x where vss_search(a, 3) limit 1" - ) - - else: - with self.assertRaisesRegex(sqlite3.OperationalError, "vss_search\(\) only support vss_search_params\(\) as a 2nd parameter for SQLite versions below 3.41.0"): + """ + create virtual table x using vss0( a(2), b(1) factory="Flat,IDMap2"); + """, + ) execute_all( cur, - f"select rowid, distance from x where vss_search(a, json(?)) limit ?", - ['[0.9, 0]', 2] - ) - - self.assertRegex( - explain_query_plan("select * from x where vss_search(a, null);"), - r'SCAN (TABLE )?x VIRTUAL TABLE INDEX 0:search' - ) - self.assertRegex( - explain_query_plan("select * from x where vss_search(b, null);"), - r'SCAN (TABLE )?x VIRTUAL TABLE INDEX 1:search' - ) - self.assertRegex( - explain_query_plan("select * from x where vss_range_search(a, null);"), - r'SCAN (TABLE )?x VIRTUAL TABLE INDEX 0:range_search' - ) - self.assertRegex( - explain_query_plan("select * from x where vss_range_search(b, null);"), - r'SCAN (TABLE )?x VIRTUAL TABLE INDEX 1:range_search' - ) - self.assertRegex( - explain_query_plan("select * from x"), - r'SCAN (TABLE )?x VIRTUAL TABLE INDEX -1:fullscan' - ) - - # TODO support rowid point queries - - self.assertEqual(db.execute("select count(*) from x_data").fetchone()[0], 4) - - execute_all(cur, "drop table x;") - with self.assertRaisesRegex(sqlite3.OperationalError, "no such table: x_data"): - self.assertEqual(db.execute("select count(*) from x_data").fetchone()[0], 4) - - with self.assertRaisesRegex(sqlite3.OperationalError, "no such table: x_index"): - self.assertEqual(db.execute("select count(*) from x_index").fetchone()[0], 2) - - - def test_vss0_persistent(self): - tf = tempfile.NamedTemporaryFile(delete=False) - tf.close() - - db = connect(tf.name) - db.execute("create table t as select 1 as a") - cur = db.cursor() - execute_all(cur, """ - create virtual table x using vss0( a(2), b(1) factory="Flat,IDMap2"); - """) - execute_all(cur, """ + """ insert into x(rowid, a, b) select key + 1000, json_extract(value, '$[0]'), json_extract(value, '$[1]') from json_each(?); - """, [""" + """, + [ + """ [ [[0, 1], [1]], [[0, -1], [2]], [[1, 0], [3]], [[-1, 0], [4]] ] - """]) - - db.commit() - - def search(cur, column, v, k): - return execute_all(cur, f"select rowid, distance from x where vss_search({column}, vss_search_params(json(?), ?))", [v, k]) - - self.assertEqual(search(cur, 'a', '[0.9, 0]', 5), [ - {'rowid': 1002, 'distance': 0.010000004433095455}, - {'rowid': 1000, 'distance': 1.809999942779541}, - {'rowid': 1001, 'distance': 1.809999942779541}, - {'rowid': 1003, 'distance': 3.609999895095825} - ]) - self.assertEqual(search(cur, 'b', '[6]', 2), [ - {'rowid': 1003, 'distance': 4.0}, - {'rowid': 1002, 'distance': 9.0}, - ]) - db.close() - - db = connect(tf.name) - cur = db.cursor() - self.assertEqual(execute_all(db.cursor(), "select a from t"), [{"a": 1}]) - self.assertEqual(search(cur, 'a', '[0.9, 0]', 5), [ - {'rowid': 1002, 'distance': 0.010000004433095455}, - {'rowid': 1000, 'distance': 1.809999942779541}, - {'rowid': 1001, 'distance': 1.809999942779541}, - {'rowid': 1003, 'distance': 3.609999895095825} - ]) - self.assertEqual(search(cur, 'b', '[6]', 2), [ - {'rowid': 1003, 'distance': 4.0}, - {'rowid': 1002, 'distance': 9.0}, - ]) - - db.close() - def test_vss0_persistent_stress(self): - tf = tempfile.NamedTemporaryFile(delete=False) - tf.close() - - # create a vss0 table with no data, then close it. When re-opening it, should work as expected - db = connect(tf.name) - db.execute("create virtual table x using vss0(a(2));") - db.close() - - db = connect(tf.name) - self.assertEqual(execute_all(db, "select rowid, length(idx) as length from x_index"), [ - {'rowid': 0, 'length': 90}, - ]) - db.execute("insert into x(rowid, a) select 1, json_array(1, 2)") - db.commit() - - self.assertEqual(execute_all(db, "select rowid, length(idx) as length from x_index"), [ - {'rowid': 0, 'length': 106}, - ]) - - self.assertEqual(execute_all(db, "select rowid, vector_debug(a) as a from x"), [ - {'rowid': 1, 'a': 'size: 2 [1.000000, 2.000000]'}, - ]) - db.close() - - def test_vss_stress(self): - cur = db.cursor() - - execute_all(cur, 'create virtual table no_id_map using vss0(a(2) factory="Flat");') - with self.assertRaisesRegex(sqlite3.OperationalError, ".*add_with_ids not implemented for this type of index.*"): - execute_all(cur, """insert into no_id_map(rowid, a) select 100, json('[0, 1]')""") - db.commit() - - - execute_all(cur, 'create virtual table no_id_map2 using vss0(a(2) factory="Flat,IDMap");') - execute_all(cur, "insert into no_id_map2(rowid, a) select 100, json('[0, 1]')") - # fails because query references `a`, but cannot reconstruct the vector from the index bc only IDMap - with self.assertRaisesRegex(sqlite3.OperationalError, ".*reconstruct not implemented for this type of index"): - execute_all(cur, "select rowid, a from no_id_map2;") - # but this suceeds, because only the rowid column is referenced - execute_all(cur, "select rowid from no_id_map2;") - - with self.assertRaisesRegex(sqlite3.OperationalError, ".*could not parse index string invalid"): - execute_all(cur, 'create virtual table t1 using vss0(a(2) factory="invalid");') - - with self.assertRaisesRegex(sqlite3.OperationalError, "unknown metric type: L3"): - db.execute("create virtual table xx using vss0( a(2) metric_type=L3)") - - with self.assertRaisesRegex(sqlite3.OperationalError, "invalid metric_type value"): - db.execute("create virtual table xx using vss0( a(2) metric_type=)") - - def test_vss_training(self): - import random - import json - cur = db.cursor() - execute_all( - cur, - 'create virtual table with_training using vss0(a(4) factory="IVF10,Flat,IDMap2", b(4) factory="IVF10,Flat,IDMap2")' - ) - data = list(map(lambda x: [random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)], range(0, 1000))) - execute_all( - cur, - """ + """ + ], + ) + + db.commit() + + def search(cur, column, v, k): + return execute_all( + cur, + f"select rowid, distance from x where vss_search({column}, vss_search_params(json(?), ?))", + [v, k], + ) + + self.assertEqual( + search(cur, "a", "[0.9, 0]", 5), + [ + {"rowid": 1002, "distance": 0.010000004433095455}, + {"rowid": 1000, "distance": 1.809999942779541}, + {"rowid": 1001, "distance": 1.809999942779541}, + {"rowid": 1003, "distance": 3.609999895095825}, + ], + ) + self.assertEqual( + search(cur, "b", "[6]", 2), + [ + {"rowid": 1003, "distance": 4.0}, + {"rowid": 1002, "distance": 9.0}, + ], + ) + db.close() + + db = connect(tf.name) + cur = db.cursor() + self.assertEqual(execute_all(db.cursor(), "select a from t"), [{"a": 1}]) + self.assertEqual( + search(cur, "a", "[0.9, 0]", 5), + [ + {"rowid": 1002, "distance": 0.010000004433095455}, + {"rowid": 1000, "distance": 1.809999942779541}, + {"rowid": 1001, "distance": 1.809999942779541}, + {"rowid": 1003, "distance": 3.609999895095825}, + ], + ) + self.assertEqual( + search(cur, "b", "[6]", 2), + [ + {"rowid": 1003, "distance": 4.0}, + {"rowid": 1002, "distance": 9.0}, + ], + ) + + db.close() + + def test_vss0_persistent_stress(self): + tf = tempfile.NamedTemporaryFile(delete=False) + tf.close() + + # create a vss0 table with no data, then close it. When re-opening it, should work as expected + db = connect(tf.name) + db.execute("create virtual table x using vss0(a(2));") + db.close() + + db = connect(tf.name) + self.assertEqual( + execute_all(db, "select rowid, length(idx) as length from x_index"), + [ + {"rowid": 0, "length": 90}, + ], + ) + db.execute("insert into x(rowid, a) select 1, json_array(1, 2)") + db.commit() + + self.assertEqual( + execute_all(db, "select rowid, length(idx) as length from x_index"), + [ + {"rowid": 0, "length": 106}, + ], + ) + + self.assertEqual( + execute_all(db, "select rowid, vector_debug(a) as a from x"), + [ + {"rowid": 1, "a": "size: 2 [1.000000, 2.000000]"}, + ], + ) + db.close() + + def test_vss_stress(self): + cur = db.cursor() + + execute_all( + cur, 'create virtual table no_id_map using vss0(a(2) factory="Flat");' + ) + with self.assertRaisesRegex( + sqlite3.OperationalError, + ".*add_with_ids not implemented for this type of index.*", + ): + execute_all( + cur, """insert into no_id_map(rowid, a) select 100, json('[0, 1]')""" + ) + db.commit() + + execute_all( + cur, + 'create virtual table no_id_map2 using vss0(a(2) factory="Flat,IDMap");', + ) + execute_all(cur, "insert into no_id_map2(rowid, a) select 100, json('[0, 1]')") + # fails because query references `a`, but cannot reconstruct the vector from the index bc only IDMap + with self.assertRaisesRegex( + sqlite3.OperationalError, + ".*reconstruct not implemented for this type of index", + ): + execute_all(cur, "select rowid, a from no_id_map2;") + # but this suceeds, because only the rowid column is referenced + execute_all(cur, "select rowid from no_id_map2;") + + with self.assertRaisesRegex( + sqlite3.OperationalError, ".*could not parse index string invalid" + ): + execute_all( + cur, 'create virtual table t1 using vss0(a(2) factory="invalid");' + ) + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Error parsing constructor: Unknown metric type: L3", + ): + db.execute("create virtual table xx using vss0( a(2) metric_type=L3)") + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Error parsing constructor: Expected column option value for metric_type", + ): + db.execute("create virtual table xx using vss0( a(2) metric_type=)") + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Error parsing constructor: storage_type value must be one of faiss_shadow or faiss_ondisk", + ): + db.execute("create virtual table xx using vss0( a(2) storage_type=xxx)") + + def test_vss0_storage_type_ondisk_only(self): + tf = tempfile.NamedTemporaryFile(delete=False) + tf.close() + + db = connect(tf.name) + db.execute( + "create virtual table vss_ondisk using vss0(a(2) storage_type=faiss_ondisk)" + ) + db.execute("insert into vss_ondisk(rowid, a) select ?1, ?2;", [1, "[0.1, 0.1]"]) + db.commit() + EXPECTED_IDX = b"IxM2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00IxF2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\xcd\xcc\xcc=\xcd\xcc\xcc=\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00" + faissindex_path = tf.name + ".main.vss_ondisk.a.faissindex" + with open(faissindex_path, "rb") as f: + bin = f.read() + self.assertEqual(bin, EXPECTED_IDX) + self.assertEqual( + execute_all(db, "select name from sqlite_master order by 1"), + # _index is here because the b colum + [ + { + "name": "sqlite_sequence", + }, + { + "name": "vss_ondisk", + }, + { + "name": "vss_ondisk_data", + }, + ], + ) + os.remove(tf.name) + os.remove(faissindex_path) + + def test_vss0_storage_type_mixed(self): + tf = tempfile.NamedTemporaryFile(delete=False) + tf.close() + + db = connect(tf.name) + db.execute( + """ + create virtual table vss_mixed using vss0( + a(2) storage_type=faiss_ondisk, + b(2), + ) + """ + ) + db.execute( + "insert into vss_mixed(rowid, a, b) select ?1, ?2, ?2;", [1, "[0.1, 0.1]"] + ) + db.commit() + faissindex_path = tf.name + ".main.vss_mixed.a.faissindex" + with open(faissindex_path, "rb") as f: + bin = f.read() + self.assertEqual(len(bin), 106) + self.assertEqual( + bin, + b"IxM2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00IxF2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\xcd\xcc\xcc=\xcd\xcc\xcc=\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00", + ) + self.assertEqual( + execute_all(db, "select name from sqlite_master order by 1"), + # _index is here because the b colum + [ + { + "name": "sqlite_sequence", + }, + { + "name": "vss_mixed", + }, + { + "name": "vss_mixed_data", + }, + { + "name": "vss_mixed_index", + }, + ], + ) + self.assertEqual( + execute_all(db, "select * from vss_mixed_index"), + [ + { + "rowid": 1, + "idx": b"IxM2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00IxF2\x02\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\xcd\xcc\xcc=\xcd\xcc\xcc=\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00", + } + ], + ) + os.remove(tf.name) + os.remove(faissindex_path) + + db = connect() + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Error parsing constructor: Cannot use on disk storage for in memory db", + ): + db.execute( + "create virtual table vss_on_disk using vss0(a(2) storage_type=faiss_ondisk)" + ) + + def test_vss_training(self): + import random + import json + + cur = db.cursor() + execute_all( + cur, + 'create virtual table with_training using vss0(a(4) factory="IVF10,Flat,IDMap2", b(4) factory="IVF10,Flat,IDMap2")', + ) + data = list( + map( + lambda x: [ + random.uniform(0, 1), + random.uniform(0, 1), + random.uniform(0, 1), + random.uniform(0, 1), + ], + range(0, 1000), + ) + ) + execute_all( + cur, + """ insert into with_training(operation, a, b) select 'training', @@ -463,14 +797,16 @@ def test_vss_training(self): value from json_each(?) """, - [json.dumps(data)] - ) - db.commit() - self.assertEqual(cur.execute('select count(*) from with_training').fetchone()[0], 0) + [json.dumps(data)], + ) + db.commit() + self.assertEqual( + cur.execute("select count(*) from with_training").fetchone()[0], 0 + ) - execute_all( - cur, - """ + execute_all( + cur, + """ insert into with_training(rowid, a, b) select key, @@ -478,79 +814,97 @@ def test_vss_training(self): value from json_each(?) """, - [json.dumps(data)] - ) - self.assertEqual(cur.execute('select count(*) from with_training').fetchone()[0], 1000) - db.commit() - self.assertEqual(cur.execute('select count(*) from with_training').fetchone()[0], 1000) - - def test_vss0_issue_29_upsert(self): - db = connect() - execute_all(db , '') - db.executescript(""" + [json.dumps(data)], + ) + self.assertEqual( + cur.execute("select count(*) from with_training").fetchone()[0], 1000 + ) + db.commit() + self.assertEqual( + cur.execute("select count(*) from with_training").fetchone()[0], 1000 + ) + + def test_vss0_issue_29_upsert(self): + db = connect() + execute_all(db, "") + db.executescript( + """ create virtual table demo using vss0(a(2)); insert into demo(rowid, a) values (1, '[1.0, 2.0]'), (2, '[2.0, 3.0]'); - """) - - self.assertEqual( - execute_all(db, "select rowid, vector_debug(a) from demo;"), - [ - {'rowid': 1, 'vector_debug(a)': 'size: 2 [1.000000, 2.000000]'}, - {'rowid': 2, 'vector_debug(a)': 'size: 2 [2.000000, 3.000000]'} - ] - ) - - db.executescript(""" + """ + ) + + self.assertEqual( + execute_all(db, "select rowid, vector_debug(a) from demo;"), + [ + {"rowid": 1, "vector_debug(a)": "size: 2 [1.000000, 2.000000]"}, + {"rowid": 2, "vector_debug(a)": "size: 2 [2.000000, 3.000000]"}, + ], + ) + + db.executescript( + """ delete from demo where rowid = 1; insert into demo(rowid, a) select 1, '[99.0, 99.0]'; delete from demo where rowid = 2; insert into demo(rowid, a) select 2, '[299.0, 299.0]'; - """) - - # This is what used to fail, since we we're clearing deleted IDs properly - self.assertEqual( - execute_all(db, "select rowid, vector_debug(a) from demo;"), - [ - {'rowid': 1, 'vector_debug(a)': 'size: 2 [99.000000, 99.000000]'}, - {'rowid': 2, 'vector_debug(a)': 'size: 2 [299.000000, 299.000000]'} - ] - ) - db.close() - - # Make sure tha VACUUMing a database with vss0 tables still works as expected - def test_vss0_vacuum(self): - cur = db.cursor() - execute_all(cur, "create virtual table x using vss0(a(2));") - execute_all(cur, """ + """ + ) + + # This is what used to fail, since we we're clearing deleted IDs properly + self.assertEqual( + execute_all(db, "select rowid, vector_debug(a) from demo;"), + [ + {"rowid": 1, "vector_debug(a)": "size: 2 [99.000000, 99.000000]"}, + {"rowid": 2, "vector_debug(a)": "size: 2 [299.000000, 299.000000]"}, + ], + ) + db.close() + + # Make sure tha VACUUMing a database with vss0 tables still works as expected + def test_vss0_vacuum(self): + cur = db.cursor() + execute_all(cur, "create virtual table x using vss0(a(2));") + execute_all( + cur, + """ insert into x(rowid, a) select key + 1000, value from json_each(?); - """, [""" + """, + [ + """ [ [1, 1], [2, 2], [3, 3] ] - """]) - db.commit() - - db.execute("VACUUM;") - - self.assertEqual( - execute_all(db, "select rowid, distance from x where vss_search(a, vss_search_params(?, ?))", ['[0, 0]', 1]), - [{'distance': 2.0, 'rowid': 1000}] - ) - - def test_vss0_metric_type(self): - cur = db.cursor() - execute_all( - cur, - '''create virtual table vss_mts using vss0( + """ + ], + ) + db.commit() + + db.execute("VACUUM;") + + self.assertEqual( + execute_all( + db, + "select rowid, distance from x where vss_search(a, vss_search_params(?, ?))", + ["[0, 0]", 1], + ), + [{"distance": 2.0, "rowid": 1000}], + ) + + def test_vss0_metric_type(self): + cur = db.cursor() + execute_all( + cur, + """create virtual table vss_mts using vss0( ip(2) metric_type=INNER_PRODUCT, l1(2) metric_type=L1, l2(2) metric_type=L2, @@ -559,180 +913,242 @@ def test_vss0_metric_type(self): canberra(2) metric_type=Canberra, braycurtis(2) metric_type=BrayCurtis, jensenshannon(2) metric_type=JensenShannon - )''' - ) - idxs = list(map(lambda row: row[0], db.execute("select idx from vss_mts_index").fetchall())) - - # ensure all the indexes are IDMap2 ("IxM2") - for idx in idxs: - idx_type = idx[0:4] - self.assertEqual(idx_type, b"IxM2") - - # manually tested until i ended up at 33 ¯\_(ツ)_/¯ - METRIC_TYPE_OFFSET = 33 - - # values should match https://github.com/facebookresearch/faiss/blob/43d86e30736ede853c384b24667fc3ab897d6ba9/faiss/MetricType.h#L22 - self.assertEqual(idxs[0][METRIC_TYPE_OFFSET], 0) # ip - self.assertEqual(idxs[1][METRIC_TYPE_OFFSET], 2) # l1 - self.assertEqual(idxs[2][METRIC_TYPE_OFFSET], 1) # l2 - self.assertEqual(idxs[3][METRIC_TYPE_OFFSET], 3) # linf - #self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 4) # lp - self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 20) # canberra - self.assertEqual(idxs[5][METRIC_TYPE_OFFSET], 21) # braycurtis - self.assertEqual(idxs[6][METRIC_TYPE_OFFSET], 22) # jensenshannon - - - db.execute( - "insert into vss_mts(rowid, ip, l1, l2, linf, /*lp,*/ canberra, braycurtis, jensenshannon) values (1, ?1,?1,?1,?1, /*?1,*/ ?1,?1,?2)", - ["[4,1]", "[0.8, 0.2]"] - ) - db.commit() - - def distance_of(metric_type, query): - return db.execute( - f"select distance from vss_mts where vss_search({metric_type}, vss_search_params(?1, 1))", - [query] - ).fetchone()[0] - - self.assertEqual(distance_of("ip", "[0,0]"), 0.0) - self.assertEqual(distance_of("l1", "[0,0]"), 5.0) - self.assertEqual(distance_of("l2", "[0,0]"), 17.0) - self.assertEqual(distance_of("linf", "[0,0]"), 4.0) - #self.assertEqual(distance_of("lp", "[0,0]"), 2.0) - self.assertEqual(distance_of("canberra", "[0,0]"), 2.0) - self.assertEqual(distance_of("braycurtis", "[0,0]"), 1.0) - - # JS distance assumes L1 normalized input (a valid probability distribution) - # additionally, faiss actually computes JS divergence and not JS distance - self.assertAlmostEqual(distance_of("jensenshannon", "[0.33333333, 0.66666667]"), 0.1157735) - - self.assertEqual(distance_of("ip", "[2,2]"), 10.0) + )""", + ) + idxs = list( + map( + lambda row: row[0], + db.execute("select idx from vss_mts_index").fetchall(), + ) + ) + + # ensure all the indexes are IDMap2 ("IxM2") + for idx in idxs: + idx_type = idx[0:4] + self.assertEqual(idx_type, b"IxM2") + + # manually tested until i ended up at 33 ¯\_(ツ)_/¯ + METRIC_TYPE_OFFSET = 33 + + # values should match https://github.com/facebookresearch/faiss/blob/43d86e30736ede853c384b24667fc3ab897d6ba9/faiss/MetricType.h#L22 + self.assertEqual(idxs[0][METRIC_TYPE_OFFSET], 0) # ip + self.assertEqual(idxs[1][METRIC_TYPE_OFFSET], 2) # l1 + self.assertEqual(idxs[2][METRIC_TYPE_OFFSET], 1) # l2 + self.assertEqual(idxs[3][METRIC_TYPE_OFFSET], 3) # linf + # self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 4) # lp + self.assertEqual(idxs[4][METRIC_TYPE_OFFSET], 20) # canberra + self.assertEqual(idxs[5][METRIC_TYPE_OFFSET], 21) # braycurtis + self.assertEqual(idxs[6][METRIC_TYPE_OFFSET], 22) # jensenshannon + + db.execute( + "insert into vss_mts(rowid, ip, l1, l2, linf, /*lp,*/ canberra, braycurtis, jensenshannon) values (1, ?1,?1,?1,?1, /*?1,*/ ?1,?1,?2)", + ["[4,1]", "[0.8, 0.2]"], + ) + db.commit() + + def distance_of(metric_type, query): + return db.execute( + f"select distance from vss_mts where vss_search({metric_type}, vss_search_params(?1, 1))", + [query], + ).fetchone()[0] + + self.assertEqual(distance_of("ip", "[0,0]"), 0.0) + self.assertEqual(distance_of("l1", "[0,0]"), 5.0) + self.assertEqual(distance_of("l2", "[0,0]"), 17.0) + self.assertEqual(distance_of("linf", "[0,0]"), 4.0) + # self.assertEqual(distance_of("lp", "[0,0]"), 2.0) + self.assertEqual(distance_of("canberra", "[0,0]"), 2.0) + self.assertEqual(distance_of("braycurtis", "[0,0]"), 1.0) + + # JS distance assumes L1 normalized input (a valid probability distribution) + # additionally, faiss actually computes JS divergence and not JS distance + self.assertAlmostEqual( + distance_of("jensenshannon", "[0.33333333, 0.66666667]"), 0.1157735 + ) + + self.assertEqual(distance_of("ip", "[2,2]"), 10.0) VECTOR_FUNCTIONS = [ - 'vector0', - 'vector_debug', - 'vector_from_blob', - 'vector_from_json', - 'vector_from_raw', - 'vector_length', - 'vector_to_blob', - 'vector_to_json', - 'vector_to_raw', - 'vector_value_at', - 'vector_version', + "vector0", + "vector_debug", + "vector_from_blob", + "vector_from_json", + "vector_from_raw", + "vector_length", + "vector_to_blob", + "vector_to_json", + "vector_to_raw", + "vector_value_at", + "vector_version", ] -VECTOR_MODULES = ['vector_fvecs_each'] -class TestVector(unittest.TestCase): - def test_funcs(self): - funcs = list(map(lambda a: a[0], db.execute("select name from vector_loaded_functions").fetchall())) - self.assertEqual(funcs, VECTOR_FUNCTIONS) - - def test_modules(self): - modules = list(map(lambda a: a[0], db.execute("select name from vector_loaded_modules").fetchall())) - self.assertEqual(modules, VECTOR_MODULES) +VECTOR_MODULES = ["vector_fvecs_each"] - def test_vector_version(self): - self.assertEqual(db.execute("select vector_version()").fetchone()[0][0], "v") - - def test_vector_debug(self): - self.assertEqual( - db.execute("select vector_debug(json('[]'))").fetchone()[0], - "size: 0 []" - ) - with self.assertRaisesRegex(sqlite3.OperationalError, "Value not a vector"): - db.execute("select vector_debug(']')").fetchone() +class TestVector(unittest.TestCase): + def test_funcs(self): + funcs = list( + map( + lambda a: a[0], + db.execute("select name from vector_loaded_functions").fetchall(), + ) + ) + self.assertEqual(funcs, VECTOR_FUNCTIONS) + + def test_modules(self): + modules = list( + map( + lambda a: a[0], + db.execute("select name from vector_loaded_modules").fetchall(), + ) + ) + self.assertEqual(modules, VECTOR_MODULES) + + def test_vector_version(self): + self.assertEqual(db.execute("select vector_version()").fetchone()[0][0], "v") + + def test_vector_debug(self): + self.assertEqual( + db.execute("select vector_debug(json('[]'))").fetchone()[0], "size: 0 []" + ) + with self.assertRaisesRegex(sqlite3.OperationalError, "Value not a vector"): + db.execute("select vector_debug(']')").fetchone() + + def test_vector0(self): + self.assertEqual(db.execute("select vector0(null)").fetchone()[0], None) + + def test_vector_from_blob(self): + self.assertEqual( + db.execute( + "select vector_debug(vector_from_blob(vector_to_blob(vector_from_json(?))))", + ["[0.1,0.2]"], + ).fetchone()[0], + "size: 2 [0.100000, 0.200000]", + ) + + def raises_small_blob_header(input): + with self.assertRaisesRegex( + sqlite3.OperationalError, "Vector blob size less than header length" + ): + db.execute("select vector_from_blob(?)", [input]).fetchall() + + raises_small_blob_header(b"") + raises_small_blob_header(b"v") + + with self.assertRaisesRegex( + sqlite3.OperationalError, "Blob not well-formatted vector blob" + ): + db.execute( + "select vector_from_blob(?)", [b"V\x01\x00\x00\x00\x00"] + ).fetchall() + + with self.assertRaisesRegex(sqlite3.OperationalError, "Blob type not right"): + db.execute( + "select vector_from_blob(?)", [b"v\x00\x00\x00\x00\x00"] + ).fetchall() + + def test_vector_to_blob(self): + vector_to_blob = lambda x: db.execute( + "select vector_to_blob(vector_from_json(json(?)))", [x] + ).fetchone()[0] + self.assertEqual(vector_to_blob("[]"), b"v\x01") + self.assertEqual(vector_to_blob("[0.1]"), b"v\x01\xcd\xcc\xcc=") + self.assertEqual( + vector_to_blob("[0.1, 0]"), b"v\x01\xcd\xcc\xcc=\x00\x00\x00\x00" + ) + + def test_vector_to_raw(self): + vector_to_raw = lambda x: db.execute( + "select vector_to_raw(vector_from_json(json(?)))", [x] + ).fetchone()[0] + self.assertEqual(vector_to_raw("[]"), None) # TODO why not b"" + self.assertEqual(vector_to_raw("[0.1]"), b"\xcd\xcc\xcc=") + self.assertEqual(vector_to_raw("[0.1, 0]"), b"\xcd\xcc\xcc=\x00\x00\x00\x00") + + def test_vector_from_raw(self): + vector_from_raw_blob = lambda x: db.execute( + "select vector_debug(vector_from_raw(?))", [x] + ).fetchone()[0] + self.assertEqual(vector_from_raw_blob(b""), "size: 0 []") + self.assertEqual( + vector_from_raw_blob(b"\x00\x00\x00\x00"), "size: 1 [0.000000]" + ) + self.assertEqual( + vector_from_raw_blob(b"\x00\x00\x00\x00\xcd\xcc\xcc="), + "size: 2 [0.000000, 0.100000]", + ) + + with self.assertRaisesRegex( + sqlite3.OperationalError, + "Invalid raw blob length, blob must be divisible by 4", + ): + vector_from_raw_blob(b"abc") + + def test_vector_from_json(self): + vector_from_json = lambda x: db.execute( + "select vector_debug(vector_from_json(?))", [x] + ).fetchone()[0] + self.assertEqual( + vector_from_json("[0.1, 0.2, 0.3]"), + "size: 3 [0.100000, 0.200000, 0.300000]", + ) + self.assertEqual(vector_from_json("[]"), "size: 0 []") + with self.assertRaisesRegex( + sqlite3.OperationalError, "input not valid json, or contains non-float data" + ): + vector_from_json("") + # db.execute("select vector_from_json(?)", [""]).fetchone()[0] + + def test_vector_to_json(self): + vector_to_json = lambda x: db.execute( + "select vector_debug(vector_to_json(vector_from_json(json(?))))", [x] + ).fetchone()[0] + self.assertEqual( + vector_to_json("[0.1, 0.2, 0.3]"), "size: 3 [0.100000, 0.200000, 0.300000]" + ) + + def test_vector_length(self): + vector_length = lambda x: db.execute( + "select vector_length(vector_from_json(json(?)))", [x] + ).fetchone()[0] + self.assertEqual(vector_length("[0.1, 0.2, 0.3]"), 3) + self.assertEqual(vector_length("[0.1]"), 1) + self.assertEqual(vector_length("[]"), 0) + + def test_vector_value_at(self): + vector_value_at = lambda x, y: db.execute( + "select vector_value_at(vector_from_json(json(?)), ?)", [x, y] + ).fetchone()[0] + self.assertAlmostEqual(vector_value_at("[0.1, 0.2, 0.3]", 0), 0.1) + self.assertAlmostEqual(vector_value_at("[0.1, 0.2, 0.3]", 1), 0.2) + self.assertAlmostEqual(vector_value_at("[0.1, 0.2, 0.3]", 2), 0.3) + # self.assertAlmostEqual(vector_value_at('[0.1, 0.2, 0.3]', 3), 0.3) - def test_vector0(self): - self.assertEqual(db.execute("select vector0(null)").fetchone()[0], None) - def test_vector_from_blob(self): - self.assertEqual( - db.execute("select vector_debug(vector_from_blob(vector_to_blob(vector_from_json(?))))", ["[0.1,0.2]"]).fetchone()[0], - "size: 2 [0.100000, 0.200000]" - ) - def raises_small_blob_header(input): - with self.assertRaisesRegex(sqlite3.OperationalError, "Vector blob size less than header length"): - db.execute("select vector_from_blob(?)", [input]).fetchall() - - raises_small_blob_header(b"") - raises_small_blob_header(b"v") - - with self.assertRaisesRegex(sqlite3.OperationalError, "Blob not well-formatted vector blob"): - db.execute("select vector_from_blob(?)", [b"V\x01\x00\x00\x00\x00"]).fetchall() - - with self.assertRaisesRegex(sqlite3.OperationalError, "Blob type not right"): - db.execute("select vector_from_blob(?)", [b"v\x00\x00\x00\x00\x00"]).fetchall() - - def test_vector_to_blob(self): - vector_to_blob = lambda x: db.execute("select vector_to_blob(vector_from_json(json(?)))", [x]).fetchone()[0] - self.assertEqual(vector_to_blob("[]"), b"v\x01") - self.assertEqual(vector_to_blob("[0.1]"), b"v\x01\xcd\xcc\xcc=") - self.assertEqual(vector_to_blob("[0.1, 0]"), b"v\x01\xcd\xcc\xcc=\x00\x00\x00\x00") - - def test_vector_to_raw(self): - vector_to_raw = lambda x: db.execute("select vector_to_raw(vector_from_json(json(?)))", [x]).fetchone()[0] - self.assertEqual(vector_to_raw("[]"), None) # TODO why not b"" - self.assertEqual(vector_to_raw("[0.1]"), b"\xcd\xcc\xcc=") - self.assertEqual(vector_to_raw("[0.1, 0]"), b"\xcd\xcc\xcc=\x00\x00\x00\x00") - - def test_vector_from_raw(self): - vector_from_raw_blob = lambda x: db.execute("select vector_debug(vector_from_raw(?))", [x]).fetchone()[0] - self.assertEqual( - vector_from_raw_blob(b""), - 'size: 0 []' - ) - self.assertEqual( - vector_from_raw_blob(b"\x00\x00\x00\x00"), - 'size: 1 [0.000000]' - ) - self.assertEqual( - vector_from_raw_blob(b"\x00\x00\x00\x00\xcd\xcc\xcc="), - 'size: 2 [0.000000, 0.100000]' - ) +class TestCoverage(unittest.TestCase): + def test_coverage_vss(self): + test_methods = [ + method for method in dir(TestVss) if method.startswith("test_vss") + ] + funcs_with_tests = set([x.replace("test_", "") for x in test_methods]) + for func in VSS_FUNCTIONS: + self.assertTrue( + func in funcs_with_tests, + f"{func} does not have cooresponding test in {funcs_with_tests}", + ) + + def test_coverage_vector(self): + test_methods = [ + method for method in dir(TestVector) if method.startswith("test_vector") + ] + funcs_with_tests = set([x.replace("test_", "") for x in test_methods]) + for func in VECTOR_FUNCTIONS: + self.assertTrue( + func in funcs_with_tests, + f"{func} does not have cooresponding test in {funcs_with_tests}", + ) - with self.assertRaisesRegex(sqlite3.OperationalError, "Invalid raw blob length, blob must be divisible by 4"): - vector_from_raw_blob(b"abc") - - def test_vector_from_json(self): - vector_from_json = lambda x: db.execute("select vector_debug(vector_from_json(?))", [x]).fetchone()[0] - self.assertEqual(vector_from_json('[0.1, 0.2, 0.3]'), "size: 3 [0.100000, 0.200000, 0.300000]") - self.assertEqual(vector_from_json('[]'), "size: 0 []") - with self.assertRaisesRegex(sqlite3.OperationalError, "input not valid json, or contains non-float data"): - vector_from_json('') - #db.execute("select vector_from_json(?)", [""]).fetchone()[0] - - def test_vector_to_json(self): - vector_to_json = lambda x: db.execute("select vector_debug(vector_to_json(vector_from_json(json(?))))", [x]).fetchone()[0] - self.assertEqual(vector_to_json('[0.1, 0.2, 0.3]'), "size: 3 [0.100000, 0.200000, 0.300000]") - - def test_vector_length(self): - vector_length = lambda x: db.execute("select vector_length(vector_from_json(json(?)))", [x]).fetchone()[0] - self.assertEqual(vector_length('[0.1, 0.2, 0.3]'), 3) - self.assertEqual(vector_length('[0.1]'), 1) - self.assertEqual(vector_length('[]'), 0) - - def test_vector_value_at(self): - vector_value_at = lambda x, y: db.execute("select vector_value_at(vector_from_json(json(?)), ?)", [x, y]).fetchone()[0] - self.assertAlmostEqual(vector_value_at('[0.1, 0.2, 0.3]', 0), 0.1) - self.assertAlmostEqual(vector_value_at('[0.1, 0.2, 0.3]', 1), 0.2) - self.assertAlmostEqual(vector_value_at('[0.1, 0.2, 0.3]', 2), 0.3) - #self.assertAlmostEqual(vector_value_at('[0.1, 0.2, 0.3]', 3), 0.3) -class TestCoverage(unittest.TestCase): - def test_coverage_vss(self): - test_methods = [method for method in dir(TestVss) if method.startswith('test_vss')] - funcs_with_tests = set([x.replace("test_", "") for x in test_methods]) - for func in VSS_FUNCTIONS: - self.assertTrue(func in funcs_with_tests, f"{func} does not have cooresponding test in {funcs_with_tests}") - - def test_coverage_vector(self): - test_methods = [method for method in dir(TestVector) if method.startswith('test_vector')] - funcs_with_tests = set([x.replace("test_", "") for x in test_methods]) - for func in VECTOR_FUNCTIONS: - self.assertTrue(func in funcs_with_tests, f"{func} does not have cooresponding test in {funcs_with_tests}") - -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()