diff --git a/README.md b/README.md index ed00d93..4023467 100644 --- a/README.md +++ b/README.md @@ -32,18 +32,18 @@ go get github.com/coder/hnsw@main ``` ```go -g := hnsw.NewGraph[hnsw.Vector]() +g := hnsw.NewGraph[int]() g.Add( - hnsw.MakeVector("1", []float32{1, 1, 1}), - hnsw.MakeVector("2", []float32{1, -1, 0.999}), - hnsw.MakeVector("3", []float32{1, 0, -0.5}), + hnsw.MakeNode(1, []float32{1, 1, 1}), + hnsw.MakeNode(2, []float32{1, -1, 0.999}), + hnsw.MakeNode(3, []float32{1, 0, -0.5}), ) neighbors := g.Search( []float32{0.5, 0.5, 0.5}, 1, ) -fmt.Printf("best friend: %v\n", neighbors[0].Embedding()) +fmt.Printf("best friend: %v\n", neighbors[0].Value) // Output: best friend: [1 1 1] ``` @@ -59,13 +59,13 @@ If you're using a single file as the backend, hnsw provides a convenient `SavedG ```go path := "some.graph" -g1, err := LoadSavedGraph[hnsw.Vector](path) +g1, err := hnsw.LoadSavedGraph[int](path) if err != nil { panic(err) } // Insert some vectors for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)})) + g1.Add(hnsw.MakeNode(i, []float32{float32(i)})) } // Save to disk @@ -76,7 +76,7 @@ if err != nil { // Later... // g2 is a copy of g1 -g2, err := LoadSavedGraph[Vector](path) +g2, err := hnsw.LoadSavedGraph[int](path) if err != nil { panic(err) } @@ -94,10 +94,10 @@ nearly at disk speed. On my M3 Macbook I get these benchmark results: goos: darwin goarch: arm64 pkg: github.com/coder/hnsw -BenchmarkGraph_Import-16 2733 369803 ns/op 228.65 MB/s 352041 B/op 9880 allocs/op -BenchmarkGraph_Export-16 6046 194441 ns/op 1076.65 MB/s 261854 B/op 3760 allocs/op +BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op +BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op PASS -ok github.com/coder/hnsw 2.530s +ok github.com/coder/hnsw 2.624s ``` when saving/loading a graph of 100 vectors with 256 dimensions. 
@@ -130,18 +130,18 @@ $$ where: * $n$ is the number of vectors in the graph -* $\text{size(id)}$ is the average size of the ID in bytes +* $\text{size(key)}$ is the average size of the key in bytes * $M$ is the maximum number of neighbors each node can have * $d$ is the dimensionality of the vectors * $mem_{graph}$ is the memory used by the graph structure across all layers * $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer You can infer that: -* Connectivity ($M$) is very expensive if IDs are large -* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data -* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure +* Connectivity ($M$) is very expensive if keys are large +* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data +* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure -In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes: +In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte keys, you would see that each vector takes: * $256 \cdot 4 = 1024$ data bytes * $16 \cdot 8 = 128$ metadata bytes diff --git a/analyzer.go b/analyzer.go index c527739..c57bb18 100644 --- a/analyzer.go +++ b/analyzer.go @@ -1,11 +1,13 @@ package hnsw +import "cmp" + // Analyzer is a struct that holds a graph and provides // methods for analyzing it. It offers no compatibility guarantee // as the methods of measuring the graph's health with change // with the implementation. 
-type Analyzer[T Embeddable] struct { - Graph *Graph[T] +type Analyzer[K cmp.Ordered] struct { + Graph *Graph[K] } func (a *Analyzer[T]) Height() int { @@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int { func (a *Analyzer[T]) Connectivity() []float64 { var layerConnectivity []float64 for _, layer := range a.Graph.layers { - if len(layer.Nodes) == 0 { + if len(layer.nodes) == 0 { continue } var sum float64 - for _, node := range layer.Nodes { + for _, node := range layer.nodes { sum += float64(len(node.neighbors)) } - layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes))) + layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes))) } return layerConnectivity @@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 { func (a *Analyzer[T]) Topography() []int { var topography []int for _, layer := range a.Graph.layers { - topography = append(topography, len(layer.Nodes)) + topography = append(topography, len(layer.nodes)) } return topography } diff --git a/encode.go b/encode.go index b161de1..e4aaf9b 100644 --- a/encode.go +++ b/encode.go @@ -2,6 +2,7 @@ package hnsw import ( "bufio" + "cmp" "encoding/binary" "fmt" "io" @@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) { *v = string(s) return len(s), err + case *[]float32: + var ln int + _, err := binaryRead(r, &ln) + if err != nil { + return 0, err + } + + *v = make([]float32, ln) + return binary.Size(*v), binary.Read(r, byteOrder, *v) + case io.ReaderFrom: n, err := v.ReadFrom(r) return int(n), err @@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) { } return n + n2, nil + case []float32: + n, err := binaryWrite(w, len(v)) + if err != nil { + return n, err + } + return n + binary.Size(v), binary.Write(w, byteOrder, v) default: sz := binary.Size(data) @@ -113,7 +130,7 @@ const encodingVersion = 1 // Export writes the graph to a writer. // // T must implement io.WriterTo. 
-func (h *Graph[T]) Export(w io.Writer) error { +func (h *Graph[K]) Export(w io.Writer) error { distFuncName, ok := distanceFuncToName(h.Distance) if !ok { return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance) @@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error { return fmt.Errorf("encode number of layers: %w", err) } for _, layer := range h.layers { - _, err = binaryWrite(w, len(layer.Nodes)) + _, err = binaryWrite(w, len(layer.nodes)) if err != nil { return fmt.Errorf("encode number of nodes: %w", err) } - for _, node := range layer.Nodes { - _, err = binaryWrite(w, node.Point) + for _, node := range layer.nodes { + _, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors)) if err != nil { - return fmt.Errorf("encode node point: %w", err) - } - - if _, err = binaryWrite(w, len(node.neighbors)); err != nil { - return fmt.Errorf("encode number of neighbors: %w", err) + return fmt.Errorf("encode node data: %w", err) } for neighbor := range node.neighbors { _, err = binaryWrite(w, neighbor) if err != nil { - return fmt.Errorf("encode neighbor %q: %w", neighbor, err) + return fmt.Errorf("encode neighbor %v: %w", neighbor, err) } } } @@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error { // T must implement io.ReaderFrom. // The imported graph does not have to match the exported graph's parameters (except for // dimensionality). The graph will converge onto the new parameters. 
-func (h *Graph[T]) Import(r io.Reader) error { +func (h *Graph[K]) Import(r io.Reader) error { var ( version int dist string @@ -195,7 +208,7 @@ func (h *Graph[T]) Import(r io.Reader) error { return err } - h.layers = make([]*layer[T], nLayers) + h.layers = make([]*layer[K], nLayers) for i := 0; i < nLayers; i++ { var nNodes int _, err = binaryRead(r, &nNodes) @@ -203,23 +216,19 @@ func (h *Graph[T]) Import(r io.Reader) error { return err } - nodes := make(map[string]*layerNode[T], nNodes) + nodes := make(map[K]*layerNode[K], nNodes) for j := 0; j < nNodes; j++ { - var point T - _, err = binaryRead(r, &point) - if err != nil { - return fmt.Errorf("decoding node %d: %w", j, err) - } - + var key K + var vec Vector var nNeighbors int - _, err = binaryRead(r, &nNeighbors) + _, err = multiBinaryRead(r, &key, &vec, &nNeighbors) if err != nil { - return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err) + return fmt.Errorf("decoding node %d: %w", j, err) } - neighbors := make([]string, nNeighbors) + neighbors := make([]K, nNeighbors) for k := 0; k < nNeighbors; k++ { - var neighbor string + var neighbor K _, err = binaryRead(r, &neighbor) if err != nil { return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err) @@ -227,23 +236,26 @@ func (h *Graph[T]) Import(r io.Reader) error { neighbors[k] = neighbor } - node := &layerNode[T]{ - Point: point, - neighbors: make(map[string]*layerNode[T]), + node := &layerNode[K]{ + Node: Node[K]{ + Key: key, + Value: vec, + }, + neighbors: make(map[K]*layerNode[K]), } - nodes[point.ID()] = node + nodes[key] = node for _, neighbor := range neighbors { node.neighbors[neighbor] = nil } } // Fill in neighbor pointers for _, node := range nodes { - for id := range node.neighbors { - node.neighbors[id] = nodes[id] + for key := range node.neighbors { + node.neighbors[key] = nodes[key] } } - h.layers[i] = &layer[T]{Nodes: nodes} + h.layers[i] = &layer[K]{nodes: nodes} } return nil @@ -253,8 +265,8 @@ func (h *Graph[T]) 
Import(r io.Reader) error { // changes to a file upon calls to Save. It is more convenient // but less powerful than calling Graph.Export and Graph.Import // directly. -type SavedGraph[T Embeddable] struct { - *Graph[T] +type SavedGraph[K cmp.Ordered] struct { + *Graph[K] Path string } @@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct { // // It does not hold open a file descriptor, so SavedGraph can be forgotten // without ever calling Save. -func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { +func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) { f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) if err != nil { return nil, err @@ -276,7 +288,7 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { return nil, err } - g := NewGraph[T]() + g := NewGraph[K]() if info.Size() > 0 { err = g.Import(bufio.NewReader(f)) if err != nil { @@ -284,7 +296,7 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { } } - return &SavedGraph[T]{Graph: g, Path: path}, nil + return &SavedGraph[K]{Graph: g, Path: path}, nil } // Save writes the graph to the file. 
diff --git a/encode_test.go b/encode_test.go index dead198..b19210e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,8 +2,7 @@ package hnsw import ( "bytes" - "math/rand" - "strconv" + "cmp" "testing" "github.com/stretchr/testify/require" @@ -50,21 +49,21 @@ func Test_binaryWrite_string(t *testing.T) { require.Empty(t, buf.Bytes()) } -func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) { +func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) { for _, layer := range g.layers { - for _, node := range layer.Nodes { + for _, node := range layer.nodes { for neighborKey, neighbor := range node.neighbors { - _, ok := layer.Nodes[neighbor.Point.ID()] + _, ok := layer.nodes[neighbor.Key] if !ok { t.Errorf( - "node %s has neighbor %s, but neighbor does not exist", - node.Point.ID(), neighbor.Point.ID(), + "node %v has neighbor %v, but neighbor does not exist", + node.Key, neighbor.Key, ) } - if neighborKey != neighbor.Point.ID() { - t.Errorf("node %s has neighbor %s, but neighbor key is %s", node.Point.ID(), - neighbor.Point.ID(), + if neighborKey != neighbor.Key { + t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.Key, + neighbor.Key, neighborKey, ) } @@ -74,10 +73,10 @@ func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) { } // requireGraphApproxEquals checks that two graphs are equal. 
-func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) { +func requireGraphApproxEquals[K cmp.Ordered](t *testing.T, g1, g2 *Graph[K]) { require.Equal(t, g1.Len(), g2.Len()) - a1 := Analyzer[T]{g1} - a2 := Analyzer[T]{g2} + a1 := Analyzer[K]{g1} + a2 := Analyzer[K]{g2} require.Equal( t, @@ -119,11 +118,13 @@ func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) { } func TestGraph_ExportImport(t *testing.T) { - rng := rand.New(rand.NewSource(0)) - - g1 := newTestGraph[Vector]() + g1 := newTestGraph[int]() for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{rng.Float32()})) + g1.Add( + Node[int]{ + i, randFloats(1), + }, + ) } buf := &bytes.Buffer{} @@ -132,7 +133,7 @@ func TestGraph_ExportImport(t *testing.T) { // Don't use newTestGraph to ensure parameters // are imported. - g2 := &Graph[Vector]{} + g2 := &Graph[int]{} err = g2.Import(buf) require.NoError(t, err) @@ -157,17 +158,21 @@ func TestGraph_ExportImport(t *testing.T) { func TestSavedGraph(t *testing.T) { dir := t.TempDir() - g1, err := LoadSavedGraph[Vector](dir + "/graph") + g1, err := LoadSavedGraph[int](dir + "/graph") require.NoError(t, err) require.Equal(t, 0, g1.Len()) for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)})) + g1.Add( + Node[int]{ + i, randFloats(1), + }, + ) } err = g1.Save() require.NoError(t, err) - g2, err := LoadSavedGraph[Vector](dir + "/graph") + g2, err := LoadSavedGraph[int](dir + "/graph") require.NoError(t, err) requireGraphApproxEquals(t, g1.Graph, g2.Graph) @@ -177,9 +182,13 @@ const benchGraphSize = 100 func BenchmarkGraph_Import(b *testing.B) { b.ReportAllocs() - g := newTestGraph[Vector]() + g := newTestGraph[int]() for i := 0; i < benchGraphSize; i++ { - g.Add(MakeVector(strconv.Itoa(i), randFloats(100))) + g.Add( + Node[int]{ + i, randFloats(256), + }, + ) } buf := &bytes.Buffer{} @@ -192,7 +201,7 @@ func BenchmarkGraph_Import(b *testing.B) { for i := 0; i < b.N; i++ { 
b.StopTimer() rdr := bytes.NewReader(buf.Bytes()) - g := newTestGraph[Vector]() + g := newTestGraph[int]() b.StartTimer() err = g.Import(rdr) require.NoError(b, err) @@ -201,9 +210,13 @@ func BenchmarkGraph_Import(b *testing.B) { func BenchmarkGraph_Export(b *testing.B) { b.ReportAllocs() - g := newTestGraph[Vector]() + g := newTestGraph[int]() for i := 0; i < benchGraphSize; i++ { - g.Add(MakeVector(strconv.Itoa(i), randFloats(256))) + g.Add( + Node[int]{ + i, randFloats(256), + }, + ) } var buf bytes.Buffer diff --git a/example/readme/main.go b/example/readme/main.go index cc661ab..928ba9f 100644 --- a/example/readme/main.go +++ b/example/readme/main.go @@ -7,16 +7,16 @@ import ( ) func main() { - g := hnsw.NewGraph[hnsw.Vector]() + g := hnsw.NewGraph[int]() g.Add( - hnsw.MakeVector("1", []float32{1, 1, 1}), - hnsw.MakeVector("2", []float32{1, -1, 0.999}), - hnsw.MakeVector("3", []float32{1, 0, -0.5}), + hnsw.MakeNode(1, []float32{1, 1, 1}), + hnsw.MakeNode(2, []float32{1, -1, 0.999}), + hnsw.MakeNode(3, []float32{1, 0, -0.5}), ) neighbors := g.Search( []float32{0.5, 0.5, 0.5}, 1, ) - fmt.Printf("best friend: %v\n", neighbors[0].Embedding()) + fmt.Printf("best friend: %v\n", neighbors[0].Value) } diff --git a/graph.go b/graph.go index e9249a6..4cfe82a 100644 --- a/graph.go +++ b/graph.go @@ -1,6 +1,7 @@ package hnsw import ( + "cmp" "fmt" "math" "math/rand" @@ -11,34 +12,36 @@ import ( "golang.org/x/exp/maps" ) -type Embedding = []float32 +type Vector = []float32 -// Embeddable describes a type that can be embedded in a HNSW graph. -type Embeddable interface { - // ID returns a unique identifier for the object. - ID() string - // Embedding returns the embedding of the object. - // float32 is used for compatibility with OpenAI embeddings. - Embedding() Embedding +// Node is a node in the graph. 
+type Node[K cmp.Ordered] struct { + Key K + Value Vector +} + +func MakeNode[K cmp.Ordered](key K, vec Vector) Node[K] { + return Node[K]{Key: key, Value: vec} } // layerNode is a node in a layer of the graph. -type layerNode[T Embeddable] struct { - Point Embeddable - // neighbors is map of neighbor IDs to neighbor nodes. +type layerNode[K cmp.Ordered] struct { + Node[K] + + // neighbors is map of neighbor keys to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. // when M is high. - neighbors map[string]*layerNode[T] + neighbors map[K]*layerNode[K] } // addNeighbor adds a o neighbor to the node, replacing the neighbor // with the worst distance if the neighbor set is full. -func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFunc) { +func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFunc) { if n.neighbors == nil { - n.neighbors = make(map[string]*layerNode[T], m) + n.neighbors = make(map[K]*layerNode[K], m) } - n.neighbors[newNode.Point.ID()] = newNode + n.neighbors[newNode.Key] = newNode if len(n.neighbors) <= m { return } @@ -46,10 +49,10 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu // Find the neighbor with the worst distance. var ( worstDist = float32(math.Inf(-1)) - worst *layerNode[T] + worst *layerNode[K] ) for _, neighbor := range n.neighbors { - d := dist(neighbor.Point.Embedding(), n.Point.Embedding()) + d := dist(neighbor.Value, n.Value) // d > worstDist may always be false if the distance function // returns NaN, e.g., when the embeddings are zero. if d > worstDist || worst == nil { @@ -58,49 +61,49 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu } } - delete(n.neighbors, worst.Point.ID()) + delete(n.neighbors, worst.Key) // Delete backlink from the worst neighbor. 
- delete(worst.neighbors, n.Point.ID()) + delete(worst.neighbors, n.Key) worst.replenish(m) } -type searchCandidate[T Embeddable] struct { - node *layerNode[T] +type searchCandidate[K cmp.Ordered] struct { + node *layerNode[K] dist float32 } -func (s searchCandidate[T]) Less(o searchCandidate[T]) bool { +func (s searchCandidate[K]) Less(o searchCandidate[K]) bool { return s.dist < o.dist } // search returns the layer node closest to the target node // within the same layer. -func (n *layerNode[T]) search( +func (n *layerNode[K]) search( // k is the number of candidates in the result set. k int, efSearch int, - target Embedding, + target Vector, distance DistanceFunc, -) []searchCandidate[T] { +) []searchCandidate[K] { // This is a basic greedy algorithm to find the entry point at the given level // that is closest to the target node. - candidates := heap.Heap[searchCandidate[T]]{} - candidates.Init(make([]searchCandidate[T], 0, efSearch)) + candidates := heap.Heap[searchCandidate[K]]{} + candidates.Init(make([]searchCandidate[K], 0, efSearch)) candidates.Push( - searchCandidate[T]{ + searchCandidate[K]{ node: n, - dist: distance(n.Point.Embedding(), target), + dist: distance(n.Value, target), }, ) var ( - result = heap.Heap[searchCandidate[T]]{} - visited = make(map[string]bool) + result = heap.Heap[searchCandidate[K]]{} + visited = make(map[K]bool) ) - result.Init(make([]searchCandidate[T], 0, k)) + result.Init(make([]searchCandidate[K], 0, k)) // Begin with the entry node in the result set. result.Push(candidates.Min()) - visited[n.Point.ID()] = true + visited[n.Key] = true for candidates.Len() > 0 { var ( @@ -110,25 +113,25 @@ func (n *layerNode[T]) search( // We iterate the map in a sorted, deterministic fashion for // tests. 
- neighborIDs := maps.Keys(current.neighbors) - slices.Sort(neighborIDs) - for _, neighborID := range neighborIDs { + neighborKeys := maps.Keys(current.neighbors) + slices.Sort(neighborKeys) + for _, neighborID := range neighborKeys { neighbor := current.neighbors[neighborID] if visited[neighborID] { continue } visited[neighborID] = true - dist := distance(neighbor.Point.Embedding(), target) + dist := distance(neighbor.Value, target) improved = improved || dist < result.Min().dist if result.Len() < k { - result.Push(searchCandidate[T]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K]{node: neighbor, dist: dist}) } else if dist < result.Max().dist { result.PopLast() - result.Push(searchCandidate[T]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K]{node: neighbor, dist: dist}) } - candidates.Push(searchCandidate[T]{node: neighbor, dist: dist}) + candidates.Push(searchCandidate[K]{node: neighbor, dist: dist}) // Always store candidates if we haven't reached the limit. if candidates.Len() > efSearch { candidates.PopLast() @@ -145,7 +148,7 @@ func (n *layerNode[T]) search( return result.Slice() } -func (n *layerNode[T]) replenish(m int) { +func (n *layerNode[K]) replenish(m int) { if len(n.neighbors) >= m { return } @@ -154,8 +157,8 @@ func (n *layerNode[T]) replenish(m int) { // This is a naive implementation that could be improved by // using a priority queue to find the best candidates. for _, neighbor := range n.neighbors { - for id, candidate := range neighbor.neighbors { - if _, ok := n.neighbors[id]; ok { + for key, candidate := range neighbor.neighbors { + if _, ok := n.neighbors[key]; ok { // do not add duplicates continue } @@ -172,46 +175,47 @@ func (n *layerNode[T]) replenish(m int) { // isolates remove the node from the graph by removing all connections // to neighbors. 
-func (n *layerNode[T]) isolate(m int) { +func (n *layerNode[K]) isolate(m int) { for _, neighbor := range n.neighbors { - delete(neighbor.neighbors, n.Point.ID()) + delete(neighbor.neighbors, n.Key) neighbor.replenish(m) } } -type layer[T Embeddable] struct { - // Nodes is a map of node IDs to Nodes. - // All Nodes in a higher layer are also in the lower layers, an essential +type layer[K cmp.Ordered] struct { + // nodes is a map of node keys to nodes. + // All nodes in a higher layer are also in the lower layers, an essential // property of the graph. // - // Nodes is exported for interop with encoding/gob. - Nodes map[string]*layerNode[T] + // nodes is unexported; the graph is serialized via Graph.Export and Graph.Import. + nodes map[K]*layerNode[K] } // entry returns the entry node of the layer. // It doesn't matter which node is returned, even that the // entry node is consistent, so we just return the first node // in the map to avoid tracking extra state. -func (l *layer[T]) entry() *layerNode[T] { +func (l *layer[K]) entry() *layerNode[K] { if l == nil { return nil } - for _, node := range l.Nodes { + for _, node := range l.nodes { return node } return nil } -func (l *layer[T]) size() int { +func (l *layer[K]) size() int { if l == nil { return 0 } - return len(l.Nodes) + return len(l.nodes) } // Graph is a Hierarchical Navigable Small World graph. // All public parameters must be set before adding nodes to the graph. -type Graph[T Embeddable] struct { +// K is cmp.Ordered instead of comparable so that keys can be sorted. +type Graph[K cmp.Ordered] struct { // Distance is the distance function used to compare embeddings. Distance DistanceFunc @@ -234,7 +238,7 @@ type Graph[T Embeddable] struct { EfSearch int // layers is a slice of layers in the graph. - layers []*layer[T] + layers []*layer[K] } func defaultRand() *rand.Rand { @@ -243,8 +247,8 @@ func defaultRand() *rand.Rand { // NewGraph returns a new graph with default parameters, roughly designed for // storing OpenAI embeddings. 
-func NewGraph[T Embeddable]() *Graph[T] { - return &Graph[T]{ +func NewGraph[K cmp.Ordered]() *Graph[K] { + return &Graph[K]{ M: 16, Ml: 0.25, Distance: CosineDistance, @@ -297,7 +301,7 @@ func (h *Graph[T]) randomLevel() int { return max } -func (g *Graph[T]) assertDims(n Embedding) { +func (g *Graph[T]) assertDims(n Vector) { if len(g.layers) == 0 { return } @@ -313,38 +317,48 @@ func (g *Graph[T]) Dims() int { if len(g.layers) == 0 { return 0 } - return len(g.layers[0].entry().Point.Embedding()) + return len(g.layers[0].entry().Value) +} + +func ptr[T any](v T) *T { + return &v } // Add inserts nodes into the graph. // If another node with the same ID exists, it is replaced. -func (g *Graph[T]) Add(nodes ...T) { - for _, n := range nodes { - g.assertDims(n.Embedding()) +func (g *Graph[K]) Add(nodes ...Node[K]) { + for _, node := range nodes { + key := node.Key + vec := node.Value + + g.assertDims(vec) insertLevel := g.randomLevel() // Create layers that don't exist yet. for insertLevel >= len(g.layers) { - g.layers = append(g.layers, &layer[T]{}) + g.layers = append(g.layers, &layer[K]{}) } if insertLevel < 0 { panic("invalid level") } - var elevator string + var elevator *K preLen := g.Len() // Insert node at each layer, beginning with the highest. for i := len(g.layers) - 1; i >= 0; i-- { layer := g.layers[i] - newNode := &layerNode[T]{ - Point: n, + newNode := &layerNode[K]{ + Node: Node[K]{ + Key: key, + Value: vec, + }, } // Insert the new node into the layer. if layer.entry() == nil { - layer.Nodes = map[string]*layerNode[T]{n.ID(): newNode} + layer.nodes = map[K]*layerNode[K]{key: newNode} continue } @@ -354,15 +368,15 @@ func (g *Graph[T]) Add(nodes ...T) { // On subsequent layers, we use the elevator node to enter the graph // at the best point. 
- if elevator != "" { - searchPoint = layer.Nodes[elevator] + if elevator != nil { + searchPoint = layer.nodes[*elevator] } if g.Distance == nil { panic("(*Graph).Distance must be set") } - neighborhood := searchPoint.search(g.M, g.EfSearch, n.Embedding(), g.Distance) + neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance) if len(neighborhood) == 0 { // This should never happen because the searchPoint itself // should be in the result set. @@ -370,14 +384,14 @@ func (g *Graph[T]) Add(nodes ...T) { } // Re-set the elevator node for the next layer. - elevator = neighborhood[0].node.Point.ID() + elevator = ptr(neighborhood[0].node.Key) if insertLevel >= i { - if _, ok := layer.Nodes[n.ID()]; ok { - g.Delete(n.ID()) + if _, ok := layer.nodes[key]; ok { + g.Delete(key) } // Insert the new node into the layer. - layer.Nodes[n.ID()] = newNode + layer.nodes[key] = newNode for _, node := range neighborhood { // Create a bi-directional edge between the new node and the best node. node.node.addNeighbor(newNode, g.M, g.Distance) @@ -394,7 +408,7 @@ func (g *Graph[T]) Add(nodes ...T) { } // Search finds the k nearest neighbors from the target node. 
-func (h *Graph[T]) Search(near Embedding, k int) []T { +func (h *Graph[K]) Search(near Vector, k int) []Node[K] { h.assertDims(near) if len(h.layers) == 0 { return nil @@ -403,27 +417,27 @@ func (h *Graph[T]) Search(near Embedding, k int) []T { var ( efSearch = h.EfSearch - elevator string + elevator *K ) for layer := len(h.layers) - 1; layer >= 0; layer-- { searchPoint := h.layers[layer].entry() - if elevator != "" { - searchPoint = h.layers[layer].Nodes[elevator] + if elevator != nil { + searchPoint = h.layers[layer].nodes[*elevator] } // Descending hierarchies if layer > 0 { nodes := searchPoint.search(1, efSearch, near, h.Distance) - elevator = nodes[0].node.Point.ID() + elevator = ptr(nodes[0].node.Key) continue } nodes := searchPoint.search(k, efSearch, near, h.Distance) - out := make([]T, 0, len(nodes)) + out := make([]Node[K], 0, len(nodes)) for _, node := range nodes { - out = append(out, node.node.Point.(T)) + out = append(out, node.node.Node) } return out @@ -440,21 +454,21 @@ func (h *Graph[T]) Len() int { return h.layers[0].size() } -// Delete removes a node from the graph by ID. +// Delete removes a node from the graph by key. // It tries to preserve the clustering properties of the graph by // replenishing connectivity in the affected neighborhoods. -func (h *Graph[T]) Delete(id string) bool { +func (h *Graph[K]) Delete(key K) bool { if len(h.layers) == 0 { return false } var deleted bool for _, layer := range h.layers { - node, ok := layer.Nodes[id] + node, ok := layer.nodes[key] if !ok { continue } - delete(layer.Nodes, id) + delete(layer.nodes, key) node.isolate(h.M) deleted = true } @@ -462,12 +476,15 @@ func (h *Graph[T]) Delete(id string) bool { return deleted } -// Lookup returns the node with the given ID. -func (h *Graph[T]) Lookup(id string) (T, bool) { - var zero T +// Lookup returns the vector with the given key. 
+func (h *Graph[K]) Lookup(key K) (Vector, bool) { if len(h.layers) == 0 { - return zero, false + return nil, false } - return h.layers[0].Nodes[id].Point.(T), true + node, ok := h.layers[0].nodes[key] + if !ok { + return nil, false + } + return node.Value, ok } diff --git a/graph_test.go b/graph_test.go index cbbc584..d2a9cab 100644 --- a/graph_test.go +++ b/graph_test.go @@ -1,6 +1,7 @@ package hnsw import ( + "cmp" "math/rand" "strconv" "testing" @@ -18,34 +19,42 @@ func Test_maxLevel(t *testing.T) { require.Equal(t, 11, m) } -type basicPoint float32 - -func (n basicPoint) ID() string { - return strconv.FormatFloat(float64(n), 'f', -1, 32) -} - -func (n basicPoint) Embedding() []float32 { - return []float32{float32(n)} -} - func Test_layerNode_search(t *testing.T) { - entry := &layerNode[basicPoint]{ - Point: basicPoint(0), - neighbors: map[string]*layerNode[basicPoint]{ - "1": { - Point: basicPoint(1), + entry := &layerNode[int]{ + Node: Node[int]{ + Value: Vector{0}, + Key: 0, + }, + neighbors: map[int]*layerNode[int]{ + 1: { + Node: Node[int]{ + Value: Vector{1}, + Key: 1, + }, }, - "2": { - Point: basicPoint(2), + 2: { + Node: Node[int]{ + Value: Vector{2}, + Key: 2, + }, }, - "3": { - Point: basicPoint(3), - neighbors: map[string]*layerNode[basicPoint]{ - "3.8": { - Point: basicPoint(3.8), + 3: { + Node: Node[int]{ + Value: Vector{3}, + Key: 3, + }, + neighbors: map[int]*layerNode[int]{ + 4: { + Node: Node[int]{ + Value: Vector{4}, + Key: 5, + }, }, - "4.3": { - Point: basicPoint(4.3), + 5: { + Node: Node[int]{ + Value: Vector{5}, + Key: 5, + }, }, }, }, @@ -54,13 +63,13 @@ func Test_layerNode_search(t *testing.T) { best := entry.search(2, 4, []float32{4}, EuclideanDistance) - require.Equal(t, "3.8", best[0].node.Point.ID()) - require.Equal(t, "4.3", best[1].node.Point.ID()) + require.Equal(t, 5, best[0].node.Key) + require.Equal(t, 3, best[1].node.Key) require.Len(t, best, 2) } -func newTestGraph[T Embeddable]() *Graph[T] { - return &Graph[T]{ +func 
newTestGraph[K cmp.Ordered]() *Graph[K] { + return &Graph[K]{ M: 6, Distance: EuclideanDistance, Ml: 0.5, @@ -72,13 +81,18 @@ func newTestGraph[T Embeddable]() *Graph[T] { func TestGraph_AddSearch(t *testing.T) { t.Parallel() - g := newTestGraph[basicPoint]() + g := newTestGraph[int]() for i := 0; i < 128; i++ { - g.Add(basicPoint(float32(i))) + g.Add( + Node[int]{ + Key: i, + Value: Vector{float32(i)}, + }, + ) } - al := Analyzer[basicPoint]{Graph: g} + al := Analyzer[int]{Graph: g} // Layers should be approximately log2(128) = 7 // Look for an approximate doubling of the number of nodes in each layer. @@ -101,11 +115,11 @@ func TestGraph_AddSearch(t *testing.T) { require.Len(t, nearest, 4) require.EqualValues( t, - []basicPoint{ - (64), - (65), - (62), - (63), + []Node[int]{ + {64, Vector{64}}, + {65, Vector{65}}, + {62, Vector{62}}, + {63, Vector{63}}, }, nearest, ) @@ -114,19 +128,22 @@ func TestGraph_AddSearch(t *testing.T) { func TestGraph_AddDelete(t *testing.T) { t.Parallel() - g := newTestGraph[basicPoint]() + g := newTestGraph[int]() for i := 0; i < 128; i++ { - g.Add(basicPoint(i)) + g.Add(Node[int]{ + Key: i, + Value: Vector{float32(i)}, + }) } require.Equal(t, 128, g.Len()) - an := Analyzer[basicPoint]{Graph: g} + an := Analyzer[int]{Graph: g} preDeleteConnectivity := an.Connectivity() // Delete every even node. for i := 0; i < 128; i += 2 { - ok := g.Delete(basicPoint(i).ID()) + ok := g.Delete(i) require.True(t, ok) } @@ -141,7 +158,7 @@ func TestGraph_AddDelete(t *testing.T) { ) t.Run("DeleteNotFound", func(t *testing.T) { - ok := g.Delete("not found") + ok := g.Delete(-1) require.False(t, ok) }) } @@ -154,11 +171,14 @@ func Benchmark_HSNW(b *testing.B) { // Use this to ensure that complexity is O(log n) where n = h.Len(). 
for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { - g := Graph[basicPoint]{} + g := Graph[int]{} g.Ml = 0.5 g.Distance = EuclideanDistance for i := 0; i < size; i++ { - g.Add(basicPoint(i)) + g.Add(Node[int]{ + Key: i, + Value: Vector{float32(i)}, + }) } b.ResetTimer() @@ -174,19 +194,6 @@ func Benchmark_HSNW(b *testing.B) { } } -type genericPoint struct { - id string - x []float32 -} - -func (n genericPoint) ID() string { - return n.id -} - -func (n genericPoint) Embedding() []float32 { - return n.x -} - func randFloats(n int) []float32 { x := make([]float32, n) for i := range x { @@ -198,11 +205,14 @@ func randFloats(n int) []float32 { func Benchmark_HNSW_1536(b *testing.B) { b.ReportAllocs() - g := newTestGraph[genericPoint]() + g := newTestGraph[int]() const size = 1000 - points := make([]genericPoint, size) + points := make([]Node[int], size) for i := 0; i < size; i++ { - points[i] = genericPoint{x: randFloats(1536), id: strconv.Itoa(i)} + points[i] = Node[int]{ + Key: i, + Value: Vector(randFloats(1536)), + } g.Add(points[i]) } b.ResetTimer() @@ -210,7 +220,7 @@ func Benchmark_HNSW_1536(b *testing.B) { b.Run("Search", func(b *testing.B) { for i := 0; i < b.N; i++ { g.Search( - points[i%size].x, + points[i%size].Value, 4, ) } @@ -218,11 +228,11 @@ func Benchmark_HNSW_1536(b *testing.B) { } func TestGraph_DefaultCosine(t *testing.T) { - g := NewGraph[Vector]() + g := NewGraph[int]() g.Add( - MakeVector("1", []float32{1, 1}), - MakeVector("2", []float32{0, 1}), - MakeVector("3", []float32{1, -1}), + Node[int]{Key: 1, Value: Vector{1, 1}}, + Node[int]{Key: 2, Value: Vector{0, 1}}, + Node[int]{Key: 3, Value: Vector{1, -1}}, ) neighbors := g.Search( @@ -232,8 +242,8 @@ func TestGraph_DefaultCosine(t *testing.T) { require.Equal( t, - []Vector{ - MakeVector("1", []float32{1, 1}), + []Node[int]{ + {1, Vector{1, 1}}, }, neighbors, ) diff --git a/vector.go b/vector.go deleted file mode 100644 index f4b8cae..0000000 --- a/vector.go +++ /dev/null 
@@ -1,53 +0,0 @@ -package hnsw - -import ( - "io" -) - -var _ Embeddable = Vector{} - -// Vector is a struct that holds an ID and an embedding -// and implements the Embeddable interface. -type Vector struct { - id string - embedding []float32 -} - -// MakeVector creates a new Vector with the given ID and embedding. -func MakeVector(id string, embedding []float32) Vector { - return Vector{ - id: id, - embedding: embedding, - } -} - -func (v Vector) ID() string { - return v.id -} - -func (v Vector) Embedding() []float32 { - return v.embedding -} - -func (v Vector) WriteTo(w io.Writer) (int64, error) { - n, err := multiBinaryWrite(w, v.id, len(v.embedding), v.embedding) - return int64(n), err -} - -func (v *Vector) ReadFrom(r io.Reader) (int64, error) { - var embLen int - n, err := multiBinaryRead(r, &v.id, &embLen) - if err != nil { - return int64(n), err - } - - v.embedding = make([]float32, embLen) - n, err = binaryRead(r, &v.embedding) - - return int64(n), err -} - -var ( - _ io.WriterTo = (*Vector)(nil) - _ io.ReaderFrom = (*Vector)(nil) -)