From d6543ed597552805806ff77e9547877d73804664 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Tue, 28 May 2024 11:28:45 -0500 Subject: [PATCH 1/7] WIP: replace `string` key with `cmp.Ordered` generic --- graph.go | 49 +++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/graph.go b/graph.go index e9249a6..92c0ffa 100644 --- a/graph.go +++ b/graph.go @@ -1,6 +1,7 @@ package hnsw import ( + "cmp" "fmt" "math" "math/rand" @@ -14,28 +15,28 @@ import ( type Embedding = []float32 // Embeddable describes a type that can be embedded in a HNSW graph. -type Embeddable interface { +type Embeddable[K cmp.Ordered] interface { // ID returns a unique identifier for the object. - ID() string + ID() K // Embedding returns the embedding of the object. // float32 is used for compatibility with OpenAI embeddings. Embedding() Embedding } // layerNode is a node in a layer of the graph. -type layerNode[T Embeddable] struct { - Point Embeddable +type layerNode[K cmp.Ordered, V Embeddable[K]] struct { + Point Embeddable[K] // neighbors is map of neighbor IDs to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. // when M is high. - neighbors map[string]*layerNode[T] + neighbors map[K]*layerNode[K, V] } // addNeighbor adds a o neighbor to the node, replacing the neighbor // with the worst distance if the neighbor set is full. -func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFunc) { +func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist DistanceFunc) { if n.neighbors == nil { - n.neighbors = make(map[string]*layerNode[T], m) + n.neighbors = make(map[K]*layerNode[K, V], m) } n.neighbors[newNode.Point.ID()] = newNode @@ -46,7 +47,7 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu // Find the neighbor with the worst distance. var ( worstDist = float32(math.Inf(-1)) - worst *layerNode[T] + worst *layerNode[K, V] ) for _, neighbor := range n.neighbors { d := dist(neighbor.Point.Embedding(), n.Point.Embedding()) @@ -64,39 +65,39 @@ func (n *layerNode[T]) addNeighbor(newNode *layerNode[T], m int, dist DistanceFu worst.replenish(m) } -type searchCandidate[T Embeddable] struct { - node *layerNode[T] +type searchCandidate[K cmp.Ordered, V Embeddable[K]] struct { + node *layerNode[K, V] dist float32 } -func (s searchCandidate[T]) Less(o searchCandidate[T]) bool { +func (s searchCandidate[K, V]) Less(o searchCandidate[K, V]) bool { return s.dist < o.dist } // search returns the layer node closest to the target node // within the same layer. -func (n *layerNode[T]) search( +func (n *layerNode[K, V]) search( // k is the number of candidates in the result set. k int, efSearch int, target Embedding, distance DistanceFunc, -) []searchCandidate[T] { +) []searchCandidate[K, V] { // This is a basic greedy algorithm to find the entry point at the given level // that is closest to the target node. - candidates := heap.Heap[searchCandidate[T]]{} - candidates.Init(make([]searchCandidate[T], 0, efSearch)) + candidates := heap.Heap[searchCandidate[K, V]]{} + candidates.Init(make([]searchCandidate[K, V], 0, efSearch)) candidates.Push( - searchCandidate[T]{ + searchCandidate[K, V]{ node: n, dist: distance(n.Point.Embedding(), target), }, ) var ( - result = heap.Heap[searchCandidate[T]]{} - visited = make(map[string]bool) + result = heap.Heap[searchCandidate[K, V]]{} + visited = make(map[K]bool) ) - result.Init(make([]searchCandidate[T], 0, k)) + result.Init(make([]searchCandidate[K, V], 0, k)) // Begin with the entry node in the result set. result.Push(candidates.Min()) @@ -122,13 +123,13 @@ func (n *layerNode[T]) search( dist := distance(neighbor.Point.Embedding(), target) improved = improved || dist < result.Min().dist if result.Len() < k { - result.Push(searchCandidate[T]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) } else if dist < result.Max().dist { result.PopLast() - result.Push(searchCandidate[T]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) } - candidates.Push(searchCandidate[T]{node: neighbor, dist: dist}) + candidates.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) // Always store candidates if we haven't reached the limit. if candidates.Len() > efSearch { candidates.PopLast() @@ -145,7 +146,7 @@ func (n *layerNode[T]) search( return result.Slice() } -func (n *layerNode[T]) replenish(m int) { +func (n *layerNode[K, V]) replenish(m int) { if len(n.neighbors) >= m { return } @@ -172,7 +173,7 @@ func (n *layerNode[T]) replenish(m int) { // isolates remove the node from the graph by removing all connections // to neighbors. -func (n *layerNode[T]) isolate(m int) { +func (n *layerNode[K, V]) isolate(m int) { for _, neighbor := range n.neighbors { delete(neighbor.neighbors, n.Point.ID()) neighbor.replenish(m) From 241409c4a9918dcf516952dd21a395e8c165f5a4 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 31 May 2024 11:57:34 -0500 Subject: [PATCH 2/7] WIP: rm Embeddable for real types --- encode.go | 4 +- graph.go | 105 +++++++++++++++++++++++--------------------------- graph_test.go | 12 +++--- vector.go | 53 ------------------------- 4 files changed, 57 insertions(+), 117 deletions(-) delete mode 100644 vector.go diff --git a/encode.go b/encode.go index b161de1..8966880 100644 --- a/encode.go +++ b/encode.go @@ -228,7 +228,7 @@ func (h *Graph[T]) Import(r io.Reader) error { } node := &layerNode[T]{ - Point: point, + vec: point, neighbors: make(map[string]*layerNode[T]), } @@ -243,7 +243,7 @@ func (h *Graph[T]) Import(r io.Reader) error { node.neighbors[id] = nodes[id] } } - h.layers[i] = &layer[T]{Nodes: nodes} + h.layers[i] = &layer[T]{nodes: nodes} } return nil diff --git a/graph.go b/graph.go index 92c0ffa..d7f0a29 100644 --- a/graph.go +++ b/graph.go @@ -12,34 +12,27 @@ import ( "golang.org/x/exp/maps" ) -type Embedding = []float32 - -// Embeddable describes a type that can be embedded in a HNSW graph. -type Embeddable[K cmp.Ordered] interface { - // ID returns a unique identifier for the object. - ID() K - // Embedding returns the embedding of the object. - // float32 is used for compatibility with OpenAI embeddings. - Embedding() Embedding -} +type Vector = []float32 // layerNode is a node in a layer of the graph. -type layerNode[K cmp.Ordered, V Embeddable[K]] struct { - Point Embeddable[K] +type layerNode[K cmp.Ordered] struct { + id K + vec Vector + // neighbors is map of neighbor IDs to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. // when M is high. - neighbors map[K]*layerNode[K, V] + neighbors map[K]*layerNode[K] } // addNeighbor adds a o neighbor to the node, replacing the neighbor // with the worst distance if the neighbor set is full. -func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist DistanceFunc) { +func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFunc) { if n.neighbors == nil { - n.neighbors = make(map[K]*layerNode[K, V], m) + n.neighbors = make(map[K]*layerNode[K], m) } - n.neighbors[newNode.Point.ID()] = newNode + n.neighbors[newNode.id] = newNode if len(n.neighbors) <= m { return } @@ -47,10 +40,10 @@ func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist Dist // Find the neighbor with the worst distance. var ( worstDist = float32(math.Inf(-1)) - worst *layerNode[K, V] + worst *layerNode[K] ) for _, neighbor := range n.neighbors { - d := dist(neighbor.Point.Embedding(), n.Point.Embedding()) + d := dist(neighbor.vec, n.vec) // d > worstDist may always be false if the distance function // returns NaN, e.g., when the embeddings are zero. if d > worstDist || worst == nil { @@ -59,49 +52,49 @@ func (n *layerNode[K, V]) addNeighbor(newNode *layerNode[K, V], m int, dist Dist } } - delete(n.neighbors, worst.Point.ID()) + delete(n.neighbors, worst.id) // Delete backlink from the worst neighbor. - delete(worst.neighbors, n.Point.ID()) + delete(worst.neighbors, n.id) worst.replenish(m) } -type searchCandidate[K cmp.Ordered, V Embeddable[K]] struct { - node *layerNode[K, V] +type searchCandidate[K cmp.Ordered] struct { + node *layerNode[K] dist float32 } -func (s searchCandidate[K, V]) Less(o searchCandidate[K, V]) bool { +func (s searchCandidate[K]) Less(o searchCandidate[K]) bool { return s.dist < o.dist } // search returns the layer node closest to the target node // within the same layer. -func (n *layerNode[K, V]) search( +func (n *layerNode[K]) search( // k is the number of candidates in the result set. k int, efSearch int, - target Embedding, + target Vector, distance DistanceFunc, -) []searchCandidate[K, V] { +) []searchCandidate[K] { // This is a basic greedy algorithm to find the entry point at the given level // that is closest to the target node. - candidates := heap.Heap[searchCandidate[K, V]]{} - candidates.Init(make([]searchCandidate[K, V], 0, efSearch)) + candidates := heap.Heap[searchCandidate[K]]{} + candidates.Init(make([]searchCandidate[K], 0, efSearch)) candidates.Push( - searchCandidate[K, V]{ + searchCandidate[K]{ node: n, - dist: distance(n.Point.Embedding(), target), + dist: distance(n.vec, target), }, ) var ( - result = heap.Heap[searchCandidate[K, V]]{} + result = heap.Heap[searchCandidate[K]]{} visited = make(map[K]bool) ) - result.Init(make([]searchCandidate[K, V], 0, k)) + result.Init(make([]searchCandidate[K], 0, k)) // Begin with the entry node in the result set. result.Push(candidates.Min()) - visited[n.Point.ID()] = true + visited[n.id] = true for candidates.Len() > 0 { var ( @@ -120,16 +113,16 @@ func (n *layerNode[K, V]) search( } visited[neighborID] = true - dist := distance(neighbor.Point.Embedding(), target) + dist := distance(neighbor.vec, target) improved = improved || dist < result.Min().dist if result.Len() < k { - result.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K]{node: neighbor, dist: dist}) } else if dist < result.Max().dist { result.PopLast() - result.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) + result.Push(searchCandidate[K]{node: neighbor, dist: dist}) } - candidates.Push(searchCandidate[K, V]{node: neighbor, dist: dist}) + candidates.Push(searchCandidate[K]{node: neighbor, dist: dist}) // Always store candidates if we haven't reached the limit. if candidates.Len() > efSearch { candidates.PopLast() @@ -146,7 +139,7 @@ func (n *layerNode[K, V]) search( return result.Slice() } -func (n *layerNode[K, V]) replenish(m int) { +func (n *layerNode[K]) replenish(m int) { if len(n.neighbors) >= m { return } @@ -173,46 +166,46 @@ func (n *layerNode[K, V]) replenish(m int) { // isolates remove the node from the graph by removing all connections // to neighbors. -func (n *layerNode[K, V]) isolate(m int) { +func (n *layerNode[K]) isolate(m int) { for _, neighbor := range n.neighbors { - delete(neighbor.neighbors, n.Point.ID()) + delete(neighbor.neighbors, n.id) neighbor.replenish(m) } } -type layer[T Embeddable] struct { - // Nodes is a map of node IDs to Nodes. - // All Nodes in a higher layer are also in the lower layers, an essential +type layer[K cmp.Ordered] struct { + // nodes is a map of nodes IDs to nodes. + // All nodes in a higher layer are also in the lower layers, an essential // property of the graph. // - // Nodes is exported for interop with encoding/gob. - Nodes map[string]*layerNode[T] + // nodes is exported for interop with encoding/gob. + nodes map[string]*layerNode[K] } // entry returns the entry node of the layer. // It doesn't matter which node is returned, even that the // entry node is consistent, so we just return the first node // in the map to avoid tracking extra state. -func (l *layer[T]) entry() *layerNode[T] { +func (l *layer[K]) entry() *layerNode[K] { if l == nil { return nil } - for _, node := range l.Nodes { + for _, node := range l.nodes { return node } return nil } -func (l *layer[T]) size() int { +func (l *layer[K]) size() int { if l == nil { return 0 } - return len(l.Nodes) + return len(l.nodes) } // Graph is a Hierarchical Navigable Small World graph. // All public parameters must be set before adding nodes to the graph. -type Graph[T Embeddable] struct { +type Graph[K cmp.Ordered] struct { // Distance is the distance function used to compare embeddings. Distance DistanceFunc @@ -235,7 +228,7 @@ type Graph[T Embeddable] struct { EfSearch int // layers is a slice of layers in the graph. - layers []*layer[T] + layers []*layer[K] } func defaultRand() *rand.Rand { @@ -244,8 +237,8 @@ func defaultRand() *rand.Rand { // NewGraph returns a new graph with default parameters, roughly designed for // storing OpenAI embeddings. -func NewGraph[T Embeddable]() *Graph[T] { - return &Graph[T]{ +func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] { + return &Graph[K, V]{ M: 16, Ml: 0.25, Distance: CosineDistance, @@ -298,7 +291,7 @@ func (h *Graph[T]) randomLevel() int { return max } -func (g *Graph[T]) assertDims(n Embedding) { +func (g *Graph[T]) assertDims(n Vector) { if len(g.layers) == 0 { return } @@ -340,7 +333,7 @@ func (g *Graph[T]) Add(nodes ...T) { for i := len(g.layers) - 1; i >= 0; i-- { layer := g.layers[i] newNode := &layerNode[T]{ - Point: n, + vec: n, } // Insert the new node into the layer. @@ -395,7 +388,7 @@ func (g *Graph[T]) Add(nodes ...T) { } // Search finds the k nearest neighbors from the target node. -func (h *Graph[T]) Search(near Embedding, k int) []T { +func (h *Graph[T]) Search(near Vector, k int) []T { h.assertDims(near) if len(h.layers) == 0 { return nil diff --git a/graph_test.go b/graph_test.go index cbbc584..688e052 100644 --- a/graph_test.go +++ b/graph_test.go @@ -30,22 +30,22 @@ func (n basicPoint) Embedding() []float32 { func Test_layerNode_search(t *testing.T) { entry := &layerNode[basicPoint]{ - Point: basicPoint(0), + vec: basicPoint(0), neighbors: map[string]*layerNode[basicPoint]{ "1": { - Point: basicPoint(1), + vec: basicPoint(1), }, "2": { - Point: basicPoint(2), + vec: basicPoint(2), }, "3": { - Point: basicPoint(3), + vec: basicPoint(3), neighbors: map[string]*layerNode[basicPoint]{ "3.8": { - Point: basicPoint(3.8), + vec: basicPoint(3.8), }, "4.3": { - Point: basicPoint(4.3), + vec: basicPoint(4.3), }, }, }, diff --git a/vector.go b/vector.go deleted file mode 100644 index f4b8cae..0000000 --- a/vector.go +++ /dev/null @@ -1,53 +0,0 @@ -package hnsw - -import ( - "io" -) - -var _ Embeddable = Vector{} - -// Vector is a struct that holds an ID and an embedding -// and implements the Embeddable interface. -type Vector struct { - id string - embedding []float32 -} - -// MakeVector creates a new Vector with the given ID and embedding. -func MakeVector(id string, embedding []float32) Vector { - return Vector{ - id: id, - embedding: embedding, - } -} - -func (v Vector) ID() string { - return v.id -} - -func (v Vector) Embedding() []float32 { - return v.embedding -} - -func (v Vector) WriteTo(w io.Writer) (int64, error) { - n, err := multiBinaryWrite(w, v.id, len(v.embedding), v.embedding) - return int64(n), err -} - -func (v *Vector) ReadFrom(r io.Reader) (int64, error) { - var embLen int - n, err := multiBinaryRead(r, &v.id, &embLen) - if err != nil { - return int64(n), err - } - - v.embedding = make([]float32, embLen) - n, err = binaryRead(r, &v.embedding) - - return int64(n), err -} - -var ( - _ io.WriterTo = (*Vector)(nil) - _ io.ReaderFrom = (*Vector)(nil) -) From 4875617857535f53328b02e06e11189044c15477 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 31 May 2024 13:24:22 -0500 Subject: [PATCH 3/7] WIP: graph.go compiles! --- graph.go | 106 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 62 insertions(+), 44 deletions(-) diff --git a/graph.go b/graph.go index d7f0a29..2c9604a 100644 --- a/graph.go +++ b/graph.go @@ -14,10 +14,15 @@ import ( type Vector = []float32 +// Node is a node in the graph. +type Node[K cmp.Ordered] struct { + ID K + Vec Vector +} + // layerNode is a node in a layer of the graph. type layerNode[K cmp.Ordered] struct { - id K - vec Vector + Node[K] // neighbors is map of neighbor IDs to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. @@ -32,7 +37,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu n.neighbors = make(map[K]*layerNode[K], m) } - n.neighbors[newNode.id] = newNode + n.neighbors[newNode.ID] = newNode if len(n.neighbors) <= m { return } @@ -43,7 +48,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu worst *layerNode[K] ) for _, neighbor := range n.neighbors { - d := dist(neighbor.vec, n.vec) + d := dist(neighbor.Vec, n.Vec) // d > worstDist may always be false if the distance function // returns NaN, e.g., when the embeddings are zero. if d > worstDist || worst == nil { @@ -52,9 +57,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu } } - delete(n.neighbors, worst.id) + delete(n.neighbors, worst.ID) // Delete backlink from the worst neighbor. - delete(worst.neighbors, n.id) + delete(worst.neighbors, n.ID) worst.replenish(m) } @@ -83,7 +88,7 @@ func (n *layerNode[K]) search( candidates.Push( searchCandidate[K]{ node: n, - dist: distance(n.vec, target), + dist: distance(n.Vec, target), }, ) var ( @@ -94,7 +99,7 @@ func (n *layerNode[K]) search( // Begin with the entry node in the result set. result.Push(candidates.Min()) - visited[n.id] = true + visited[n.ID] = true for candidates.Len() > 0 { var ( @@ -113,7 +118,7 @@ func (n *layerNode[K]) search( } visited[neighborID] = true - dist := distance(neighbor.vec, target) + dist := distance(neighbor.Vec, target) improved = improved || dist < result.Min().dist if result.Len() < k { result.Push(searchCandidate[K]{node: neighbor, dist: dist}) @@ -168,7 +173,7 @@ func (n *layerNode[K]) replenish(m int) { // to neighbors. func (n *layerNode[K]) isolate(m int) { for _, neighbor := range n.neighbors { - delete(neighbor.neighbors, n.id) + delete(neighbor.neighbors, n.ID) neighbor.replenish(m) } } @@ -179,7 +184,7 @@ type layer[K cmp.Ordered] struct { // property of the graph. // // nodes is exported for interop with encoding/gob. - nodes map[string]*layerNode[K] + nodes map[K]*layerNode[K] } // entry returns the entry node of the layer. @@ -237,8 +242,8 @@ func defaultRand() *rand.Rand { // NewGraph returns a new graph with default parameters, roughly designed for // storing OpenAI embeddings. -func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] { - return &Graph[K, V]{ +func NewGraph[K cmp.Ordered]() *Graph[K] { + return &Graph[K]{ M: 16, Ml: 0.25, Distance: CosineDistance, @@ -307,38 +312,48 @@ func (g *Graph[T]) Dims() int { if len(g.layers) == 0 { return 0 } - return len(g.layers[0].entry().Point.Embedding()) + return len(g.layers[0].entry().Vec) +} + +func ptr[T any](v T) *T { + return &v } // Add inserts nodes into the graph. // If another node with the same ID exists, it is replaced. -func (g *Graph[T]) Add(nodes ...T) { - for _, n := range nodes { - g.assertDims(n.Embedding()) +func (g *Graph[K]) Add(nodes ...Node[K]) { + for _, node := range nodes { + id := node.ID + vec := node.Vec + + g.assertDims(vec) insertLevel := g.randomLevel() // Create layers that don't exist yet. for insertLevel >= len(g.layers) { - g.layers = append(g.layers, &layer[T]{}) + g.layers = append(g.layers, &layer[K]{}) } if insertLevel < 0 { panic("invalid level") } - var elevator string + var elevator *K preLen := g.Len() // Insert node at each layer, beginning with the highest. for i := len(g.layers) - 1; i >= 0; i-- { layer := g.layers[i] - newNode := &layerNode[T]{ - vec: n, + newNode := &layerNode[K]{ + Node: Node[K]{ + ID: id, + Vec: vec, + }, } // Insert the new node into the layer. if layer.entry() == nil { - layer.Nodes = map[string]*layerNode[T]{n.ID(): newNode} + layer.nodes = map[K]*layerNode[K]{id: newNode} continue } @@ -348,15 +363,15 @@ func (g *Graph[T]) Add(nodes ...T) { // On subsequent layers, we use the elevator node to enter the graph // at the best point. - if elevator != "" { - searchPoint = layer.Nodes[elevator] + if elevator != nil { + searchPoint = layer.nodes[*elevator] } if g.Distance == nil { panic("(*Graph).Distance must be set") } - neighborhood := searchPoint.search(g.M, g.EfSearch, n.Embedding(), g.Distance) + neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance) if len(neighborhood) == 0 { // This should never happen because the searchPoint itself // should be in the result set. @@ -364,14 +379,14 @@ func (g *Graph[T]) Add(nodes ...T) { } // Re-set the elevator node for the next layer. - elevator = neighborhood[0].node.Point.ID() + elevator = ptr(neighborhood[0].node.ID) if insertLevel >= i { - if _, ok := layer.Nodes[n.ID()]; ok { - g.Delete(n.ID()) + if _, ok := layer.nodes[id]; ok { + g.Delete(id) } // Insert the new node into the layer. - layer.Nodes[n.ID()] = newNode + layer.nodes[id] = newNode for _, node := range neighborhood { // Create a bi-directional edge between the new node and the best node. node.node.addNeighbor(newNode, g.M, g.Distance) @@ -388,7 +403,7 @@ func (g *Graph[T]) Add(nodes ...T) { } // Search finds the k nearest neighbors from the target node. -func (h *Graph[T]) Search(near Vector, k int) []T { +func (h *Graph[K]) Search(near Vector, k int) []Node[K] { h.assertDims(near) if len(h.layers) == 0 { return nil @@ -397,27 +412,27 @@ func (h *Graph[T]) Search(near Vector, k int) []T { var ( efSearch = h.EfSearch - elevator string + elevator *K ) for layer := len(h.layers) - 1; layer >= 0; layer-- { searchPoint := h.layers[layer].entry() - if elevator != "" { - searchPoint = h.layers[layer].Nodes[elevator] + if elevator != nil { + searchPoint = h.layers[layer].nodes[*elevator] } // Descending hierarchies if layer > 0 { nodes := searchPoint.search(1, efSearch, near, h.Distance) - elevator = nodes[0].node.Point.ID() + elevator = ptr(nodes[0].node.ID) continue } nodes := searchPoint.search(k, efSearch, near, h.Distance) - out := make([]T, 0, len(nodes)) + out := make([]Node[K], 0, len(nodes)) for _, node := range nodes { - out = append(out, node.node.Point.(T)) + out = append(out, node.node.Node) } return out @@ -437,18 +452,18 @@ func (h *Graph[T]) Len() int { // Delete removes a node from the graph by ID. // It tries to preserve the clustering properties of the graph by // replenishing connectivity in the affected neighborhoods. -func (h *Graph[T]) Delete(id string) bool { +func (h *Graph[K]) Delete(id K) bool { if len(h.layers) == 0 { return false } var deleted bool for _, layer := range h.layers { - node, ok := layer.Nodes[id] + node, ok := layer.nodes[id] if !ok { continue } - delete(layer.Nodes, id) + delete(layer.nodes, id) node.isolate(h.M) deleted = true } @@ -456,12 +471,15 @@ func (h *Graph[T]) Delete(id string) bool { return deleted } -// Lookup returns the node with the given ID. -func (h *Graph[T]) Lookup(id string) (T, bool) { - var zero T +// Lookup returns the vector with the given ID. +func (h *Graph[K]) Lookup(id K) (Vector, bool) { if len(h.layers) == 0 { - return zero, false + return nil, false } - return h.layers[0].Nodes[id].Point.(T), true + node, ok := h.layers[0].nodes[id] + if !ok { + return nil, false + } + return node.Vec, ok } From 52ae936a25467f3a24771f650ff36d44d9997d03 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 31 May 2024 13:37:27 -0500 Subject: [PATCH 4/7] WIP: graph_test.go compiles! --- analyzer.go | 14 ++--- graph_test.go | 142 +++++++++++++++++++++++++++----------------------- 2 files changed, 84 insertions(+), 72 deletions(-) diff --git a/analyzer.go b/analyzer.go index c527739..c57bb18 100644 --- a/analyzer.go +++ b/analyzer.go @@ -1,11 +1,13 @@ package hnsw +import "cmp" + // Analyzer is a struct that holds a graph and provides // methods for analyzing it. It offers no compatibility guarantee // as the methods of measuring the graph's health with change // with the implementation. -type Analyzer[T Embeddable] struct { - Graph *Graph[T] +type Analyzer[K cmp.Ordered] struct { + Graph *Graph[K] } func (a *Analyzer[T]) Height() int { @@ -17,16 +19,16 @@ func (a *Analyzer[T]) Height() int { func (a *Analyzer[T]) Connectivity() []float64 { var layerConnectivity []float64 for _, layer := range a.Graph.layers { - if len(layer.Nodes) == 0 { + if len(layer.nodes) == 0 { continue } var sum float64 - for _, node := range layer.Nodes { + for _, node := range layer.nodes { sum += float64(len(node.neighbors)) } - layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes))) + layerConnectivity = append(layerConnectivity, sum/float64(len(layer.nodes))) } return layerConnectivity @@ -36,7 +38,7 @@ func (a *Analyzer[T]) Connectivity() []float64 { func (a *Analyzer[T]) Topography() []int { var topography []int for _, layer := range a.Graph.layers { - topography = append(topography, len(layer.Nodes)) + topography = append(topography, len(layer.nodes)) } return topography } diff --git a/graph_test.go b/graph_test.go index 688e052..c14c586 100644 --- a/graph_test.go +++ b/graph_test.go @@ -1,6 +1,7 @@ package hnsw import ( + "cmp" "math/rand" "strconv" "testing" @@ -18,34 +19,42 @@ func Test_maxLevel(t *testing.T) { require.Equal(t, 11, m) } -type basicPoint float32 - -func (n basicPoint) ID() string { - return strconv.FormatFloat(float64(n), 'f', -1, 32) -} - -func (n basicPoint) Embedding() []float32 { - return []float32{float32(n)} -} - func Test_layerNode_search(t *testing.T) { - entry := &layerNode[basicPoint]{ - vec: basicPoint(0), - neighbors: map[string]*layerNode[basicPoint]{ - "1": { - vec: basicPoint(1), + entry := &layerNode[int]{ + Node: Node[int]{ + Vec: Vector{0}, + ID: 0, + }, + neighbors: map[int]*layerNode[int]{ + 1: { + Node: Node[int]{ + Vec: Vector{1}, + ID: 1, + }, }, - "2": { - vec: basicPoint(2), + 2: { + Node: Node[int]{ + Vec: Vector{2}, + ID: 2, + }, }, - "3": { - vec: basicPoint(3), - neighbors: map[string]*layerNode[basicPoint]{ - "3.8": { - vec: basicPoint(3.8), + 3: { + Node: Node[int]{ + Vec: Vector{3}, + ID: 3, + }, + neighbors: map[int]*layerNode[int]{ + 4: { + Node: Node[int]{ + Vec: Vector{4}, + ID: 5, + }, }, - "4.3": { - vec: basicPoint(4.3), + 5: { + Node: Node[int]{ + Vec: Vector{5}, + ID: 5, + }, }, }, }, @@ -54,13 +63,13 @@ func Test_layerNode_search(t *testing.T) { best := entry.search(2, 4, []float32{4}, EuclideanDistance) - require.Equal(t, "3.8", best[0].node.Point.ID()) - require.Equal(t, "4.3", best[1].node.Point.ID()) + require.Equal(t, 4, best[0].node.ID) + require.Equal(t, 3, best[1].node.ID) require.Len(t, best, 2) } -func newTestGraph[T Embeddable]() *Graph[T] { - return &Graph[T]{ +func newTestGraph[K cmp.Ordered]() *Graph[K] { + return &Graph[K]{ M: 6, Distance: EuclideanDistance, Ml: 0.5, @@ -72,13 +81,18 @@ func newTestGraph[T Embeddable]() *Graph[T] { func TestGraph_AddSearch(t *testing.T) { t.Parallel() - g := newTestGraph[basicPoint]() + g := newTestGraph[int]() for i := 0; i < 128; i++ { - g.Add(basicPoint(float32(i))) + g.Add( + Node[int]{ + ID: i, + Vec: Vector{float32(i)}, + }, + ) } - al := Analyzer[basicPoint]{Graph: g} + al := Analyzer[int]{Graph: g} // Layers should be approximately log2(128) = 7 // Look for an approximate doubling of the number of nodes in each layer. @@ -101,11 +115,11 @@ func TestGraph_AddSearch(t *testing.T) { require.Len(t, nearest, 4) require.EqualValues( t, - []basicPoint{ - (64), - (65), - (62), - (63), + []Node[int]{ + {64, Vector{64}}, + {65, Vector{65}}, + {62, Vector{62}}, + {63, Vector{63}}, }, nearest, ) @@ -114,19 +128,22 @@ func TestGraph_AddSearch(t *testing.T) { func TestGraph_AddDelete(t *testing.T) { t.Parallel() - g := newTestGraph[basicPoint]() + g := newTestGraph[int]() for i := 0; i < 128; i++ { - g.Add(basicPoint(i)) + g.Add(Node[int]{ + ID: i, + Vec: Vector{float32(i)}, + }) } require.Equal(t, 128, g.Len()) - an := Analyzer[basicPoint]{Graph: g} + an := Analyzer[int]{Graph: g} preDeleteConnectivity := an.Connectivity() // Delete every even node. for i := 0; i < 128; i += 2 { - ok := g.Delete(basicPoint(i).ID()) + ok := g.Delete(i) require.True(t, ok) } @@ -141,7 +158,7 @@ func TestGraph_AddDelete(t *testing.T) { ) t.Run("DeleteNotFound", func(t *testing.T) { - ok := g.Delete("not found") + ok := g.Delete(-1) require.False(t, ok) }) } @@ -154,11 +171,14 @@ func Benchmark_HSNW(b *testing.B) { // Use this to ensure that complexity is O(log n) where n = h.Len(). for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { - g := Graph[basicPoint]{} + g := Graph[int]{} g.Ml = 0.5 g.Distance = EuclideanDistance for i := 0; i < size; i++ { - g.Add(basicPoint(i)) + g.Add(Node[int]{ + ID: i, + Vec: Vector{float32(i)}, + }) } b.ResetTimer() @@ -174,19 +194,6 @@ func Benchmark_HSNW(b *testing.B) { } } -type genericPoint struct { - id string - x []float32 -} - -func (n genericPoint) ID() string { - return n.id -} - -func (n genericPoint) Embedding() []float32 { - return n.x -} - func randFloats(n int) []float32 { x := make([]float32, n) for i := range x { @@ -198,11 +205,14 @@ func randFloats(n int) []float32 { func Benchmark_HNSW_1536(b *testing.B) { b.ReportAllocs() - g := newTestGraph[genericPoint]() + g := newTestGraph[int]() const size = 1000 - points := make([]genericPoint, size) + points := make([]Node[int], size) for i := 0; i < size; i++ { - points[i] = genericPoint{x: randFloats(1536), id: strconv.Itoa(i)} + points[i] = Node[int]{ + ID: i, + Vec: Vector(randFloats(1536)), + } g.Add(points[i]) } b.ResetTimer() @@ -210,7 +220,7 @@ func Benchmark_HNSW_1536(b *testing.B) { b.Run("Search", func(b *testing.B) { for i := 0; i < b.N; i++ { g.Search( - points[i%size].x, + points[i%size].Vec, 4, ) } @@ -218,11 +228,11 @@ func Benchmark_HNSW_1536(b *testing.B) { } func TestGraph_DefaultCosine(t *testing.T) { - g := NewGraph[Vector]() + g := NewGraph[int]() g.Add( - MakeVector("1", []float32{1, 1}), - MakeVector("2", []float32{0, 1}), - MakeVector("3", []float32{1, -1}), + Node[int]{ID: 1, Vec: Vector{1, 1}}, + Node[int]{ID: 2, Vec: Vector{0, 1}}, + Node[int]{ID: 3, Vec: Vector{1, -1}}, ) neighbors := g.Search( @@ -232,8 +242,8 @@ func TestGraph_DefaultCosine(t *testing.T) { require.Equal( t, - []Vector{ - MakeVector("1", []float32{1, 1}), + []Node[int]{ + {1, Vector{1, 1}}, }, neighbors, ) From 0236acbac9eed247734db68f360eb12d1dc6b719 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 31 May 2024 13:55:29 -0500 Subject: [PATCH 5/7] Tests pass! --- encode.go | 78 ++++++++++++++++++++++++------------------ encode_test.go | 65 +++++++++++++++++++++-------------- example/readme/main.go | 10 +++--- graph.go | 4 +++ graph_test.go | 2 +- 5 files changed, 94 insertions(+), 65 deletions(-) diff --git a/encode.go b/encode.go index 8966880..08c43d7 100644 --- a/encode.go +++ b/encode.go @@ -2,6 +2,7 @@ package hnsw import ( "bufio" + "cmp" "encoding/binary" "fmt" "io" @@ -43,6 +44,16 @@ func binaryRead(r io.Reader, data interface{}) (int, error) { *v = string(s) return len(s), err + case *[]float32: + var ln int + _, err := binaryRead(r, &ln) + if err != nil { + return 0, err + } + + *v = make([]float32, ln) + return binary.Size(*v), binary.Read(r, byteOrder, *v) + case io.ReaderFrom: n, err := v.ReadFrom(r) return int(n), err @@ -73,6 +84,12 @@ func binaryWrite(w io.Writer, data any) (int, error) { } return n + n2, nil + case []float32: + n, err := binaryWrite(w, len(v)) + if err != nil { + return n, err + } + return n + binary.Size(v), binary.Write(w, byteOrder, v) default: sz := binary.Size(data) @@ -113,7 +130,7 @@ const encodingVersion = 1 // Export writes the graph to a writer. // // T must implement io.WriterTo. -func (h *Graph[T]) Export(w io.Writer) error { +func (h *Graph[K]) Export(w io.Writer) error { distFuncName, ok := distanceFuncToName(h.Distance) if !ok { return fmt.Errorf("distance function %v must be registered with RegisterDistanceFunc", h.Distance) @@ -134,24 +151,20 @@ func (h *Graph[T]) Export(w io.Writer) error { return fmt.Errorf("encode number of layers: %w", err) } for _, layer := range h.layers { - _, err = binaryWrite(w, len(layer.Nodes)) + _, err = binaryWrite(w, len(layer.nodes)) if err != nil { return fmt.Errorf("encode number of nodes: %w", err) } - for _, node := range layer.Nodes { - _, err = binaryWrite(w, node.Point) + for _, node := range layer.nodes { + _, err = multiBinaryWrite(w, node.ID, node.Vec, len(node.neighbors)) if err != nil { - return fmt.Errorf("encode node point: %w", err) - } - - if _, err = binaryWrite(w, len(node.neighbors)); err != nil { - return fmt.Errorf("encode number of neighbors: %w", err) + return fmt.Errorf("encode node data: %w", err) } for neighbor := range node.neighbors { _, err = binaryWrite(w, neighbor) if err != nil { - return fmt.Errorf("encode neighbor %q: %w", neighbor, err) + return fmt.Errorf("encode neighbor %v: %w", neighbor, err) } } } @@ -164,7 +177,7 @@ func (h *Graph[T]) Export(w io.Writer) error { // T must implement io.ReaderFrom. // The imported graph does not have to match the exported graph's parameters (except for // dimensionality). The graph will converge onto the new parameters. -func (h *Graph[T]) Import(r io.Reader) error { +func (h *Graph[K]) Import(r io.Reader) error { var ( version int dist string @@ -195,7 +208,7 @@ func (h *Graph[T]) Import(r io.Reader) error { return err } - h.layers = make([]*layer[T], nLayers) + h.layers = make([]*layer[K], nLayers) for i := 0; i < nLayers; i++ { var nNodes int _, err = binaryRead(r, &nNodes) @@ -203,23 +216,19 @@ func (h *Graph[T]) Import(r io.Reader) error { return err } - nodes := make(map[string]*layerNode[T], nNodes) + nodes := make(map[K]*layerNode[K], nNodes) for j := 0; j < nNodes; j++ { - var point T - _, err = binaryRead(r, &point) - if err != nil { - return fmt.Errorf("decoding node %d: %w", j, err) - } - + var id K + var vec Vector var nNeighbors int - _, err = binaryRead(r, &nNeighbors) + _, err = multiBinaryRead(r, &id, &vec, &nNeighbors) if err != nil { - return fmt.Errorf("decoding number of neighbors for node %d: %w", j, err) + return fmt.Errorf("decoding node %d: %w", j, err) } - neighbors := make([]string, nNeighbors) + neighbors := make([]K, nNeighbors) for k := 0; k < nNeighbors; k++ { - var neighbor string + var neighbor K _, err = binaryRead(r, &neighbor) if err != nil { return fmt.Errorf("decoding neighbor %d for node %d: %w", k, j, err) @@ -227,12 +236,15 @@ func (h *Graph[T]) Import(r io.Reader) error { neighbors[k] = neighbor } - node := &layerNode[T]{ - vec: point, - neighbors: make(map[string]*layerNode[T]), + node := &layerNode[K]{ + Node: Node[K]{ + ID: id, + Vec: vec, + }, + neighbors: make(map[K]*layerNode[K]), } - nodes[point.ID()] = node + nodes[id] = node for _, neighbor := range neighbors { node.neighbors[neighbor] = nil } @@ -243,7 +255,7 @@ func (h *Graph[T]) Import(r io.Reader) error { node.neighbors[id] = nodes[id] } } - h.layers[i] = &layer[T]{nodes: nodes} + h.layers[i] = &layer[K]{nodes: nodes} } return nil @@ -253,8 +265,8 @@ func (h *Graph[T]) Import(r io.Reader) error { // changes to a file upon calls to Save. It is more convenient // but less powerful than calling Graph.Export and Graph.Import // directly. -type SavedGraph[T Embeddable] struct { - *Graph[T] +type SavedGraph[K cmp.Ordered] struct { + *Graph[K] Path string } @@ -265,7 +277,7 @@ type SavedGraph[T Embeddable] struct { // // It does not hold open a file descriptor, so SavedGraph can be forgotten // without ever calling Save. -func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { +func LoadSavedGraph[K cmp.Ordered](path string) (*SavedGraph[K], error) { f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) if err != nil { return nil, err @@ -276,7 +288,7 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { return nil, err } - g := NewGraph[T]() + g := NewGraph[K]() if info.Size() > 0 { err = g.Import(bufio.NewReader(f)) if err != nil { @@ -284,7 +296,7 @@ func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) { } } - return &SavedGraph[T]{Graph: g, Path: path}, nil + return &SavedGraph[K]{Graph: g, Path: path}, nil } // Save writes the graph to the file. diff --git a/encode_test.go b/encode_test.go index dead198..cb5e38e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,8 +2,7 @@ package hnsw import ( "bytes" - "math/rand" - "strconv" + "cmp" "testing" "github.com/stretchr/testify/require" @@ -50,21 +49,21 @@ func Test_binaryWrite_string(t *testing.T) { require.Empty(t, buf.Bytes()) } -func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) { +func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) { for _, layer := range g.layers { - for _, node := range layer.Nodes { + for _, node := range layer.nodes { for neighborKey, neighbor := range node.neighbors { - _, ok := layer.Nodes[neighbor.Point.ID()] + _, ok := layer.nodes[neighbor.ID] if !ok { t.Errorf( - "node %s has neighbor %s, but neighbor does not exist", - node.Point.ID(), neighbor.Point.ID(), + "node %v has neighbor %v, but neighbor does not exist", + node.ID, neighbor.ID, ) } - if neighborKey != neighbor.Point.ID() { - t.Errorf("node %s has neighbor %s, but neighbor key is %s", node.Point.ID(), - neighbor.Point.ID(), + if neighborKey != neighbor.ID { + t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.ID, + neighbor.ID, neighborKey, ) } @@ -74,10 +73,10 @@ func verifyGraphNodes[T Embeddable](t *testing.T, g *Graph[T]) { } // requireGraphApproxEquals checks that two graphs are equal. -func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) { +func requireGraphApproxEquals[K cmp.Ordered](t *testing.T, g1, g2 *Graph[K]) { require.Equal(t, g1.Len(), g2.Len()) - a1 := Analyzer[T]{g1} - a2 := Analyzer[T]{g2} + a1 := Analyzer[K]{g1} + a2 := Analyzer[K]{g2} require.Equal( t, @@ -119,11 +118,13 @@ func requireGraphApproxEquals[T Embeddable](t *testing.T, g1, g2 *Graph[T]) { } func TestGraph_ExportImport(t *testing.T) { - rng := rand.New(rand.NewSource(0)) - - g1 := newTestGraph[Vector]() + g1 := newTestGraph[int]() for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{rng.Float32()})) + g1.Add( + Node[int]{ + i, randFloats(1), + }, + ) } buf := &bytes.Buffer{} @@ -132,7 +133,7 @@ func TestGraph_ExportImport(t *testing.T) { // Don't use newTestGraph to ensure parameters // are imported. - g2 := &Graph[Vector]{} + g2 := &Graph[int]{} err = g2.Import(buf) require.NoError(t, err) @@ -157,17 +158,21 @@ func TestGraph_ExportImport(t *testing.T) { func TestSavedGraph(t *testing.T) { dir := t.TempDir() - g1, err := LoadSavedGraph[Vector](dir + "/graph") + g1, err := LoadSavedGraph[int](dir + "/graph") require.NoError(t, err) require.Equal(t, 0, g1.Len()) for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)})) + g1.Add( + Node[int]{ + i, randFloats(1), + }, + ) } err = g1.Save() require.NoError(t, err) - g2, err := LoadSavedGraph[Vector](dir + "/graph") + g2, err := LoadSavedGraph[int](dir + "/graph") require.NoError(t, err) requireGraphApproxEquals(t, g1.Graph, g2.Graph) @@ -177,9 +182,13 @@ const benchGraphSize = 100 func BenchmarkGraph_Import(b *testing.B) { b.ReportAllocs() - g := newTestGraph[Vector]() + g := newTestGraph[int]() for i := 0; i < benchGraphSize; i++ { - g.Add(MakeVector(strconv.Itoa(i), randFloats(100))) + g.Add( + Node[int]{ + i, randFloats(256), + }, + ) } buf := &bytes.Buffer{} @@ -192,7 +201,7 @@ func BenchmarkGraph_Import(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() rdr := bytes.NewReader(buf.Bytes()) - g := newTestGraph[Vector]() + g := newTestGraph[int]() b.StartTimer() err = g.Import(rdr) require.NoError(b, err) @@ -201,9 +210,13 @@ func BenchmarkGraph_Import(b *testing.B) { func BenchmarkGraph_Export(b *testing.B) { b.ReportAllocs() - g := newTestGraph[Vector]() + g := newTestGraph[int]() for i := 0; i < benchGraphSize; i++ { - g.Add(MakeVector(strconv.Itoa(i), randFloats(256))) + g.Add( + Node[int]{ + i, randFloats(256), + }, + ) } var buf bytes.Buffer diff --git a/example/readme/main.go b/example/readme/main.go index cc661ab..6207ec7 100644 --- a/example/readme/main.go +++ b/example/readme/main.go @@ -7,16 +7,16 @@ import ( ) func main() { - g := hnsw.NewGraph[hnsw.Vector]() + g := hnsw.NewGraph[int]() g.Add( - hnsw.MakeVector("1", []float32{1, 1, 1}), - hnsw.MakeVector("2", []float32{1, -1, 0.999}), - hnsw.MakeVector("3", []float32{1, 0, -0.5}), + hnsw.MakeNode(1, []float32{1, 1, 1}), + hnsw.MakeNode(2, []float32{1, -1, 0.999}), + hnsw.MakeNode(3, []float32{1, 0, -0.5}), ) neighbors := g.Search( []float32{0.5, 0.5, 0.5}, 1, ) - fmt.Printf("best friend: %v\n", neighbors[0].Embedding()) + fmt.Printf("best friend: %v\n", neighbors[0].Vec) } diff --git a/graph.go b/graph.go index 2c9604a..3057c37 100644 --- a/graph.go +++ b/graph.go @@ -20,6 +20,10 @@ type Node[K cmp.Ordered] struct { Vec Vector } +func MakeNode[K cmp.Ordered](id K, vec Vector) Node[K] { + return Node[K]{ID: id, Vec: vec} +} + // layerNode is a node in a layer of the graph. type layerNode[K cmp.Ordered] struct { Node[K] diff --git a/graph_test.go b/graph_test.go index c14c586..01ace54 100644 --- a/graph_test.go +++ b/graph_test.go @@ -63,7 +63,7 @@ func Test_layerNode_search(t *testing.T) { best := entry.search(2, 4, []float32{4}, EuclideanDistance) - require.Equal(t, 4, best[0].node.ID) + require.Equal(t, 5, best[0].node.ID) require.Equal(t, 3, best[1].node.ID) require.Len(t, best, 2) } From f00e907804c11f1c0938fd8043af2b6218561a86 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 31 May 2024 13:58:19 -0500 Subject: [PATCH 6/7] Update README for new API --- README.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index ed00d93..d664a28 100644 --- a/README.md +++ b/README.md @@ -32,18 +32,18 @@ go get github.com/coder/hnsw@main ``` ```go -g := hnsw.NewGraph[hnsw.Vector]() +g := hnsw.NewGraph[int]() g.Add( - hnsw.MakeVector("1", []float32{1, 1, 1}), - hnsw.MakeVector("2", []float32{1, -1, 0.999}), - hnsw.MakeVector("3", []float32{1, 0, -0.5}), + hnsw.MakeNode(1, []float32{1, 1, 1}), + hnsw.MakeNode(2, []float32{1, -1, 0.999}), + hnsw.MakeNode(3, []float32{1, 0, -0.5}), ) neighbors := g.Search( []float32{0.5, 0.5, 0.5}, 1, ) -fmt.Printf("best friend: %v\n", neighbors[0].Embedding()) +fmt.Printf("best friend: %v\n", neighbors[0].Vec) // Output: best friend: [1 1 1] ``` @@ -59,13 +59,13 @@ If you're using a single file as the backend, hnsw provides a convenient `SavedG ```go path := "some.graph" -g1, err := LoadSavedGraph[hnsw.Vector](path) +g1, err := LoadSavedGraph[int](path) if err != nil { panic(err) } // Insert some vectors for i := 0; i < 128; i++ { - g1.Add(MakeVector(strconv.Itoa(i), []float32{float32(i)})) + g1.Add(hnsw.MakeNode(i, []float32{float32(i)})) } // Save to disk @@ -76,7 +76,7 @@ if err != nil { // Later... // g2 is a copy of g1 -g2, err := LoadSavedGraph[Vector](path) +g2, err := LoadSavedGraph[int](path) if err != nil { panic(err) } @@ -94,10 +94,10 @@ nearly at disk speed. On my M3 Macbook I get these benchmark results: goos: darwin goarch: arm64 pkg: github.com/coder/hnsw -BenchmarkGraph_Import-16 2733 369803 ns/op 228.65 MB/s 352041 B/op 9880 allocs/op -BenchmarkGraph_Export-16 6046 194441 ns/op 1076.65 MB/s 261854 B/op 3760 allocs/op +BenchmarkGraph_Import-16 4029 259927 ns/op 796.85 MB/s 496022 B/op 3212 allocs/op +BenchmarkGraph_Export-16 7042 168028 ns/op 1232.49 MB/s 239886 B/op 2388 allocs/op PASS -ok github.com/coder/hnsw 2.530s +ok github.com/coder/hnsw 2.624s ``` when saving/loading a graph of 100 vectors with 256 dimensions. From ac9deaf635ca5d466b6a0c75a22afb2b4e5bc28a Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Fri, 14 Jun 2024 11:23:23 -0500 Subject: [PATCH 7/7] Update variable name ID to key --- README.md | 10 +++--- encode.go | 16 ++++----- encode_test.go | 10 +++--- example/readme/main.go | 2 +- graph.go | 75 +++++++++++++++++++++--------------------- graph_test.go | 52 ++++++++++++++--------------- 6 files changed, 83 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index d664a28..4023467 100644 --- a/README.md +++ b/README.md @@ -130,18 +130,18 @@ $$ where: * $n$ is the number of vectors in the graph -* $\text{size(id)}$ is the average size of the ID in bytes +* $\text{size(key)}$ is the average size of the key in bytes * $M$ is the maximum number of neighbors each node can have * $d$ is the dimensionality of the vectors * $mem_{graph}$ is the memory used by the graph structure across all layers * $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer You can infer that: -* Connectivity ($M$) is very expensive if IDs are large -* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data -* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure +* Connectivity ($M$) is very expensive if keys are large +* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data +* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure -In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes: +In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte keys, you would see that each vector takes: * $256 \cdot 4 = 1024$ data bytes * $16 \cdot 8 = 128$ metadata bytes diff --git a/encode.go b/encode.go index 08c43d7..e4aaf9b 100644 --- a/encode.go +++ b/encode.go @@ -156,7 +156,7 @@ func (h *Graph[K]) Export(w io.Writer) error { return fmt.Errorf("encode number of nodes: %w", err) } for _, node := range layer.nodes { - _, err = multiBinaryWrite(w, node.ID, node.Vec, len(node.neighbors)) + _, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors)) if err != nil { return fmt.Errorf("encode node data: %w", err) } @@ -218,10 +218,10 @@ func (h *Graph[K]) Import(r io.Reader) error { nodes := make(map[K]*layerNode[K], nNodes) for j := 0; j < nNodes; j++ { - var id K + var key K var vec Vector var nNeighbors int - _, err = multiBinaryRead(r, &id, &vec, &nNeighbors) + _, err = multiBinaryRead(r, &key, &vec, &nNeighbors) if err != nil { return fmt.Errorf("decoding node %d: %w", j, err) } @@ -238,21 +238,21 @@ func (h *Graph[K]) Import(r io.Reader) error { node := &layerNode[K]{ Node: Node[K]{ - ID: id, - Vec: vec, + Key: key, + Value: vec, }, neighbors: make(map[K]*layerNode[K]), } - nodes[id] = node + nodes[key] = node for _, neighbor := range neighbors { node.neighbors[neighbor] = nil } } // Fill in neighbor pointers for _, node := range nodes { - for id := range node.neighbors { - node.neighbors[id] = nodes[id] + for key := range node.neighbors { + node.neighbors[key] = nodes[key] } } h.layers[i] = &layer[K]{nodes: nodes} diff --git a/encode_test.go b/encode_test.go index cb5e38e..b19210e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -53,17 +53,17 @@ func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) { for _, layer := range g.layers { for _, node := range layer.nodes { for neighborKey, neighbor := range node.neighbors { - _, ok := layer.nodes[neighbor.ID] + _, ok := layer.nodes[neighbor.Key] if !ok { t.Errorf( "node %v has neighbor %v, but neighbor does not exist", - node.ID, neighbor.ID, + node.Key, neighbor.Key, ) } - if neighborKey != neighbor.ID { - t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.ID, - neighbor.ID, + if neighborKey != neighbor.Key { + t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.Key, + neighbor.Key, neighborKey, ) } diff --git a/example/readme/main.go b/example/readme/main.go index 6207ec7..928ba9f 100644 --- a/example/readme/main.go +++ b/example/readme/main.go @@ -18,5 +18,5 @@ func main() { []float32{0.5, 0.5, 0.5}, 1, ) - fmt.Printf("best friend: %v\n", neighbors[0].Vec) + fmt.Printf("best friend: %v\n", neighbors[0].Value) } diff --git a/graph.go b/graph.go index 3057c37..4cfe82a 100644 --- a/graph.go +++ b/graph.go @@ -16,19 +16,19 @@ type Vector = []float32 // Node is a node in the graph. type Node[K cmp.Ordered] struct { - ID K - Vec Vector + Key K + Value Vector } -func MakeNode[K cmp.Ordered](id K, vec Vector) Node[K] { - return Node[K]{ID: id, Vec: vec} +func MakeNode[K cmp.Ordered](key K, vec Vector) Node[K] { + return Node[K]{Key: key, Value: vec} } // layerNode is a node in a layer of the graph. type layerNode[K cmp.Ordered] struct { Node[K] - // neighbors is map of neighbor IDs to neighbor nodes. + // neighbors is map of neighbor keys to neighbor nodes. // It is a map and not a slice to allow for efficient deletes, esp. // when M is high. neighbors map[K]*layerNode[K] @@ -41,7 +41,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu n.neighbors = make(map[K]*layerNode[K], m) } - n.neighbors[newNode.ID] = newNode + n.neighbors[newNode.Key] = newNode if len(n.neighbors) <= m { return } @@ -52,7 +52,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu worst *layerNode[K] ) for _, neighbor := range n.neighbors { - d := dist(neighbor.Vec, n.Vec) + d := dist(neighbor.Value, n.Value) // d > worstDist may always be false if the distance function // returns NaN, e.g., when the embeddings are zero. if d > worstDist || worst == nil { @@ -61,9 +61,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu } } - delete(n.neighbors, worst.ID) + delete(n.neighbors, worst.Key) // Delete backlink from the worst neighbor. - delete(worst.neighbors, n.ID) + delete(worst.neighbors, n.Key) worst.replenish(m) } @@ -92,7 +92,7 @@ func (n *layerNode[K]) search( candidates.Push( searchCandidate[K]{ node: n, - dist: distance(n.Vec, target), + dist: distance(n.Value, target), }, ) var ( @@ -103,7 +103,7 @@ func (n *layerNode[K]) search( // Begin with the entry node in the result set. result.Push(candidates.Min()) - visited[n.ID] = true + visited[n.Key] = true for candidates.Len() > 0 { var ( @@ -113,16 +113,16 @@ func (n *layerNode[K]) search( // We iterate the map in a sorted, deterministic fashion for // tests. - neighborIDs := maps.Keys(current.neighbors) - slices.Sort(neighborIDs) - for _, neighborID := range neighborIDs { + neighborKeys := maps.Keys(current.neighbors) + slices.Sort(neighborKeys) + for _, neighborID := range neighborKeys { neighbor := current.neighbors[neighborID] if visited[neighborID] { continue } visited[neighborID] = true - dist := distance(neighbor.Vec, target) + dist := distance(neighbor.Value, target) improved = improved || dist < result.Min().dist if result.Len() < k { result.Push(searchCandidate[K]{node: neighbor, dist: dist}) @@ -157,8 +157,8 @@ func (n *layerNode[K]) replenish(m int) { // This is a naive implementation that could be improved by // using a priority queue to find the best candidates. for _, neighbor := range n.neighbors { - for id, candidate := range neighbor.neighbors { - if _, ok := n.neighbors[id]; ok { + for key, candidate := range neighbor.neighbors { + if _, ok := n.neighbors[key]; ok { // do not add duplicates continue } @@ -177,7 +177,7 @@ func (n *layerNode[K]) replenish(m int) { // to neighbors. func (n *layerNode[K]) isolate(m int) { for _, neighbor := range n.neighbors { - delete(neighbor.neighbors, n.ID) + delete(neighbor.neighbors, n.Key) neighbor.replenish(m) } } @@ -214,6 +214,7 @@ func (l *layer[K]) size() int { // Graph is a Hierarchical Navigable Small World graph. // All public parameters must be set before adding nodes to the graph. +// K is cmp.Ordered instead of of comparable so that they can be sorted. type Graph[K cmp.Ordered] struct { // Distance is the distance function used to compare embeddings. Distance DistanceFunc @@ -316,7 +317,7 @@ func (g *Graph[T]) Dims() int { if len(g.layers) == 0 { return 0 } - return len(g.layers[0].entry().Vec) + return len(g.layers[0].entry().Value) } func ptr[T any](v T) *T { @@ -327,8 +328,8 @@ func ptr[T any](v T) *T { // If another node with the same ID exists, it is replaced. func (g *Graph[K]) Add(nodes ...Node[K]) { for _, node := range nodes { - id := node.ID - vec := node.Vec + key := node.Key + vec := node.Value g.assertDims(vec) insertLevel := g.randomLevel() @@ -350,14 +351,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) { layer := g.layers[i] newNode := &layerNode[K]{ Node: Node[K]{ - ID: id, - Vec: vec, + Key: key, + Value: vec, }, } // Insert the new node into the layer. if layer.entry() == nil { - layer.nodes = map[K]*layerNode[K]{id: newNode} + layer.nodes = map[K]*layerNode[K]{key: newNode} continue } @@ -383,14 +384,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) { } // Re-set the elevator node for the next layer. - elevator = ptr(neighborhood[0].node.ID) + elevator = ptr(neighborhood[0].node.Key) if insertLevel >= i { - if _, ok := layer.nodes[id]; ok { - g.Delete(id) + if _, ok := layer.nodes[key]; ok { + g.Delete(key) } // Insert the new node into the layer. - layer.nodes[id] = newNode + layer.nodes[key] = newNode for _, node := range neighborhood { // Create a bi-directional edge between the new node and the best node. node.node.addNeighbor(newNode, g.M, g.Distance) @@ -428,7 +429,7 @@ func (h *Graph[K]) Search(near Vector, k int) []Node[K] { // Descending hierarchies if layer > 0 { nodes := searchPoint.search(1, efSearch, near, h.Distance) - elevator = ptr(nodes[0].node.ID) + elevator = ptr(nodes[0].node.Key) continue } @@ -453,21 +454,21 @@ func (h *Graph[T]) Len() int { return h.layers[0].size() } -// Delete removes a node from the graph by ID. +// Delete removes a node from the graph by key. // It tries to preserve the clustering properties of the graph by // replenishing connectivity in the affected neighborhoods. -func (h *Graph[K]) Delete(id K) bool { +func (h *Graph[K]) Delete(key K) bool { if len(h.layers) == 0 { return false } var deleted bool for _, layer := range h.layers { - node, ok := layer.nodes[id] + node, ok := layer.nodes[key] if !ok { continue } - delete(layer.nodes, id) + delete(layer.nodes, key) node.isolate(h.M) deleted = true } @@ -475,15 +476,15 @@ func (h *Graph[K]) Delete(id K) bool { return deleted } -// Lookup returns the vector with the given ID. -func (h *Graph[K]) Lookup(id K) (Vector, bool) { +// Lookup returns the vector with the given key. +func (h *Graph[K]) Lookup(key K) (Vector, bool) { if len(h.layers) == 0 { return nil, false } - node, ok := h.layers[0].nodes[id] + node, ok := h.layers[0].nodes[key] if !ok { return nil, false } - return node.Vec, ok + return node.Value, ok } diff --git a/graph_test.go b/graph_test.go index 01ace54..d2a9cab 100644 --- a/graph_test.go +++ b/graph_test.go @@ -22,38 +22,38 @@ func Test_maxLevel(t *testing.T) { func Test_layerNode_search(t *testing.T) { entry := &layerNode[int]{ Node: Node[int]{ - Vec: Vector{0}, - ID: 0, + Value: Vector{0}, + Key: 0, }, neighbors: map[int]*layerNode[int]{ 1: { Node: Node[int]{ - Vec: Vector{1}, - ID: 1, + Value: Vector{1}, + Key: 1, }, }, 2: { Node: Node[int]{ - Vec: Vector{2}, - ID: 2, + Value: Vector{2}, + Key: 2, }, }, 3: { Node: Node[int]{ - Vec: Vector{3}, - ID: 3, + Value: Vector{3}, + Key: 3, }, neighbors: map[int]*layerNode[int]{ 4: { Node: Node[int]{ - Vec: Vector{4}, - ID: 5, + Value: Vector{4}, + Key: 5, }, }, 5: { Node: Node[int]{ - Vec: Vector{5}, - ID: 5, + Value: Vector{5}, + Key: 5, }, }, }, @@ -63,8 +63,8 @@ func Test_layerNode_search(t *testing.T) { best := entry.search(2, 4, []float32{4}, EuclideanDistance) - require.Equal(t, 5, best[0].node.ID) - require.Equal(t, 3, best[1].node.ID) + require.Equal(t, 5, best[0].node.Key) + require.Equal(t, 3, best[1].node.Key) require.Len(t, best, 2) } @@ -86,8 +86,8 @@ func TestGraph_AddSearch(t *testing.T) { for i := 0; i < 128; i++ { g.Add( Node[int]{ - ID: i, - Vec: Vector{float32(i)}, + Key: i, + Value: Vector{float32(i)}, }, ) } @@ -131,8 +131,8 @@ func TestGraph_AddDelete(t *testing.T) { g := newTestGraph[int]() for i := 0; i < 128; i++ { g.Add(Node[int]{ - ID: i, - Vec: Vector{float32(i)}, + Key: i, + Value: Vector{float32(i)}, }) } @@ -176,8 +176,8 @@ func Benchmark_HSNW(b *testing.B) { g.Distance = EuclideanDistance for i := 0; i < size; i++ { g.Add(Node[int]{ - ID: i, - Vec: Vector{float32(i)}, + Key: i, + Value: Vector{float32(i)}, }) } b.ResetTimer() @@ -210,8 +210,8 @@ func Benchmark_HNSW_1536(b *testing.B) { points := make([]Node[int], size) for i := 0; i < size; i++ { points[i] = Node[int]{ - ID: i, - Vec: Vector(randFloats(1536)), + Key: i, + Value: Vector(randFloats(1536)), } g.Add(points[i]) } @@ -220,7 +220,7 @@ func Benchmark_HNSW_1536(b *testing.B) { b.Run("Search", func(b *testing.B) { for i := 0; i < b.N; i++ { g.Search( - points[i%size].Vec, + points[i%size].Value, 4, ) } @@ -230,9 +230,9 @@ func Benchmark_HNSW_1536(b *testing.B) { func TestGraph_DefaultCosine(t *testing.T) { g := NewGraph[int]() g.Add( - Node[int]{ID: 1, Vec: Vector{1, 1}}, - Node[int]{ID: 2, Vec: Vector{0, 1}}, - Node[int]{ID: 3, Vec: Vector{1, -1}}, + Node[int]{Key: 1, Value: Vector{1, 1}}, + Node[int]{Key: 2, Value: Vector{0, 1}}, + Node[int]{Key: 3, Value: Vector{1, -1}}, ) neighbors := g.Search(