diff --git a/.github/workflows/ci-dgraph-vector-tests.yml b/.github/workflows/ci-dgraph-vector-tests.yml index 94e6cbc1139..d1650870924 100644 --- a/.github/workflows/ci-dgraph-vector-tests.yml +++ b/.github/workflows/ci-dgraph-vector-tests.yml @@ -26,7 +26,7 @@ jobs: dgraph-vector-tests: if: github.event.pull_request.draft == false runs-on: warp-ubuntu-latest-x64-4x - timeout-minutes: 30 + timeout-minutes: 120 steps: - uses: actions/checkout@v5 - name: Set up Go diff --git a/.vscode/launch.json b/.vscode/launch.json index f4b79b21747..d703d5d6911 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -15,7 +15,7 @@ "--security", "whitelist=0.0.0.0/0;" ], - "showLog": true + "showLog": false }, { "name": "Zero", @@ -25,7 +25,7 @@ "program": "${workspaceRoot}/dgraph/", "env": {}, "args": ["zero"], - "showLog": true + "showLog": false }, { "name": "AlphaACL", diff --git a/posting/index.go b/posting/index.go index f17d5e6a57a..f30831b0b8b 100644 --- a/posting/index.go +++ b/posting/index.go @@ -33,6 +33,9 @@ import ( "github.com/hypermodeinc/dgraph/v25/schema" "github.com/hypermodeinc/dgraph/v25/tok" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + tokIndex "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/tok/kmeans" + "github.com/hypermodeinc/dgraph/v25/types" "github.com/hypermodeinc/dgraph/v25/x" ) @@ -1412,6 +1415,284 @@ func (rb *indexRebuildInfo) prefixesForTokIndexes() ([][]byte, error) { return prefixes, nil } +func rebuildVectorIndex(ctx context.Context, factorySpecs []*tok.FactoryCreateSpec, rb *IndexRebuild) error { + pk := x.ParsedKey{Attr: rb.Attr} + + indexer, err := factorySpecs[0].CreateIndex(pk.Attr) + if err != nil { + return err + } + + dimension := indexer.Dimension() + // If dimension is -1, it means that the dimension is not set through options in case of partitioned hnsw. 
+ if dimension == -1 { + numVectorsToCheck := 100 + lenFreq := make(map[int]int, numVectorsToCheck) + maxFreq := 0 + MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + Prefix: pk.DataPrefix(), + ReadTs: rb.StartTs, + AllVersions: false, + Reverse: false, + CheckInclusion: func(uid uint64) error { + return nil + }, + Function: func(l *List, pk x.ParsedKey) error { + val, err := l.Value(rb.StartTs) + if err != nil { + return err + } + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + lenFreq[len(inVec)] += 1 + if lenFreq[len(inVec)] > maxFreq { + maxFreq = lenFreq[len(inVec)] + dimension = len(inVec) + } + numVectorsToCheck -= 1 + if numVectorsToCheck <= 0 { + return ErrStopIteration + } + return nil + }, + StartKey: x.DataKey(rb.Attr, 0), + }) + + indexer.SetDimension(rb.CurrentSchema, dimension) + } + + fmt.Println("Selecting vector dimension to be:", dimension) + + norm := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + norm.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + val, err := pl.Value(rb.StartTs) + if err != nil { + return nil, err + } + if val.Tid == types.VFloatID { + return nil, nil + } + + // Convert to VFloatID and persist as binary bytes. 
+ sv, err := types.Convert(val, types.VFloatID) + if err != nil { + return nil, err + } + b := types.ValueForType(types.BinaryID) + if err = types.Marshal(sv, &b); err != nil { + return nil, err + } + + edge := &pb.DirectedEdge{ + Attr: rb.Attr, + Entity: uid, + Value: b.Value.([]byte), + ValueType: types.VFloatID.Enum(), + } + inKey := x.DataKey(edge.Attr, uid) + p, err := txn.Get(inKey) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + if err := p.addMutation(ctx, txn, edge); err != nil { + return []*pb.DirectedEdge{}, err + } + return nil, nil + } + + if err := norm.RunWithoutTemp(ctx); err != nil { + return err + } + + count := 0 + + if indexer.NumSeedVectors() > 0 { + err := MemLayerInstance.IterateDisk(ctx, IterateDiskArgs{ + Prefix: pk.DataPrefix(), + ReadTs: rb.StartTs, + AllVersions: false, + Reverse: false, + CheckInclusion: func(uid uint64) error { + return nil + }, + Function: func(l *List, pk x.ParsedKey) error { + val, err := l.Value(rb.StartTs) + if err != nil { + return err + } + + if val.Tid != types.VFloatID { + // Here, we convert the defaultID type vector into vfloat. 
+ sv, err := types.Convert(val, types.VFloatID) + if err != nil { + return err + } + b := types.ValueForType(types.BinaryID) + if err = types.Marshal(sv, &b); err != nil { + return err + } + + val.Value = b.Value + val.Tid = types.VFloatID + } + + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension { + return fmt.Errorf("vector dimension mismatch expected dimension %d but got %d", dimension, len(inVec)) + } + count += 1 + indexer.AddSeedVector(inVec) + if count == indexer.NumSeedVectors() { + return ErrStopIteration + } + return nil + }, + StartKey: x.DataKey(rb.Attr, 0), + }) + if err != nil { + return err + } + } + + txns := make([]*Txn, indexer.NumThreads()) + for i := range txns { + txns[i] = NewTxn(rb.StartTs) + } + caches := make([]tokIndex.CacheType, indexer.NumThreads()) + for i := range caches { + caches[i] = hnsw.NewTxnCache(NewViTxn(txns[i]), rb.StartTs) + } + + if count < indexer.NumSeedVectors() { + indexer.SetNumPasses(0) + } + + for pass_idx := range indexer.NumBuildPasses() { + fmt.Println("Building pass", pass_idx) + + indexer.StartBuild(caches) + + builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + val, err := pl.Value(rb.StartTs) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension { + return []*pb.DirectedEdge{}, nil + } + indexer.BuildInsert(ctx, uid, inVec) + return []*pb.DirectedEdge{}, nil + } + + err := builder.RunWithoutTemp(ctx) + if err != nil { + return err + } + + indexer.EndBuild() + } + + centroids := indexer.GetCentroids() + + if centroids != nil { + txn := NewTxn(rb.StartTs) + + bCentroids, err := json.Marshal(centroids) + if err != nil { + return err + } + + if err := addCentroidInDB(ctx, rb.Attr, bCentroids, txn); err != nil { + return err + } + txn.Update() + writer := NewTxnWriter(pstore) + if err 
:= txn.CommitToDisk(writer, rb.StartTs); err != nil { + return err + } + } + + numIndexPasses := indexer.NumIndexPasses() + + if count < indexer.NumSeedVectors() { + numIndexPasses = 1 + } + + for pass_idx := range numIndexPasses { + fmt.Println("Indexing pass", pass_idx) + + indexer.StartBuild(caches) + + builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { + val, err := pl.Value(rb.StartTs) + if err != nil { + return []*pb.DirectedEdge{}, err + } + + inVec := types.BytesAsFloatArray(val.Value.([]byte)) + if len(inVec) != dimension && centroids != nil { + if pass_idx == 0 { + glog.Warningf("Skipping vector with invalid dimension uid: %d, dimension: %d", uid, len(inVec)) + } + return []*pb.DirectedEdge{}, nil + } + + indexer.BuildInsert(ctx, uid, inVec) + + return []*pb.DirectedEdge{}, nil + } + + err := builder.RunWithoutTemp(ctx) + if err != nil { + return err + } + + for _, idx := range indexer.EndBuild() { + txns[idx].Update() + writer := NewTxnWriter(pstore) + + x.ExponentialRetry(int(x.Config.MaxRetries), + 20*time.Millisecond, func() error { + err := txns[idx].CommitToDisk(writer, rb.StartTs) + if err == badger.ErrBannedKey { + glog.Errorf("Error while writing to banned namespace.") + return nil + } + return err + }) + + txns[idx].cache.plists = nil + txns[idx] = nil + } + } + + return nil +} + +func addCentroidInDB(ctx context.Context, attr string, vec []byte, txn *Txn) error { + indexCountAttr := hnsw.ConcatStrings(attr, kmeans.CentroidPrefix) + countKey := x.DataKey(indexCountAttr, 1) + pl, err := txn.Get(countKey) + if err != nil { + return err + } + + edge := &pb.DirectedEdge{ + Entity: 1, + Attr: indexCountAttr, + Value: vec, + ValueType: pb.Posting_ValType(12), + } + if err := pl.addMutation(ctx, txn, edge); err != nil { + return err + } + return nil +} + // rebuildTokIndex rebuilds index for a given attribute. 
// We commit mutations with startTs and ignore the errors. func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { @@ -1443,6 +1724,9 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } runForVectors := (len(factorySpecs) != 0) + if runForVectors { + return rebuildVectorIndex(ctx, factorySpecs, rb) + } pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} diff --git a/systest/vector/backup_test.go b/systest/vector/backup_test.go index 2effaa171be..020ed5b0026 100644 --- a/systest/vector/backup_test.go +++ b/systest/vector/backup_test.go @@ -12,7 +12,6 @@ import ( "fmt" "slices" "strings" - "testing" "time" "github.com/stretchr/testify/require" @@ -23,7 +22,12 @@ import ( "github.com/hypermodeinc/dgraph/v25/x" ) -func TestVectorIncrBackupRestore(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorIncrBackupRestore() { + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorIncrBackupRestore for partitioned index") + } + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -41,21 +45,19 @@ func TestVectorIncrBackupRestore(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) - - numVectors := 500 - pred := "project_description_v" + numVectors := 1500 allVectors := make([][][]float32, 0, 5) allRdfs := make([]string, 0, 5) for i := 1; i <= 5; i++ { var rdfs string var vectors [][]float32 - rdfs, vectors = dgraphapi.GenerateRandomVectors(numVectors*(i-1), numVectors*i, 1, pred) + rdfs, vectors = dgraphapi.GenerateRandomVectors(numVectors*(i-1), numVectors*i, 10, pred) allVectors = append(allVectors, vectors) allRdfs = append(allRdfs, rdfs) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err 
:= gc.Mutate(mu) require.NoError(t, err) + require.NoError(t, gc.SetupSchema(vsuite.schemaVecDimesion10)) t.Logf("taking backup #%v\n", i) require.NoError(t, hc.Backup(c, i == 1, dgraphtest.DefaultBackupDir)) @@ -77,10 +79,8 @@ func TestVectorIncrBackupRestore(t *testing.T) { require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors*i), string(result.GetJson())) var allSpredVec [][]float32 - for i, vecArr := range allVectors { - if i <= i { - allSpredVec = append(allSpredVec, vecArr...) - } + for _, vecArr := range allVectors { + allSpredVec = append(allSpredVec, vecArr...) } for p, vector := range allVectors[i-1] { triple := strings.Split(allRdfs[i-1], "\n")[p] @@ -89,7 +89,6 @@ func TestVectorIncrBackupRestore(t *testing.T) { require.NoError(t, err) require.Equal(t, allVectors[i-1][p], queriedVector[0]) - similarVectors, err := gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, numVectors) require.NoError(t, err) require.GreaterOrEqual(t, len(similarVectors), 10) @@ -100,7 +99,8 @@ func TestVectorIncrBackupRestore(t *testing.T) { } } -func TestVectorBackupRestore(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorBackupRestore() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -118,15 +118,13 @@ func TestVectorBackupRestore(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) - - numVectors := 1000 - pred := "project_description_v" - rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) + numVectors := 1001 + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) + require.NoError(t, gc.SetupSchema(vsuite.schema)) t.Log("taking backup 
\n") require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) @@ -135,10 +133,22 @@ func TestVectorBackupRestore(t *testing.T) { require.NoError(t, hc.Restore(c, dgraphtest.DefaultBackupDir, "", 0, 0)) require.NoError(t, dgraphapi.WaitForRestore(c)) - testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) + for _, vector := range vectors { + similarVectors, err := gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 100) + require.NoError(t, err) + require.GreaterOrEqual(t, len(similarVectors), 100) + for _, similarVector := range similarVectors { + require.Contains(t, vectors, similarVector) + } + } } -func TestVectorBackupRestoreDropIndex(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorBackupRestoreDropIndex() { + + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorBackupRestoreDropIndex for partitioned index") + } // setup cluster conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) @@ -158,11 +168,11 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) { dgraphapi.DefaultPassword, x.RootNamespace)) // add vector predicate + index - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) // add data to the vector predicate - numVectors := 3 + numVectors := 1000 pred := "project_description_v" - rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 1, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -174,7 +184,7 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) { require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) // add more data to the vector predicate - rdfs, vectors2 := dgraphapi.GenerateRandomVectors(3, numVectors+3, 1, pred) + rdfs, vectors2 := dgraphapi.GenerateRandomVectors(numVectors, numVectors+3, 
100, pred) mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) require.NoError(t, err) @@ -195,7 +205,7 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) { require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) // add index - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) t.Log("taking second incr backup \n") require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) @@ -212,7 +222,7 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) { }` resp, err := gc.Query(query) require.NoError(t, err) - require.JSONEq(t, `{"vectors":[{"count":4}]}`, string(resp.GetJson())) + require.JSONEq(t, `{"vectors":[{"count":1001}]}`, string(resp.GetJson())) require.NoError(t, err) allVec := append(vectors, vectors2...) @@ -227,7 +237,11 @@ func TestVectorBackupRestoreDropIndex(t *testing.T) { } } -func TestVectorBackupRestoreReIndexing(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorBackupRestoreReIndexing() { + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorBackupRestoreReIndexing for partitioned index") + } conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -245,11 +259,11 @@ func TestVectorBackupRestoreReIndexing(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) numVectors := 1000 pred := "project_description_v" - rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) @@ -258,7 +272,7 @@ func TestVectorBackupRestoreReIndexing(t *testing.T) { 
t.Log("taking backup \n") require.NoError(t, hc.Backup(c, false, dgraphtest.DefaultBackupDir)) - rdfs2, vectors2 := dgraphapi.GenerateRandomVectors(numVectors, numVectors+300, 10, pred) + rdfs2, vectors2 := dgraphapi.GenerateRandomVectors(numVectors, numVectors+300, 100, pred) mu = &api.Mutation{SetNquads: []byte(rdfs2), CommitNow: true} _, err = gc.Mutate(mu) @@ -271,7 +285,7 @@ func TestVectorBackupRestoreReIndexing(t *testing.T) { // drop index require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) // add index - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) } vectors = append(vectors, vectors2...) rdfs = rdfs + rdfs2 diff --git a/systest/vector/load_test.go b/systest/vector/load_test.go index 8bb83e9ecb4..e521e4f21ab 100644 --- a/systest/vector/load_test.go +++ b/systest/vector/load_test.go @@ -27,17 +27,18 @@ type Node struct { Vtest []float32 `json:"vtest"` } -func TestLiveLoadAndExportRDFFormat(t *testing.T) { +func (vsuite *VectorTestSuite) TestLiveLoadAndExportRDFFormat() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) defer func() { c.Cleanup(t.Failed()) }() require.NoError(t, c.Start()) - testExportAndLiveLoad(t, c, "rdf") + testExportAndLiveLoad(t, c, "rdf", vsuite.schemaVecDimesion10) } -func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportFormat string) { +func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportFormat string, schema string) { gc, cleanup, err := c.Client() require.NoError(t, err) defer cleanup() @@ -49,9 +50,9 @@ func testExportAndLiveLoad(t *testing.T, c *dgraphtest.LocalCluster, exportForma require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(schema)) - 
numVectors := 100 + numVectors := 1000 pred := "project_description_v" rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 10, pred) diff --git a/systest/vector/vector_test.go b/systest/vector/vector_test.go index 5fd0b991d68..92aeff24b31 100644 --- a/systest/vector/vector_test.go +++ b/systest/vector/vector_test.go @@ -9,12 +9,14 @@ package main import ( "context" + "errors" "fmt" "strings" "testing" "time" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/dgraph-io/dgo/v250/protos/api" "github.com/hypermodeinc/dgraph/v25/dgraphapi" @@ -23,11 +25,16 @@ import ( ) const ( - testSchema = `project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .` testSchemaWithoutIndex = `project_description_v: float32vector .` pred = "project_description_v" + schemaVecDimension10 = `project_description_v: float32vector @index(partionedhnsw(numClusters: "1000", partitionStratOpt: "kmeans", vectorDimension: "10", metric: "euclidean")) .` ) +var schemas = map[string]string{ + "hnsw": `project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .`, + "partitionedhnsw": `project_description_v: float32vector @index(partionedhnsw(numClusters: "1000", partitionStratOpt: "kmeans", vectorDimension: "100", metric: "euclidean")) .`, +} + func testVectorQuery(t *testing.T, gc *dgraphapi.GrpcClient, vectors [][]float32, rdfs, pred string, topk int) { for i, vector := range vectors { triple := strings.Split(rdfs, "\n")[i] @@ -44,7 +51,11 @@ func testVectorQuery(t *testing.T, gc *dgraphapi.GrpcClient, vectors [][]float32 } } -func TestVectorDropAll(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorDropAll() { + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorDropAll for partitioned index") + } conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) 
require.NoError(t, err) @@ -62,7 +73,7 @@ func TestVectorDropAll(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - numVectors := 100 + numVectors := 10 testVectorSimilarTo := func(vectors [][]float32) { for _, vector := range vectors { @@ -73,7 +84,7 @@ func TestVectorDropAll(t *testing.T) { } for i := 0; i < 10; i++ { - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) @@ -100,7 +111,11 @@ func TestVectorDropAll(t *testing.T) { } } -func TestVectorSnapshot(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorSnapshot() { + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorSnapshot for partitioned index") + } conf := dgraphtest.NewClusterConfig().WithNumAlphas(3).WithNumZeros(3).WithReplicas(3).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -131,7 +146,7 @@ func TestVectorSnapshot(t *testing.T) { require.NoError(t, gc.LoginIntoNamespace(context.Background(), dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) prevSnapshotTs, err := hc.GetCurrentSnapshotTs(1) require.NoError(t, err) @@ -173,7 +188,11 @@ func TestVectorSnapshot(t *testing.T) { testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) } -func TestVectorDropNamespace(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorDropNamespace() { + t := vsuite.T() + if vsuite.isForPartitionedIndex { + t.Skip("Skipping TestVectorDropNamespace for partitioned index") + } conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, 
err) @@ -195,7 +214,7 @@ func TestVectorDropNamespace(t *testing.T) { for i := 0; i < 6; i++ { ns, err := hc.AddNamespace() require.NoError(t, err) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} _, err = gc.Mutate(mu) @@ -223,7 +242,8 @@ func TestVectorDropNamespace(t *testing.T) { } } -func TestVectorIndexRebuilding(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorIndexRebuilding() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -241,7 +261,7 @@ func TestVectorIndexRebuilding(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) numVectors := 1000 rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) @@ -265,7 +285,7 @@ func TestVectorIndexRebuilding(t *testing.T) { time.Sleep(5 * time.Second) // rebuild index - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) time.Sleep(5 * time.Second) result, err = gc.Query(query) @@ -275,7 +295,8 @@ func TestVectorIndexRebuilding(t *testing.T) { testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) } -func TestVectorIndexOnVectorPredWithoutData(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorIndexOnVectorPredWithoutData() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) require.NoError(t, err) @@ -293,14 +314,15 @@ func TestVectorIndexOnVectorPredWithoutData(t *testing.T) { require.NoError(t, 
hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) vector := []float32{1.0, 2.0, 3.0} _, err = gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10) require.NoError(t, err) } -func TestVectorIndexDropPredicate(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorIndexDropPredicate() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) @@ -320,7 +342,6 @@ func TestVectorIndexDropPredicate(t *testing.T) { require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) - require.NoError(t, gc.SetupSchema(testSchema)) numVectors := 1000 // add vectors @@ -329,7 +350,7 @@ func TestVectorIndexDropPredicate(t *testing.T) { _, err = gc.Mutate(mu) require.NoError(t, err) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) for _, vect := range vectors { similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2) @@ -363,7 +384,7 @@ func TestVectorIndexDropPredicate(t *testing.T) { require.NoError(t, err) // add index back - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) result, err = gc.Query(query) require.NoError(t, err) @@ -376,7 +397,8 @@ func TestVectorIndexDropPredicate(t *testing.T) { } } -func TestVectorIndexWithoutSchema(t *testing.T) { +func (vsuite *VectorTestSuite) TestVectorIndexWithoutSchema() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) @@ -399,7 +421,7 @@ func TestVectorIndexWithoutSchema(t *testing.T) { _, err = gc.Mutate(mu) require.NoError(t, err) - require.NoError(t, gc.SetupSchema(testSchema)) + require.NoError(t, 
gc.SetupSchema(vsuite.schema)) for _, vect := range vectors { similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 100) @@ -418,7 +440,50 @@ func TestVectorIndexWithoutSchema(t *testing.T) { require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) } -func TestVectorIndexWithoutSchemaWithoutIndex(t *testing.T) { +func (vsuite *VectorTestSuite) TestIndexRebuildingWithoutSchema() { + t := vsuite.T() + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + require.NoError(t, c.Start()) + + defer func() { c.Cleanup(t.Failed()) }() + + gc, cleanup, err := c.Client() + require.NoError(t, err) + defer cleanup() + + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.RootNamespace)) + + require.NoError(t, gc.DropAll()) + require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) + + numVectors := 1000 + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + require.NoError(t, gc.SetupSchema(vsuite.schema)) + + query := `{ + vector(func: has(project_description_v)) { + count(uid) + } + }` + + result, err := gc.Query(query) + require.NoError(t, err) + require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) + + for _, vect := range vectors { + similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 100) + require.NoError(t, err) + require.Equal(t, 100, len(similarVects)) + } +} + +func (vsuite *VectorTestSuite) TestVectorIndexWithoutSchemaWithoutIndex() { + t := vsuite.T() conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) c, err := dgraphtest.NewLocalCluster(conf) @@ -441,7 +506,7 @@ func 
TestVectorIndexWithoutSchemaWithoutIndex(t *testing.T) { _, err = gc.Mutate(mu) require.NoError(t, err) - require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) + require.NoError(t, gc.SetupSchema(vsuite.schema)) for i, vect := range vectors { triple := strings.Split(rdfs, "\n")[i] @@ -461,3 +526,124 @@ func TestVectorIndexWithoutSchemaWithoutIndex(t *testing.T) { require.NoError(t, err) require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) } + +func (vsuite *VectorTestSuite) TestPartitionedHNSWIndex() { + t := vsuite.T() + + if !vsuite.isForPartitionedIndex { + t.Skip("Skipping TestPartitionedHNSWIndex for non partitioned index") + } + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1) + c, err := dgraphtest.NewLocalCluster(conf) + + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + defer cleanup() + require.NoError(t, err) + + schemaWithoutIndex := `project_description_v: float32vector .` + + t.Run("with more than 1000 vectors", func(t *testing.T) { + require.NoError(t, gc.DropAll()) + + numVectors := 5000 + + require.NoError(t, gc.SetupSchema(schemaWithoutIndex)) + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + err = gc.SetupSchema(vsuite.schema) + require.NoError(t, err) + + testVectorQuery(t, gc, vectors, rdfs, pred, 5) + }) + + t.Run("without providing vector dimension", func(t *testing.T) { + require.NoError(t, gc.DropAll()) + + numVectors := 1001 + + require.NoError(t, gc.SetupSchema(schemaWithoutIndex)) + + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + s := `project_description_v: float32vector 
@index(partionedhnsw` + + `(numClusters:"1000", partitionStratOpt: "kmeans",metric: "euclidean")) .` + err = gc.SetupSchema(s) + require.NoError(t, err) + + testVectorQuery(t, gc, vectors, rdfs, pred, 1000) + }) + + t.Run("with less than 1000 vectors", func(t *testing.T) { + require.NoError(t, gc.DropAll()) + numVectors := 100 + require.NoError(t, gc.SetupSchema(schemaWithoutIndex)) + + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + err = gc.SetupSchema(vsuite.schema) + require.NoError(t, err) + + testVectorQuery(t, gc, vectors, rdfs, pred, numVectors) + }) + + t.Run("with different length of vectors", func(t *testing.T) { + require.NoError(t, gc.DropAll()) + numVectors := 1100 + require.NoError(t, gc.SetupSchema(schemaWithoutIndex)) + + q := `schema {}` + result, err := gc.Query(q) + require.NoError(t, err) + + rdfs, _ := dgraphapi.GenerateRandomVectors(0, numVectors, 8, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + err = gc.SetupSchema(vsuite.schema) + require.NoError(t, err) + + // here check schema it should not be changed + q = `schema {}` + result1, err := gc.Query(q) + require.NoError(t, err) + require.JSONEq(t, string(result.GetJson()), string(result1.GetJson())) + }) +} + +type VectorTestSuite struct { + suite.Suite + schema string + schemaVecDimesion10 string + isForPartitionedIndex bool +} + +func TestVectorSuite(t *testing.T) { + for _, schema := range schemas { + var ssuite VectorTestSuite + ssuite.schema = schema + if strings.Contains(schema, "partionedhnsw") { + ssuite.schemaVecDimesion10 = schemaVecDimension10 + ssuite.isForPartitionedIndex = true + } else { + ssuite.schemaVecDimesion10 = schema + } + suite.Run(t, &ssuite) + if t.Failed() { + x.Panic(errors.New("vector tests failed")) + } + } +} diff --git a/t/t.go b/t/t.go 
index 47bc7e427fd..f5f92bdb4a9 100644 --- a/t/t.go +++ b/t/t.go @@ -376,7 +376,7 @@ func runTestsFor(ctx context.Context, pkg, prefix string, xmlFile string) error // Todo: There are few race errors in tests itself. Enable this once that is fixed. // args = append(args, "-race") } else { - args = append(args, "-timeout", "30m") + args = append(args, "-timeout", "90m") } if *count > 0 { diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 477f5bc9b27..c3b78c7a488 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -114,6 +114,10 @@ func euclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance) } +func EuclideanDistanceSq[T c.Float](a, b []T, floatBits int) (T, error) { + return applyDistanceFunction(a, b, floatBits, "euclidean distance", vek32.Distance, vek.Distance) +} + // Used for distance, since shorter distance is better func insortPersistentHeapAscending[T c.Float]( slice []minPersistentHeapElement[T], diff --git a/tok/hnsw/persistent_factory.go b/tok/hnsw/persistent_factory.go index ff4c622f218..4bc13b48ea6 100644 --- a/tok/hnsw/persistent_factory.go +++ b/tok/hnsw/persistent_factory.go @@ -78,6 +78,17 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions { return retVal } +func UpdateIndexSplit[T c.Float](vi index.VectorIndex[T], split int) error { + hnsw, ok := vi.(*persistentHNSW[T]) + if !ok { + return errors.New("index is not a persistent HNSW index") + } + hnsw.vecEntryKey = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecEntry, split)) + hnsw.vecKey = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecKeyword, split)) + hnsw.vecDead = ConcatStrings(hnsw.pred, fmt.Sprintf("%s_%d", VecDead, split)) + return nil +} + // Create is an implementation of the IndexFactory interface function, invoked by an HNSWIndexFactory // instance. 
It takes in a string name and a VectorSource implementation, and returns a VectorIndex and error // flag. It creates an HNSW instance using the index name and populates other parts of the HNSW struct such as diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index e13ddddaf89..a01b19fa011 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -8,10 +8,12 @@ package hnsw import ( "context" "fmt" + "sort" "strings" "time" "github.com/golang/glog" + "github.com/hypermodeinc/dgraph/v25/protos/pb" c "github.com/hypermodeinc/dgraph/v25/tok/constraints" "github.com/hypermodeinc/dgraph/v25/tok/index" opt "github.com/hypermodeinc/dgraph/v25/tok/options" @@ -32,6 +34,7 @@ type persistentHNSW[T c.Float] struct { // layer for uuid 65443. The result will be a neighboring uuid. nodeAllEdges map[uint64][][]uint64 deadNodes map[uint64]struct{} + cache index.CacheType } func GetPersistantOptions[T c.Float](o opt.Options) string { @@ -111,6 +114,66 @@ func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error { return nil } +func (ph *persistentHNSW[T]) NumBuildPasses() int { + return 0 +} + +func (ph *persistentHNSW[T]) SetNumPasses(int) { + return +} + +func (ph *persistentHNSW[T]) Dimension() int { + return 0 +} + +func (ph *persistentHNSW[T]) SetDimension(schema *pb.SchemaUpdate, dimension int) { + glog.Info("not implemented") +} + +func (ph *persistentHNSW[T]) NumIndexPasses() int { + return 1 +} + +func (ph *persistentHNSW[T]) NumSeedVectors() int { + return 0 +} + +func (ph *persistentHNSW[T]) StartBuild(caches []index.CacheType) { + ph.nodeAllEdges = make(map[uint64][][]uint64) + ph.cache = caches[0] +} + +func (ph *persistentHNSW[T]) EndBuild() []int { + ph.nodeAllEdges = nil + ph.cache = nil + return []int{0} +} + +func (ph *persistentHNSW[T]) NumThreads() int { + return 1 +} + +func (ph *persistentHNSW[T]) BuildInsert(ctx context.Context, uid uint64, vec []T) error { + newPh := &persistentHNSW[T]{ + maxLevels: ph.maxLevels, + 
efConstruction: ph.efConstruction, + efSearch: ph.efSearch, + pred: ph.pred, + vecEntryKey: ph.vecEntryKey, + vecKey: ph.vecKey, + vecDead: ph.vecDead, + simType: ph.simType, + floatBits: ph.floatBits, + nodeAllEdges: make(map[uint64][][]uint64), + cache: ph.cache, + } + _, err := newPh.Insert(ctx, ph.cache, uid, vec) + return err +} + +func (ph *persistentHNSW[T]) AddSeedVector(vec []T) { +} + func (ph *persistentHNSW[T]) emptyFinalResultWithError(e error) ( *index.SearchPathResult, error) { return index.NewSearchPathResult(), e @@ -254,6 +317,53 @@ func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, quer return r.Neighbors, err } +type resultRow[T c.Float] struct { + uid uint64 + dist T +} + +// MergeResults takes a list of UIDs and returns the maxResults nearest neighbors +// in order of increasing distance. It returns an error if any of the UIDs are +// not present in the index. +// +// The filter parameter is not used by this method. +// +// This method is part of the index.MultipleIndex interface. 
+func (ph *persistentHNSW[T]) MergeResults(ctx context.Context, c index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + var result []resultRow[T] + + for i := range list { + var vec []T + err := ph.getVecFromUid(list[i], c, &vec) + if err != nil { + return nil, err + } + + dist, err := ph.simType.distanceScore(vec, query, ph.floatBits) + if err != nil { + return nil, err + } + result = append(result, resultRow[T]{ + uid: list[i], + dist: dist, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].dist < result[j].dist + }) + + uids := []uint64{} + for i := range maxResults { + // i must stay a valid index into result; i == len(result) would panic. + if i >= len(result) { + break + } + uids = append(uids, result[i].uid) + } + + return uids, nil +} + // SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid // and returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64, @@ -400,6 +510,9 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, _, edges, err := ph.insertHelper(ctx, tc, inUuid, inVec) return edges, err } +func (ph *persistentHNSW[T]) GetCentroids() [][]T { + return nil +} // InsertToPersistentStorage inserts a node into the hnsw graph and returns the // traversal path and the edges created diff --git a/tok/index/index.go b/tok/index/index.go index e0a62255ce1..2e8506e6cc6 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -8,6 +8,7 @@ package index import ( "context" + "github.com/hypermodeinc/dgraph/v25/protos/pb" c "github.com/hypermodeinc/dgraph/v25/tok/constraints" opts "github.com/hypermodeinc/dgraph/v25/tok/options" ) @@ -89,10 +90,26 @@ type OptionalIndexSupport[T c.Float] interface { filter SearchFilter[T]) (*SearchPathResult, error) } +type VectorPartitionStrat[T c.Float] interface { + FindIndexForSearch(vec []T) ([]int, error) + FindIndexForInsert(vec []T) (int, error) + NumPasses() int + 
SetNumPasses(int) + NumSeedVectors() int + StartBuildPass() + EndBuildPass() + AddSeedVector(vec []T) + AddVector(vec []T) error + GetCentroids() [][]T +} + // A VectorIndex can be used to Search for vectors and add vectors to an index. type VectorIndex[T c.Float] interface { OptionalIndexSupport[T] + MergeResults(ctx context.Context, c CacheType, list []uint64, query []T, maxResults int, + filter SearchFilter[T]) ([]uint64, error) + // Search will find the uids for a given set of vectors based on the // input query, limiting to the specified maximum number of results. // The filter parameter indicates that we might discard certain parameters @@ -116,6 +133,19 @@ type VectorIndex[T c.Float] interface { // Insert will add a vector and uuid into the existing VectorIndex. If // uuid already exists, it should throw an error to not insert duplicate uuids Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error) + + BuildInsert(ctx context.Context, uuid uint64, vec []T) error + GetCentroids() [][]T + AddSeedVector(vec []T) + NumBuildPasses() int + SetNumPasses(int) + NumIndexPasses() int + NumSeedVectors() int + StartBuild(caches []CacheType) + EndBuild() []int + NumThreads() int + Dimension() int + SetDimension(schema *pb.SchemaUpdate, dimension int) } // A Txn is an interface representation of a persistent storage transaction, diff --git a/tok/kmeans/kmeans.go b/tok/kmeans/kmeans.go new file mode 100644 index 00000000000..81fc07c537a --- /dev/null +++ b/tok/kmeans/kmeans.go @@ -0,0 +1,233 @@ +package kmeans + +import ( + "encoding/json" + "fmt" + "math" + "sync" + + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/x" +) + +const ( + CentroidPrefix = "__centroid_" +) + +type Kmeans[T c.Float] struct { + floatBits int + numPasses int + centroids *vectorCentroids[T] +} + +func CreateKMeans[T 
c.Float](floatBits int, pred string, distFunc func(a, b []T, floatBits int) (T, error)) index.VectorPartitionStrat[T] { + return &Kmeans[T]{ + floatBits: floatBits, + numPasses: 5, + centroids: &vectorCentroids[T]{ + distFunc: distFunc, + floatBits: floatBits, + pred: pred, + }, + } +} + +func (km *Kmeans[T]) AddSeedVector(vec []T) { + km.centroids.addSeedCentroid(vec) +} + +func (km *Kmeans[T]) AddVector(vec []T) error { + return km.centroids.addVector(vec) +} + +func (km *Kmeans[T]) GetCentroids() [][]T { + return km.centroids.centroids +} + +func (km *Kmeans[T]) FindIndexForSearch(vec []T) ([]int, error) { + if km.NumPasses() == 0 { + return []int{0}, nil + } + res := make([]int, km.NumSeedVectors()) + for i := range res { + res[i] = i + } + return res, nil +} + +func (km *Kmeans[T]) FindIndexForInsert(vec []T) (int, error) { + if km.NumPasses() == 0 { + return 0, nil + } + return km.centroids.findCentroid(vec) +} + +func (km *Kmeans[T]) NumPasses() int { + return km.numPasses +} + +func (km *Kmeans[T]) SetNumPasses(n int) { + km.numPasses = n +} + +func (km *Kmeans[T]) NumSeedVectors() int { + return 1000 +} + +func (km *Kmeans[T]) StartBuildPass() { + if km.centroids.weights == nil { + km.centroids.randomInit() + } +} + +func (km *Kmeans[T]) EndBuildPass() { + km.centroids.updateCentroids() +} + +type vectorCentroids[T c.Float] struct { + dimension int + numCenters int + pred string + + distFunc func(a, b []T, floatBits int) (T, error) + + centroids [][]T + counts []int64 + weights [][]T + mutexs []*sync.Mutex + floatBits int +} + +func (vc *vectorCentroids[T]) findCentroid(input []T) (int, error) { + minIdx := 0 + minDist := math.MaxFloat32 + for i, centroid := range vc.centroids { + dist, err := vc.distFunc(centroid, input, vc.floatBits) + if err != nil { + return 0, err + } + if float64(dist) < minDist { + minDist = float64(dist) + minIdx = i + } + } + return minIdx, nil +} + +func (vc *vectorCentroids[T]) getCentroids(txn index.CacheType) ([][]T, error) { 
+ if len(vc.centroids) > 0 { + return vc.centroids, nil + } + indexCountAttr := hnsw.ConcatStrings(vc.pred, CentroidPrefix) + key := x.DataKey(indexCountAttr, 1) + centroidsMarshalled, err := txn.Get(key) + if err != nil { + return nil, err + } + + centroids := [][]T{} + err = json.Unmarshal(centroidsMarshalled, &centroids) + if err != nil { + return nil, err + } + + vc.centroids = centroids + return vc.centroids, nil +} + +func (vc *vectorCentroids[T]) findNClosestCentroids(input []T, n int, txn index.CacheType) ([]int, error) { + cNS, err := vc.getCentroids(txn) + if err != nil { + return nil, err + } + if n <= 0 || len(vc.centroids) == 0 { + return []int{}, nil + } + if n >= len(vc.centroids) { + res := make([]int, len(vc.centroids)) + for i := range res { + res[i] = i + } + return res, nil + } + res := []int{} + resDist := []float64{} + // get centroids + + for i, centroid := range cNS { + dist, err := vc.distFunc(centroid, input, vc.floatBits) + if err != nil { + return nil, err + } + if len(res) < n { + res = append(res, i) + resDist = append(resDist, float64(dist)) + } else { + // Find the farthest in current top-n + maxIdx, maxDist := 0, resDist[0] + for j, d := range resDist { + if d > maxDist { + maxIdx, maxDist = j, d + } + } + if float64(dist) < maxDist { + res[maxIdx] = i + resDist[maxIdx] = float64(dist) + } + } + } + return res, nil +} + +func (vc *vectorCentroids[T]) addVector(vec []T) error { + idx, err := vc.findCentroid(vec) + if err != nil { + return err + } + vc.mutexs[idx].Lock() + defer vc.mutexs[idx].Unlock() + for i := 0; i < vc.dimension; i++ { + vc.weights[idx][i] += vec[i] + } + vc.counts[idx]++ + return nil +} + +func (vc *vectorCentroids[T]) updateCentroids() { + x.AssertTrue(len(vc.centroids) == vc.numCenters) + x.AssertTrue(len(vc.counts) == vc.numCenters) + x.AssertTrue(len(vc.weights) == vc.numCenters) + for i := 0; i < vc.numCenters; i++ { + for j := 0; j < vc.dimension; j++ { + x.AssertTrue(len(vc.centroids[i]) == vc.dimension) + 
x.AssertTrue(len(vc.weights[i]) == vc.dimension) + vc.centroids[i][j] = vc.weights[i][j] / T(vc.counts[i]) + vc.weights[i][j] = 0 + } + fmt.Printf("%d, ", vc.counts[i]) + vc.counts[i] = 0 + } + fmt.Println() +} + +func (vc *vectorCentroids[T]) randomInit() { + vc.dimension = len(vc.centroids[0]) + for i := range vc.centroids { + x.AssertTrue(len(vc.centroids[i]) == vc.dimension) + } + vc.numCenters = len(vc.centroids) + vc.counts = make([]int64, vc.numCenters) + vc.weights = make([][]T, vc.numCenters) + vc.mutexs = make([]*sync.Mutex, vc.numCenters) + for i := 0; i < vc.numCenters; i++ { + vc.weights[i] = make([]T, vc.dimension) + vc.counts[i] = 0 + vc.mutexs[i] = &sync.Mutex{} + } +} + +func (vc *vectorCentroids[T]) addSeedCentroid(vec []T) { + vc.centroids = append(vc.centroids, vec) +} diff --git a/tok/partitioned_hnsw/partitioned_factory.go b/tok/partitioned_hnsw/partitioned_factory.go new file mode 100644 index 00000000000..f8f46e87f8d --- /dev/null +++ b/tok/partitioned_hnsw/partitioned_factory.go @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package partitioned_hnsw + +import ( + "errors" + "fmt" + "sync" + + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" +) + +const ( + NumClustersOpt string = "numClusters" + vectorDimension string = "vectorDimension" + PartitionStratOpt string = "partitionStratOpt" + PartitionedHNSW string = "partionedhnsw" +) + +type partitionedHNSWIndexFactory[T c.Float] struct { + indexMap map[string]index.VectorIndex[T] + floatBits int + mu sync.RWMutex +} + +// CreateFactory creates an instance of the private struct persistentIndexFactory. +// NOTE: if T and floatBits do not match in # of bits, there will be consequences. 
+func CreateFactory[T c.Float](floatBits int) index.IndexFactory[T] { + return &partitionedHNSWIndexFactory[T]{ + indexMap: map[string]index.VectorIndex[T]{}, + floatBits: floatBits, + } +} + +// Implements NamedFactory interface for use as a plugin. +func (hf *partitionedHNSWIndexFactory[T]) Name() string { return PartitionedHNSW } + +func (hf *partitionedHNSWIndexFactory[T]) GetOptions(o opt.Options) string { + return hnsw.GetPersistantOptions[T](o) +} + +func (hf *partitionedHNSWIndexFactory[T]) isNameAvailableWithLock(name string) bool { + _, nameUsed := hf.indexMap[name] + return !nameUsed +} + +// hf.AllowedOptions() allows persistentIndexFactory to implement the +// IndexFactory interface (see vector-indexer/index/index.go for details). +// We define here options for exponent, maxLevels, efSearch, efConstruction, +// and metric. +func (hf *partitionedHNSWIndexFactory[T]) AllowedOptions() opt.AllowedOptions { + retVal := opt.NewAllowedOptions() + retVal.AddIntOption(hnsw.ExponentOpt). + AddIntOption(hnsw.MaxLevelsOpt). + AddIntOption(hnsw.EfConstructionOpt). + AddIntOption(hnsw.EfSearchOpt). + AddIntOption(NumClustersOpt). + AddStringOption(PartitionStratOpt).AddIntOption(vectorDimension) + getSimFunc := func(optValue string) (any, error) { + if optValue != hnsw.Euclidean && optValue != hnsw.Cosine && optValue != hnsw.DotProd { + return nil, fmt.Errorf("Can't create a vector index for %s", optValue) + } + return hnsw.GetSimType[T](optValue, hf.floatBits), nil + } + + retVal.AddCustomOption(hnsw.MetricOpt, getSimFunc) + return retVal +} + +// Create is an implementation of the IndexFactory interface function, invoked by an HNSWIndexFactory +// instance. It takes in a string name and a VectorSource implementation, and returns a VectorIndex and error +// flag. It creates an HNSW instance using the index name and populates other parts of the HNSW struct such as +// multFactor, maxLevels, efConstruction, maxNeighbors, and efSearch using struct parameters. 
+// It then populates the HNSW graphs using the InsertChunk function until there are no more items to populate. +// Finally, the function adds the name and hnsw object to the in memory map and returns the object. +func (hf *partitionedHNSWIndexFactory[T]) Create( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + hf.mu.Lock() + defer hf.mu.Unlock() + return hf.createWithLock(name, o, floatBits) +} + +func (hf *partitionedHNSWIndexFactory[T]) createWithLock( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + if !hf.isNameAvailableWithLock(name) { + err := errors.New("index with name " + name + " already exists") + return nil, err + } + retVal := &partitionedHNSW[T]{ + pred: name, + floatBits: floatBits, + clusterMap: map[int]index.VectorIndex[T]{}, + buildSyncMaps: map[int]*sync.Mutex{}, + } + err := retVal.applyOptions(o) + if err != nil { + return nil, err + } + hf.indexMap[name] = retVal + return retVal, nil +} + +// Find is an implementation of the IndexFactory interface function, invoked by an persistentIndexFactory +// instance. It returns the VectorIndex corresponding with a string name using the in memory map. +func (hf *partitionedHNSWIndexFactory[T]) Find(name string) (index.VectorIndex[T], error) { + hf.mu.RLock() + defer hf.mu.RUnlock() + return hf.findWithLock(name) +} + +func (hf *partitionedHNSWIndexFactory[T]) findWithLock(name string) (index.VectorIndex[T], error) { + vecInd := hf.indexMap[name] + return vecInd, nil +} + +// Remove is an implementation of the IndexFactory interface function, invoked by an persistentIndexFactory +// instance. It removes the VectorIndex corresponding with a string name using the in memory map. 
+func (hf *partitionedHNSWIndexFactory[T]) Remove(name string) error { + hf.mu.Lock() + defer hf.mu.Unlock() + return hf.removeWithLock(name) +} + +func (hf *partitionedHNSWIndexFactory[T]) removeWithLock(name string) error { + delete(hf.indexMap, name) + return nil +} + +// CreateOrReplace is an implementation of the IndexFactory interface funciton, +// invoked by an persistentIndexFactory. It checks if a VectorIndex +// correpsonding with name exists. If it does, it removes it, and replaces it +// via the Create function using the passed VectorSource. If the VectorIndex +// does not exist, it creates that VectorIndex corresponding with the name using +// the VectorSource. +func (hf *partitionedHNSWIndexFactory[T]) CreateOrReplace( + name string, + o opt.Options, + floatBits int) (index.VectorIndex[T], error) { + hf.mu.Lock() + defer hf.mu.Unlock() + vi, err := hf.findWithLock(name) + if err != nil { + return nil, err + } + if vi != nil { + err = hf.removeWithLock(name) + if err != nil { + return nil, err + } + } + return hf.createWithLock(name, o, floatBits) +} diff --git a/tok/partitioned_hnsw/partitioned_hnsw.go b/tok/partitioned_hnsw/partitioned_hnsw.go new file mode 100644 index 00000000000..e68f41283d4 --- /dev/null +++ b/tok/partitioned_hnsw/partitioned_hnsw.go @@ -0,0 +1,235 @@ +// CreateFactory creates an instance of the private struct persistentIndexFactory. +// NOTE: if T and floatBits do not match in # of bits, there will be consequences. 
+ +package partitioned_hnsw + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + + "github.com/hypermodeinc/dgraph/v25/protos/pb" + c "github.com/hypermodeinc/dgraph/v25/tok/constraints" + hnsw "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/index" + "github.com/hypermodeinc/dgraph/v25/tok/kmeans" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" +) + +type partitionedHNSW[T c.Float] struct { + floatBits int + pred string + + clusterMap map[int]index.VectorIndex[T] + numClusters int + vectorDimension int + vecCount int + numPasses int + partition index.VectorPartitionStrat[T] + + hnswOptions opt.Options + partitionStrat string + + caches []index.CacheType + buildPass int + buildSyncMaps map[int]*sync.Mutex +} + +func (ph *partitionedHNSW[T]) applyOptions(o opt.Options) error { + ph.numClusters, _, _ = opt.GetOpt(o, NumClustersOpt, 1000) + ph.vectorDimension, _, _ = opt.GetOpt(o, vectorDimension, -1) + ph.partitionStrat, _, _ = opt.GetOpt(o, PartitionStratOpt, "kmeans") + + if ph.partitionStrat != "kmeans" && ph.partitionStrat != "query" { + return errors.New("partition strategy must be kmeans or query") + } + + if ph.partitionStrat == "kmeans" { + ph.partition = kmeans.CreateKMeans(ph.floatBits, ph.pred, hnsw.EuclideanDistanceSq[T]) + } + + ph.buildPass = 0 + ph.numPasses = 10 + ph.hnswOptions = o + for i := range ph.numClusters { + factory := hnsw.CreateFactory[T](ph.floatBits) + vi, err := factory.Create(ph.pred, ph.hnswOptions, ph.floatBits) + if err != nil { + return err + } + err = hnsw.UpdateIndexSplit(vi, i) + if err != nil { + return err + } + ph.clusterMap[i] = vi + } + return nil +} + +func (ph *partitionedHNSW[T]) AddSeedVector(vec []T) { + ph.partition.AddSeedVector(vec) +} + +func (ph *partitionedHNSW[T]) BuildInsert(ctx context.Context, uuid uint64, vec []T) error { + passIdx := ph.buildPass - ph.partition.NumPasses() + if passIdx < 0 { + return ph.partition.AddVector(vec) + } + index, err := 
ph.partition.FindIndexForInsert(vec) + if err != nil { + return err + } + if index%ph.numPasses != passIdx { + return nil + } + ph.buildSyncMaps[index].Lock() + defer ph.buildSyncMaps[index].Unlock() + _, err = ph.clusterMap[index].Insert(ctx, ph.caches[index], uuid, vec) + return err +} + +func (ph *partitionedHNSW[T]) GetCentroids() [][]T { + return ph.partition.GetCentroids() +} + +func (ph *partitionedHNSW[T]) NumBuildPasses() int { + return ph.partition.NumPasses() +} + +func (ph *partitionedHNSW[T]) SetNumPasses(n int) { + ph.partition.SetNumPasses(n) +} + +func (ph *partitionedHNSW[T]) Dimension() int { + return ph.vectorDimension +} + +func (ph *partitionedHNSW[T]) SetDimension(schema *pb.SchemaUpdate, dimension int) { + ph.vectorDimension = dimension + for _, vs := range schema.IndexSpecs { + if vs.Name == "partionedhnsw" { + vs.Options = append(vs.Options, &pb.OptionPair{ + Key: "vectorDimension", + Value: strconv.Itoa(dimension), + }) + } + } +} + +func (ph *partitionedHNSW[T]) NumIndexPasses() int { + return ph.numPasses +} + +func (ph *partitionedHNSW[T]) NumThreads() int { + return ph.numClusters +} + +func (ph *partitionedHNSW[T]) NumSeedVectors() int { + return ph.partition.NumSeedVectors() +} + +func (ph *partitionedHNSW[T]) StartBuild(caches []index.CacheType) { + ph.caches = caches + if ph.buildPass < ph.partition.NumPasses() { + ph.partition.StartBuildPass() + return + } + + for i := range ph.clusterMap { + ph.buildSyncMaps[i] = &sync.Mutex{} + if i%ph.numPasses != (ph.buildPass - ph.partition.NumPasses()) { + continue + } + ph.clusterMap[i].StartBuild([]index.CacheType{ph.caches[i]}) + } +} + +func (ph *partitionedHNSW[T]) EndBuild() []int { + res := []int{} + + if ph.buildPass >= ph.partition.NumPasses() { + for i := range ph.clusterMap { + if i%ph.numPasses != (ph.buildPass - ph.partition.NumPasses()) { + continue + } + ph.clusterMap[i].EndBuild() + res = append(res, i) + } + } + + ph.buildPass += 1 + + if len(res) > 0 { + return res + } + + 
if ph.buildPass < ph.partition.NumPasses() { + ph.partition.EndBuildPass() + } + return []int{} +} + +func (ph *partitionedHNSW[T]) Insert(ctx context.Context, txn index.CacheType, uid uint64, vec []T) ([]*index.KeyValue, error) { + // Lazily adopt the dimension of the first inserted vector when it was not + // configured via options (applyOptions defaults vectorDimension to -1). + if ph.vectorDimension == -1 { + ph.vectorDimension = len(vec) + } + + if len(vec) != ph.vectorDimension { + return nil, fmt.Errorf("cannot insert vector of length %d, vector length should be %d", len(vec), ph.vectorDimension) + } + + index, err := ph.partition.FindIndexForInsert(vec) + if err != nil { + return nil, err + } + return ph.clusterMap[index].Insert(ctx, txn, uid, vec) +} + +func (ph *partitionedHNSW[T]) Search(ctx context.Context, txn index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + indexes, err := ph.partition.FindIndexForSearch(query) + if err != nil { + return nil, err + } + res := []uint64{} + mutex := &sync.Mutex{} + var wg sync.WaitGroup + for _, index := range indexes { + wg.Add(1) + go func(i int) { + defer wg.Done() + ids, err := ph.clusterMap[i].Search(ctx, txn, query, maxResults, filter) + if err != nil { + return + } + mutex.Lock() + res = append(res, ids...) 
+ mutex.Unlock() + }(index) + } + wg.Wait() + + if len(res) == 0 { + return res, nil + } + + return ph.clusterMap[0].MergeResults(ctx, txn, res, query, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) SearchWithPath(ctx context.Context, txn index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) (*index.SearchPathResult, error) { + indexes, err := ph.partition.FindIndexForSearch(query) + if err != nil { + return nil, err + } + return ph.clusterMap[indexes[0]].SearchWithPath(ctx, txn, query, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) SearchWithUid(ctx context.Context, txn index.CacheType, uid uint64, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + // #TODO + return ph.clusterMap[0].SearchWithUid(ctx, txn, uid, maxResults, filter) +} + +func (ph *partitionedHNSW[T]) MergeResults(ctx context.Context, txn index.CacheType, list []uint64, query []T, maxResults int, filter index.SearchFilter[T]) ([]uint64, error) { + return ph.clusterMap[0].MergeResults(ctx, txn, list, query, maxResults, filter) +} diff --git a/tok/tok.go b/tok/tok.go index 216f3ae7b06..3d598b9af16 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -22,6 +22,7 @@ import ( "github.com/hypermodeinc/dgraph/v25/protos/pb" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" opts "github.com/hypermodeinc/dgraph/v25/tok/options" + "github.com/hypermodeinc/dgraph/v25/tok/partitioned_hnsw" "github.com/hypermodeinc/dgraph/v25/types" "github.com/hypermodeinc/dgraph/v25/x" ) @@ -87,6 +88,7 @@ var indexFactories = make(map[string]IndexFactory) func init() { registerTokenizer(BigFloatTokenizer{}) registerIndexFactory(createIndexFactory(hnsw.CreateFactory[float32](32))) + registerIndexFactory(createIndexFactory(partitioned_hnsw.CreateFactory[float32](32))) registerTokenizer(GeoTokenizer{}) registerTokenizer(IntTokenizer{}) registerTokenizer(FloatTokenizer{}) diff --git a/worker/backup.go b/worker/backup.go index 0d803f47c06..e2096f3be70 100644 --- a/worker/backup.go +++ 
b/worker/backup.go @@ -14,6 +14,7 @@ import ( "math" "net/url" "reflect" + "strconv" "strings" "sync" "time" @@ -32,6 +33,8 @@ import ( "github.com/hypermodeinc/dgraph/v25/posting" "github.com/hypermodeinc/dgraph/v25/protos/pb" "github.com/hypermodeinc/dgraph/v25/tok/hnsw" + "github.com/hypermodeinc/dgraph/v25/tok/kmeans" + "github.com/hypermodeinc/dgraph/v25/tok/partitioned_hnsw" "github.com/hypermodeinc/dgraph/v25/x" ) @@ -298,8 +301,29 @@ func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { for _, pred := range schema { if pred.Type == "float32vector" && len(pred.IndexSpecs) != 0 { - vecPredMap[gid] = append(predMap[gid], pred.Predicate+hnsw.VecEntry, pred.Predicate+hnsw.VecKeyword, - pred.Predicate+hnsw.VecDead) + for _, spec := range pred.IndexSpecs { + if spec.Name == partitioned_hnsw.PartitionedHNSW { + vecPredMap[gid] = append(predMap[gid], pred.Predicate+kmeans.CentroidPrefix) + for _, opt := range spec.Options { + if opt.Key == partitioned_hnsw.NumClustersOpt { + numClusters, err := strconv.Atoi(opt.Value) + if err != nil { + return fmt.Errorf(`unable to parse number of clusters %s for predicate %s: %w`, + opt.Value, pred.Predicate, err) + } + for i := range numClusters { + vecEntryKey := hnsw.ConcatStrings(pred.Predicate, fmt.Sprintf("%s_%d", hnsw.VecEntry, i)) + vecKey := hnsw.ConcatStrings(pred.Predicate, fmt.Sprintf("%s_%d", hnsw.VecKeyword, i)) + vecDead := hnsw.ConcatStrings(pred.Predicate, fmt.Sprintf("%s_%d", hnsw.VecDead, i)) + vecPredMap[gid] = append(vecPredMap[gid], vecEntryKey, vecKey, vecDead) + } + } + } + } else { + vecPredMap[gid] = append(predMap[gid], pred.Predicate+hnsw.VecEntry, pred.Predicate+hnsw.VecKeyword, + pred.Predicate+hnsw.VecDead) + } + } } } } @@ -602,6 +626,9 @@ func (pr *BackupProcessor) WriteBackup(ctx context.Context) (*pb.BackupResponse, err, hex.EncodeToString(item.Key())) continue } + + fmt.Println("Backup key:", parsedKey.Attr, "isType:", parsedKey.IsType()) + // This check makes sense only 
for the schema keys. The types are not stored in it. if _, ok := predMap[parsedKey.Attr]; !parsedKey.IsType() && !ok { continue diff --git a/worker/task.go b/worker/task.go index ba7e859572f..8df1f663db7 100644 --- a/worker/task.go +++ b/worker/task.go @@ -363,22 +363,18 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er posting.NewViLocalCache(qs.cache), args.q.ReadTs, ) + indexer, err := cspec.CreateIndex(args.q.Attr) if err != nil { return err } - var nnUids []uint64 - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) - } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) - } + nnUids, err := indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { return err } + sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: nnUids}) return nil