diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java new file mode 100644 index 0000000000..1a40dfedbb --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java @@ -0,0 +1,101 @@ +/* + * AbstractNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +/** + * An abstract base class implementing the {@link Node} interface. + *
<p>
+ * This class provides the fundamental structure for a node within the HNSW graph, + * managing a unique {@link Tuple} primary key and an immutable list of its neighbors. + * Subclasses are expected to provide concrete implementations, potentially adding + * more state or behavior. + * + * @param the type of the node reference used for neighbors, which must extend {@link NodeReference} + */ +abstract class AbstractNode implements Node { + @Nonnull + private final Tuple primaryKey; + + @Nonnull + private final List neighbors; + + /** + * Constructs a new {@code AbstractNode} with a specified primary key and a list of neighbors. + * + * @param primaryKey the unique identifier for this node; must not be {@code null} + * @param neighbors the list of nodes connected to this node; must not be {@code null} + */ + protected AbstractNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + this.primaryKey = primaryKey; + this.neighbors = ImmutableList.copyOf(neighbors); + } + + /** + * Gets the primary key that uniquely identifies this object. + * @return the primary key {@link Tuple}, which will never be {@code null}. + */ + @Nonnull + @Override + public Tuple getPrimaryKey() { + return primaryKey; + } + + /** + * Gets the list of neighbors connected to this node. + *
<p>
+ * This method returns a direct reference to the internal list, which is + * immutable. + * @return a non-null, possibly empty, list of neighbors. + */ + @Nonnull + @Override + public List<N> getNeighbors() { + return neighbors; + } + + /** + * Converts this node into its {@link CompactNode} representation. + *
<p>
+ * A {@code CompactNode} is a space-efficient implementation of {@code Node}. This method provides the + * conversion logic to transform the current object into that compact form. + * + * @return a non-null {@link CompactNode} representing the current node. + */ + @Nonnull + public abstract CompactNode asCompactNode(); + + /** + * Converts this node into its {@link InliningNode} representation. + * @return this object cast to an {@link InliningNode}; never {@code null}. + * @throws IllegalStateException if this object is not actually an instance of + * {@link InliningNode}. + */ + @Nonnull + public abstract InliningNode asInliningNode(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java new file mode 100644 index 0000000000..84e7db99ab --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -0,0 +1,236 @@ +/* + * AbstractStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * An abstract base class for {@link StorageAdapter} implementations. + *
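The `asCompactNode()`/`asInliningNode()` pair above implements a simple self-narrowing idiom instead of scattered casts, and the constructor's `ImmutableList.copyOf` snapshot is what makes it safe for `getNeighbors()` to return the internal list. A minimal, self-contained sketch of that defensive-copy pattern (using the JDK's `List.copyOf` in place of Guava; `ExampleNode` is a hypothetical stand-in, not part of this change):

```java
import java.util.List;

// Sketch of the copy-on-construction discipline used by AbstractNode: the
// neighbor list is snapshotted once, so the getter can hand out the internal
// reference without risking external mutation.
final class ExampleNode {
    private final List<Long> neighbors;

    ExampleNode(final List<Long> neighbors) {
        this.neighbors = List.copyOf(neighbors); // immutable snapshot, like ImmutableList.copyOf
    }

    List<Long> getNeighbors() {
        return neighbors; // safe: callers cannot modify this list
    }
}
```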
<p>
+ * This class provides the common infrastructure for managing HNSW graph data within a {@link Subspace}. + * It handles the configuration, node creation, and listener management, while delegating the actual + * storage-specific read and write operations to concrete subclasses through the {@code fetchNodeInternal} + * and {@code writeNodeInternal} abstract methods. + * + * @param the type of {@link NodeReference} used to reference nodes in the graph + */ +abstract class AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(AbstractStorageAdapter.class); + + @Nonnull + private final Config config; + @Nonnull + private final NodeFactory nodeFactory; + @Nonnull + private final Subspace subspace; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + @Nonnull + private final Subspace dataSubspace; + + /** + * Constructs a new {@code AbstractStorageAdapter}. + *
<p>
+ * This constructor initializes the adapter with the necessary configuration, + * factories, and listeners for managing an HNSW graph. It also sets up a + * dedicated data subspace within the provided main subspace for storing node data. + * + * @param config the HNSW graph configuration + * @param nodeFactory the factory to create new nodes of type {@code } + * @param subspace the primary subspace for storing all graph-related data + * @param onWriteListener the listener to be called on write operations + * @param onReadListener the listener to be called on read operations + */ + protected AbstractStorageAdapter(@Nonnull final Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + this.config = config; + this.nodeFactory = nodeFactory; + this.subspace = subspace; + this.onWriteListener = onWriteListener; + this.onReadListener = onReadListener; + this.dataSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA)); + } + + @Override + @Nonnull + public Config getConfig() { + return config; + } + + @Nonnull + @Override + public NodeFactory getNodeFactory() { + return nodeFactory; + } + + @Override + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + /** + * Gets the cached subspace for the data associated with this component. + *
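The constructor above hangs all node data off a dedicated prefix via `subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA))`. A runnable sketch of that nesting; `"hnsw"` and `"D"` are hypothetical stand-ins for the actual root prefix and the `SUBSPACE_PREFIX_DATA` constant:

```java
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;

// Sketch of the subspace nesting performed in the constructor: all node data
// lives under a dedicated prefix inside the adapter's main subspace. "hnsw"
// and "D" are assumed placeholder values.
public final class SubspaceSketch {
    public static void main(String[] args) {
        Subspace main = new Subspace(Tuple.from("hnsw"));
        Subspace data = main.subspace(Tuple.from("D"));
        // node keys are then packed below this prefix, e.g. (layer, primaryKey):
        byte[] key = data.pack(Tuple.from(0L, Tuple.from(42L)));
        System.out.println(key.length);
    }
}
```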
<p>
+ * The data subspace defines the portion of the directory space where the data + * for this component is stored. + * + * @return the non-null {@link Subspace} for the data + */ + @Override + @Nonnull + public Subspace getDataSubspace() { + return dataSubspace; + } + + @Override + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + @Override + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + /** + * Asynchronously fetches a node from a specific layer of the HNSW. + *
<p>
+ * The node is identified by its {@code layer} and {@code primaryKey}. The entire fetch operation is + * performed within the given {@link ReadTransaction}. After the underlying + * fetch operation completes, the retrieved node is validated by the + * {@link #checkNode(Node)} method before the returned future is completed. + * + * @param readTransaction the non-null transaction to use for the read operation + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer of the graph from which to fetch the node + * @param primaryKey the non-null primary key that identifies the node to fetch + * + * @return a {@link CompletableFuture} that will complete with the fetched {@link AbstractNode} + * once it has been read from storage and validated + */ + @Nonnull + @Override + public CompletableFuture<AbstractNode<N>> fetchNode(@Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + int layer, @Nonnull Tuple primaryKey) { + return fetchNodeInternal(readTransaction, storageTransform, layer, primaryKey).thenApply(this::checkNode); + } + + /** + * Asynchronously fetches a specific node from the data store for a given layer and primary key. + *
<p>
+ * This is an internal, abstract method that concrete subclasses must implement to define + * the storage-specific logic for retrieving a node. The operation is performed within the + * context of the provided {@link ReadTransaction}. + * + * @param readTransaction the transaction to use for the read operation; must not be {@code null} + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer index from which to fetch the node + * @param primaryKey the primary key that uniquely identifies the node to be fetched; must not be {@code null} + * + * @return a {@link CompletableFuture} that will be completed with the fetched {@link AbstractNode}. + * The future will complete with {@code null} if no node is found for the given key and layer. + */ + @Nonnull + protected abstract CompletableFuture<AbstractNode<N>> fetchNodeInternal(@Nonnull ReadTransaction readTransaction, + @Nonnull AffineOperator storageTransform, + int layer, @Nonnull Tuple primaryKey); + + /** + * Method to perform basic invariant check(s) on a newly-fetched node. + * + * @param node the node to check + * + * @return the node that was passed in + */ + @Nullable + private <T extends AbstractNode<N>> T checkNode(@Nullable final T node) { + return node; + } + + /** + * Writes a given node and its neighbor modifications to the underlying storage. + *
<p>
+ * This operation is executed within the context of the provided {@link Transaction}. + * It handles persisting the node's data at a specific {@code layer} and applies + * the changes to its neighbors as defined in the {@link NeighborsChangeSet}. + * This method delegates the core writing logic to an internal method and provides + * debug logging upon completion. + * + * @param transaction the non-null {@link Transaction} context for this write operation + * @param quantizer the quantizer to use + * @param node the non-null {@link Node} to be written to storage + * @param layer the layer index where the node is being written + * @param changeSet the non-null {@link NeighborsChangeSet} detailing the modifications + * to the node's neighbors + */ + @Override + public void writeNode(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, + @Nonnull final AbstractNode node, final int layer, + @Nonnull final NeighborsChangeSet changeSet) { + writeNodeInternal(transaction, quantizer, node, layer, changeSet); + if (logger.isTraceEnabled()) { + logger.trace("written node with key={} at layer={}", node.getPrimaryKey(), layer); + } + } + + /** + * Writes a single node to the data store as part of a larger transaction. + *
<p>
+ * This is an abstract method that concrete implementations must provide. + * It is responsible for the low-level persistence of the given {@code node} at a + * specific {@code layer}. The implementation should also handle the modifications + * to the node's neighbors, as detailed in the {@code changeSet}. + * + * @param transaction the non-null transaction context for the write operation + * @param quantizer the quantizer to use + * @param node the non-null {@link Node} to write + * @param layer the layer or level of the node in the structure + * @param changeSet the non-null {@link NeighborsChangeSet} detailing additions or + * removals of neighbor links + */ + protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Quantizer quantizer, + @Nonnull AbstractNode<N> node, int layer, + @Nonnull NeighborsChangeSet<N> changeSet); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AccessInfo.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AccessInfo.java new file mode 100644 index 0000000000..792012796b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AccessInfo.java @@ -0,0 +1,111 @@ +/* + * AccessInfo.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Objects; + +/** + * Class to capture the current state of this HNSW that cannot be expressed as metadata but that also is not the actual + * data that is inserted, organized and retrieved. For instance, an HNSW needs to keep track of the entry point that + * resides in the highest layer (currently). Another example is any information that pertains to coordinate system + * transformations that have to be carried out prior/posterior to inserting/retrieving an item into/from the HNSW. + */ +class AccessInfo { + /** + * The current entry point. All searches start here. + */ + @Nonnull + private final EntryNodeReference entryNodeReference; + + /** + * A seed that can be used to reconstruct a random rotator {@link com.apple.foundationdb.linear.FhtKacRotator} used + * in {@link StorageTransform}. + */ + private final long rotatorSeed; + + /** + * The negated centroid that is usually derived as an average over some vectors seen so far. It is used to create + * the {@link StorageTransform}. The centroid is stored in its negated form (i.e. {@code centroid * (-1)}) as the + * {@link com.apple.foundationdb.linear.AffineOperator} adds its translation vector but the centroid needs to be + * subtracted. 
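Both halves of this adapter follow the same template-method shape: the public `fetchNode`/`writeNode` entry points add the cross-cutting pieces (post-fetch validation through `checkNode`, trace logging) and delegate the storage-specific work to the abstract `fetchNodeInternal`/`writeNodeInternal` hooks. A self-contained sketch of that structure with simplified types (`ExampleAdapter` is illustrative only):

```java
import java.util.concurrent.CompletableFuture;

// Sketch of the template-method shape used by AbstractStorageAdapter: the
// public entry point adds cross-cutting behavior and delegates the actual
// storage access to an abstract hook that subclasses implement.
abstract class ExampleAdapter<T> {
    final CompletableFuture<T> fetch(final long key) {
        return fetchInternal(key).thenApply(this::check); // validate after the raw read
    }

    private T check(final T value) {
        return value; // placeholder for invariant checks, mirroring checkNode
    }

    protected abstract CompletableFuture<T> fetchInternal(long key); // storage-specific logic
}
```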
+ */ + @Nullable + private final RealVector negatedCentroid; + + public AccessInfo(@Nonnull final EntryNodeReference entryNodeReference, final long rotatorSeed, + @Nullable final RealVector negatedCentroid) { + this.entryNodeReference = entryNodeReference; + this.rotatorSeed = rotatorSeed; + this.negatedCentroid = negatedCentroid; + } + + @Nonnull + public EntryNodeReference getEntryNodeReference() { + return entryNodeReference; + } + + public boolean canUseRaBitQ() { + return getNegatedCentroid() != null; + } + + public long getRotatorSeed() { + return rotatorSeed; + } + + @Nullable + public RealVector getNegatedCentroid() { + return negatedCentroid; + } + + @Nonnull + public AccessInfo withNewEntryNodeReference(@Nonnull final EntryNodeReference entryNodeReference) { + return new AccessInfo(entryNodeReference, getRotatorSeed(), getNegatedCentroid()); + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof AccessInfo)) { + return false; + } + final AccessInfo that = (AccessInfo)o; + return rotatorSeed == that.rotatorSeed && + Objects.equals(entryNodeReference, that.entryNodeReference) && + Objects.equals(negatedCentroid, that.negatedCentroid); + } + + @Override + public int hashCode() { + return Objects.hash(entryNodeReference, rotatorSeed, negatedCentroid); + } + + @Nonnull + @Override + public String toString() { + return "AccessInfo[" + + "entryNodeReference=" + entryNodeReference + + ", rotatorSeed=" + rotatorSeed + + ", centroid=" + negatedCentroid + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AggregatedVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AggregatedVector.java new file mode 100644 index 0000000000..b47dcd1131 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AggregatedVector.java @@ -0,0 +1,70 @@ +/* + * AggregatedVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * A record-like class wrapping a {@link RealVector} and a count. This data structure is used to keep a running sum + * of many vectors in order to compute their centroid at a later time. 
+ */ +class AggregatedVector { + private final int partialCount; + @Nonnull + private final Transformed partialVector; + + public AggregatedVector(final int partialCount, @Nonnull final Transformed partialVector) { + this.partialCount = partialCount; + this.partialVector = partialVector; + } + + public int getPartialCount() { + return partialCount; + } + + @Nonnull + public Transformed getPartialVector() { + return partialVector; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof AggregatedVector)) { + return false; + } + final AggregatedVector that = (AggregatedVector)o; + return partialCount == that.partialCount && Objects.equals(partialVector, that.partialVector); + } + + @Override + public int hashCode() { + return Objects.hash(partialCount, partialVector); + } + + @Override + public String toString() { + return "AggregatedVector[" + partialCount + ", " + partialVector + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java new file mode 100644 index 0000000000..490b4bc844 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -0,0 +1,96 @@ +/* + * BaseNeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.function.Predicate; + +/** + * A base implementation of the {@link NeighborsChangeSet} interface. + *
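Taken together, `AggregatedVector` and `AccessInfo` implement the following bookkeeping: sampled vectors are folded into a running (sum, count) pair, the centroid is derived later as sum / count, and it is stored negated because the affine storage transform adds its translation vector rather than subtracting it. A sketch with made-up two-dimensional numbers:

```java
// Sketch of the sampling bookkeeping behind AggregatedVector and AccessInfo:
// vectors are summed together with a count, the centroid is sum / count, and
// the *negated* centroid is what gets persisted. All numbers are made up.
public final class CentroidSketch {
    public static void main(String[] args) {
        double[][] samples = {{4.0, 6.0}, {2.0, 2.0}};
        double[] sum = new double[2];
        int count = 0;
        for (double[] s : samples) { // partialVector += s; partialCount++
            sum[0] += s[0];
            sum[1] += s[1];
            count++;
        }
        // centroid = {3.0, 4.0}; stored negated so that "add translation" means "subtract centroid"
        double[] negatedCentroid = {-sum[0] / count, -sum[1] / count};
        System.out.println(negatedCentroid[0] + ", " + negatedCentroid[1]); // -3.0, -4.0
    }
}
```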
<p>
+ * This class represents a complete, non-delta state of a node's neighbors. It holds a fixed, immutable + * list of neighbors provided at construction time. As such, it does not support parent change sets or writing deltas. + * + * @param the type of the node reference, which must extend {@link NodeReference} + */ +class BaseNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private final List neighbors; + + /** + * Creates a new change set with the specified neighbors. + *
<p>
+ * This constructor creates an immutable copy of the provided list. + * + * @param neighbors the list of neighbors for this change set; must not be null. + */ + public BaseNeighborsChangeSet(@Nonnull final List neighbors) { + this.neighbors = ImmutableList.copyOf(neighbors); + } + + /** + * Gets the parent change set. + *
<p>
+ * This implementation always returns {@code null}, as this type of change set + * does not have a parent. + * + * @return always {@code null}. + */ + @Nullable + @Override + public BaseNeighborsChangeSet getParent() { + return null; + } + + /** + * Retrieves the list of neighbors associated with this object. + *
<p>
+ * This implementation fulfills the {@code merge} contract by simply returning the + * existing list of neighbors without performing any additional merging logic. + * @return a non-null list of neighbors. The generic type {@code N} represents + * the type of the neighboring elements. + */ + @Nonnull + @Override + public List merge() { + return neighbors; + } + + /** + * {@inheritDoc} + * + *
<p>
This implementation is a no-op and does not write any delta information, + * as indicated by the empty method body. + */ + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + @Nonnull final Quantizer quantizer, final int layer, @Nonnull final AbstractNode node, + @Nonnull final Predicate primaryKeyPredicate) { + // nothing to be written + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java new file mode 100644 index 0000000000..cb742506ca --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java @@ -0,0 +1,166 @@ +/* + * CompactNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.half.Half; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Represents a compact node within a graph structure, extending {@link AbstractNode}. + *
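The no-op `writeDelta` and parent-less `merge` make `BaseNeighborsChangeSet` the anchor of a chain of delta change sets, such as the `DeleteNeighborsChangeSet` appearing at the end of this diff. A self-contained sketch of that chaining idea with plain strings as neighbor keys (the types here are illustrative, not the real interfaces):

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative chaining of neighbor change sets: a base set holds the persisted
// neighbors and returns them unchanged from merge(); delta sets layer their
// modifications on top of the parent's merged view.
interface ChangeSetSketch {
    List<String> merge();
}

final class BaseSetSketch implements ChangeSetSketch {
    private final List<String> neighbors;

    BaseSetSketch(final List<String> neighbors) {
        this.neighbors = List.copyOf(neighbors);
    }

    @Override
    public List<String> merge() {
        return neighbors; // no parent, nothing to apply
    }
}

final class DeleteSetSketch implements ChangeSetSketch {
    private final ChangeSetSketch parent;
    private final String deleted;

    DeleteSetSketch(final ChangeSetSketch parent, final String deleted) {
        this.parent = parent;
        this.deleted = deleted;
    }

    @Override
    public List<String> merge() {
        final List<String> result = new ArrayList<>(parent.merge());
        result.remove(deleted); // apply this delta on top of the parent's view
        return result;
    }
}
```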
<p>
+ * This node type is considered "compact" because it directly stores its associated + * data vector of type {@link RealVector}. It is used to represent a vector in a + * vector space and maintains references to its neighbors via {@link NodeReference} objects. + * + * @see AbstractNode + * @see NodeReference + */ +class CompactNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public AbstractNode create(@Nonnull final Tuple primaryKey, + @Nullable final Transformed vector, + @Nonnull final List neighbors) { + return new CompactNode(primaryKey, Objects.requireNonNull(vector), (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.COMPACT; + } + }; + + @Nonnull + private final Transformed vector; + + /** + * Constructs a new {@code CompactNode} instance. + *
<p>
+ * This constructor initializes the node with its primary key, a data vector, + * and a list of its neighbors. It delegates the initialization of the + * {@code primaryKey} and {@code neighbors} to the superclass constructor. + * + * @param primaryKey the primary key that uniquely identifies this node; must not be {@code null}. + * @param vector the data vector of type {@code RealVector} associated with this node; must not be {@code null}. + * @param neighbors a list of {@link NodeReference} objects representing the neighbors of this node; must not be + * {@code null}. + */ + public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Transformed vector, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + this.vector = vector; + } + + /** + * Returns a {@link NodeReference} that uniquely identifies this node. + *
<p>
+ * This implementation creates the reference using the node's primary key, obtained via {@code getPrimaryKey()}. It + * ignores the provided {@code vector} parameter, which exists to fulfill the contract of the overridden method. + * + * @param vector the vector context, which is ignored in this implementation. + * Per the {@code @Nullable} annotation, this can be {@code null}. + * + * @return a non-null {@link NodeReference} to this node. + */ + @Nonnull + @Override + public NodeReference getSelfReference(@Nullable final Transformed vector) { + return new NodeReference(getPrimaryKey()); + } + + /** + * Gets the kind of this node. + * This implementation always returns {@link NodeKind#COMPACT}. + * @return the node kind, which is guaranteed to be {@link NodeKind#COMPACT}. + */ + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.COMPACT; + } + + /** + * Gets the vector of {@code Half} objects. + * @return the non-null vector of {@link Half} objects. + */ + @Nonnull + public Transformed getVector() { + return vector; + } + + /** + * Returns this node as a {@code CompactNode}. As this class is already a {@code CompactNode}, this method provides + * {@code this}. + * @return this object cast as a {@code CompactNode}, which is guaranteed to be non-null. + */ + @Nonnull + @Override + public CompactNode asCompactNode() { + return this; + } + + /** + * Returns this node as an {@link InliningNode}. + *
<p>
+ * This override is for node types that are not inlining nodes. As such, it + * will always fail. + * @return this node as a non-null {@link InliningNode} + * @throws IllegalStateException always, as this is not an inlining node + */ + @Nonnull + @Override + public InliningNode asInliningNode() { + throw new IllegalStateException("this is not an inlining node"); + } + + /** + * Gets the shared factory instance for creating {@link NodeReference} objects. + *
<p>
+ * This static factory method is the preferred way to obtain a {@code NodeFactory} + * for {@link NodeReference} instances, as it returns a shared, pre-configured object. + * + * @return a shared, non-null instance of {@code NodeFactory} + */ + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "C[primaryKey=" + getPrimaryKey() + + ";vector=" + vector + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java new file mode 100644 index 0000000000..d14bee5368 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -0,0 +1,292 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.base.Verify; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * The {@code CompactStorageAdapter} class is a concrete implementation of {@link StorageAdapter} for managing HNSW + * graph data in a compact format. + *
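Reduced to its essentials, the compact representation stores one key/value pair per node: the node's own vector plus bare neighbor primary keys, and no neighbor vectors. As a hypothetical record (simplified types, not part of this change):

```java
import java.util.List;

// Hypothetical, simplified shape of the "compact" representation: the node
// carries its own vector but only the primary keys of its neighbors, so each
// node round-trips as a single key/value pair.
record CompactSketch(String primaryKey, float[] vector, List<String> neighborKeys) { }
```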
<p>
+ * It handles the serialization and deserialization of graph nodes to and from a persistent data store. This + * implementation is optimized for space efficiency by storing nodes with their accompanying vector data and by storing + * just neighbor primary keys. It extends {@link AbstractStorageAdapter} to inherit common storage logic. + */ +class CompactStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class); + + /** + * Constructs a new {@code CompactStorageAdapter}. + * + * @param config the HNSW graph configuration, must not be null. See {@link Config}. + * @param nodeFactory the factory used to create new nodes of type {@link NodeReference}, must not be null. + * @param subspace the {@link Subspace} where the graph data is stored, must not be null. + * @param onWriteListener the listener to be notified of write events, must not be null. + * @param onReadListener the listener to be notified of read events, must not be null. + */ + public CompactStorageAdapter(@Nonnull final Config config, + @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + /** + * Asynchronously fetches a node from the database for a given layer and primary key. + *
<p>
+ * This internal method constructs a raw byte key from the {@code layer} and {@code primaryKey} within the store's + * data subspace. It then uses the provided {@link ReadTransaction} to retrieve the raw value. If a value is found, + * it is deserialized into a {@link AbstractNode} object using the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for the read operation + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the current storage space + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a future that will complete with the fetched {@link AbstractNode} + * + * @throws IllegalStateException if the node cannot be found in the database for the given key + */ + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] keyBytes = getDataSubspace().pack(Tuple.from(layer, primaryKey)); + + return readTransaction.get(keyBytes) + .thenApply(valueBytes -> { + if (valueBytes == null) { + throw new IllegalStateException("cannot fetch node"); + } + return nodeFromRaw(storageTransform, layer, primaryKey, keyBytes, valueBytes); + }); + } + + /** + * Deserializes a raw key-value byte array pair into a {@code Node}. + *
<p>
+ * This method first converts the {@code valueBytes} into a {@link Tuple} and then, + * along with the {@code primaryKey}, constructs the final {@code Node} object. + * It also notifies any registered {@link OnReadListener} about the raw key-value + * read and the resulting node creation. + * + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer of the HNSW where this node resides + * @param primaryKey the primary key for the node + * @param keyBytes the raw byte representation of the node's key + * @param valueBytes the raw byte representation of the node's value, which will be deserialized + * + * @return a non-null, deserialized {@link AbstractNode} object + */ + @Nonnull + private AbstractNode nodeFromRaw(@Nonnull final AffineOperator storageTransform, final int layer, + final @Nonnull Tuple primaryKey, + @Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) { + final Tuple nodeTuple = Tuple.fromBytes(valueBytes); + final AbstractNode node = nodeFromKeyValuesTuples(storageTransform, primaryKey, nodeTuple); + final OnReadListener onReadListener = getOnReadListener(); + onReadListener.onNodeRead(layer, node); + onReadListener.onKeyValueRead(layer, keyBytes, valueBytes); + return node; + } + + /** + * Constructs a compact {@link AbstractNode} from its representation as stored key and value tuples. + *
<p>
+ * This method deserializes a node by extracting its components from the provided tuples. It verifies that the + * node is of type {@link NodeKind#COMPACT} before delegating the final construction to + * {@link #compactNodeFromTuples(AffineOperator, Tuple, Tuple, Tuple)}. The {@code valueTuple} is expected to have + * a specific structure: the serialized node kind at index 0, a nested tuple for the vector at index 1, and a nested + * tuple for the neighbors at index 2. + * + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param primaryKey the tuple representing the primary key of the node + * @param valueTuple the tuple containing the serialized node data, including kind, vector, and neighbors + * + * @return the reconstructed compact {@link AbstractNode} + * + * @throws com.google.common.base.VerifyException if the node kind encoded in {@code valueTuple} is not + * {@link NodeKind#COMPACT} + */ + @Nonnull + private AbstractNode nodeFromKeyValuesTuples(@Nonnull final AffineOperator storageTransform, + @Nonnull final Tuple primaryKey, + @Nonnull final Tuple valueTuple) { + final NodeKind nodeKind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0)); + Verify.verify(nodeKind == NodeKind.COMPACT); + + final Tuple vectorTuple; + final Tuple neighborsTuple; + + vectorTuple = valueTuple.getNestedTuple(1); + neighborsTuple = valueTuple.getNestedTuple(2); + return compactNodeFromTuples(storageTransform, primaryKey, vectorTuple, neighborsTuple); + } + + /** + * Creates a compact in-memory representation of a graph node from its constituent storage tuples. + *
<p>
+ * This method deserializes the raw data stored in {@code Tuple} objects into their + * corresponding in-memory types. It extracts the vector, constructs a list of + * {@link NodeReference} objects for the neighbors, and then uses a factory to + * assemble the final {@code Node} object. + *
<p>
+ * + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param primaryKey the tuple representing the node's primary key + * @param vectorTuple the tuple containing the node's vector data + * @param neighborsTuple the tuple containing a list of nested tuples, where each nested tuple represents a neighbor + * + * @return a new {@code Node} instance containing the deserialized data from the input tuples + */ + @Nonnull + private AbstractNode compactNodeFromTuples(@Nonnull final AffineOperator storageTransform, + @Nonnull final Tuple primaryKey, + @Nonnull final Tuple vectorTuple, + @Nonnull final Tuple neighborsTuple) { + final Transformed vector = + storageTransform.transform(StorageAdapter.vectorFromTuple(getConfig(), vectorTuple)); + final List nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size()); + + for (int i = 0; i < neighborsTuple.size(); i ++) { + final Tuple neighborTuple = neighborsTuple.getNestedTuple(i); + nodeReferences.add(new NodeReference(neighborTuple)); + } + + return getNodeFactory().create(primaryKey, vector, nodeReferences); + } + + /** + * Writes the internal representation of a compact node to the data store within a given transaction. + * This method handles the serialization of the node's vector and its final set of neighbors based on the + * provided {@code neighborsChangeSet}. + * + *
<p>
The node is stored as a {@link Tuple} with the structure {@code (NodeKind, RealVector, NeighborPrimaryKeys)}. + * The key for the storage is derived from the node's layer and its primary key. After writing, it notifies any + * registered write listeners via {@code onNodeWritten} and {@code onKeyValueWritten}. + * + * @param transaction the {@link Transaction} to use for the write operation. + * @param quantizer the quantizer to use + * @param node the {@link AbstractNode} to be serialized and written; it is processed as a {@link CompactNode}. + * @param layer the graph layer index for the node, used to construct the storage key. + * @param neighborsChangeSet a {@link NeighborsChangeSet} containing the additions and removals, which are + * merged to determine the final set of neighbors to be written. + */ + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, + @Nonnull final AbstractNode node, final int layer, + @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())); + + final List nodeItems = Lists.newArrayListWithExpectedSize(3); + nodeItems.add(NodeKind.COMPACT.getSerialized()); + final CompactNode compactNode = node.asCompactNode(); + // getting underlying vector is okay as it is only written to the database + nodeItems.add(StorageAdapter.tupleFromVector(quantizer.encode(compactNode.getVector()))); + + final Iterable neighbors = neighborsChangeSet.merge(); + + final List neighborItems = Lists.newArrayList(); + for (final NodeReference neighborReference : neighbors) { + neighborItems.add(neighborReference.getPrimaryKey()); + } + nodeItems.add(Tuple.fromList(neighborItems)); + + final Tuple nodeTuple = Tuple.fromList(nodeItems); + + final byte[] value = nodeTuple.pack(); + transaction.set(key, value); + getOnWriteListener().onNodeWritten(layer, node); + getOnWriteListener().onKeyValueWritten(layer, key, value); + + if (logger.isTraceEnabled()) { + logger.trace("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(), + node.getNeighbors().size(), neighborItems.size()); + } + } + + /** + * Scans a given layer for nodes, returning an iterable over the results. + *
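The key/value layout produced by `writeNodeInternal`, made concrete. Only the `(layer, primaryKey)` key shape and the `(nodeKind, vector, neighborPrimaryKeys)` value shape are taken from the code above; the literal values (layer `0`, kind `0`, the vector components) are assumptions for illustration:

```java
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;

import java.util.List;

// Sketch of the on-disk layout written by writeNodeInternal. The key is
// (layer, primaryKey) under the data subspace; the value is the tuple
// (nodeKind, vectorTuple, neighborPrimaryKeys). Kind byte 0 and the vector
// components are assumed placeholder values.
public final class LayoutSketch {
    public static void main(String[] args) {
        final Subspace dataSubspace = new Subspace(Tuple.from("D"));
        final Tuple primaryKey = Tuple.from(42L);
        final byte[] key = dataSubspace.pack(Tuple.from(0L, primaryKey)); // layer 0
        final Tuple value = Tuple.from(
                0L,                                                       // serialized node kind
                Tuple.from(1.0, 2.0),                                     // vector components
                Tuple.fromList(List.of(Tuple.from(7L), Tuple.from(9L)))); // neighbor primary keys
        transactionSetPlaceholder(key, value.pack()); // stands in for transaction.set(key, value)
    }

    static void transactionSetPlaceholder(final byte[] key, final byte[] value) {
        System.out.println(key.length + " / " + value.length);
    }
}
```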
<p>
+ * This method reads a limited number of nodes from a specific layer in the underlying data store. + * The scan can be started from a specific point using the {@code lastPrimaryKey} parameter, which is + * useful for paginating through the nodes in a large layer. + * + * @param readTransaction the transaction to use for reading data; must not be {@code null} + * @param layer the layer to scan for nodes + * @param lastPrimaryKey the primary key of the last node from a previous scan. If {@code null}, + * the scan starts from the beginning of the layer. + * @param maxNumRead the maximum number of nodes to read in this scan + * + * @return an {@link Iterable} of {@link AbstractNode} objects found in the specified layer, + * limited by {@code maxNumRead} + */ + @Nonnull + @Override + public AsyncIterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, + @Nullable final Tuple lastPrimaryKey, int maxNumRead) { + final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer)); + final Range range = + lastPrimaryKey == null + ? Range.startsWith(layerPrefix) + : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))), + ByteArrayUtil.strinc(layerPrefix)); + final AsyncIterable itemsIterable = + readTransaction.getRange(range, maxNumRead, false, StreamingMode.ITERATOR); + + return AsyncUtil.mapIterable(itemsIterable, keyValue -> { + final byte[] key = keyValue.getKey(); + final byte[] value = keyValue.getValue(); + final Tuple primaryKey = getDataSubspace().unpack(key).getNestedTuple(1); + return nodeFromRaw(AffineOperator.identity(), layer, primaryKey, key, value); + }); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java new file mode 100644 index 0000000000..eda5a0c17d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java @@ -0,0 +1,539 @@ +/* + * Config.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.Metric; +import com.google.errorprone.annotations.CanIgnoreReturnValue; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Configuration settings for a {@link HNSW}. 
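The resumption contract of `scanLayer` (a null `lastPrimaryKey` starts at the beginning of the layer; otherwise the range begins strictly after the last key already seen, hence the `strinc` on the packed key), sketched with plain longs instead of tuples. `scanPage` is a hypothetical stand-in for the real method:

```java
import java.util.ArrayList;
import java.util.List;

// Self-contained sketch of the pagination contract scanLayer exposes: null
// starts at the beginning, and each subsequent page resumes strictly after
// the last key already seen. The key space is simulated with longs 0..9.
public final class ScanSketch {
    static List<Long> scanPage(final Long lastPrimaryKey, final int maxNumRead) { // stand-in for scanLayer
        final List<Long> page = new ArrayList<>();
        final long start = lastPrimaryKey == null ? 0 : lastPrimaryKey + 1;
        for (long k = start; k < Math.min(start + maxNumRead, 10); k++) {
            page.add(k);
        }
        return page;
    }

    public static void main(String[] args) {
        Long lastSeen = null;
        List<Long> page;
        while (!(page = scanPage(lastSeen, 4)).isEmpty()) {
            lastSeen = page.get(page.size() - 1); // resume point for the next transaction
            System.out.println(page);
        }
    }
}
```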
+ */ +@SuppressWarnings("checkstyle:MemberName") +public final class Config { + public static final long DEFAULT_RANDOM_SEED = 0L; + @Nonnull public static final Metric DEFAULT_METRIC = Metric.EUCLIDEAN_METRIC; + public static final boolean DEFAULT_USE_INLINING = false; + public static final int DEFAULT_M = 16; + public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; + public static final int DEFAULT_M_MAX = DEFAULT_M; + public static final int DEFAULT_EF_CONSTRUCTION = 200; + public static final boolean DEFAULT_EXTEND_CANDIDATES = false; + public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; + // stats + public static final double DEFAULT_SAMPLE_VECTOR_STATS_PROBABILITY = 0.5d; + public static final double DEFAULT_MAINTAIN_STATS_PROBABILITY = 0.05d; + public static final int DEFAULT_STATS_THRESHOLD = 1000; + // RaBitQ + public static final boolean DEFAULT_USE_RABITQ = false; + public static final int DEFAULT_RABITQ_NUM_EX_BITS = 4; + + // concurrency + public static final int DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES = 16; + public static final int DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES = 16; + + /** + * The random seed that is used to probabilistically determine the highest layer of an insert. + */ + private final long randomSeed; + + /** + * The metric that is used to determine distances between vectors. + */ + @Nonnull + private final Metric metric; + + /** + * The number of dimensions used. All vectors must have exactly this number of dimensions. + */ + private final int numDimensions; + + /** + * Indicator if all layers except layer {@code 0} use inlining. If inlining is used, each node is persisted + * as one key/value pair per neighbor, which includes the neighbor's vector but not the node's own. If inlining is + * not used, each node is persisted as exactly one key/value pair per node, which stores its own vector but + * specifically excludes the vectors of the neighbors. + */ + private final boolean useInlining; + + /** + * This attribute (named {@code M} by the HNSW paper) is the connectivity value for all nodes stored on any layer. + * While by no means enforced or even enforceable, we strive to create and maintain exactly {@code m} neighbors for + * a node. Due to insert/delete operations it is possible that the actual number of neighbors a node references is + * not exactly {@code m} at any given time. + */ + private final int m; + + /** + * This attribute (named {@code M_max} by the HNSW paper) is the maximum connectivity value for nodes stored on a + * layer greater than {@code 0}. We will never create more than {@code mMax} neighbors for a node. That means that + * we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed {@code mMax}. + */ + private final int mMax; + + /** + * This attribute (named {@code M_max0} by the HNSW paper) is the maximum connectivity value for nodes stored on + * layer {@code 0}. We will never create more than {@code mMax0} neighbors for a node that is stored on that layer. + * That means that we even prune the neighbors of a node if the actual number of neighbors would otherwise exceed + * {@code mMax0}. + */ + private final int mMax0; + + /** + * Maximum size of the search queues (one independent queue per layer) that are used during the insertion of a new + * node. 
If {@code efConstruction} is set to {@code 1}, the search naturally follows a greedy approach + * (monotonous descent), whereas a high number for {@code efConstruction} allows for a more nuanced search that can + * tolerate (false) local minima. + */ + private final int efConstruction; + + /** + * Indicator to signal if, during the insertion of a node, the set of nearest neighbors of that node is to be + * extended by the actual neighbors of those neighbors to form a set of candidates that the new node may be + * connected to during the insert operation. + */ + private final boolean extendCandidates; + + /** + * Indicator to signal if, during the insertion of a node, candidates that have been discarded due to not satisfying + * the select-neighbor heuristic may get added back in to pad the set of neighbors if the new node would otherwise + * have too few neighbors (see {@link #m}). + */ + private final boolean keepPrunedConnections; + + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * probability that a vector being inserted is also written into the + * {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace. The vectors in that subspace are continuously aggregated + * until a total {@link #statsThreshold} has been reached. + */ + private final double sampleVectorStatsProbability; + + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * probability that the {@link StorageAdapter#SUBSPACE_PREFIX_SAMPLES} subspace is further aggregated (rolled up) + * when a new vector is inserted. The vectors in that subspace are continuously aggregated until a total + * {@link #statsThreshold} has been reached. + */ + private final double maintainStatsProbability; + + /** + * If sampling is necessary (currently iff {@link #isUseRaBitQ()} is {@code true}), this attribute represents the + * threshold (being a number of vectors) that, when reached, causes the stats maintenance logic to compute the actual + * statistics (currently the centroid of the vectors that have been inserted so far). + */ + private final int statsThreshold; + + /** + * Indicator if we should use RaBitQ quantization. See {@link com.apple.foundationdb.rabitq.RaBitQuantizer} for more + * details. + */ + private final boolean useRaBitQ; + + /** + * Number of bits per dimension iff {@link #isUseRaBitQ()} is set to {@code true}, ignored otherwise. If RaBitQ + * encoding is used, a vector is stored using roughly {@code 25 + numDimensions * (numExBits + 1) / 8} bytes. + */ + private final int raBitQNumExBits; + + /** + * Maximum number of concurrent node fetches during search and modification operations. + */ + private final int maxNumConcurrentNodeFetches; + + /** + * Maximum number of concurrent neighborhood fetches during modification operations when the neighbors are pruned. 
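A worked instance of the storage estimate quoted in the `raBitQNumExBits` javadoc above; the dimension count is an assumed example value, not a project default:

```java
// Worked instance of the estimate from the javadoc above: a RaBitQ-encoded
// vector takes roughly 25 + numDimensions * (numExBits + 1) / 8 bytes.
// 768 dimensions is an assumed example value.
public final class RaBitQSizeSketch {
    public static void main(String[] args) {
        int numDimensions = 768;
        int numExBits = 4; // DEFAULT_RABITQ_NUM_EX_BITS
        int approxBytes = 25 + numDimensions * (numExBits + 1) / 8;
        System.out.println(approxBytes); // 505
    }
}
```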
+ */ + private final int maxNumConcurrentNeighborhoodFetches; + + private Config(final long randomSeed, @Nonnull final Metric metric, final int numDimensions, + final boolean useInlining, final int m, final int mMax, final int mMax0, + final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections, + final double sampleVectorStatsProbability, final double maintainStatsProbability, + final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, + final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { + this.randomSeed = randomSeed; + this.metric = metric; + this.numDimensions = numDimensions; + this.useInlining = useInlining; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efConstruction = efConstruction; + this.extendCandidates = extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + this.sampleVectorStatsProbability = sampleVectorStatsProbability; + this.maintainStatsProbability = maintainStatsProbability; + this.statsThreshold = statsThreshold; + this.useRaBitQ = useRaBitQ; + this.raBitQNumExBits = raBitQNumExBits; + this.maxNumConcurrentNodeFetches = maxNumConcurrentNodeFetches; + this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; + } + + public long getRandomSeed() { + return randomSeed; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + public int getNumDimensions() { + return numDimensions; + } + + public boolean isUseInlining() { + return useInlining; + } + + public int getM() { + return m; + } + + public int getMMax() { + return mMax; + } + + public int getMMax0() { + return mMax0; + } + + public int getEfConstruction() { + return efConstruction; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + public boolean isKeepPrunedConnections() { + return keepPrunedConnections; + } + + public double getSampleVectorStatsProbability() { + return sampleVectorStatsProbability; + } + + public double getMaintainStatsProbability() { + return maintainStatsProbability; + } + + public int getStatsThreshold() { + return statsThreshold; + } + + public boolean isUseRaBitQ() { + return useRaBitQ; + } + + public int getRaBitQNumExBits() { + return raBitQNumExBits; + } + + public int getMaxNumConcurrentNodeFetches() { + return maxNumConcurrentNodeFetches; + } + + public int getMaxNumConcurrentNeighborhoodFetches() { + return maxNumConcurrentNeighborhoodFetches; + } + + @Nonnull + public ConfigBuilder toBuilder() { + return new ConfigBuilder(getRandomSeed(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), + getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), + isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), + getMaxNumConcurrentNeighborhoodFetches()); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Config)) { + return false; + } + final Config config = (Config)o; + return randomSeed == config.randomSeed && numDimensions == config.numDimensions && + useInlining == config.useInlining && m == config.m && mMax == config.mMax && mMax0 == config.mMax0 && + efConstruction == config.efConstruction && extendCandidates == config.extendCandidates && + keepPrunedConnections == config.keepPrunedConnections && + Double.compare(sampleVectorStatsProbability, config.sampleVectorStatsProbability) == 0 && + 
Double.compare(maintainStatsProbability, config.maintainStatsProbability) == 0 && + statsThreshold == config.statsThreshold && useRaBitQ == config.useRaBitQ && + raBitQNumExBits == config.raBitQNumExBits && metric == config.metric && + maxNumConcurrentNodeFetches == config.maxNumConcurrentNodeFetches && + maxNumConcurrentNeighborhoodFetches == config.maxNumConcurrentNeighborhoodFetches; + } + + @Override + public int hashCode() { + return Objects.hash(randomSeed, metric, numDimensions, useInlining, m, mMax, mMax0, efConstruction, + extendCandidates, keepPrunedConnections, sampleVectorStatsProbability, maintainStatsProbability, + statsThreshold, useRaBitQ, raBitQNumExBits, maxNumConcurrentNodeFetches, maxNumConcurrentNeighborhoodFetches); + } + + @Override + @Nonnull + public String toString() { + return "Config[randomSeed=" + getRandomSeed() + ", metric=" + getMetric() + + ", numDimensions=" + getNumDimensions() + ", isUseInlining=" + isUseInlining() + ", M=" + getM() + + ", MMax=" + getMMax() + ", MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() + + ", isExtendCandidates=" + isExtendCandidates() + + ", isKeepPrunedConnections=" + isKeepPrunedConnections() + + ", sampleVectorStatsProbability=" + getSampleVectorStatsProbability() + + ", mainStatsProbability=" + getMaintainStatsProbability() + ", statsThreshold=" + getStatsThreshold() + + ", useRaBitQ=" + isUseRaBitQ() + ", raBitQNumExBits=" + getRaBitQNumExBits() + + ", maxNumConcurrentNodeFetches=" + getMaxNumConcurrentNodeFetches() + + ", maxNumConcurrentNeighborhoodFetches=" + getMaxNumConcurrentNeighborhoodFetches() + + "]"; + } + + /** + * Builder for {@link Config}. + * + * @see HNSW#newConfigBuilder + */ + @CanIgnoreReturnValue + @SuppressWarnings("checkstyle:MemberName") + public static class ConfigBuilder { + private long randomSeed = DEFAULT_RANDOM_SEED; + @Nonnull + private Metric metric = DEFAULT_METRIC; + private boolean useInlining = DEFAULT_USE_INLINING; + private int m = DEFAULT_M; + private int mMax = DEFAULT_M_MAX; + private int mMax0 = DEFAULT_M_MAX_0; + private int efConstruction = DEFAULT_EF_CONSTRUCTION; + private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES; + private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; + + private double sampleVectorStatsProbability = DEFAULT_SAMPLE_VECTOR_STATS_PROBABILITY; + private double maintainStatsProbability = DEFAULT_MAINTAIN_STATS_PROBABILITY; + private int statsThreshold = DEFAULT_STATS_THRESHOLD; + + private boolean useRaBitQ = DEFAULT_USE_RABITQ; + private int raBitQNumExBits = DEFAULT_RABITQ_NUM_EX_BITS; + + private int maxNumConcurrentNodeFetches = DEFAULT_MAX_NUM_CONCURRENT_NODE_FETCHES; + private int maxNumConcurrentNeighborhoodFetches = DEFAULT_MAX_NUM_CONCURRENT_NEIGHBOR_FETCHES; + + public ConfigBuilder() { + } + + public ConfigBuilder(final long randomSeed, @Nonnull final Metric metric, final boolean useInlining, + final int m, final int mMax, final int mMax0, final int efConstruction, + final boolean extendCandidates, final boolean keepPrunedConnections, + final double sampleVectorStatsProbability, final double maintainStatsProbability, + final int statsThreshold, final boolean useRaBitQ, final int raBitQNumExBits, + final int maxNumConcurrentNodeFetches, final int maxNumConcurrentNeighborhoodFetches) { + this.randomSeed = randomSeed; + this.metric = metric; + this.useInlining = useInlining; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efConstruction = efConstruction; + this.extendCandidates = 
extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + this.sampleVectorStatsProbability = sampleVectorStatsProbability; + this.maintainStatsProbability = maintainStatsProbability; + this.statsThreshold = statsThreshold; + this.useRaBitQ = useRaBitQ; + this.raBitQNumExBits = raBitQNumExBits; + this.maxNumConcurrentNodeFetches = maxNumConcurrentNodeFetches; + this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; + } + + public long getRandomSeed() { + return randomSeed; + } + + @Nonnull + public ConfigBuilder setRandomSeed(final long randomSeed) { + this.randomSeed = randomSeed; + return this; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + @Nonnull + public ConfigBuilder setMetric(@Nonnull final Metric metric) { + this.metric = metric; + return this; + } + + public boolean isUseInlining() { + return useInlining; + } + + @Nonnull + public ConfigBuilder setUseInlining(final boolean useInlining) { + this.useInlining = useInlining; + return this; + } + + public int getM() { + return m; + } + + @Nonnull + public ConfigBuilder setM(final int m) { + this.m = m; + return this; + } + + public int getMMax() { + return mMax; + } + + @Nonnull + public ConfigBuilder setMMax(final int mMax) { + this.mMax = mMax; + return this; + } + + public int getMMax0() { + return mMax0; + } + + @Nonnull + public ConfigBuilder setMMax0(final int mMax0) { + this.mMax0 = mMax0; + return this; + } + + public int getEfConstruction() { + return efConstruction; + } + + @Nonnull + public ConfigBuilder setEfConstruction(final int efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + @Nonnull + public ConfigBuilder setExtendCandidates(final boolean extendCandidates) { + this.extendCandidates = extendCandidates; + return this; + } + + public boolean isKeepPrunedConnections() { + return keepPrunedConnections; + } + + @Nonnull + public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnections) { + this.keepPrunedConnections = keepPrunedConnections; + return this; + } + + public double getSampleVectorStatsProbability() { + return sampleVectorStatsProbability; + } + + @Nonnull + public ConfigBuilder setSampleVectorStatsProbability(final double sampleVectorStatsProbability) { + this.sampleVectorStatsProbability = sampleVectorStatsProbability; + return this; + } + + public double getMaintainStatsProbability() { + return maintainStatsProbability; + } + + @Nonnull + public ConfigBuilder setMaintainStatsProbability(final double maintainStatsProbability) { + this.maintainStatsProbability = maintainStatsProbability; + return this; + } + + public int getStatsThreshold() { + return statsThreshold; + } + + @Nonnull + public ConfigBuilder setStatsThreshold(final int statsThreshold) { + this.statsThreshold = statsThreshold; + return this; + } + + public boolean isUseRaBitQ() { + return useRaBitQ; + } + + @Nonnull + public ConfigBuilder setUseRaBitQ(final boolean useRaBitQ) { + this.useRaBitQ = useRaBitQ; + return this; + } + + public int getRaBitQNumExBits() { + return raBitQNumExBits; + } + + @Nonnull + public ConfigBuilder setRaBitQNumExBits(final int raBitQNumExBits) { + this.raBitQNumExBits = raBitQNumExBits; + return this; + } + + public int getMaxNumConcurrentNodeFetches() { + return maxNumConcurrentNodeFetches; + } + + public ConfigBuilder setMaxNumConcurrentNodeFetches(final int maxNumConcurrentNodeFetches) { + this.maxNumConcurrentNodeFetches = 
maxNumConcurrentNodeFetches; + return this; + } + + public int getMaxNumConcurrentNeighborhoodFetches() { + return maxNumConcurrentNeighborhoodFetches; + } + + public ConfigBuilder setMaxNumConcurrentNeighborhoodFetches(final int maxNumConcurrentNeighborhoodFetches) { + this.maxNumConcurrentNeighborhoodFetches = maxNumConcurrentNeighborhoodFetches; + return this; + } + + public Config build(final int numDimensions) { + return new Config(getRandomSeed(), getMetric(), numDimensions, isUseInlining(), getM(), getMMax(), + getMMax0(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections(), + getSampleVectorStatsProbability(), getMaintainStatsProbability(), getStatsThreshold(), + isUseRaBitQ(), getRaBitQNumExBits(), getMaxNumConcurrentNodeFetches(), + getMaxNumConcurrentNeighborhoodFetches()); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java new file mode 100644 index 0000000000..1d6b5ff4a6 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -0,0 +1,141 @@ +/* + * DeleteNeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Objects; +import java.util.Set; +import java.util.function.Predicate; + +/** + * A {@link NeighborsChangeSet} that represents the deletion of a set of neighbors from a parent change set. + *

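Before this class's description continues, a brief note on the Config/ConfigBuilder pair that closed just above. A minimal usage sketch: it assumes `Config` is a top-level public class in this package (as the qualified references in this diff suggest), sticks to setters that actually appear in the patch, uses illustrative values, and leaves the metric at its default because the constants of `Metric` are not shown in this excerpt.

```java
import com.apple.foundationdb.async.hnsw.Config;
import com.apple.foundationdb.async.hnsw.HNSW;

// Sketch only: configure a 128-dimensional index. All values are illustrative.
public final class ConfigUsage {
    public static void main(final String[] args) {
        final Config config = HNSW.newConfigBuilder()
                .setM(16)                 // desired neighbors per node
                .setMMax(16)              // cap for layers above 0
                .setMMax0(32)             // cap for layer 0
                .setEfConstruction(200)   // candidate-list size while inserting
                .setUseRaBitQ(true)       // opt into RaBitQ quantization
                .setRaBitQNumExBits(4)    // illustrative value
                .build(128);              // numDimensions is fixed at build time
        System.out.println(config);
    }
}
```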
+ * This class acts as a filter, wrapping a parent {@link NeighborsChangeSet} and providing a view of the neighbors + * that excludes those whose primary keys have been marked for deletion. + * + * @param the type of the node reference, which must extend {@link NodeReference} + */ +class DeleteNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(DeleteNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Set deletedNeighborsPrimaryKeys; + + /** + * Constructs a new {@code DeleteNeighborsChangeSet}. + *

+ * This object represents a set of changes where specific neighbors are marked for deletion. + * It holds a reference to a parent {@link NeighborsChangeSet} and creates an immutable copy + * of the primary keys for the neighbors to be deleted. + * + * @param parent the parent {@link NeighborsChangeSet} to which this deletion change belongs. Must not be null. + * @param deletedNeighborsPrimaryKeys a {@link Collection} of primary keys, represented as {@link Tuple}s, + * identifying the neighbors to be deleted. Must not be null. + */ + public DeleteNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final Collection deletedNeighborsPrimaryKeys) { + this.parent = parent; + this.deletedNeighborsPrimaryKeys = ImmutableSet.copyOf(deletedNeighborsPrimaryKeys); + } + + /** + * Gets the parent change set from which this change set was derived. + *

+ * In a sequence of modifications, each {@code NeighborsChangeSet} is derived from a previous state, which is + * considered its parent. This method allows traversing the history of changes backward. + * + * @return the parent {@link NeighborsChangeSet} + */ + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + /** + * Merges the neighbors from the parent context, filtering out any neighbors that have been marked as deleted. + *

+ * This implementation retrieves the collection of neighbors from its parent by calling + * {@code getParent().merge()}. + * It then filters this collection, removing any neighbor whose primary key is present in the + * {@code deletedNeighborsPrimaryKeys} set. + * This ensures the resulting {@link Iterable} represents a consistent view of neighbors, respecting deletions made + * in the current context. + * + * @return an {@link Iterable} of the merged neighbors, excluding those marked as deleted. This method never returns + * {@code null}. + */ + @Nonnull + @Override + public Iterable merge() { + return Iterables.filter(getParent().merge(), + current -> !deletedNeighborsPrimaryKeys.contains(Objects.requireNonNull(current).getPrimaryKey())); + } + + /** + * Writes the delta of changes for a given node to the storage layer. + *

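To make the layering concrete, the sketch below composes the three change-set flavors used in this patch: `BaseNeighborsChangeSet`, `InsertNeighborsChangeSet`, and this class. The constructor shapes follow their usages elsewhere in the diff; the type parameter is instantiated with `NodeReference` itself for simplicity, and building the inputs is left to the caller since those APIs are not part of this excerpt.

```java
package com.apple.foundationdb.async.hnsw; // package-private API, so same package

import com.apple.foundationdb.tuple.Tuple;

import java.util.List;
import java.util.Set;

// Sketch: the effective neighbor view is (persisted plus inserted) minus deleted.
final class ChangeSetExample {
    static Iterable<NodeReference> effectiveNeighbors(
            final List<NodeReference> persistedNeighbors,
            final List<NodeReference> insertedNeighbors,
            final Set<Tuple> deletedPrimaryKeys) {
        final NeighborsChangeSet<NodeReference> base =
                new BaseNeighborsChangeSet<>(persistedNeighbors);                 // as stored
        final NeighborsChangeSet<NodeReference> withInserts =
                new InsertNeighborsChangeSet<>(base, insertedNeighbors);          // plus new links
        final NeighborsChangeSet<NodeReference> withDeletes =
                new DeleteNeighborsChangeSet<>(withInserts, deletedPrimaryKeys);  // minus dropped links
        // merge() walks the chain parent-first, so every layer's edit is applied.
        return withDeletes.merge();
    }

    private ChangeSetExample() {
    }
}
```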
+ * This implementation first delegates to the parent's {@code writeDelta} method to handle its changes, but modifies + * the predicate to exclude any neighbors that are marked for deletion in this delta. + *

+ * It then iterates through the set of locally deleted neighbor primary keys. For each key that matches the supplied + * {@code tuplePredicate}, it instructs the {@link InliningStorageAdapter} to delete the corresponding neighbor + * relationship for the given {@code node}. + * + * @param storageAdapter the storage adapter to which the changes are written + * @param quantizer the quantizer to use + * @param transaction the transaction context for the write operations + * @param layer the layer index where the write operations should occur + * @param node the node for which the delta is being written + * @param tuplePredicate a predicate to filter which neighbor tuples should be processed; + * only deletions matching this predicate will be written + */ + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + @Nonnull final Quantizer quantizer, final int layer, @Nonnull final AbstractNode node, + @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, quantizer, layer, node, + tuplePredicate.and(tuple -> !deletedNeighborsPrimaryKeys.contains(tuple))); + + for (final Tuple deletedNeighborPrimaryKey : deletedNeighborsPrimaryKeys) { + if (tuplePredicate.test(deletedNeighborPrimaryKey)) { + storageAdapter.deleteNeighbor(transaction, layer, node.asInliningNode(), deletedNeighborPrimaryKey); + if (logger.isTraceEnabled()) { + logger.trace("deleted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + deletedNeighborPrimaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java new file mode 100644 index 0000000000..dfb7c9082b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -0,0 +1,99 @@ +/* + * EntryNodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents an entry reference to a node within a hierarchical graph structure. + *

+ * This class extends {@link NodeReferenceWithVector} by adding a {@code layer} + * attribute. It is used to encapsulate all the necessary information for an + * entry point into a specific layer of the graph, including its unique identifier + * (primary key), its vector representation, and its hierarchical level. + */ +class EntryNodeReference extends NodeReferenceWithVector { + private final int layer; + + /** + * Constructs a new reference to an entry node. + *

+ * This constructor initializes the node with its primary key, its associated vector, + * and the specific layer it belongs to within a hierarchical graph structure. It calls the + * superclass constructor to set the {@code primaryKey} and {@code vector}. + * + * @param primaryKey the primary key identifying the node. Must not be {@code null}. + * @param vector the vector data associated with the node. Must not be {@code null}. + * @param layer the layer number where this entry node is located. + */ + public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Transformed vector, + final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + /** + * Gets the layer value for this object. + * @return the integer representing the layer + */ + public int getLayer() { + return layer; + } + + @Nonnull + public EntryNodeReference withVector(@Nonnull final Transformed newVector) { + return new EntryNodeReference(getPrimaryKey(), newVector, getLayer()); + } + + /** + * Compares this {@code EntryNodeReference} to the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is an instance of {@code EntryNodeReference}, the + * superclass's {@link #equals(Object)} method returns {@code true}, and the {@code layer} fields of both objects + * are equal. + * @param o the object to compare this {@code EntryNodeReference} against. + * @return {@code true} if the given object is equal to this one; {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!super.equals(o)) { + return false; + } + return layer == ((EntryNodeReference)o).layer; + } + + /** + * Generates a hash code for this object. + *

+ * The hash code is computed by combining the hash code of the superclass with the hash code of the {@code layer} + * field. This implementation is consistent with the contract of {@link Object#hashCode()}. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java new file mode 100644 index 0000000000..2e7816b530 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -0,0 +1,1549 @@ +/* + * HNSW.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.MoreAsyncUtil; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.Estimator; +import com.apple.foundationdb.linear.FhtKacRotator; +import com.apple.foundationdb.linear.Metric; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.rabitq.RaBitQuantizer; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.Streams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; +import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; + +/** + * An implementation of the Hierarchical Navigable Small World (HNSW) algorithm for + * efficient approximate nearest neighbor (ANN) search. + *

+ * HNSW constructs a multi-layer graph, where each layer is a subset of the one below it. + * The top layers serve as fast entry points to navigate the graph, while the bottom layer + * contains all the data points. This structure allows for logarithmic-time complexity + * for search operations, making it suitable for large-scale, high-dimensional datasets. + *

+ * This class provides methods for building the graph ({@link #insert(Transaction, Tuple, RealVector)}) + * and performing k-NN searches ({@link #kNearestNeighborsSearch(ReadTransaction, int, int, boolean, RealVector)}). + * It is designed to be used with a transactional storage backend, managed via a {@link Subspace}. + * + * @see Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs + */ +@API(API.Status.EXPERIMENTAL) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSW { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(HNSW.class); + + @Nonnull + private final Random random; + @Nonnull + private final Subspace subspace; + @Nonnull + private final Executor executor; + @Nonnull + private final Config config; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + /** + * Start building a {@link Config}. + * @return a new {@code Config} that can be altered and then built for use with a {@link HNSW} + * @see Config.ConfigBuilder#build + */ + public static Config.ConfigBuilder newConfigBuilder() { + return new Config.ConfigBuilder(); + } + + /** + * Returns a default {@link Config}. + * @param numDimensions number of dimensions + * @return a new default {@code Config}. + * @see Config.ConfigBuilder#build + */ + @Nonnull + public static Config defaultConfig(int numDimensions) { + return new Config.ConfigBuilder().build(numDimensions); + } + + /** + * Constructs a new HNSW graph instance. + *

+ * This constructor initializes the HNSW graph with the necessary components for storage, + * execution, configuration, and event handling. All parameters are mandatory and must not be null. + * + * @param subspace the {@link Subspace} where the graph data is stored. + * @param executor the {@link Executor} service to use for concurrent operations. + * @param config the {@link Config} object containing HNSW algorithm parameters. + * @param onWriteListener a listener to be notified of write events on the graph. + * @param onReadListener a listener to be notified of read events on the graph. + * + * @throws NullPointerException if any of the parameters are {@code null}. + */ + public HNSW(@Nonnull final Subspace subspace, + @Nonnull final Executor executor, + @Nonnull final Config config, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + this.random = new Random(config.getRandomSeed()); + this.subspace = subspace; + this.executor = executor; + this.config = config; + this.onWriteListener = onWriteListener; + this.onReadListener = onReadListener; + } + + + /** + * Gets the subspace associated with this object. + * + * @return the non-null subspace + */ + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + /** + * Get the executor used by this hnsw. + * @return executor used when running asynchronous tasks + */ + @Nonnull + public Executor getExecutor() { + return executor; + } + + /** + * Get this hnsw's configuration. + * @return hnsw configuration + */ + @Nonnull + public Config getConfig() { + return config; + } + + /** + * Get the on-write listener. + * @return the on-write listener + */ + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + /** + * Get the on-read listener. + * @return the on-read listener + */ + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + @Nonnull + private AffineOperator storageTransform(@Nullable final AccessInfo accessInfo) { + if (accessInfo == null || !accessInfo.canUseRaBitQ()) { + return AffineOperator.identity(); + } + + return new StorageTransform(accessInfo.getRotatorSeed(), + getConfig().getNumDimensions(), Objects.requireNonNull(accessInfo.getNegatedCentroid())); + } + + @Nonnull + private Quantizer quantizer(@Nullable final AccessInfo accessInfo) { + if (accessInfo == null || !accessInfo.canUseRaBitQ()) { + return Quantizer.noOpQuantizer(config.getMetric()); + } + + final Config config = getConfig(); + return config.isUseRaBitQ() + ? new RaBitQuantizer(config.getMetric(), config.getRaBitQNumExBits()) + : Quantizer.noOpQuantizer(config.getMetric()); + } + + // + // Read Path + // + + /** + * Performs a k-nearest neighbors (k-NN) search for a given query vector. + *

+ * This method implements the search algorithm for an HNSW graph. The search begins at an entry point in the + * highest layer and greedily traverses down through the layers. In each layer, it finds the node closest to the + * {@code queryVector}. This node then serves as the entry point for the search in the layer below. + *

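The descent just described, together with the base-layer widening covered in the next paragraph, is compact enough to sketch in simplified synchronous form. This is illustrative pseudocode only, not the method's real control flow (which is asynchronous and transactional); `greedySearchLayer` and `searchFinalLayer` stand for the private helpers defined later in this file, and `entryPointWithDistance` is a hypothetical stand-in for reading the stored entry point.

```java
// Simplified, synchronous shape of the k-NN search; not the real control flow.
// entryLayer and the entry point come from the stored AccessInfo.
NodeReferenceWithDistance entry = entryPointWithDistance(queryVector); // hypothetical helper
for (int layer = entryLayer; layer > 0; layer--) {
    entry = greedySearchLayer(layer, entry, queryVector); // beam width 1
}
// On layer 0, widen the beam to efSearch and keep the best k results.
List<ResultEntry> results = searchFinalLayer(entry, k, efSearch, queryVector);
```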
+ * Once the search reaches the base layer (layer 0), it performs a more exhaustive search starting from the + * determined entry point. It explores the graph, maintaining a dynamic list of the best candidates found so far. + * The size of this candidate list is controlled by the {@code efSearch} parameter. Finally, the method selects + * the top {@code k} nodes from the search results, sorted by their distance to the query vector. + * + * @param readTransaction the transaction to use for reading from the database + * @param k the number of nearest neighbors to return + * @param efSearch the size of the dynamic candidate list for the search. A larger value increases accuracy + * at the cost of performance. + * @param includeVectors indicator if the caller would like the search to also include vectors in the result set + * @param queryVector the vector to find the nearest neighbors for + * + * @return a {@link CompletableFuture} that will complete with a list of the {@code k} nearest neighbors, + * sorted by distance in ascending order. + */ + @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper + @Nonnull + public CompletableFuture> + kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, + final int k, + final int efSearch, + final boolean includeVectors, + @Nonnull final RealVector queryVector) { + return StorageAdapter.fetchAccessInfo(getConfig(), readTransaction, getSubspace(), getOnReadListener()) + .thenCompose(accessInfo -> { + if (accessInfo == null) { + return CompletableFuture.completedFuture(ImmutableList.of()); // not a single node in the index + } + final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference(); + + final AffineOperator storageTransform = storageTransform(accessInfo); + final Transformed transformedQueryVector = storageTransform.transform(queryVector); + final Quantizer quantizer = quantizer(accessInfo); + final Estimator estimator = quantizer.estimator(); + + final NodeReferenceWithDistance entryState = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + estimator.distance(transformedQueryVector, entryNodeReference.getVector())); + + final int entryLayer = entryNodeReference.getLayer(); + return forLoop(entryLayer, entryState, + layer -> layer > 0, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final var storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, readTransaction, storageTransform, estimator, + previousNodeReference, layer, transformedQueryVector); + }, executor) + .thenCompose(nodeReference -> { + final var storageAdapter = getStorageAdapterForLayer(0); + + return searchFinalLayer(storageAdapter, readTransaction, storageTransform, estimator, + k, efSearch, nodeReference, includeVectors, transformedQueryVector); + }); + }); + } + + /** + * Method to search layer {@code 0} starting at a {@code nodeReference} for the {@code k} nearest neighbors of + * {@code transformedQueryVector}. The vectors that are part of the result of this search are transformed into the + * client coordinate system. 
+ *
+ * @param  type parameter for the type of node reference to use
+ * @param storageAdapter the storage adapter
+ * @param readTransaction the transaction to use
+ * @param storageTransform the storage transform needed to transform vector data back into the client coordinate
+ * system
+ * @param estimator the distance estimator in use
+ * @param k the number of nearest neighbors the caller wants us to find
+ * @param efSearch the search queue capacity
+ * @param nodeReference the entry node reference
+ * @param includeVectors indicator if the caller would like the search to also include vectors in the result set
+ * @param transformedQueryVector the transformed query vector
+ *
+ * @return a list of {@link NodeReferenceAndNode} representing the {@code k} nearest neighbors of
+ * {@code transformedQueryVector}
+ */
+ @Nonnull
+ private CompletableFuture>
+ searchFinalLayer(@Nonnull final StorageAdapter storageAdapter,
+ @Nonnull final ReadTransaction readTransaction,
+ @Nonnull final AffineOperator storageTransform,
+ @Nonnull final Estimator estimator,
+ final int k,
+ final int efSearch,
+ @Nonnull final NodeReferenceWithDistance nodeReference,
+ final boolean includeVectors,
+ @Nonnull final Transformed transformedQueryVector) {
+ return searchLayer(storageAdapter, readTransaction, storageTransform, estimator,
+ ImmutableList.of(nodeReference), 0, efSearch, Maps.newConcurrentMap(),
+ transformedQueryVector)
+ .thenApply(searchResult ->
+ postProcessNearestNeighbors(storageTransform, k, searchResult, includeVectors));
+ }
+
+ @Nonnull
+ private ImmutableList
+ postProcessNearestNeighbors(@Nonnull final AffineOperator storageTransform, final int k,
+ @Nonnull final List> nearestNeighbors,
+ final boolean includeVectors) {
+ final int lastIndex = Math.max(nearestNeighbors.size() - k, 0);
+
+ final ImmutableList.Builder resultBuilder =
+ ImmutableList.builder();
+
+ for (int i = nearestNeighbors.size() - 1; i >= lastIndex; i--) {
+ final var nodeReferenceAndNode = nearestNeighbors.get(i);
+ final var nodeReference =
+ Objects.requireNonNull(nodeReferenceAndNode).getNodeReferenceWithDistance();
+ final AbstractNode node = nodeReferenceAndNode.getNode();
+ @Nullable final RealVector reconstructedVector =
+ includeVectors ? storageTransform.untransform(node.asCompactNode().getVector()) : null;
+
+ resultBuilder.add(
+ new ResultEntry(node.getPrimaryKey(),
+ reconstructedVector, nodeReference.getDistance(),
+ nearestNeighbors.size() - i - 1));
+ }
+ return resultBuilder.build();
+ }
+
+ /**
+ * Performs a greedy search on a single layer of the HNSW graph.
+ *

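A small, self-contained illustration of the ordering used by `postProcessNearestNeighbors` above: the result queue keeps the furthest element on top, so capping it at `k` evicts the worst candidate, and draining it back-to-front yields the `k` best in ascending distance order. Plain doubles stand in for node references here.

```java
import java.util.ArrayDeque;
import java.util.Comparator;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;

// Toy version of the furthest-first heap plus reverse-drain trick.
public final class TopKExample {
    public static void main(final String[] args) {
        final int k = 3;
        final PriorityQueue<Double> furthestFirst =
                new PriorityQueue<>(Comparator.<Double>naturalOrder().reversed());
        for (final double d : new double[] {0.9, 0.1, 0.5, 0.3, 0.7}) {
            furthestFirst.add(d);
            if (furthestFirst.size() > k) {
                furthestFirst.poll(); // evict the furthest once we exceed k
            }
        }
        final Deque<Double> ascending = new ArrayDeque<>();
        while (!furthestFirst.isEmpty()) {
            ascending.addFirst(furthestFirst.poll()); // furthest drains first
        }
        System.out.println(List.copyOf(ascending)); // prints [0.1, 0.3, 0.5]
    }
}
```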
+ * This method finds the node on the specified layer that is closest to the given query vector, + * starting the search from a designated entry point. The search is "greedy" because it aims to find + * only the single best neighbor. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} for accessing the graph data + * @param readTransaction the {@link ReadTransaction} to use for the search + * @param estimator a distance estimator + * @param nodeReference the starting point for the search on this layer, which includes the node and its distance to + * the query vector + * @param layer the zero-based index of the layer to search within + * @param queryVector the query vector for which to find the nearest neighbor + * + * @return a {@link CompletableFuture} that, upon completion, will contain the closest node found on the layer, + * represented as a {@link NodeReferenceWithDistance} + */ + @Nonnull + private CompletableFuture + greedySearchLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + @Nonnull final Estimator estimator, + @Nonnull final NodeReferenceWithDistance nodeReference, + final int layer, + @Nonnull final Transformed queryVector) { + return searchLayer(storageAdapter, readTransaction, storageTransform, estimator, + ImmutableList.of(nodeReference), layer, 1, Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> + Iterables.getOnlyElement(searchResult).getNodeReferenceWithDistance()); + } + + /** + * Searches a single layer of the graph to find the nearest neighbors to a query vector. + *

+ * This method implements the greedy search algorithm used in HNSW (Hierarchical Navigable Small World) + * graphs for a specific layer. It begins with a set of entry points and iteratively explores the graph, + * always moving towards nodes that are closer to the {@code queryVector}. + *

+ * It maintains a priority queue of candidates to visit and a result set of the nearest neighbors found so far. + * The size of the dynamic candidate list is controlled by the {@code efSearch} parameter, which balances + * search quality and performance. The entire process is asynchronous, leveraging + * {@link java.util.concurrent.CompletableFuture} + * to handle I/O operations (fetching nodes) without blocking. + * + * @param The type of the node reference, extending {@link NodeReference}. + * @param storageAdapter The storage adapter for accessing node data from the underlying storage. + * @param readTransaction The transaction context for all database read operations. + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param estimator the estimator to use + * @param nodeReferences A collection of starting node references for the search in this layer, with their distances + * to the query vector already calculated. + * @param layer The zero-based index of the layer to search. + * @param efSearch The size of the dynamic candidate list. A larger value increases recall at the + * cost of performance. + * @param nodeCache A cache of nodes that have already been fetched from storage to avoid redundant I/O. + * @param queryVector The vector for which to find the nearest neighbors. + * + * @return A {@link java.util.concurrent.CompletableFuture} that, upon completion, will contain a list of the + * best candidate nodes found in this layer, paired with their full node data. + */ + @Nonnull + private CompletableFuture>> + searchLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + @Nonnull final Estimator estimator, + @Nonnull final Collection nodeReferences, + final int layer, + final int efSearch, + @Nonnull final Map> nodeCache, + @Nonnull final Transformed queryVector) { + final Set visited = Sets.newConcurrentHashSet(NodeReference.primaryKeys(nodeReferences)); + final Queue candidates = + // This initial capacity is somewhat arbitrary as m is not necessarily a limit, + // but it gives us a number that is better than the default. 
+ new PriorityQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(nodeReferences); + final Queue nearestNeighbors = + new PriorityQueue<>(efSearch + 1, // prevent reallocation further down + Comparator.comparing(NodeReferenceWithDistance::getDistance) + .thenComparing(NodeReferenceWithDistance::getPrimaryKey).reversed()); + nearestNeighbors.addAll(nodeReferences); + + return AsyncUtil.whileTrue(() -> { + if (candidates.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + + final NodeReferenceWithDistance candidate = candidates.poll(); + final NodeReferenceWithDistance furthestNeighbor = Objects.requireNonNull(nearestNeighbors.peek()); + + if (candidate.getDistance() > furthestNeighbor.getDistance()) { + return AsyncUtil.READY_FALSE; + } + + return fetchNodeIfNotCached(storageAdapter, readTransaction, storageTransform, layer, candidate, nodeCache) + .thenApply(candidateNode -> + Iterables.filter(candidateNode.getNeighbors(), + neighbor -> !visited.contains(Objects.requireNonNull(neighbor).getPrimaryKey()))) + .thenCompose(neighborReferences -> fetchNeighborhood(storageAdapter, readTransaction, + storageTransform, layer, neighborReferences, nodeCache)) + .thenApply(neighborReferences -> { + for (final NodeReferenceWithVector current : neighborReferences) { + visited.add(current.getPrimaryKey()); + final double furthestDistance = + Objects.requireNonNull(nearestNeighbors.peek()).getDistance(); + + final double currentDistance = estimator.distance(queryVector, current.getVector()); + if (currentDistance < furthestDistance || nearestNeighbors.size() < efSearch) { + final NodeReferenceWithDistance currentWithDistance = + new NodeReferenceWithDistance(current.getPrimaryKey(), current.getVector(), + currentDistance); + candidates.add(currentWithDistance); + nearestNeighbors.add(currentWithDistance); + if (nearestNeighbors.size() > efSearch) { + nearestNeighbors.poll(); + } + } + } + return true; + }); + }) + .thenCompose(ignored -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, storageTransform, layer, + drain(nearestNeighbors), nodeCache)) + .thenApply(searchResult -> { + if (logger.isTraceEnabled()) { + logger.trace("searched layer={} for efSearch={} with result=={}", layer, efSearch, + searchResult.stream() + .map(nodeReferenceAndNode -> + "(primaryKey=" + + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(","))); + } + return searchResult; + }); + } + + /** + * Asynchronously fetches a node if it is not already present in the cache. + *

+ * This method first attempts to retrieve the node from the provided {@code nodeCache} using the + * primary key of the {@code nodeReference}. If the node is not found in the cache, it is + * fetched from the underlying storage using the {@code storageAdapter}. Once fetched, the node + * is added to the {@code nodeCache} before the future is completed. + *

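The check-then-fetch pattern described above is easier to see without the storage machinery. Below is a self-contained sketch that memoizes an asynchronous lookup in a `ConcurrentHashMap`; like the real method, two concurrent callers may both trigger a fetch before one populates the map, which is benign when the fetched value is immutable.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

// Generic shape of fetchNodeIfNotCached: consult the cache, otherwise fetch
// asynchronously and populate the cache before handing the value back.
public final class AsyncCacheExample {
    private final Map<Long, String> cache = new ConcurrentHashMap<>();

    CompletableFuture<String> fetchIfNotCached(final long key) {
        final String cached = cache.get(key);
        if (cached != null) {
            return CompletableFuture.completedFuture(cached); // bypass the fetch
        }
        return fetchFromStorage(key).thenApply(value -> {
            cache.put(key, value); // populate for subsequent lookups
            return value;
        });
    }

    private CompletableFuture<String> fetchFromStorage(final long key) {
        return CompletableFuture.supplyAsync(() -> "node-" + key); // stand-in for I/O
    }
}
```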
+ * This is a convenience method that delegates to + * {@link #fetchNodeIfNecessaryAndApply(StorageAdapter, ReadTransaction, AffineOperator, int, NodeReference, Function, BiFunction)}. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param storageAdapter the storage adapter used to fetch the node from persistent storage + * @param readTransaction the transaction to use for reading from storage + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer the layer index where the node is located + * @param nodeReference the reference to the node to fetch + * @param nodeCache the cache to check for the node and to which the node will be added if fetched + * + * @return a {@link CompletableFuture} that will be completed with the fetched or cached {@link AbstractNode} + */ + @Nonnull + private CompletableFuture> + fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final NodeReference nodeReference, + @Nonnull final Map> nodeCache) { + return fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, storageTransform, layer, nodeReference, + nR -> nodeCache.get(nR.getPrimaryKey()), + (nR, node) -> { + nodeCache.put(nR.getPrimaryKey(), node); + return node; + }); + } + + /** + * Conditionally fetches a node from storage and applies a function to it. + *

+ * This method first attempts to generate a result by applying the {@code fetchBypassFunction}. + * If this function returns a non-null value, that value is returned immediately in a + * completed {@link CompletableFuture}, and no storage access occurs. This provides an + * optimization path, for example, if the required data is already available in a cache. + *

+ * If the bypass function returns {@code null}, the method proceeds to asynchronously fetch the + * node from the given {@code StorageAdapter}. Once the node is retrieved, the + * {@code biMapFunction} is applied to the original {@code nodeReference} and the fetched + * {@code Node} to produce the final result. + * + * @param The type of the input node reference. + * @param The type of the node reference used by the storage adapter. + * @param The type of the result. + * @param storageAdapter The storage adapter used to fetch the node if necessary. + * @param readTransaction The read transaction context for the storage operation. + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer The layer index from which to fetch the node. + * @param nodeReference The reference to the node that may need to be fetched. + * @param fetchBypassFunction A function that provides a potential shortcut. If it returns a + * non-null value, the node fetch is bypassed. + * @param biMapFunction A function to be applied after a successful node fetch, combining the + * original reference and the fetched node to produce the final result. + * + * @return A {@link CompletableFuture} that will complete with the result from either the + * {@code fetchBypassFunction} or the {@code biMapFunction}. + */ + @Nonnull + private CompletableFuture + fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final R nodeReference, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + final U bypass = fetchBypassFunction.apply(nodeReference); + if (bypass != null) { + return CompletableFuture.completedFuture(bypass); + } + + return onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, storageTransform, layer, + nodeReference.getPrimaryKey())) + .thenApply(node -> biMapFunction.apply(nodeReference, node)); + } + + /** + * Asynchronously fetches neighborhood nodes and returns them as {@link NodeReferenceWithVector} instances, + * which include the node's vector. + *

+ * This method efficiently retrieves node data by first checking an in-memory {@code nodeCache}. If a node is not + * in the cache, it is fetched from the {@link StorageAdapter}. Fetched nodes are then added to the cache to + * optimize subsequent lookups. It also handles cases where the input {@code neighborReferences} may already + * contain {@link NodeReferenceWithVector} instances, avoiding redundant work. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes from if they are not in the cache + * @param readTransaction the transaction context for database read operations + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer the graph layer from which to fetch the nodes + * @param neighborReferences an iterable of references to the neighbor nodes to be fetched + * @param nodeCache a map serving as an in-memory cache for nodes. This map will be populated with any + * nodes fetched from storage. + * + * @return a {@link CompletableFuture} that, upon completion, will contain a list of + * {@link NodeReferenceWithVector} objects for the specified neighbors + */ + @Nonnull + private CompletableFuture> + fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Iterable neighborReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, storageTransform, layer, neighborReferences, + neighborReference -> { + if (neighborReference.isNodeReferenceWithVector()) { + return neighborReference.asNodeReferenceWithVector(); + } + final AbstractNode neighborNode = nodeCache.get(neighborReference.getPrimaryKey()); + if (neighborNode == null) { + return null; + } + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), + neighborNode.asCompactNode().getVector()); + }, + (neighborReference, neighborNode) -> { + // + // At this point we know that the node needed to be fetched which excludes INLINING nodes + // as they never have to be fetched. Therefore, we can safely treat the nodes as compact nodes. + // + nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), + neighborNode.asCompactNode().getVector()); + }); + } + + /** + * Fetches a collection of nodes, attempting to retrieve them from a cache first before + * accessing the underlying storage. + *

+ * This method iterates through the provided {@code nodeReferences}. For each reference, it + * first checks the {@code nodeCache}. If the corresponding {@link AbstractNode} is found, it is + * used directly. If not, the node is fetched from the {@link StorageAdapter}. Any nodes + * fetched from storage are then added to the {@code nodeCache} to optimize subsequent lookups. + * The entire operation is performed asynchronously. + * + * @param The type of the node reference, which must extend {@link NodeReference}. + * @param storageAdapter The storage adapter used to fetch nodes from storage if they are not in the cache. + * @param readTransaction The transaction context for the read operation. + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer The layer from which to fetch the nodes. + * @param nodeReferences An {@link Iterable} of {@link NodeReferenceWithDistance} objects identifying the nodes to + * be fetched. + * @param nodeCache A map used as a cache. It is checked for existing nodes and updated with any newly fetched + * nodes. + * + * @return A {@link CompletableFuture} which will complete with a {@link List} of {@link NodeReferenceAndNode} + * objects, pairing each requested reference with its corresponding node. + */ + @Nonnull + private CompletableFuture>> + fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, storageTransform, layer, nodeReferences, + nodeReference -> { + final AbstractNode node = nodeCache.get(nodeReference.getPrimaryKey()); + if (node == null) { + return null; + } + return new NodeReferenceAndNode<>(nodeReference, node); + }, + (nodeReferenceWithDistance, node) -> { + nodeCache.put(nodeReferenceWithDistance.getPrimaryKey(), node); + return new NodeReferenceAndNode<>(nodeReferenceWithDistance, node); + }); + } + + /** + * Asynchronously fetches a collection of nodes from storage and applies a function to each. + *

+ * For each {@link NodeReference} in the provided iterable, this method concurrently fetches the corresponding + * {@code Node} using the given {@link StorageAdapter}. The logic delegates to + * {@code fetchNodeIfNecessaryAndApply}, which determines whether a full node fetch is required. + * If a node is fetched from storage, the {@code biMapFunction} is applied. If the fetch is bypassed + * (e.g., because the reference itself contains sufficient information), the {@code fetchBypassFunction} is used + * instead. + * + * @param The type of the node references to be processed, extending {@link NodeReference}. + * @param The type of the key references within the nodes, extending {@link NodeReference}. + * @param The type of the result after applying one of the mapping functions. + * @param storageAdapter The {@link StorageAdapter} used to fetch nodes from the underlying storage. + * @param readTransaction The {@link ReadTransaction} context for the read operations. + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer The layer index from which the nodes are being fetched. + * @param nodeReferences An {@link Iterable} of {@link NodeReference}s for the nodes to be fetched and processed. + * @param fetchBypassFunction The function to apply to a node reference when the actual node fetch is bypassed, + * mapping the reference directly to a result of type {@code U}. + * @param biMapFunction The function to apply when a node is successfully fetched, mapping the original + * reference and the fetched {@link AbstractNode} to a result of type {@code U}. + * + * @return A {@link CompletableFuture} that, upon completion, will hold a {@link java.util.List} of results + * of type {@code U}, corresponding to each processed node reference. + */ + @Nonnull + private CompletableFuture> + fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + return forEach(nodeReferences, + currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, + storageTransform, layer, currentNeighborReference, fetchBypassFunction, biMapFunction), + getConfig().getMaxNumConcurrentNodeFetches(), + getExecutor()); + } + + /** + * Inserts a new vector with its associated primary key into the HNSW graph. + *

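Before turning to insertion, one note on `fetchSomeNodesAndApply` above: the `forEach(..., getMaxNumConcurrentNodeFetches(), ...)` call bounds how many node fetches are in flight at once. The exact semantics of `MoreAsyncUtil.forEach` are not shown in this excerpt; the self-contained sketch below approximates the idea by running the work in fixed-size waves (a rolling window would keep the pipeline fuller, but this is the simplest correct shape).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

// Bounded async fan-out: at most maxConcurrency work items start per wave,
// and a wave only starts after the previous wave has completed.
public final class BoundedFanOut {
    static <T, U> CompletableFuture<List<U>> forEachBounded(
            final List<T> items,
            final Function<T, CompletableFuture<U>> work,
            final int maxConcurrency) {
        final List<CompletableFuture<U>> results = new ArrayList<>();
        CompletableFuture<Void> previousWave = CompletableFuture.completedFuture(null);
        for (int i = 0; i < items.size(); i += maxConcurrency) {
            final List<T> wave = items.subList(i, Math.min(i + maxConcurrency, items.size()));
            final CompletableFuture<Void> gate = previousWave;
            final List<CompletableFuture<U>> waveResults = new ArrayList<>();
            for (final T item : wave) {
                waveResults.add(gate.thenCompose(ignored -> work.apply(item)));
            }
            results.addAll(waveResults);
            previousWave = CompletableFuture.allOf(waveResults.toArray(new CompletableFuture<?>[0]));
        }
        return previousWave.thenApply(ignored -> {
            final List<U> collected = new ArrayList<>(results.size());
            for (final CompletableFuture<U> future : results) {
                collected.add(future.join()); // all futures are complete at this point
            }
            return collected;
        });
    }

    private BoundedFanOut() {
    }
}
```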
+ * The method first determines a random layer for the new node, called the {@code insertionLayer}. + * It then traverses the graph from the entry point downwards, greedily searching for the nearest + * neighbors to the {@code newVector} at each layer. This search identifies the optimal + * connection points for the new node. + *

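The "random layer" mentioned above is, in the published HNSW algorithm, drawn from a geometric-like distribution. The body of the private `insertionLayer()` method is not part of this excerpt, so the sketch below shows the standard formula from the Malkov and Yashunin paper rather than the code's actual implementation.

```java
import java.util.Random;

// Standard HNSW level draw: with normalization factor mL = 1/ln(M), the
// probability of reaching level l is M^(-l), so each layer holds roughly
// 1/M of the nodes of the layer below. Requires m > 1.
public final class LayerDraw {
    static int insertionLayer(final Random random, final int m) {
        final double mL = 1.0 / Math.log(m);
        final double u = 1.0 - random.nextDouble(); // uniform in (0, 1]; avoids log(0)
        return (int) Math.floor(-Math.log(u) * mL);
    }

    private LayerDraw() {
    }
}
```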
+ * Once the nearest neighbors are found, the new node is linked into the graph structure at all + * layers up to its {@code insertionLayer}. Special handling is included for inserting the + * first-ever node into the graph or when a new node's layer is higher than any existing node, + * which updates the graph's entry point. All operations are performed asynchronously. + * + * @param transaction the {@link Transaction} context for all database operations + * @param newPrimaryKey the unique {@link Tuple} primary key for the new node being inserted + * @param newVector the {@link RealVector} data to be inserted into the graph + * + * @return a {@link CompletableFuture} that completes when the insertion operation is finished + */ + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, + @Nonnull final RealVector newVector) { + final int insertionLayer = insertionLayer(); + if (logger.isTraceEnabled()) { + logger.trace("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer); + } + + return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener()) + .thenCompose(accessInfo -> { + final AccessInfo currentAccessInfo; + final AffineOperator storageTransform = storageTransform(accessInfo); + final Transformed transformedNewVector = storageTransform.transform(newVector); + final Quantizer quantizer = quantizer(accessInfo); + final Estimator estimator = quantizer.estimator(); + + if (accessInfo == null) { + // this is the first node + writeLonelyNodes(quantizer, transaction, newPrimaryKey, transformedNewVector, + insertionLayer, -1); + currentAccessInfo = new AccessInfo( + new EntryNodeReference(newPrimaryKey, transformedNewVector, insertionLayer), + -1L, null); + StorageAdapter.writeAccessInfo(transaction, getSubspace(), currentAccessInfo, + getOnWriteListener()); + if (logger.isTraceEnabled()) { + logger.trace("written initial entry node reference with key={} on layer={}", + newPrimaryKey, insertionLayer); + } + return AsyncUtil.DONE; + } else { + final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference(); + final int lMax = entryNodeReference.getLayer(); + if (insertionLayer > lMax) { + writeLonelyNodes(quantizer, transaction, newPrimaryKey, transformedNewVector, + insertionLayer, lMax); + currentAccessInfo = accessInfo.withNewEntryNodeReference( + new EntryNodeReference(newPrimaryKey, transformedNewVector, + insertionLayer)); + StorageAdapter.writeAccessInfo(transaction, getSubspace(), currentAccessInfo, + getOnWriteListener()); + if (logger.isTraceEnabled()) { + logger.trace("written higher entry node reference with key={} on layer={}", + newPrimaryKey, insertionLayer); + } + } else { + currentAccessInfo = accessInfo; + } + } + + final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference(); + final int lMax = entryNodeReference.getLayer(); + if (logger.isTraceEnabled()) { + logger.trace("entry node read with key {} at layer {}", entryNodeReference.getPrimaryKey(), lMax); + } + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + estimator.distance(transformedNewVector, entryNodeReference.getVector())); + return forLoop(lMax, initialNodeReference, + layer -> layer > insertionLayer, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); 
return greedySearchLayer(storageAdapter, transaction, storageTransform,
+ estimator, previousNodeReference, layer, transformedNewVector);
+ }, executor)
+ .thenCompose(nodeReference ->
+ insertIntoLayers(transaction, storageTransform, quantizer, newPrimaryKey,
+ transformedNewVector, nodeReference, lMax, insertionLayer))
+ .thenCompose(ignored ->
+ addToStatsIfNecessary(transaction, currentAccessInfo, transformedNewVector));
+ }).thenCompose(ignored -> AsyncUtil.DONE);
+ }
+
+ /**
+ * Method to keep statistics if necessary. Stats need to be kept and maintained when the client would like to
+ * use e.g. RaBitQ, since RaBitQ needs a stable and reasonably accurate centroid in order to function properly.
+ *

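The stats maintenance detailed in the next paragraph boils down to associative (count, sum) aggregation, which is what makes it safe to consume partial aggregates in any order. A self-contained sketch of the arithmetic follows; note that the code below stores the negated centroid, via `multiply(-1.0d / partialCount)`.

```java
import java.util.List;

// Combine partial (count, sum) pairs and derive the centroid.
public final class CentroidMath {
    static double[] centroid(final List<double[]> partialSums, final List<Integer> partialCounts) {
        final int dims = partialSums.get(0).length;
        final double[] sum = new double[dims];
        long count = 0;
        for (int i = 0; i < partialSums.size(); i++) {
            final double[] partial = partialSums.get(i);
            for (int d = 0; d < dims; d++) {
                sum[d] += partial[d];
            }
            count += partialCounts.get(i);
        }
        final double[] centroid = new double[dims];
        for (int d = 0; d < dims; d++) {
            centroid[d] = sum[d] / count; // the centroid; the code below stores its negation
        }
        return centroid;
    }

    private CentroidMath() {
    }
}
```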
+ * Specifically for RaBitQ, we add vectors to a set of sampled vectors in a designated subspace of the HNSW
+ * structure. The parameter {@link Config#getSampleVectorStatsProbability()} governs how often we sample. A
+ * second parameter, {@link Config#getMaintainStatsProbability()}, determines how often we consume vectors from
+ * this sampled-vector space and aggregate them using the usual running count/running sum scheme. Once
+ * {@link Config#getStatsThreshold()} vectors have been sampled and aggregated, we compute the centroid, which
+ * is then used to update the access info.
+ *
+ * @param transaction the transaction
+ * @param currentAccessInfo the current access info that was fetched as part of an insert
+ * @param transformedNewVector the new vector (in the transformed coordinate system) that may be added
+ * @return a future that returns {@code null} when completed
+ */
+ @Nonnull
+ private CompletableFuture addToStatsIfNecessary(@Nonnull final Transaction transaction,
+ @Nonnull final AccessInfo currentAccessInfo,
+ @Nonnull final Transformed transformedNewVector) {
+ if (getConfig().isUseRaBitQ() && !currentAccessInfo.canUseRaBitQ()) {
+ if (shouldSampleVector()) {
+ StorageAdapter.appendSampledVector(transaction, getSubspace(),
+ 1, transformedNewVector, onWriteListener);
+ }
+ if (shouldMaintainStats()) {
+ return StorageAdapter.consumeSampledVectors(transaction, getSubspace(),
+ 50, onReadListener)
+ .thenApply(sampledVectors -> {
+ final AggregatedVector aggregatedSampledVector =
+ aggregateVectors(sampledVectors);
+
+ if (aggregatedSampledVector != null) {
+ final int partialCount = aggregatedSampledVector.getPartialCount();
+ final Transformed partialVector = aggregatedSampledVector.getPartialVector();
+ StorageAdapter.appendSampledVector(transaction, getSubspace(),
+ partialCount, partialVector, onWriteListener);
+ if (logger.isTraceEnabled()) {
+ logger.trace("updated stats with numVectors={}, partialCount={}, partialVector={}",
+ sampledVectors.size(), partialCount, partialVector);
+ }
+
+ if (partialCount >= getConfig().getStatsThreshold()) {
+ final long rotatorSeed = random.nextLong();
+ final FhtKacRotator rotator =
+ new FhtKacRotator(rotatorSeed, getConfig().getNumDimensions(), 10);
+
+ final Transformed centroid =
+ partialVector.multiply(-1.0d / partialCount);
+ final RealVector rotatedCentroid =
+ rotator.apply(centroid.getUnderlyingVector());
+ final StorageTransform storageTransform =
+ new StorageTransform(rotator, rotatedCentroid);
+
+ //
+ // The entry node reference is expressed in a transformation that has so far been
+ // the identity transformation. We now need to take the underlying, identical vector
+ // and, for the first time, transform it into the new rotated and
+ // translated coordinate system. In this way we guarantee that the entry node is
+ // always expressed in the internal coordinate system, while data vectors may be a
+ // mix of vectors written before and after the switch.
+ // + final Transformed transformedEntryNodeVector = + storageTransform.transform(currentAccessInfo.getEntryNodeReference() + .getVector().getUnderlyingVector()); + + final AccessInfo newAccessInfo = + new AccessInfo(currentAccessInfo.getEntryNodeReference().withVector(transformedEntryNodeVector), + rotatorSeed, rotatedCentroid); + StorageAdapter.writeAccessInfo(transaction, getSubspace(), newAccessInfo, onWriteListener); + StorageAdapter.removeAllSampledVectors(transaction, getSubspace()); + if (logger.isTraceEnabled()) { + logger.trace("established rotatorSeed={}, centroid with count={}, centroid={}", + rotatorSeed, partialCount, rotatedCentroid); + } + } + } + return null; + }); + } + } + return AsyncUtil.DONE; + } + + @Nullable + private AggregatedVector aggregateVectors(@Nonnull final Iterable vectors) { + Transformed partialVector = null; + int partialCount = 0; + for (final AggregatedVector vector : vectors) { + partialVector = partialVector == null + ? vector.getPartialVector() : partialVector.add(vector.getPartialVector()); + partialCount += vector.getPartialCount(); + } + return partialCount == 0 ? null : new AggregatedVector(partialCount, partialVector); + } + + /** + * Inserts a new vector into the HNSW graph across multiple layers, starting from a given entry point. + *

+ * This method implements the second phase of the HNSW insertion algorithm. It begins at a starting layer, which is + * the minimum of the graph's maximum layer ({@code lMax}) and the new node's randomly assigned + * {@code insertionLayer}. It then iterates downwards to layer 0. In each layer, it invokes + * {@link #insertIntoLayer(StorageAdapter, Transaction, AffineOperator, Quantizer, List, int, Tuple, Transformed)} + * to perform the search and connect the new node. The set of nearest neighbors found at layer {@code L} serves as + * the entry points for the search at layer {@code L-1}. + *
+ * + * @param transaction the transaction to use for database operations + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param quantizer the quantizer to be used for this insert + * @param newPrimaryKey the primary key of the new node being inserted + * @param newVector the vector data of the new node + * @param nodeReference the initial entry point for the search, typically the nearest neighbor found in the highest + * layer + * @param lMax the maximum layer number in the HNSW graph + * @param insertionLayer the randomly determined layer for the new node. The node will be inserted into all layers + * from this layer down to 0. + * + * @return a {@link CompletableFuture} that completes when the new node has been successfully inserted into all + * its designated layers + */ + @Nonnull + private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, + @Nonnull final AffineOperator storageTransform, + @Nonnull final Quantizer quantizer, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Transformed newVector, + @Nonnull final NodeReferenceWithDistance nodeReference, + final int lMax, + final int insertionLayer) { + if (logger.isTraceEnabled()) { + logger.trace("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey()); + } + return MoreAsyncUtil.>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference), + layer -> layer >= 0, + layer -> layer - 1, + (layer, previousNodeReferences) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return insertIntoLayer(storageAdapter, transaction, storageTransform, quantizer, + previousNodeReferences, layer, newPrimaryKey, newVector); + }, executor).thenCompose(ignored -> AsyncUtil.DONE); + } + + /** + * Inserts a new node into a specified layer of the HNSW graph. + *
+ * <p>
+ * This method orchestrates the complete insertion process for a single layer. It begins by performing a search + * within the given layer, starting from the provided {@code nearestNeighbors} as entry points, to find a set of + * candidate neighbors for the new node. From this candidate set, it selects the best connections based on the + * graph's parameters (M). + *
+ * <p>
+ * After selecting the neighbors, it creates the new node and links it to them. It then reciprocally updates + * the selected neighbors to link back to the new node. If adding this new link causes a neighbor to exceed its + * maximum allowed connections, its connections are pruned. All changes, including the new node and the updated + * neighbors, are persisted to storage within the given transaction. + *
+ * <p>
+ * The operation is asynchronous and returns a {@link CompletableFuture}. The future completes with the list of + * nodes found during the initial search phase, which are then used as the entry points for insertion into the + * next lower layer. + *
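+ * <p>
+ * As a rough, simplified sketch of the stages (the names {@code writeNewNode} and {@code insertBackLink} are
+ * shorthand for the change-set machinery used in the actual implementation):
+ * <pre>{@code
+ * candidates = searchLayer(nearestNeighbors, layer, efConstruction, newVector);   // search this layer
+ * selected = selectNeighbors(candidates, layer, m, newVector);                    // choose diverse links
+ * writeNewNode(newPrimaryKey, newVector, selected);                               // persist the new node
+ * for (var selectedNeighbor : selected) {
+ *     insertBackLink(selectedNeighbor, newPrimaryKey);                            // reciprocal edge
+ *     pruneNeighborsIfNecessary(selectedNeighbor, layer == 0 ? mMax0 : mMax);     // enforce the limit
+ * }
+ * return candidates;                                                              // entry points for layer - 1
+ * }</pre>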
+ * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter for reading from and writing to the graph + * @param transaction the transaction context for the database operations + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param quantizer the quantizer for this insert + * @param nearestNeighbors the list of nearest neighbors from the layer above, used as entry points for the search + * in this layer + * @param layer the layer number to insert the new node into + * @param newPrimaryKey the primary key of the new node to be inserted + * @param newVector the vector associated with the new node + * + * @return a {@code CompletableFuture} that completes with a list of the nearest neighbors found during the + * initial search phase. This list serves as the entry point for insertion into the next lower layer + * (i.e., {@code layer - 1}). + */ + @Nonnull + private CompletableFuture> + insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final AffineOperator storageTransform, + @Nonnull final Quantizer quantizer, + @Nonnull final List nearestNeighbors, + final int layer, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Transformed newVector) { + if (logger.isTraceEnabled()) { + logger.trace("begin insert key={} at layer={}", newPrimaryKey, layer); + } + final Map> nodeCache = Maps.newConcurrentMap(); + final Estimator estimator = quantizer.estimator(); + + return searchLayer(storageAdapter, transaction, storageTransform, estimator, + nearestNeighbors, layer, config.getEfConstruction(), nodeCache, newVector) + .thenCompose(searchResult -> { + final List references = NodeReferenceAndNode.getReferences(searchResult); + + return selectNeighbors(storageAdapter, transaction, storageTransform, estimator, searchResult, + layer, getConfig().getM(), getConfig().isExtendCandidates(), nodeCache, newVector) + .thenCompose(selectedNeighbors -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final AbstractNode newNode = + nodeFactory.create(newPrimaryKey, newVector, + NodeReferenceAndNode.getReferences(selectedNeighbors)); + + final NeighborsChangeSet newNodeChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + newNode.getNeighbors()); + + storageAdapter.writeNode(transaction, quantizer, newNode, layer, newNodeChangeSet); + + // create change sets for each selected neighbor and insert new node into them + final Map> neighborChangeSetMap = + Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { + final NeighborsChangeSet baseSet = + new BaseNeighborsChangeSet<>(selectedNeighbor.getNode().getNeighbors()); + final NeighborsChangeSet insertSet = + new InsertNeighborsChangeSet<>(baseSet, ImmutableList.of(newNode.getSelfReference(newVector))); + neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), + insertSet); + } + + final int currentMMax = layer == 0 ? 
getConfig().getMMax0() : getConfig().getMMax(); + return forEach(selectedNeighbors, + selectedNeighbor -> { + final AbstractNode selectedNeighborNode = selectedNeighbor.getNode(); + final NeighborsChangeSet changeSet = + Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + storageTransform, estimator, selectedNeighbor, layer, + currentMMax, changeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return changeSet; + } + return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); + }); + }, getConfig().getMaxNumConcurrentNeighborhoodFetches(), getExecutor()) + .thenApply(changeSets -> { + for (int i = 0; i < selectedNeighbors.size(); i++) { + final NodeReferenceAndNode selectedNeighbor = selectedNeighbors.get(i); + final NeighborsChangeSet changeSet = changeSets.get(i); + storageAdapter.writeNode(transaction, quantizer, + selectedNeighbor.getNode(), layer, changeSet); + } + return ImmutableList.copyOf(references); + }); + }); + }).thenApply(nodeReferencesWithDistances -> { + if (logger.isTraceEnabled()) { + logger.trace("end insert key={} at layer={}", newPrimaryKey, layer); + } + return nodeReferencesWithDistances; + }); + } + + /** + * Calculates the delta between a current set of neighbors and a new set, producing a + * {@link NeighborsChangeSet} that represents the required insertions and deletions. + *
+ * <p>
+ * This method compares the neighbors present in the initial {@code beforeChangeSet} with + * the provided {@code afterNeighbors}. It identifies which neighbors from the "before" state + * are missing in the "after" state (to be deleted) and which new neighbors are present in the + * "after" state but not in the "before" state (to be inserted). It then constructs a new + * {@code NeighborsChangeSet} by wrapping the original one with {@link DeleteNeighborsChangeSet} + * and {@link InsertNeighborsChangeSet} as needed. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param beforeChangeSet the change set representing the state of neighbors before the update. + * This is used as the base for calculating changes. Must not be null. + * @param afterNeighbors an iterable collection of the desired neighbors after the update. + * Must not be null. + * + * @return a new {@code NeighborsChangeSet} that includes the necessary deletion and insertion + * operations to transform the neighbors from the "before" state to the "after" state. + */ + private NeighborsChangeSet + resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, + @Nonnull final Iterable> afterNeighbors) { + final Map beforeNeighborsMap = Maps.newLinkedHashMap(); + for (final N n : beforeChangeSet.merge()) { + beforeNeighborsMap.put(n.getPrimaryKey(), n); + } + + final Map afterNeighborsMap = Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + + afterNeighborsMap.put(nodeReferenceWithDistance.getPrimaryKey(), + nodeReferenceAndNode.getNode().getSelfReference(nodeReferenceWithDistance.getVector())); + } + + final ImmutableList.Builder toBeDeletedBuilder = ImmutableList.builder(); + for (final Map.Entry beforeNeighborEntry : beforeNeighborsMap.entrySet()) { + if (!afterNeighborsMap.containsKey(beforeNeighborEntry.getKey())) { + toBeDeletedBuilder.add(beforeNeighborEntry.getValue().getPrimaryKey()); + } + } + final List toBeDeleted = toBeDeletedBuilder.build(); + + final ImmutableList.Builder toBeInsertedBuilder = ImmutableList.builder(); + for (final Map.Entry afterNeighborEntry : afterNeighborsMap.entrySet()) { + if (!beforeNeighborsMap.containsKey(afterNeighborEntry.getKey())) { + toBeInsertedBuilder.add(afterNeighborEntry.getValue()); + } + } + final List toBeInserted = toBeInsertedBuilder.build(); + + NeighborsChangeSet changeSet = beforeChangeSet; + + if (!toBeDeleted.isEmpty()) { + changeSet = new DeleteNeighborsChangeSet<>(changeSet, toBeDeleted); + } + if (!toBeInserted.isEmpty()) { + changeSet = new InsertNeighborsChangeSet<>(changeSet, toBeInserted); + } + return changeSet; + } + + /** + * Prunes the neighborhood of a given node if its number of connections exceeds the maximum allowed ({@code mMax}). + *
+ * <p>
+ * This is a maintenance operation for the HNSW graph. When new nodes are added, an existing node's neighborhood
+ * might temporarily grow beyond its limit. This method identifies such cases and trims the neighborhood back down
+ * to the {@code mMax} best connections, based on the configured distance metric. If the neighborhood size is
+ * already within the limit, this method does nothing.
+ *
+ * @param <N> the type of the node reference, extending {@link NodeReference}
+ * @param storageAdapter the storage adapter to fetch nodes from the database
+ * @param transaction the transaction context for database operations
+ * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the
+ * storage space that is currently being used
+ * @param estimator an estimator to estimate distances
+ * @param selectedNeighbor the node whose neighborhood is being considered for pruning
+ * @param layer the graph layer on which the operation is performed
+ * @param mMax the maximum number of neighbors a node is allowed to have on this layer
+ * @param neighborChangeSet a set of pending changes to the neighborhood that must be included in the pruning
+ * calculation
+ * @param nodeCache a cache of nodes to avoid redundant database fetches
+ *
+ * @return a {@link CompletableFuture} which completes with a list of the newly selected neighbors for the pruned
+ * node. If no pruning was necessary, it completes with {@code null}.
+ */
+ @Nonnull
+ private <N extends NodeReference> CompletableFuture<List<NodeReferenceAndNode<N>>>
+ pruneNeighborsIfNecessary(@Nonnull final StorageAdapter<N> storageAdapter,
+ @Nonnull final Transaction transaction,
+ @Nonnull final AffineOperator storageTransform,
+ @Nonnull final Estimator estimator,
+ @Nonnull final NodeReferenceAndNode<N> selectedNeighbor,
+ final int layer,
+ final int mMax,
+ @Nonnull final NeighborsChangeSet<N> neighborChangeSet,
+ @Nonnull final Map<Tuple, AbstractNode<N>> nodeCache) {
+ final AbstractNode<N> selectedNeighborNode = selectedNeighbor.getNode();
+ final int numNeighbors =
+ Iterables.size(neighborChangeSet.merge()); // this is a view over the iterable neighbors in the set
+ if (numNeighbors < mMax) {
+ return CompletableFuture.completedFuture(null);
+ } else {
+ if (logger.isTraceEnabled()) {
+ logger.trace("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}",
+ selectedNeighborNode.getPrimaryKey(), numNeighbors, mMax);
+ }
+ return fetchNeighborhood(storageAdapter, transaction, storageTransform, layer, neighborChangeSet.merge(), nodeCache)
+ .thenCompose(nodeReferenceWithVectors -> {
+ final ImmutableList.Builder<NodeReferenceWithDistance> nodeReferencesWithDistancesBuilder =
+ ImmutableList.builder();
+ for (final NodeReferenceWithVector nodeReferenceWithVector : nodeReferenceWithVectors) {
+ final var vector = nodeReferenceWithVector.getVector();
+ final double distance =
+ estimator.distance(vector,
+ selectedNeighbor.getNodeReferenceWithDistance().getVector());
+ nodeReferencesWithDistancesBuilder.add(
+ new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(),
+ vector, distance));
+ }
+ return fetchSomeNodesIfNotCached(storageAdapter, transaction, storageTransform, layer,
+ nodeReferencesWithDistancesBuilder.build(), nodeCache);
+ })
+ .thenCompose(nodeReferencesAndNodes ->
+ selectNeighbors(storageAdapter, transaction, storageTransform, estimator,
+ nodeReferencesAndNodes, layer,
+ mMax, false, nodeCache,
+ selectedNeighbor.getNodeReferenceWithDistance().getVector()));
+ }
+ }
+
+ /**
+ * Selects the {@code m} best neighbors for a new node from a set of candidates using the HNSW selection
+ * heuristic.
+ * <p>
+ * This method implements the core logic for neighbor selection within a layer of the HNSW graph. It starts with an + * initial set of candidates ({@code nearestNeighbors}), which can be optionally extended by fetching their own + * neighbors. + * It then iteratively refines this set using a greedy best-first search. + *
+ * <p>
+ * The selection heuristic ensures diversity among neighbors. A candidate is added to the result set only if it is + * closer to the query {@code vector} than to any node already in the result set. This prevents selecting neighbors + * that are clustered together. If the {@code keepPrunedConnections} configuration is enabled, candidates that are + * pruned by this heuristic are kept and may be added at the end if the result set is not yet full. + *
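+ * <p>
+ * The core of the heuristic corresponds to the following self-contained sketch, mirroring the candidate loop in
+ * the method body (the real loop additionally skips the diversity check for metrics that do not satisfy the
+ * triangle inequality):
+ * <pre>{@code
+ * while (!candidates.isEmpty() && selected.size() < m) {
+ *     NodeReferenceWithDistance candidate = candidates.poll();    // nearest remaining candidate
+ *     boolean diverse = true;
+ *     for (NodeReferenceWithDistance already : selected) {
+ *         // closer to an already-selected neighbor than to the query: reject (or keep as pruned)
+ *         if (estimator.distance(candidate.getVector(), already.getVector()) < candidate.getDistance()) {
+ *             diverse = false;
+ *             break;
+ *         }
+ *     }
+ *     if (diverse) {
+ *         selected.add(candidate);
+ *     }
+ * }
+ * }</pre>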
+ * <p>
+ * The process is asynchronous and returns a {@link CompletableFuture} that will eventually contain the list of
+ * selected neighbors with their full node data.
+ *
+ * @param <N> the type of the node reference, extending {@link NodeReference}
+ * @param storageAdapter the storage adapter to fetch nodes and their neighbors
+ * @param readTransaction the transaction for performing database reads
+ * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the
+ * storage space that is currently being used
+ * @param estimator the estimator in use
+ * @param nearestNeighbors the initial pool of candidate neighbors, typically from a search in a higher layer
+ * @param layer the layer in the HNSW graph where the selection is being performed
+ * @param m the maximum number of neighbors to select
+ * @param isExtendCandidates a flag indicating whether to extend the initial candidate pool by fetching the
+ * neighbors of the {@code nearestNeighbors}
+ * @param nodeCache a cache of nodes to avoid redundant storage lookups
+ * @param vector the query vector for which neighbors are being selected
+ *
+ * @return a {@link CompletableFuture} which will complete with a list of the selected neighbors,
+ * each represented as a {@link NodeReferenceAndNode}
+ */
+ private <N extends NodeReference> CompletableFuture<List<NodeReferenceAndNode<N>>>
+ selectNeighbors(@Nonnull final StorageAdapter<N> storageAdapter,
+ @Nonnull final ReadTransaction readTransaction,
+ @Nonnull final AffineOperator storageTransform,
+ @Nonnull final Estimator estimator,
+ @Nonnull final Iterable<NodeReferenceAndNode<N>> nearestNeighbors,
+ final int layer,
+ final int m,
+ final boolean isExtendCandidates,
+ @Nonnull final Map<Tuple, AbstractNode<N>> nodeCache,
+ @Nonnull final Transformed vector) {
+ final Metric metric = getConfig().getMetric();
+ return extendCandidatesIfNecessary(storageAdapter, readTransaction, storageTransform, estimator,
+ nearestNeighbors, layer, isExtendCandidates, nodeCache, vector)
+ .thenApply(extendedCandidates -> {
+ final List<NodeReferenceWithDistance> selected = Lists.newArrayListWithExpectedSize(m);
+ final Queue<NodeReferenceWithDistance> candidates =
+ new PriorityQueue<>(extendedCandidates.size(),
+ Comparator.comparing(NodeReferenceWithDistance::getDistance));
+ candidates.addAll(extendedCandidates);
+ final Queue<NodeReferenceWithDistance> discardedCandidates =
+ getConfig().isKeepPrunedConnections()
+ ? new PriorityQueue<>(config.getM(),
+ Comparator.comparing(NodeReferenceWithDistance::getDistance))
+ : null;
+
+ while (!candidates.isEmpty() && selected.size() < m) {
+ final NodeReferenceWithDistance nearestCandidate = candidates.poll();
+ boolean shouldSelect = true;
+ // if the metric does not satisfy the triangle inequality, we should not use the heuristic
+ if (metric.satisfiesTriangleInequality()) {
+ for (final NodeReferenceWithDistance alreadySelected : selected) {
+ if (estimator.distance(nearestCandidate.getVector(),
+ alreadySelected.getVector()) < nearestCandidate.getDistance()) {
+ shouldSelect = false;
+ break;
+ }
+ }
+ }
+ if (shouldSelect) {
+ selected.add(nearestCandidate);
+ } else if (discardedCandidates != null) {
+ discardedCandidates.add(nearestCandidate);
+ }
+ }
+
+ if (discardedCandidates != null) { // isKeepPrunedConnections is set to true
+ while (!discardedCandidates.isEmpty() && selected.size() < m) {
+ selected.add(discardedCandidates.poll());
+ }
+ }
+
+ return ImmutableList.copyOf(selected);
+ }).thenCompose(selectedNeighbors ->
+ fetchSomeNodesIfNotCached(storageAdapter, readTransaction, storageTransform, layer,
+ selectedNeighbors, nodeCache))
+ .thenApply(selectedNeighbors -> {
+ if (logger.isTraceEnabled()) {
+ logger.trace("selected neighbors={}",
+ selectedNeighbors.stream()
+ .map(selectedNeighbor ->
+ "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() +
+ ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")")
+ .collect(Collectors.joining(",")));
+ }
+ return selectedNeighbors;
+ });
+ }
+
+ /**
+ * Conditionally extends a set of candidate nodes by fetching and evaluating their neighbors.
+ * <p>
+ * If {@code isExtendCandidates} is {@code true}, this method gathers the neighbors of the provided + * {@code candidates}, fetches their full node data, and calculates their distance to the given + * {@code vector}. The resulting list will contain both the original candidates and their newly + * evaluated neighbors. + *
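+ * <p>
+ * The extension step amounts to the following sketch ({@code extension} is shorthand for the list of
+ * neighbor references that subsequently get fetched and evaluated):
+ * <pre>{@code
+ * Set<Tuple> seen = new HashSet<>();                       // primary keys already considered
+ * for (var candidate : candidates) {
+ *     seen.add(candidate.getNode().getPrimaryKey());
+ * }
+ * for (var candidate : candidates) {
+ *     for (var neighbor : candidate.getNode().getNeighbors()) {
+ *         if (seen.add(neighbor.getPrimaryKey())) {        // deduplicate across all candidates
+ *             extension.add(neighbor);                     // distance to the query is computed after fetching
+ *         }
+ *     }
+ * }
+ * }</pre>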
+ * <p>
+ * If {@code isExtendCandidates} is {@code false}, the method simply returns a list containing
+ * only the original candidates. This operation is asynchronous and returns a {@link CompletableFuture}.
+ *
+ * @param <N> the type of the {@link NodeReference}
+ * @param storageAdapter the {@link StorageAdapter} used to access node data from storage
+ * @param readTransaction the active {@link ReadTransaction} for database access
+ * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the
+ * storage space that is currently being used
+ * @param estimator the estimator
+ * @param candidates an {@link Iterable} of initial candidate nodes, which have already been evaluated
+ * @param layer the graph layer from which to fetch nodes
+ * @param isExtendCandidates a boolean flag; if {@code true}, the candidate set is extended with neighbors
+ * @param nodeCache a cache mapping primary keys to {@link AbstractNode} objects to avoid redundant fetches
+ * @param vector the query vector used to calculate distances for any new neighbor nodes
+ *
+ * @return a {@link CompletableFuture} which will complete with a list of {@link NodeReferenceWithDistance},
+ * containing the original candidates and potentially their neighbors
+ */
+ private <N extends NodeReference> CompletableFuture<List<NodeReferenceWithDistance>>
+ extendCandidatesIfNecessary(@Nonnull final StorageAdapter<N> storageAdapter,
+ @Nonnull final ReadTransaction readTransaction,
+ @Nonnull final AffineOperator storageTransform,
+ @Nonnull final Estimator estimator,
+ @Nonnull final Iterable<NodeReferenceAndNode<N>> candidates,
+ int layer,
+ boolean isExtendCandidates,
+ @Nonnull final Map<Tuple, AbstractNode<N>> nodeCache,
+ @Nonnull final Transformed vector) {
+ if (isExtendCandidates) {
+ final Set<Tuple> candidatesSeen = Sets.newConcurrentHashSet();
+ for (final NodeReferenceAndNode<N> candidate : candidates) {
+ candidatesSeen.add(candidate.getNode().getPrimaryKey());
+ }
+
+ final ImmutableList.Builder<N> neighborsOfCandidatesBuilder = ImmutableList.builder();
+ for (final NodeReferenceAndNode<N> candidate : candidates) {
+ for (final N neighbor : candidate.getNode().getNeighbors()) {
+ final Tuple neighborPrimaryKey = neighbor.getPrimaryKey();
+ if (!candidatesSeen.contains(neighborPrimaryKey)) {
+ candidatesSeen.add(neighborPrimaryKey);
+ neighborsOfCandidatesBuilder.add(neighbor);
+ }
+ }
+ }
+
+ final Iterable<N> neighborsOfCandidates = neighborsOfCandidatesBuilder.build();
+
+ return fetchNeighborhood(storageAdapter, readTransaction, storageTransform, layer,
+ neighborsOfCandidates, nodeCache)
+ .thenApply(withVectors -> {
+ final ImmutableList.Builder<NodeReferenceWithDistance> extendedCandidatesBuilder =
+ ImmutableList.builder();
+ for (final NodeReferenceAndNode<N> candidate : candidates) {
+ extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance());
+ }
+
+ for (final NodeReferenceWithVector withVector : withVectors) {
+ final double distance = estimator.distance(vector, withVector.getVector());
+ extendedCandidatesBuilder.add(new NodeReferenceWithDistance(withVector.getPrimaryKey(),
+ withVector.getVector(), distance));
+ }
+ return extendedCandidatesBuilder.build();
+ });
+ } else {
+ final ImmutableList.Builder<NodeReferenceWithDistance> resultBuilder = ImmutableList.builder();
+ for (final NodeReferenceAndNode<N> candidate : candidates) {
+ resultBuilder.add(candidate.getNodeReferenceWithDistance());
+ }
+
+ return CompletableFuture.completedFuture(resultBuilder.build());
+ }
+ }
+
+ /**
+ * Writes lonely nodes for a given key across a specified range of layers.
+ * <p>
+ * A "lonely node" is a node in the layered structure that has no neighbors.
+ * This method iterates downwards from the {@code highestLayerInclusive}
+ * to the {@code lowestLayerExclusive}. For each layer in this range, it
+ * retrieves the appropriate {@link StorageAdapter} and calls
+ * {@link #writeLonelyNodeOnLayer} to persist the node's information.
+ *
+ * @param quantizer the quantizer
+ * @param transaction the transaction to use for writing to the database
+ * @param primaryKey the primary key of the record for which lonely nodes are being written
+ * @param vector the vector of the node for which lonely nodes are being written
+ * @param highestLayerInclusive the highest layer (inclusive) to begin writing lonely nodes on
+ * @param lowestLayerExclusive the lowest layer (exclusive) at which to stop writing lonely nodes
+ */
+ private void writeLonelyNodes(@Nonnull final Quantizer quantizer,
+ @Nonnull final Transaction transaction,
+ @Nonnull final Tuple primaryKey,
+ @Nonnull final Transformed vector,
+ final int highestLayerInclusive,
+ final int lowestLayerExclusive) {
+ for (int layer = highestLayerInclusive; layer > lowestLayerExclusive; layer--) {
+ final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer);
+ writeLonelyNodeOnLayer(quantizer, storageAdapter, transaction, layer, primaryKey, vector);
+ }
+ }
+
+ /**
+ * Writes a new, isolated ('lonely') node to a specified layer within the graph.
+ * <p>
+ * This method uses the provided {@link StorageAdapter} to create a new node with the + * given primary key and vector but with an empty set of neighbors. The write + * operation is performed as part of the given {@link Transaction}. This is typically + * used to insert the very first node into an empty graph layer. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param quantizer the quantizer + * @param storageAdapter the {@link StorageAdapter} used to access the data store and create nodes; must not be null + * @param transaction the {@link Transaction} context for the write operation; must not be null + * @param layer the layer index where the new node will be written + * @param primaryKey the primary key for the new node; must not be null + * @param vector the vector data for the new node; must not be null + */ + private void writeLonelyNodeOnLayer(@Nonnull final Quantizer quantizer, + @Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final Transformed vector) { + storageAdapter.writeNode(transaction, quantizer, + storageAdapter.getNodeFactory() + .create(primaryKey, vector, ImmutableList.of()), layer, + new BaseNeighborsChangeSet<>(ImmutableList.of())); + if (logger.isTraceEnabled()) { + logger.trace("written lonely node at key={} on layer={}", primaryKey, layer); + } + } + + /** + * Scans all nodes within a given layer of the database. + *
+ * <p>
+ * The scan is performed transactionally in batches to avoid loading the entire layer into memory at once. Each + * discovered node is passed to the provided {@link Consumer} for processing. The operation continues fetching + * batches until all nodes in the specified layer have been processed. + * + * @param db the non-null {@link Database} instance to run the scan against. + * @param layer the specific layer index to scan. + * @param batchSize the number of nodes to retrieve and process in each batch. + * @param nodeConsumer the non-null {@link Consumer} that will accept each {@link AbstractNode} + * found in the layer. + */ + @VisibleForTesting + void scanLayer(@Nonnull final Database db, + final int layer, + final int batchSize, + @Nonnull final Consumer> nodeConsumer) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + final AtomicReference lastPrimaryKeyAtomic = new AtomicReference<>(); + Tuple newPrimaryKey; + do { + final Tuple lastPrimaryKey = lastPrimaryKeyAtomic.get(); + lastPrimaryKeyAtomic.set(null); + newPrimaryKey = db.run(tr -> { + Streams.stream(storageAdapter.scanLayer(tr, layer, lastPrimaryKey, batchSize)) + .forEach(node -> { + nodeConsumer.accept(Objects.requireNonNull(node)); + lastPrimaryKeyAtomic.set(node.getPrimaryKey()); + }); + return lastPrimaryKeyAtomic.get(); + }, executor); + } while (newPrimaryKey != null); + } + + /** + * Gets the appropriate storage adapter for a given layer. + *
+ * <p>
+ * This method selects a {@link StorageAdapter} implementation based on the layer number: an
+ * {@code InliningStorageAdapter} is used for layers greater than {@code 0} and a {@code CompactStorageAdapter}
+ * for layer {@code 0}. Note that inlining is only used at all if the config indicates that we should use
+ * inlining; otherwise, a {@code CompactStorageAdapter} is used for every layer.
+ *
+ * @param layer the layer number for which to get the storage adapter
+ * @return a non-null {@link StorageAdapter} instance: an {@link InliningStorageAdapter} if inlining is enabled
+ * and {@code layer > 0}, a {@link CompactStorageAdapter} otherwise
+ */
+ @Nonnull
+ private StorageAdapter getStorageAdapterForLayer(final int layer) {
+ return config.isUseInlining() && layer > 0
+ ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(),
+ getOnReadListener())
+ : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(),
+ getOnReadListener());
+ }
+
+ /**
+ * Calculates a random layer for a new element to be inserted.
+ * <p>
+ * The layer is selected according to a logarithmic distribution, which ensures that + * the probability of choosing a higher layer decreases exponentially. This is + * achieved by applying the inverse transform sampling method. The specific formula + * is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random + * number and {@code lambda} is a normalization factor derived from a system + * configuration parameter {@code M}. + * + * @return a non-negative integer representing the randomly selected layer. + */ + private int insertionLayer() { + double lambda = 1.0 / Math.log(getConfig().getM()); + double u = 1.0 - random.nextDouble(); // Avoid log(0) + return (int) Math.floor(-Math.log(u) * lambda); + } + + private boolean shouldSampleVector() { + return random.nextDouble() < getConfig().getSampleVectorStatsProbability(); + } + + private boolean shouldMaintainStats() { + return random.nextDouble() < getConfig().getMaintainStatsProbability(); + } + + @Nonnull + private static List drain(@Nonnull Queue queue) { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + while (!queue.isEmpty()) { + resultBuilder.add(queue.poll()); + } + return resultBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java new file mode 100644 index 0000000000..7be527a73f --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -0,0 +1,148 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * Represents a specific type of node within a graph structure that is used to represent nodes in an HNSW structure. + *
+ * <p>
+ * This node extends {@link AbstractNode}, does not store its own vector and instead specifically manages neighbors + * of type {@link NodeReferenceWithVector} (which do store a vector each). + * It provides a concrete implementation for an "inlining" node, distinguishing it from other node types such as + * {@link CompactNode}. + */ +class InliningNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + public AbstractNode create(@Nonnull final Tuple primaryKey, + @Nullable final Transformed vector, + @Nonnull final List neighbors) { + return new InliningNode(primaryKey, (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.INLINING; + } + }; + + /** + * Constructs a new {@code InliningNode} with a specified primary key and a list of its neighbors. + *
+ * <p>
+ * This constructor initializes the node by calling the constructor of its superclass, + * passing the primary key and neighbor list. + * + * @param primaryKey the non-null primary key of the node, represented by a {@link Tuple}. + * @param neighbors the non-null list of neighbors for this node, where each neighbor + * is a {@link NodeReferenceWithVector}. + */ + public InliningNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + } + + /** + * Gets a reference to this node. + * + * @param vector the vector to be associated with the node reference. Despite the + * {@code @Nullable} annotation, this parameter must not be null. + * + * @return a new {@link NodeReferenceWithVector} instance containing the node's + * primary key and the provided vector; will never be null. + * + * @throws NullPointerException if the provided {@code vector} is null. + */ + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public NodeReferenceWithVector getSelfReference(@Nullable final Transformed vector) { + return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); + } + + /** + * Gets the kind of this node. + * @return the non-null {@link NodeKind} of this node, which is always + * {@code NodeKind.INLINING}. + */ + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.INLINING; + } + + /** + * Casts this node to a {@link CompactNode}. + *
+ * <p>
+ * This implementation always throws an exception because this specific node type + * cannot be represented as a compact node. + * @return this node as a {@link CompactNode}, never {@code null} + * @throws IllegalStateException always, as this node is not a compact node + */ + @Nonnull + @Override + public CompactNode asCompactNode() { + throw new IllegalStateException("this is not a compact node"); + } + + /** + * Returns this object as an {@link InliningNode}. + *
+ * <p>
+ * As this class is already an instance of {@code InliningNode}, this method simply returns {@code this}.
+ * @return this object, which is guaranteed to be an {@code InliningNode} and never {@code null}.
+ */
+ @Nonnull
+ @Override
+ public InliningNode asInliningNode() {
+ return this;
+ }
+
+ /**
+ * Returns the singleton factory instance used to create {@link InliningNode} instances.
+ * <p>
+ * This method provides a standard way to obtain the factory, ensuring that a single, shared instance is used + * throughout the application. + * + * @return the singleton {@link NodeFactory} instance, never {@code null}. + */ + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "I[primaryKey=" + getPrimaryKey() + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java new file mode 100644 index 0000000000..fce0fdac34 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -0,0 +1,379 @@ +/* + * InliningStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * An implementation of {@link StorageAdapter} for an HNSW graph that stores node vectors "in-line" with the node's + * neighbor information. + *
+ * <p>
+ * In this storage model, each key-value pair in the database represents a single neighbor relationship. The key + * contains the primary keys of both the source node and the neighbor node, while the value contains the neighbor's + * vector. This contrasts with a "compact" storage model where a compact node represents a vector and all of its + * neighbors. This adapter is responsible for serializing and deserializing these structures to and from the underlying + * key-value store. + * + * @see StorageAdapter + * @see HNSW + */ +class InliningStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + /** + * Constructs a new {@code InliningStorageAdapter} with the given configuration and components. + *
+ * <p>
+ * This constructor initializes the storage adapter by passing all necessary components + * to its superclass. + * + * @param config the HNSW configuration to use for the graph + * @param nodeFactory the factory to create new {@link NodeReferenceWithVector} instances + * @param subspace the subspace where the HNSW graph data is stored + * @param onWriteListener the listener to be notified on write operations + * @param onReadListener the listener to be notified on read operations + */ + public InliningStorageAdapter(@Nonnull final Config config, + @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + /** + * Asynchronously fetches a single node from a given layer by its primary key. + *
+ * <p>
+ * This internal method constructs a prefix key based on the {@code layer} and {@code primaryKey}. + * It then performs an asynchronous range scan to retrieve all key-value pairs associated with that prefix. + * Finally, it reconstructs the complete {@link AbstractNode} object from the collected raw data using + * the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for reading from the database + * @param storageTransform an affine transformation operator that is used to transform the fetched vector into the + * storage space that is currently being used + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a {@link CompletableFuture} that will complete with the fetched {@link AbstractNode} containing + * {@link NodeReferenceWithVector}s + */ + @Nonnull + @Override + protected CompletableFuture> + fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + @Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] rangeKey = getNodeKey(layer, primaryKey); + + return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey), + ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor()) + .thenApply(keyValues -> nodeFromRaw(storageTransform, layer, primaryKey, keyValues)); + } + + /** + * Constructs a {@code Node} from its raw key-value representation from storage. + *
+ * <p>
+ * This method is responsible for deserializing a node and its neighbors. It processes a list of {@code KeyValue} + * pairs, where each pair represents a neighbor of the node being constructed. Each neighbor is converted from its + * raw form into a {@link NodeReferenceWithVector} by calling the + * {@link #neighborFromRaw(AffineOperator, int, byte[], byte[])} method. + *
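+ * <p>
+ * For orientation, each key-value pair decoded here follows the layout produced by {@code getNeighborKey} and
+ * {@code writeNeighbor} below (one pair per neighbor):
+ * <pre>{@code
+ * key   = getDataSubspace().pack(Tuple.from(layer, nodePrimaryKey, neighborPrimaryKey))
+ * value = StorageAdapter.tupleFromVector(quantizer.encode(neighborVector)).pack()
+ * }</pre>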
+ * <p>
+ * Once the node is created with its primary key and list of neighbors, it notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer in the graph where this node exists + * @param primaryKey the primary key that uniquely identifies the node + * @param keyValues a list of {@code KeyValue} pairs representing the raw data of the node's neighbors + * + * @return a non-null, fully constructed {@link AbstractNode} object with its neighbors + */ + @Nonnull + private AbstractNode nodeFromRaw(@Nonnull final AffineOperator storageTransform, + final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final List keyValues) { + final OnReadListener onReadListener = getOnReadListener(); + + final ImmutableList.Builder nodeReferencesWithVectorBuilder = ImmutableList.builder(); + for (final KeyValue keyValue : keyValues) { + nodeReferencesWithVectorBuilder.add(neighborFromRaw(storageTransform, layer, keyValue.getKey(), + keyValue.getValue())); + } + + final AbstractNode node = + getNodeFactory().create(primaryKey, null, nodeReferencesWithVectorBuilder.build()); + onReadListener.onNodeRead(layer, node); + return node; + } + + /** + * Constructs a {@code NodeReferenceWithVector} from raw key and value byte arrays retrieved from storage. + *
+ * <p>
+ * This helper method deserializes a neighbor's data. It unpacks the provided {@code key} to extract the neighbor's + * primary key and unpacks the {@code value} to extract the neighbor's vector. It also notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer of the graph where the neighbor node is located. + * @param key the raw byte array key from the database, which contains the neighbor's primary key. + * @param value the raw byte array value from the database, which represents the neighbor's vector. + * @return a new {@link NodeReferenceWithVector} instance representing the deserialized neighbor. + * @throws IllegalArgumentException if the key or value byte arrays are malformed and cannot be unpacked. + */ + @Nonnull + private NodeReferenceWithVector neighborFromRaw(@Nonnull final AffineOperator storageTransform, final int layer, + @Nonnull final byte[] key, @Nonnull final byte[] value) { + final OnReadListener onReadListener = getOnReadListener(); + onReadListener.onKeyValueRead(layer, key, value); + + final Tuple neighborKeyTuple = getDataSubspace().unpack(key); + final Tuple neighborValueTuple = Tuple.fromBytes(value); + + return neighborFromTuples(storageTransform, neighborKeyTuple, neighborValueTuple); + } + + /** + * Constructs a {@code NodeReferenceWithVector} from tuples retrieved from storage. + *
+ * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector
+ * into the storage space that is currently being used
+ * @param keyTuple the key tuple from the database, which contains the neighbor's primary key.
+ * @param valueTuple the value tuple from the database, which represents the neighbor's vector.
+ * @return a new {@link NodeReferenceWithVector} instance representing the deserialized neighbor.
+ * @throws IllegalArgumentException if the key or value tuples are malformed and cannot be decoded.
+ */
+ @Nonnull
+ private NodeReferenceWithVector neighborFromTuples(@Nonnull final AffineOperator storageTransform,
+ @Nonnull final Tuple keyTuple, @Nonnull final Tuple valueTuple) {
+ final Tuple neighborPrimaryKey = keyTuple.getNestedTuple(2); // key layout: (layer, nodePrimaryKey, neighborPrimaryKey)
+ //
+ // Transform the raw vector that was just fetched into the internal coordinate system. If we do not have
+ // a need to transform coordinates, this transform is the identity transformation. Vectors are always stored
+ // in the internal coordinate system in use at the time the vector is written. If that coordinate system changes
+ // afterward, for instance by enabling RaBitQ, subsequent reads of vectors that were written prior to
+ // the coordinate system change need to be transformed when they are read back.
+ //
+ final Transformed neighborVector =
+ storageTransform.transform(
+ StorageAdapter.vectorFromTuple(getConfig(), valueTuple)); // the entire value is the vector
+ return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector);
+ }
+
+ /**
+ * Writes a given node and its neighbor changes to the specified layer within a transaction.
+ * <p>
+ * This implementation first converts the provided {@link AbstractNode} to an {@link InliningNode}. It then + * delegates the writing of neighbor modifications to the {@link NeighborsChangeSet#writeDelta} method. After the + * changes are written, it notifies the registered {@code OnWriteListener} that the node has been processed via + * {@code getOnWriteListener().onNodeWritten()}. + * + * @param transaction the transaction context for the write operation; must not be null + * @param quantizer the quantizer to use + * @param node the node to be written, which is expected to be an + * {@code InliningNode}; must not be null + * @param layer the layer index where the node and its neighbor changes should be written + * @param neighborsChangeSet the set of changes to the node's neighbors to be + * persisted; must not be null + */ + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, + @Nonnull final AbstractNode node, final int layer, + @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final InliningNode inliningNode = node.asInliningNode(); + + neighborsChangeSet.writeDelta(this, transaction, quantizer, layer, inliningNode, t -> true); + getOnWriteListener().onNodeWritten(layer, node); + } + + /** + * Constructs the raw database key for a node based on its layer and primary key. + *
+ * <p>
+ * This key is created by packing a tuple containing the specified {@code layer} and the node's {@code primaryKey} + * within the data subspace. The resulting byte array is suitable for use in direct database lookups and preserves + * the sort order of the components. + * + * @param layer the layer index where the node resides + * @param primaryKey the primary key that uniquely identifies the node within its layer, + * encapsulated in a {@link Tuple} + * + * @return a byte array representing the packed key for the specified node + */ + @Nonnull + private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { + return getDataSubspace().pack(Tuple.from(layer, primaryKey)); + } + + /** + * Writes a neighbor for a given node to the underlying storage within a specific transaction. + *
+ * <p>
+ * This method serializes the neighbor's vector and constructs a unique key based on the layer, the source + * {@code node}, and the neighbor's primary key. It then persists this key-value pair using the provided + * {@link Transaction}. After a successful write, it notifies any registered listeners. + * + * @param transaction the {@link Transaction} to use for the write operation + * @param quantizer quantizer to use + * @param layer the layer index where the node and its neighbor reside + * @param node the source {@link AbstractNode} for which the neighbor is being written + * @param neighbor the {@link NodeReferenceWithVector} representing the neighbor to persist + */ + public void writeNeighbor(@Nonnull final Transaction transaction, @Nonnull final Quantizer quantizer, + final int layer, @Nonnull final AbstractNode node, + @Nonnull final NodeReferenceWithVector neighbor) { + final byte[] neighborKey = getNeighborKey(layer, node, neighbor.getPrimaryKey()); + // getting underlying vector is okay as it is only written to the database + final byte[] value = + StorageAdapter.tupleFromVector( + quantizer.encode(neighbor.getVector())).pack(); + transaction.set(neighborKey, value); + getOnWriteListener().onNeighborWritten(layer, node, neighbor); + getOnWriteListener().onKeyValueWritten(layer, neighborKey, value); + } + + /** + * Deletes a neighbor edge from a given node within a specific layer. + *
+ * <p>
+ * This operation removes the key-value pair representing the neighbor relationship from the database within the + * given {@link Transaction}. It also notifies the {@code onWriteListener} about the deletion. + * + * @param transaction the transaction in which to perform the deletion + * @param layer the layer of the graph where the node resides + * @param node the node from which the neighbor edge is removed + * @param neighborPrimaryKey the primary key of the neighbor node to be deleted + */ + public void deleteNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final AbstractNode node, + @Nonnull final Tuple neighborPrimaryKey) { + transaction.clear(getNeighborKey(layer, node, neighborPrimaryKey)); + getOnWriteListener().onNeighborDeleted(layer, node, neighborPrimaryKey); + } + + /** + * Constructs the key for a specific neighbor of a node within a given layer. + *
+ * <p>
+ * This key is used to uniquely identify and store the neighbor relationship in the underlying data store. It is + * formed by packing a {@link Tuple} containing the {@code layer}, the primary key of the source {@code node}, and + * the {@code neighborPrimaryKey}. + * + * @param layer the layer of the graph where the node and its neighbor reside + * @param node the non-null source node for which the neighbor key is being generated + * @param neighborPrimaryKey the non-null primary key of the neighbor node + * @return a non-null byte array representing the packed key for the neighbor relationship + */ + @Nonnull + private byte[] getNeighborKey(final int layer, + @Nonnull final AbstractNode node, + @Nonnull final Tuple neighborPrimaryKey) { + return getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey(), neighborPrimaryKey)); + } + + /** + * Scans a specific layer of the graph, reconstructing nodes and their neighbors from the underlying key-value + * store. + *
+ * <p>
+ * This method reads raw {@link com.apple.foundationdb.KeyValue} records from the database within a given layer.
+ * It groups adjacent records that belong to the same parent node and uses a {@link NodeFactory} to construct
+ * {@link AbstractNode} objects. The method supports pagination through the {@code lastPrimaryKey} parameter,
+ * allowing for incremental scanning of large layers.
+ *
+ * @param readTransaction the transaction to use for reading data
+ * @param layer the layer of the graph to scan
+ * @param lastPrimaryKey the primary key of the last node read in a previous scan, used for pagination.
+ * If {@code null}, the scan starts from the beginning of the layer.
+ * @param maxNumRead the maximum number of raw key-value records to read from the database
+ * @return an {@code Iterable} of {@link AbstractNode} objects reconstructed from the scanned layer. Each node
+ * contains its neighbors within that layer.
+ */
+ @Nonnull
+ @Override
+ public Iterable<AbstractNode<NodeReferenceWithVector>> scanLayer(@Nonnull final ReadTransaction readTransaction,
+ int layer,
+ @Nullable final Tuple lastPrimaryKey,
+ int maxNumRead) {
+ final OnReadListener onReadListener = getOnReadListener();
+ final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer));
+ final Range range =
+ lastPrimaryKey == null
+ ? Range.startsWith(layerPrefix)
+ : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))),
+ ByteArrayUtil.strinc(layerPrefix));
+ final AsyncIterable<KeyValue> itemsIterable =
+ readTransaction.getRange(range,
+ maxNumRead, false, StreamingMode.ITERATOR);
+ Tuple nodePrimaryKey = null;
+ ImmutableList.Builder<AbstractNode<NodeReferenceWithVector>> nodeBuilder = ImmutableList.builder();
+ ImmutableList.Builder<NodeReferenceWithVector> neighborsBuilder = null;
+
+ int numRead = 0;
+ for (final KeyValue item : itemsIterable) {
+ final byte[] key = item.getKey();
+ final byte[] value = item.getValue();
+ onReadListener.onKeyValueRead(layer, key, value);
+
+ final Tuple neighborKeyTuple = getDataSubspace().unpack(key);
+ final Tuple neighborValueTuple = Tuple.fromBytes(value);
+ final NodeReferenceWithVector neighbor =
+ neighborFromTuples(AffineOperator.identity(), neighborKeyTuple, neighborValueTuple);
+ final Tuple nodePrimaryKeyFromNeighbor = neighborKeyTuple.getNestedTuple(1);
+ if (nodePrimaryKey == null || !nodePrimaryKey.equals(nodePrimaryKeyFromNeighbor)) {
+ if (nodePrimaryKey != null) {
+ nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build()));
+ }
+ nodePrimaryKey = nodePrimaryKeyFromNeighbor;
+ neighborsBuilder = ImmutableList.builder();
+ }
+ neighborsBuilder.add(neighbor);
+ numRead++;
+ }
+
+ //
+ // There may be a remainder; deal with it here. Create a last node if we exhausted the items read from the db.
+ // If we didn't exhaust the dataset but reached maxNumRead, do not create a node and assume the caller
+ // will come back for more. We always assume that maxNumRead is greater than the maximum number of neighbors
+ // a node can have.
+ // + if (numRead < maxNumRead && nodePrimaryKey != null) { + nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build())); + } + + return nodeBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java new file mode 100644 index 0000000000..b3b5ef8a12 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java @@ -0,0 +1,134 @@ +/* + * InsertNeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +/** + * Represents an immutable change set for the neighbors of a node in the HNSW graph, specifically + * capturing the insertion of new neighbors. + *
+ * <p>
+ * This class layers new neighbors on top of a parent {@link NeighborsChangeSet}, allowing for a + * layered representation of modifications. The changes are not applied to the database until + * {@link #writeDelta} is called. + * + * @param the type of the node reference, which must extend {@link NodeReference} + */ +class InsertNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(InsertNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Map insertedNeighborsMap; + + /** + * Creates a new {@code InsertNeighborsChangeSet}. + *
+ * <p>
+ * This constructor initializes the change set with its parent and a list of neighbors + * to be inserted. It internally builds an immutable map of the inserted neighbors, + * keyed by their primary key for efficient lookups. + * + * @param parent the parent {@link NeighborsChangeSet} on which this insertion is based. + * @param insertedNeighbors the list of neighbors to be inserted. + */ + public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final List insertedNeighbors) { + this.parent = parent; + final ImmutableMap.Builder insertedNeighborsMapBuilder = ImmutableMap.builder(); + for (final N insertedNeighbor : insertedNeighbors) { + insertedNeighborsMapBuilder.put(insertedNeighbor.getPrimaryKey(), insertedNeighbor); + } + + this.insertedNeighborsMap = insertedNeighborsMapBuilder.build(); + } + + /** + * Gets the parent {@code NeighborsChangeSet} from which this change set was derived. + * @return the parent {@link NeighborsChangeSet}, which is never {@code null}. + */ + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + /** + * Merges the neighbors from this level of the hierarchy with all neighbors from parent levels. + *

+ * This is achieved by creating a combined view that includes the results of the parent's {@link #merge()} call and + * the neighbors that have been inserted at the current level. The resulting {@code Iterable} provides a complete + * set of neighbors from this level and all parent levels. + * @return a non-null {@code Iterable} containing all neighbors from this level and all parent levels. + */ + @Nonnull + @Override + public Iterable<N> merge() { + return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); + } + + /** + * Writes the delta of this layer to the specified storage adapter. + *

+ * This implementation first delegates to the parent to write its delta, but excludes any neighbors that have been + * newly inserted in the current context (i.e., those in {@code insertedNeighborsMap}). It then iterates through its + * own newly inserted neighbors. For each neighbor that satisfies the given {@code tuplePredicate}, it writes the + * neighbor relationship to storage via the {@link InliningStorageAdapter}. + * + * @param storageAdapter the storage adapter to write to; must not be null + * @param transaction the transaction context for the write operation; must not be null + * @param layer the layer index to write the data to + * @param node the source node for which the neighbor delta is being written; must not be null + * @param tuplePredicate a predicate to filter which neighbor tuples should be written; must not be null + */ + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + @Nonnull final Quantizer quantizer, final int layer, @Nonnull final AbstractNode node, + @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, quantizer, layer, node, + tuplePredicate.and(tuple -> !insertedNeighborsMap.containsKey(tuple))); + + for (final Map.Entry entry : insertedNeighborsMap.entrySet()) { + final Tuple primaryKey = entry.getKey(); + if (tuplePredicate.test(primaryKey)) { + storageAdapter.writeNeighbor(transaction, quantizer, layer, node.asInliningNode(), + entry.getValue().asNodeReferenceWithVector()); + if (logger.isTraceEnabled()) { + logger.trace("inserted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + primaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java new file mode 100644 index 0000000000..207c6a1f1f --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -0,0 +1,83 @@ +/* + * NeighborsChangeSet.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.function.Predicate; + +/** + * Represents a set of changes to the neighbors of a node within an HNSW graph. + *
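The layering that `merge()` and `writeDelta(...)` implement is easier to see in isolation. Below is a minimal, storage-free sketch of the same idea (all names are hypothetical and not part of this patch): a base layer holds the persisted neighbors, an insert layer overlays additions, `merge()` flattens the chain, and `writeDelta(...)` persists only what the parent has not already covered, excluding keys re-inserted at the current level.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Storage-free model of layered neighbor change sets (illustrative only).
final class LayeredChangeSetDemo {
    interface ChangeSet {
        List<String> merge();
        void writeDelta(List<String> sink, Predicate<String> filter);
    }

    // Root layer: the neighbors as currently persisted.
    static ChangeSet base(List<String> persisted) {
        return new ChangeSet() {
            @Override
            public List<String> merge() {
                return new ArrayList<>(persisted);
            }

            @Override
            public void writeDelta(List<String> sink, Predicate<String> filter) {
                // nothing to write; the base layer is already persisted
            }
        };
    }

    // Insert layer: overlays newly inserted neighbors on a parent layer.
    static ChangeSet insert(ChangeSet parent, List<String> inserted) {
        return new ChangeSet() {
            @Override
            public List<String> merge() {
                final List<String> result = parent.merge();
                result.addAll(inserted);
                return result;
            }

            @Override
            public void writeDelta(List<String> sink, Predicate<String> filter) {
                // exclude keys this layer re-inserts, then write our own additions
                parent.writeDelta(sink, filter.and(key -> !inserted.contains(key)));
                for (String key : inserted) {
                    if (filter.test(key)) {
                        sink.add(key);
                    }
                }
            }
        };
    }

    public static void main(String[] args) {
        ChangeSet set = insert(base(List.of("n1", "n2")), List.of("n3"));
        System.out.println(set.merge());        // [n1, n2, n3]
        List<String> written = new ArrayList<>();
        set.writeDelta(written, key -> true);
        System.out.println(written);            // [n3] (only the delta is written)
    }
}
```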

+ * Implementations of this interface manage modifications, such as additions or removals of neighbors, often in a + * layered fashion. This allows for composing changes before they are committed to storage. The {@link #getParent()} + * method returns the next element in this layered structure while {@link #merge()} consolidates changes into + * a final neighbor list. + * + * @param <N> the type of the node reference, which must extend {@link NodeReference} + */ +interface NeighborsChangeSet<N extends NodeReference> { + /** + * Gets the parent change set from which this change set was derived. + *

+ * Change sets can be layered, forming a chain of modifications. + * This method allows for traversing up this chain to the preceding set of changes. + * + * @return the parent {@code NeighborsChangeSet}, or {@code null} if this change set + * is the root of the change chain and has no parent. + */ + @Nullable + NeighborsChangeSet<N> getParent(); + + /** + * Merges multiple internal sequences into a single, consolidated iterable sequence. + *

+ * This method combines distinct internal changesets into one continuous stream of neighbors. The specific order + * of the merged elements depends on the implementation. + * + * @return a non-null {@code Iterable} containing the merged sequence of elements. + */ + @Nonnull + Iterable merge(); + + /** + * Writes the neighbor delta for a given {@link AbstractNode} to the specified storage layer. + *

+ * This method processes the provided {@code node} and writes only the records that match the given + * {@code primaryKeyPredicate} to the storage system via the {@link InliningStorageAdapter}. The entire operation + * is performed within the context of the supplied {@link Transaction}. + * + * @param storageAdapter the storage adapter to which the delta will be written; must not be null + * @param quantizer quantizer to use + * @param transaction the transaction context for the write operation; must not be null + * @param layer the specific storage layer to write the delta to + * @param node the source node containing the data to be written; must not be null + * @param primaryKeyPredicate a predicate to filter records by their primary key. Only records + * for which the predicate returns {@code true} will be written. Must not be null. + */ + void writeDelta(@Nonnull InliningStorageAdapter storageAdapter, @Nonnull Transaction transaction, + @Nonnull Quantizer quantizer, int layer, @Nonnull AbstractNode node, + @Nonnull Predicate primaryKeyPredicate); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java new file mode 100644 index 0000000000..ae1bc07bc2 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java @@ -0,0 +1,82 @@ +/* + * Node.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * Represents a node within an HNSW (Hierarchical Navigable Small World) structure. + *

+ * A node corresponds to a data point (vector) in the structure and maintains a list of its neighbors. + * This interface defines the common contract for different node representations, such as {@link CompactNode} + * and {@link InliningNode}. + *

+ * + * @param <N> the type of reference used to point to other nodes, which must extend {@link NodeReference} + */ +public interface Node<N extends NodeReference> { + /** + * Gets the primary key for this object. + *

+ * The primary key is represented as a {@link Tuple} and uniquely identifies + * the object within its storage context. This method is guaranteed not to + * return a null value. + * + * @return the primary key as a {@code Tuple}, which is never {@code null} + */ + @Nonnull + Tuple getPrimaryKey(); + + /** + * Returns a reference of type {@code N} that identifies this node. This allows the creation of node references + * that contain a vector and are independent of the storage implementation. + * @param vector the vector to associate with the returned reference. This parameter + * is optional and can be {@code null}. + * + * @return a non-null node reference identifying this node + */ + @Nonnull + N getSelfReference(@Nullable Transformed vector); + + /** + * Gets the list of neighboring nodes. + *

+ * This method is guaranteed to not return {@code null}. If there are no neighbors, an empty list is returned. + * + * @return a non-null list of neighboring nodes. + */ + @Nonnull + List getNeighbors(); + + /** + * Return the kind of the node, i.e. {@link NodeKind#COMPACT} or {@link NodeKind#INLINING}. + * @return the kind of this node as a {@link NodeKind} + */ + @Nonnull + NodeKind getKind(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java new file mode 100644 index 0000000000..584e75ccf7 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -0,0 +1,67 @@ +/* + * NodeFactory.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * A factory interface for creating {@link AbstractNode} instances within a Hierarchical Navigable Small World (HNSW) + * graph. + *

+ * Implementations of this interface define how nodes are constructed, allowing for different node types + * or storage strategies within the HNSW structure. + * + * @param <N> the type of {@link NodeReference} used to refer to nodes in the graph + */ +interface NodeFactory<N extends NodeReference> { + /** + * Creates a new node with the specified properties. + *

+ * This method is responsible for instantiating a {@code Node} object, initializing it + * with a primary key, an optional feature vector, and a list of its initial neighbors. + * + * @param primaryKey the {@link Tuple} representing the unique primary key for the new node. Must not be + * {@code null}. + * @param vector the optional feature {@link RealVector} associated with the node, which can be used for similarity + * calculations. May be {@code null} if the node does not encode a vector (see {@link CompactNode} versus + * {@link InliningNode}). + * @param neighbors the list of initial {@link NodeReference}s for the new node, + * establishing its initial connections in the graph. Must not be {@code null}. + * + * @return a new, non-null {@link AbstractNode} instance configured with the provided parameters. + */ + @Nonnull + AbstractNode create(@Nonnull Tuple primaryKey, @Nullable Transformed vector, + @Nonnull List neighbors); + + /** + * Gets the kind of this node. + * @return the kind of this node, never {@code null}. + */ + @Nonnull + NodeKind getNodeKind(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java new file mode 100644 index 0000000000..656726a93b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java @@ -0,0 +1,88 @@ +/* + * NodeKind.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; + +/** + * Represents the different kinds of nodes, each associated with a unique byte value for serialization and + * deserialization. + */ +public enum NodeKind { + /** + * Compact node. Serialization and deserialization is implemented in {@link CompactNode}. + *

+ * Compact nodes store their own vector, and their neighbors list contains only the primary key of each neighbor. + */ + COMPACT((byte)0x00), + + /** + * Inlining node. Serialization and deserialization are implemented in {@link InliningNode}. + *

+ * Inlining nodes do not store their own vector, and their neighbors list contains both the primary key and the + * vector of each neighbor. Each neighbor is stored in its own key/value pair. + */ + INLINING((byte)0x01); + + private final byte serialized; + + /** + * Constructs a new {@code NodeKind} instance with its serialized representation. + * @param serialized the byte value used for serialization + */ + NodeKind(final byte serialized) { + this.serialized = serialized; + } + + /** + * Gets the serialized byte value. + * @return the serialized byte value + */ + public byte getSerialized() { + return serialized; + } + + /** + * Deserializes a byte into the corresponding {@link NodeKind}. + * @param serializedNodeKind the byte representation of the node kind. + * @return the corresponding {@link NodeKind}, never {@code null}. + * @throws IllegalArgumentException if the {@code serializedNodeKind} does not + * correspond to a known node kind. + */ + @Nonnull + static NodeKind fromSerializedNodeKind(byte serializedNodeKind) { + final NodeKind nodeKind; + switch (serializedNodeKind) { + case 0x00: + nodeKind = NodeKind.COMPACT; + break; + case 0x01: + nodeKind = NodeKind.INLINING; + break; + default: + throw new IllegalArgumentException("unknown node kind"); + } + Verify.verify(nodeKind.getSerialized() == serializedNodeKind); + return nodeKind; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java new file mode 100644 index 0000000000..529601b82e --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java @@ -0,0 +1,132 @@ +/* + * NodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.Streams; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents a reference to a node, uniquely identified by its primary key. It provides fundamental operations such as + * equality comparison, hashing, and string representation based on this key. It also serves as a base class for more + * specialized node references. + */ +public class NodeReference { + @Nonnull + private final Tuple primaryKey; + + /** + * Constructs a new {@code NodeReference} with the specified primary key. + * @param primaryKey the primary key of the node to reference; must not be {@code null}. + */ + public NodeReference(@Nonnull final Tuple primaryKey) { + this.primaryKey = primaryKey; + } + + /** + * Gets the primary key for this object. + * @return the primary key as a {@code Tuple} object, which is guaranteed to be non-null.
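The two physical layouts that {@code COMPACT} and {@code INLINING} imply can be visualized with FoundationDB's tuple layer. The following sketch is illustrative only; the demo subspace name, key shapes, and placeholder vector bytes are assumptions, not the actual encoding used by `CompactNode`/`InliningNode`:

```java
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;

// Illustrative comparison of the two node layouts described in NodeKind.
class NodeLayoutSketch {
    public static void main(String[] args) {
        final Subspace data = new Subspace(Tuple.from("hnsw-demo", 0x00));  // hypothetical data subspace
        final Tuple nodePk = Tuple.from(42L);
        final Tuple neighborPk = Tuple.from(43L);
        final int layer = 1;

        // COMPACT: one key/value pair per node; the value carries the node's own
        // vector plus the primary keys of its neighbors.
        final byte[] compactKey = data.pack(Tuple.from(layer, nodePk));
        final byte[] compactValue = Tuple.from(
                new byte[] {0x7f},          // placeholder for the node's serialized vector
                Tuple.from(neighborPk)).pack();

        // INLINING: one key/value pair per (node, neighbor) edge; the value carries
        // the neighbor's vector, so a search need not fetch the neighbor node itself.
        final byte[] inliningKey = data.pack(Tuple.from(layer, nodePk, neighborPk));
        final byte[] inliningValue = Tuple.from(
                new byte[] {0x7f}).pack();  // placeholder for the neighbor's vector

        System.out.printf("compact:  %d-byte key, %d-byte value%n", compactKey.length, compactValue.length);
        System.out.printf("inlining: %d-byte key, %d-byte value%n", inliningKey.length, inliningValue.length);
    }
}
```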
+ */ + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + /** + * Method to indicate if the method {@link #asNodeReferenceWithVector()} can be safely called. + * @return {@code true} iff this instance is in fact at least a {@link NodeReferenceWithVector}. + */ + boolean isNodeReferenceWithVector() { + return false; + } + + /** + * Casts this object to a {@link NodeReferenceWithVector}. + *

+ * This method is intended to be used on subclasses that actually represent a node reference with a vector. For this + * base class, it is not a valid operation. + * @return this instance cast as a {@code NodeReferenceWithVector} + * @throws IllegalStateException always, to indicate that this object cannot be + * represented as a {@link NodeReferenceWithVector}. + */ + @Nonnull + public NodeReferenceWithVector asNodeReferenceWithVector() { + throw new IllegalStateException("method should not be called"); + } + + /** + * Compares this {@code NodeReference} to the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is not {@code null} and is a {@code NodeReference} object + * that has the same {@code primaryKey} as this object. + * + * @param o the object to compare with this {@code NodeReference} for equality. + * @return {@code true} if the given object is equal to this one; + * {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (o == null) { + return false; + } + if (this == o) { + return true; + } + if (o.getClass() != this.getClass()) { + return false; + } + final NodeReference that = (NodeReference)o; + return Objects.equals(primaryKey, that.primaryKey); + } + + /** + * Generates a hash code for this object based on the primary key. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hashCode(primaryKey); + } + + /** + * Returns a string representation of the object. + * @return a string representation of this object. + */ + @Override + public String toString() { + return "NR[primaryKey=" + primaryKey + "]"; + } + + /** + * Helper to extract the primary keys from a given collection of node references. + * @param neighbors an iterable of {@link NodeReference} objects from which to extract primary keys. + * @return a lazily-evaluated {@code Iterable} of {@link Tuple}s, representing the primary keys of the input nodes. + */ + @Nonnull + public static Iterable primaryKeys(@Nonnull Iterable neighbors) { + return () -> Streams.stream(neighbors) + .map(nodeReference -> + Objects.requireNonNull(nodeReference).getPrimaryKey()) + .iterator(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java new file mode 100644 index 0000000000..a6c4f33abe --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -0,0 +1,87 @@ +/* + * NodeReferenceAndNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +/** + * A container class that pairs a {@link NodeReferenceWithDistance} with its corresponding {@link AbstractNode} object. + *

+ * This is often used during graph traversal or searching, where a reference to a node (along with its distance from a + * query point) is first identified, and then the complete node data is fetched. This class holds these two related + * pieces of information together. + * @param the type of {@link NodeReference} used within the {@link AbstractNode} + */ +class NodeReferenceAndNode { + @Nonnull + private final NodeReferenceWithDistance nodeReferenceWithDistance; + @Nonnull + private final AbstractNode node; + + /** + * Constructs a new instance that pairs a node reference (with distance) with its + * corresponding {@link AbstractNode} object. + * @param nodeReferenceWithDistance the reference to a node, which also includes distance information. Must not be + * {@code null}. + * @param node the actual {@link AbstractNode} object that the reference points to. Must not be {@code null}. + */ + public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, + @Nonnull final AbstractNode node) { + this.nodeReferenceWithDistance = nodeReferenceWithDistance; + this.node = node; + } + + /** + * Gets the node reference and its associated distance. + * @return the non-null {@link NodeReferenceWithDistance} object. + */ + @Nonnull + public NodeReferenceWithDistance getNodeReferenceWithDistance() { + return nodeReferenceWithDistance; + } + + /** + * Gets the underlying node represented by this object. + * @return the associated {@link Node} instance, never {@code null}. + */ + @Nonnull + public AbstractNode getNode() { + return node; + } + + /** + * Helper to extract the references from a given collection of objects of this container class. + * @param referencesAndNodes an iterable of {@link NodeReferenceAndNode} objects from which to extract the + * references. + * @return a {@link List} of {@link NodeReferenceAndNode}s + */ + @Nonnull + public static List getReferences(@Nonnull List> referencesAndNodes) { + final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode referenceWithNode : referencesAndNodes) { + referencesBuilder.add(referenceWithNode.getNodeReferenceWithDistance()); + } + return referencesBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java new file mode 100644 index 0000000000..d413eb785f --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -0,0 +1,91 @@ +/* + * NodeReferenceWithDistance.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents a reference to a node that includes its vector and its distance from a query vector. + *

+ * This class extends {@link NodeReferenceWithVector} by additionally associating a distance value, typically the result + * of a distance calculation in a nearest neighbor search. Objects of this class are immutable. + */ +public class NodeReferenceWithDistance extends NodeReferenceWithVector { + private final double distance; + + /** + * Constructs a new instance of {@code NodeReferenceWithDistance}. + *

+ * This constructor initializes the reference with the node's primary key, its vector, and the calculated distance + * from some origin vector (e.g., a query vector). It calls the superclass constructor to set the {@code primaryKey} + * and {@code vector}. + * @param primaryKey the primary key of the referenced node, represented as a {@link Tuple}. Must not be null. + * @param vector the vector associated with the referenced node. Must not be null. + * @param distance the calculated distance of this node reference to some query vector or similar. + */ + public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Transformed vector, + final double distance) { + super(primaryKey, vector); + this.distance = distance; + } + + /** + * Gets the distance. + * @return the current distance value + */ + public double getDistance() { + return distance; + } + + /** + * Compares this object against the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is not {@code null}, + * is a {@code NodeReferenceWithDistance} object, has the same properties as + * determined by the superclass's {@link #equals(Object)} method, and has + * the same {@code distance} value. + * @param o the object to compare with this instance for equality. + * @return {@code true} if the specified object is equal to this {@code NodeReferenceWithDistance}; + * {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!super.equals(o)) { + return false; + } + final NodeReferenceWithDistance that = (NodeReferenceWithDistance)o; + return Double.compare(distance, that.distance) == 0; + } + + /** + * Generates a hash code for this object. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), distance); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java new file mode 100644 index 0000000000..62b71300de --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -0,0 +1,120 @@ +/* + * NodeReferenceWithVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * Represents a reference to a node that includes an associated vector. + *
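A reference type that carries its distance exists mainly to be ordered. As a usage illustration (hypothetical names, plain Java, not code from this patch), a bounded max-heap keeps the k closest candidates seen so far, which is the standard consumption pattern in nearest-neighbor search:

```java
import java.util.Comparator;
import java.util.PriorityQueue;

// Sketch: keep the k best (smallest-distance) candidates with a bounded max-heap.
class BoundedCandidates {
    record Candidate(long primaryKey, double distance) {}

    private final int k;
    // max-heap: the worst (largest-distance) retained candidate sits at the head
    private final PriorityQueue<Candidate> heap =
            new PriorityQueue<>(Comparator.comparingDouble(Candidate::distance).reversed());

    BoundedCandidates(int k) {
        this.k = k;
    }

    void offer(Candidate candidate) {
        if (heap.size() < k) {
            heap.add(candidate);
        } else if (candidate.distance() < heap.peek().distance()) {
            heap.poll();        // evict the current worst
            heap.add(candidate);
        }
    }

    public static void main(String[] args) {
        BoundedCandidates top2 = new BoundedCandidates(2);
        top2.offer(new Candidate(1, 0.9));
        top2.offer(new Candidate(2, 0.1));
        top2.offer(new Candidate(3, 0.5));
        System.out.println(top2.heap);  // candidates 2 and 3 survive
    }
}
```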

+ * This class extends {@link NodeReference} by adding a {@link RealVector} field. It encapsulates both the primary key + * of a node and its corresponding vector data, which is particularly useful in vector-based search and + * indexing scenarios. Primarily, node references are used to refer to {@link Node}s in a storage-independent way, i.e. + * a node reference always contains the vector of a node while the node itself (depending on the storage adapter) + * may not. + */ +public class NodeReferenceWithVector extends NodeReference { + @Nonnull + private final Transformed vector; + + /** + * Constructs a new {@code NodeReferenceWithVector} with a specified primary key and vector. + *

+ * The primary key is used to initialize the parent class via a call to {@code super()}, + * while the vector is stored as a field in this instance. Both parameters are expected + * to be non-null. + * + * @param primaryKey the primary key of the node, must not be null + * @param vector the vector associated with the node, must not be null + */ + public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Transformed vector) { + super(primaryKey); + this.vector = vector; + } + + /** + * Gets the vector associated with this node reference. + *

+ * This method provides access to the internal vector. The returned vector is guaranteed + * not to be null, as indicated by the {@code @Nonnull} annotation. + * + * @return the vector of {@code Half} objects; will never be {@code null}. + */ + @Nonnull + public Transformed getVector() { + return vector; + } + + /** + * Override to declare that this class in fact is a {@link NodeReferenceWithVector}. + * @return {@code true} + */ + @Override + boolean isNodeReferenceWithVector() { + return true; + } + + /** + * Returns this instance cast as a {@code NodeReferenceWithVector}. + * @return this instance as a {@code NodeReferenceWithVector}, which is never {@code null}. + */ + @Nonnull + @Override + public NodeReferenceWithVector asNodeReferenceWithVector() { + return this; + } + + /** + * Compares this {@code NodeReferenceWithVector} to the specified object for equality. + * @param o the object to compare with this {@code NodeReferenceWithVector}. + * @return {@code true} if the objects are equal; {@code false} otherwise. + */ + @Override + public boolean equals(final Object o) { + if (!super.equals(o)) { + return false; + } + return Objects.equals(vector, ((NodeReferenceWithVector)o).vector); + } + + /** + * Computes the hash code for this object. + * @return a hash code value for this object. + */ + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), vector); + } + + /** + * Returns a string representation of this object. + * @return a concise string representation of this object. + */ + @Override + public String toString() { + return "NRV[primaryKey=" + getPrimaryKey() + ";vector=" + vector + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java new file mode 100644 index 0000000000..4844ddf98a --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java @@ -0,0 +1,79 @@ +/* + * OnReadListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * Interface for call backs whenever we read node data from the database. + */ +public interface OnReadListener { + OnReadListener NOOP = new OnReadListener() { + }; + + /** + * A callback method that can be overridden to intercept the result of an asynchronous node read. + *

+ * This method provides a hook for subclasses to inspect or modify the {@code CompletableFuture} after an + * asynchronous read operation is initiated. The default implementation is a no-op that simply returns the original + * future. This method is intended to be used to measure elapsed time between the creation of a + * {@link CompletableFuture} and its completion. + * @param the type of the {@code NodeReference} + * @param future the {@code CompletableFuture} representing the pending asynchronous read operation. + * @return a {@code CompletableFuture} that will complete with the read {@code Node}. + * By default, this is the same future that was passed as an argument. + */ + @SuppressWarnings("unused") + default > CompletableFuture onAsyncRead(@Nonnull CompletableFuture future) { + return future; + } + + /** + * Callback method invoked when a node is read during a traversal process. + *
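As a sketch of that timing use case (hypothetical names and a simplified, non-generic signature), a listener can capture the creation time and accumulate the elapsed nanoseconds when the future completes:

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.LongAdder;

// Sketch of an elapsed-time hook around an asynchronous read.
class ReadTimer {
    private final LongAdder totalNanos = new LongAdder();

    <T> CompletableFuture<T> onAsyncRead(CompletableFuture<T> future) {
        final long startNanos = System.nanoTime();
        // record the elapsed time once the pending read completes (success or failure)
        return future.whenComplete((result, throwable) ->
                totalNanos.add(System.nanoTime() - startNanos));
    }

    long totalNanos() {
        return totalNanos.sum();
    }

    public static void main(String[] args) {
        ReadTimer timer = new ReadTimer();
        timer.onAsyncRead(CompletableFuture.supplyAsync(() -> "node")).join();
        System.out.println("spent " + timer.totalNanos() + "ns reading");
    }
}
```

Returning the future produced by `whenComplete` preserves the hook's contract: callers still receive a future that completes with the read result.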

+ * This default implementation does nothing. Implementors can override this method to add custom logic that should + * be executed for each node encountered. This serves as an optional hook for processing nodes as they are read. + * @param layer the layer or depth of the node in the structure, starting from 0. + * @param node the {@link Node} that was just read (guaranteed to be non-null). + */ + @SuppressWarnings("unused") + default void onNodeRead(int layer, @Nonnull Node node) { + // nothing + } + + /** + * Callback invoked when a key-value pair is read from a specific layer. + *

+ * This method is typically called during a scan or iteration over data for each key/value pair. + * The default implementation is a no-op and does nothing. + * @param layer the layer from which the key-value pair was read. + * @param key the key that was read, guaranteed to be non-null. + * @param value the value associated with the key, can be null if the key was not found + */ + @SuppressWarnings("unused") + default void onKeyValueRead(int layer, + @Nonnull byte[] key, + @Nullable byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java new file mode 100644 index 0000000000..aacc1ca8f2 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -0,0 +1,86 @@ +/* + * OnWriteListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; + +/** + * Interface for call backs whenever we write data to the database. + */ +public interface OnWriteListener { + OnWriteListener NOOP = new OnWriteListener() { + }; + + /** + * Callback method invoked after a node has been successfully written to a specific layer. + *
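A typical implementation of the read hooks above aggregates counters. The sketch below uses simplified signatures (the real interface is generic over the node reference type) and counts keys and bytes read per layer; a write-side listener can be instrumented symmetrically:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

// Sketch of a metrics-collecting read listener.
class ReadMetrics {
    private final Map<Integer, LongAdder> keysPerLayer = new ConcurrentHashMap<>();
    private final Map<Integer, LongAdder> bytesPerLayer = new ConcurrentHashMap<>();

    void onKeyValueRead(int layer, byte[] key, byte[] value) {
        keysPerLayer.computeIfAbsent(layer, l -> new LongAdder()).increment();
        bytesPerLayer.computeIfAbsent(layer, l -> new LongAdder())
                .add(key.length + (value == null ? 0 : value.length));
    }

    @Override
    public String toString() {
        return "keys=" + keysPerLayer + ", bytes=" + bytesPerLayer;
    }

    public static void main(String[] args) {
        ReadMetrics metrics = new ReadMetrics();
        metrics.onKeyValueRead(0, new byte[]{1, 2}, new byte[]{3});
        metrics.onKeyValueRead(0, new byte[]{4}, null);   // key not found
        System.out.println(metrics);                      // keys={0=2}, bytes={0=4}
    }
}
```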

+ * This is a default method with an empty implementation, allowing implementing classes to override it only if they + * need to react to this event. + * @param layer the index of the layer where the node was written. + * @param node the {@link Node} that was written; guaranteed to be non-null. + */ + @SuppressWarnings("unused") + default void onNodeWritten(final int layer, @Nonnull final Node node) { + // nothing + } + + /** + * Callback method invoked when a neighbor is written for a specific node. + *

+ * This method serves as a notification that a neighbor relationship has been established or updated. It is + * typically called after a write operation successfully adds a {@code neighbor} to the specified {@code node} + * within a given {@code layer}. + *

+ * As a {@code default} method, the base implementation does nothing. Implementers can override this to perform + * custom actions, such as updating caches or triggering subsequent events in response to the change. + * @param layer the index of the layer where the neighbor write operation occurred + * @param node the {@link Node} for which the neighbor was written; must not be null + * @param neighbor the {@link NodeReference} of the neighbor that was written; must not be null + */ + @SuppressWarnings("unused") + default void onNeighborWritten(final int layer, @Nonnull final Node node, + @Nonnull final NodeReference neighbor) { + // nothing + } + + /** + * Callback method invoked when a neighbor of a specific node is deleted. + *

+ * This is a default method and its base implementation is a no-op. Implementors of the interface can override this + * method to react to the deletion of a neighbor node, for example, to clean up related resources or update internal + * state. + * @param layer the layer index where the deletion occurred + * @param node the {@link Node} whose neighbor was deleted + * @param neighborPrimaryKey the primary key (as a {@link Tuple}) of the neighbor that was deleted + */ + @SuppressWarnings("unused") + default void onNeighborDeleted(final int layer, @Nonnull final Node node, + @Nonnull final Tuple neighborPrimaryKey) { + // nothing + } + + @SuppressWarnings("unused") + default void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/ResultEntry.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/ResultEntry.java new file mode 100644 index 0000000000..f76f9fa48e --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/ResultEntry.java @@ -0,0 +1,117 @@ +/* + * ResultEntry.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Objects; + +/** + * Record-like class to wrap the results of a kNN-search. + */ +public class ResultEntry { + /** + * Primary key of the item in the HNSW. + */ + @Nonnull + private final Tuple primaryKey; + + /** + * The vector that is stored with the item in the structure. This vector is expressed in the client's coordinate + * system and should be of class {@link com.apple.foundationdb.linear.HalfRealVector}, + * {@link com.apple.foundationdb.linear.FloatRealVector}, or {@link com.apple.foundationdb.linear.DoubleRealVector}. + * This member is nullable. It is set to {@code null}, if the caller to + * {@link HNSW#kNearestNeighborsSearch(ReadTransaction, int, int, boolean, RealVector)} requested to not return + * vectors. + *

+ * The vector, if set, may or may not be exactly equal to the vector that was originally inserted in the HNSW. + * Depending on quantization settings (see {@link Config}, the vector that + * is returned may only be an approximation of the original vector. + */ + @Nullable + private final RealVector vector; + + /** + * The distance of item's vector to the query vector. + */ + private final double distance; + + /** + * The row number of the item. TODO support rank. + */ + private final int rankOrRowNumber; + + public ResultEntry(@Nonnull final Tuple primaryKey, @Nullable final RealVector vector, final double distance, + final int rankOrRowNumber) { + this.primaryKey = primaryKey; + this.vector = vector; + this.distance = distance; + this.rankOrRowNumber = rankOrRowNumber; + } + + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nullable + public RealVector getVector() { + return vector; + } + + public double getDistance() { + return distance; + } + + public int getRankOrRowNumber() { + return rankOrRowNumber; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof ResultEntry)) { + return false; + } + final ResultEntry that = (ResultEntry)o; + return Double.compare(distance, that.distance) == 0 && + rankOrRowNumber == that.rankOrRowNumber && + Objects.equals(primaryKey, that.primaryKey) && + Objects.equals(vector, that.vector); + } + + @Override + public int hashCode() { + return Objects.hash(primaryKey, vector, distance, rankOrRowNumber); + } + + @Override + public String toString() { + return "[" + + "primaryKey=" + primaryKey + + ", vector=" + vector + + ", distance=" + distance + + ", rankOrRowNumber=" + rankOrRowNumber + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java new file mode 100644 index 0000000000..9bffaacc67 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ -0,0 +1,390 @@ +/* + * StorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
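The distinction between rank and row number flagged in the TODO above can be made concrete. In the sketch below (hypothetical record types, not the `ResultEntry` class itself), results are sorted by distance and row numbers are assigned densely and uniquely, whereas a rank would repeat on distance ties:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Sketch: turning raw (primaryKey, distance) matches into row-numbered results.
class ResultRanking {
    record Match(long primaryKey, double distance) {}
    record RankedResult(long primaryKey, double distance, int rowNumber) {}

    static List<RankedResult> rank(List<Match> matches) {
        final List<Match> sorted = new ArrayList<>(matches);
        sorted.sort(Comparator.comparingDouble(Match::distance));
        final List<RankedResult> results = new ArrayList<>(sorted.size());
        for (int row = 0; row < sorted.size(); row++) {
            final Match match = sorted.get(row);
            // row numbers are dense and unique; a rank would instead repeat on ties
            results.add(new RankedResult(match.primaryKey(), match.distance(), row));
        }
        return results;
    }

    public static void main(String[] args) {
        System.out.println(rank(List.of(new Match(7, 0.4), new Match(3, 0.1))));
    }
}
```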
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.DoubleRealVector; +import com.apple.foundationdb.linear.FloatRealVector; +import com.apple.foundationdb.linear.HalfRealVector; +import com.apple.foundationdb.linear.Quantizer; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.linear.VectorType; +import com.apple.foundationdb.rabitq.EncodedRealVector; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +/** + * Defines the contract for storing and retrieving HNSW graph data to/from a persistent store. + *

+ * This interface provides an abstraction layer over the underlying database, handling the serialization and + * deserialization of HNSW graph components such as nodes, vectors, and their relationships. Implementations of this + * interface are responsible for managing the physical layout of data within a given {@link Subspace}. + * The generic type {@code N} represents the specific type of {@link NodeReference} that this storage adapter manages. + * + * @param the type of {@link NodeReference} this storage adapter manages + */ +interface StorageAdapter { + ImmutableList VECTOR_TYPES = ImmutableList.copyOf(VectorType.values()); + + /** + * Subspace for data. + */ + long SUBSPACE_PREFIX_DATA = 0x00; + + /** + * Subspace for the access info; contains entry nodes; these are kept separately from the data. + */ + long SUBSPACE_PREFIX_ACCESS_INFO = 0x01; + + /** + * Subspace for (mostly) statistical analysis (like finding a centroid, etc.). Contains samples of vectors. + */ + long SUBSPACE_PREFIX_SAMPLES = 0x02; + + /** + * Returns the configuration of the HNSW graph. + *
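The three prefixes above partition the index's keyspace into disjoint regions. A small sketch using the tuple layer (the root subspace name is a placeholder; only the prefix values are taken from this interface):

```java
import com.apple.foundationdb.subspace.Subspace;
import com.apple.foundationdb.tuple.Tuple;

// Illustrative layout: three sibling subspaces hang off the index's root subspace.
class SubspaceLayoutSketch {
    public static void main(String[] args) {
        final Subspace root = new Subspace(Tuple.from("my-hnsw-index"));  // hypothetical root
        final Subspace data = root.subspace(Tuple.from(0x00));        // SUBSPACE_PREFIX_DATA
        final Subspace accessInfo = root.subspace(Tuple.from(0x01));  // SUBSPACE_PREFIX_ACCESS_INFO
        final Subspace samples = root.subspace(Tuple.from(0x02));     // SUBSPACE_PREFIX_SAMPLES

        // Keys in distinct subspaces can never collide because their packed
        // prefixes differ in the first tuple element after the root.
        System.out.println(data.contains(accessInfo.pack()));                           // false
        System.out.println(data.contains(data.pack(Tuple.from(1, Tuple.from(42L)))));   // true
        System.out.println(samples.contains(samples.pack()));                           // true
    }
}
```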

+ * This configuration object contains all the parameters used to build and search the graph, + * such as the number of neighbors to connect (M), the size of the dynamic list for + * construction (efConstruction), and the beam width for searching (ef). + * @return the {@code HNSW.Config} for this graph, never {@code null}. + */ + @Nonnull + Config getConfig(); + + /** + * Gets the factory used to create new nodes. + *

+ * This factory is responsible for instantiating new nodes of type {@code N}. + * @return the non-null factory for creating nodes. + */ + @Nonnull + NodeFactory getNodeFactory(); + + /** + * Get the subspace used to store this HNSW structure. + * @return the subspace + */ + @Nonnull + Subspace getSubspace(); + + /** + * Gets the subspace that contains the data for this object. + *

+ * This subspace represents the portion of the keyspace dedicated to storing the actual data, as opposed to metadata + * or other system-level information. + * @return the subspace containing the data, which is guaranteed to be non-null + */ + @Nonnull + Subspace getDataSubspace(); + + /** + * Get the on-write listener. + * @return the on-write listener. + */ + @Nonnull + OnWriteListener getOnWriteListener(); + + /** + * Get the on-read listener. + * @return the on-read listener. + */ + @Nonnull + OnReadListener getOnReadListener(); + + /** + * Asynchronously fetches a node from a specific layer, identified by its primary key. + *

+ * The fetch operation is performed within the scope of the provided {@link ReadTransaction}, ensuring a consistent + * view of the data. The returned {@link CompletableFuture} will be completed with the node once it has been + * retrieved from the underlying data store. + * @param readTransaction the {@link ReadTransaction} context for this read operation + * @param storageTransform an affine vector transformation operator that is used to transform the fetched vector + * into the storage space that is currently being used + * @param layer the layer from which to fetch the node + * @param primaryKey the {@link Tuple} representing the primary key of the node to retrieve + * @return a non-null {@link CompletableFuture} which will complete with the fetched {@link AbstractNode}. + */ + @Nonnull + CompletableFuture<AbstractNode<N>> fetchNode(@Nonnull ReadTransaction readTransaction, + @Nonnull AffineOperator storageTransform, + int layer, + @Nonnull Tuple primaryKey); + + /** + * Writes a node and its neighbor changes to the data store within a given transaction. + *

+ * This method is responsible for persisting the state of a {@link AbstractNode} and applying any modifications to its + * neighboring nodes as defined in the {@code NeighborsChangeSet}. The entire operation is performed atomically as + * part of the provided {@link Transaction}. + * @param transaction the non-null transaction context for this write operation. + * @param quantizer the quantizer to use + * @param node the non-null node to be written to the data store. + * @param layer the layer index where the node resides. + * @param changeSet the non-null set of changes describing additions or removals of + * neighbors for the given {@link AbstractNode}. + */ + void writeNode(@Nonnull Transaction transaction, @Nonnull Quantizer quantizer, @Nonnull AbstractNode node, + int layer, @Nonnull NeighborsChangeSet changeSet); + + /** + * Scans a specified layer of the structure, returning an iterable sequence of nodes. + *

+ * This method allows for paginated scanning of a layer. The scan can be started from the beginning of the layer by + * passing {@code null} for the {@code lastPrimaryKey}, or it can be resumed from a previous point by providing the + * key of the last item from the prior scan. The number of nodes returned is limited by {@code maxNumRead}. + * + * @param readTransaction the transaction to use for the read operation + * @param layer the index of the layer to scan + * @param lastPrimaryKey the primary key of the last node from a previous scan, + * or {@code null} to start from the beginning of the layer + * @param maxNumRead the maximum number of nodes to return in this scan + * @return an {@code Iterable} that provides the nodes found in the specified layer range + */ + @VisibleForTesting + Iterable<AbstractNode<N>> scanLayer(@Nonnull ReadTransaction readTransaction, int layer, + @Nullable Tuple lastPrimaryKey, int maxNumRead); + + /** + * Creates a {@link RealVector} from a given {@link Tuple}. + *
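The resume-from-last-key contract lends itself to a simple driver loop. The following sketch models it with a simplified pager function rather than the real `scanLayer` signature (which takes a `ReadTransaction` and returns nodes); the termination condition, a batch shorter than the requested maximum, is the same:

```java
import java.util.List;
import java.util.function.BiFunction;

// Sketch of a resumable paginated scan.
class ScanDriver {
    // pager.apply(lastKeyExclusive, maxNumRead) returns the next batch, ordered by key
    static int countAll(BiFunction<Long, Integer, List<Long>> pager, int batchSize) {
        int total = 0;
        Long lastKey = null;  // null means: start from the beginning of the layer
        while (true) {
            final List<Long> batch = pager.apply(lastKey, batchSize);
            total += batch.size();
            if (batch.size() < batchSize) {
                return total;  // a short batch means the layer is exhausted
            }
            lastKey = batch.get(batch.size() - 1);  // resume after the last key seen
        }
    }

    public static void main(String[] args) {
        final List<Long> keys = List.of(1L, 2L, 3L, 4L, 5L);
        final int count = countAll((last, max) ->
                keys.stream()
                        .filter(k -> last == null || k > last)
                        .limit(max)
                        .toList(),
                2);
        System.out.println(count);  // 5
    }
}
```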

+ * This method assumes the vector data is stored as a byte array at the first position (index 0) of the tuple. It + * extracts this byte array and then delegates to the {@link #vectorFromBytes(Config, byte[])} method for the + * actual conversion. + * @param config an HNSW configuration + * @param vectorTuple the tuple containing the vector data as a byte array at index 0. Must not be {@code null}. + * @return a new {@link RealVector} instance created from the tuple's data. + * This method never returns {@code null}. + */ + @Nonnull + static RealVector vectorFromTuple(@Nonnull final Config config, @Nonnull final Tuple vectorTuple) { + return vectorFromBytes(config, vectorTuple.getBytes(0)); + } + + /** + * Creates a {@link RealVector} from a byte array. + *

+ * This method decodes the input byte array, interpreting the first byte of the array as the vector type ordinal. + * The byte array must have the proper size, i.e. the invariant {@code (bytesLength - 1) % precision == 0} must + * hold. + * @param config an HNSW config + * @param vectorBytes the non-null byte array to convert. + * @return a new {@link RealVector} instance created from the byte array. + * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} does not meet the invariant + * {@code (bytesLength - 1) % precision == 0} + */ + @Nonnull + static RealVector vectorFromBytes(@Nonnull final Config config, @Nonnull final byte[] vectorBytes) { + final byte vectorTypeOrdinal = vectorBytes[0]; + switch (fromVectorTypeOrdinal(vectorTypeOrdinal)) { + case HALF: + return HalfRealVector.fromBytes(vectorBytes); + case SINGLE: + return FloatRealVector.fromBytes(vectorBytes); + case DOUBLE: + return DoubleRealVector.fromBytes(vectorBytes); + case RABITQ: + Verify.verify(config.isUseRaBitQ()); + return EncodedRealVector.fromBytes(vectorBytes, config.getNumDimensions(), + config.getRaBitQNumExBits()); + default: + throw new RuntimeException("unable to deserialize vector"); + } + } + + /** + * Converts a transformed vector into a tuple. + * @param vector a transformed vector + * @return a new, non-null {@code Tuple} instance representing the contents of the underlying vector. + */ + @Nonnull + static Tuple tupleFromVector(@Nonnull final Transformed vector) { + return tupleFromVector(vector.getUnderlyingVector()); + } + + /** + * Converts a {@link RealVector} into a {@link Tuple}. + *
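The tag-byte layout that `vectorFromBytes` decodes can be shown in miniature. The sketch below is a deliberate simplification: it hard-codes a single tag for single-precision floats, whereas the real tag values come from `VectorType` and the component width depends on the chosen precision:

```java
import java.nio.ByteBuffer;

// Simplified sketch of the tagged vector layout: one leading type byte, then
// fixed-width components. Real decoding lives in the *RealVector classes.
class TaggedVectorCodec {
    static final byte TAG_SINGLE = 0x01;  // illustrative tag value

    static byte[] encode(float[] components) {
        final ByteBuffer buffer = ByteBuffer.allocate(1 + Float.BYTES * components.length);
        buffer.put(TAG_SINGLE);
        for (float component : components) {
            buffer.putFloat(component);
        }
        return buffer.array();
    }

    static float[] decode(byte[] bytes) {
        if (bytes[0] != TAG_SINGLE) {
            throw new IllegalArgumentException("unknown vector type: " + bytes[0]);
        }
        // the invariant from the javadoc: payload must be a whole number of components
        if ((bytes.length - 1) % Float.BYTES != 0) {
            throw new IllegalArgumentException("truncated vector payload");
        }
        final ByteBuffer buffer = ByteBuffer.wrap(bytes, 1, bytes.length - 1);
        final float[] components = new float[(bytes.length - 1) / Float.BYTES];
        for (int i = 0; i < components.length; i++) {
            components[i] = buffer.getFloat();
        }
        return components;
    }

    public static void main(String[] args) {
        final float[] roundTripped = decode(encode(new float[] {1.5f, -2.0f}));
        System.out.println(roundTripped[0] + ", " + roundTripped[1]);  // 1.5, -2.0
    }
}
```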

+ * This method first serializes the given vector into a byte array using the {@link RealVector#getRawData()} getter + * method. It then creates a {@link Tuple} from the resulting byte array. + * @param vector the {@link RealVector} to convert. Cannot be null. + * @return a new, non-null {@code Tuple} instance representing the contents of the vector. + */ + @Nonnull + @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod") + static Tuple tupleFromVector(@Nonnull final RealVector vector) { + return Tuple.from(vector.getRawData()); + } + + @Nonnull + static VectorType fromVectorTypeOrdinal(final int ordinal) { + return VECTOR_TYPES.get(ordinal); + } + + @Nonnull + static CompletableFuture fetchAccessInfo(@Nonnull final Config config, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Subspace subspace, + @Nonnull final OnReadListener onReadListener) { + final Subspace entryNodeSubspace = accessInfoSubspace(subspace); + final byte[] key = entryNodeSubspace.pack(); + + return readTransaction.get(key) + .thenApply(valueBytes -> { + onReadListener.onKeyValueRead(-1, key, valueBytes); + if (valueBytes == null) { + return null; // not a single node in the index + } + + final Tuple entryTuple = Tuple.fromBytes(valueBytes); + final int layer = (int)entryTuple.getLong(0); + final Tuple primaryKey = entryTuple.getNestedTuple(1); + final Tuple entryVectorTuple = entryTuple.getNestedTuple(2); + final Transformed entryNodeVector = + AffineOperator.identity() + .transform(StorageAdapter.vectorFromTuple(config, entryVectorTuple)); + final EntryNodeReference entryNodeReference = + new EntryNodeReference(primaryKey, entryNodeVector, layer); + final long rotatorSeed = entryTuple.getLong(3); + final Tuple centroidVectorTuple = entryTuple.getNestedTuple(4); + return new AccessInfo(entryNodeReference, + rotatorSeed, + centroidVectorTuple == null + ? null + : StorageAdapter.vectorFromTuple(config, centroidVectorTuple)); + }); + } + + /** + * Writes an {@link AccessInfo} to the database within a given transaction and subspace. + *
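The access-info value layout that `fetchAccessInfo` decodes above, {@code (layer, primaryKey, vectorTuple, rotatorSeed, centroidTuple)}, round-trips through the tuple layer as follows (illustrative values; the vector bytes are placeholders):

```java
import com.apple.foundationdb.tuple.Tuple;

// Round-trip sketch of the access-info value layout.
class AccessInfoTupleSketch {
    public static void main(String[] args) {
        final Tuple primaryKey = Tuple.from(42L);
        final Tuple vectorTuple = Tuple.from(new byte[] {0x01, 0x02});  // placeholder vector bytes
        final byte[] value = Tuple.from(2L, primaryKey, vectorTuple, 12345L, null).pack();

        final Tuple decoded = Tuple.fromBytes(value);
        System.out.println(decoded.getLong(0));           // entry layer
        System.out.println(decoded.getNestedTuple(1));    // primary key of the entry node
        System.out.println(decoded.getNestedTuple(2).getBytes(0).length);  // vector payload size
        System.out.println(decoded.getLong(3));           // rotator seed
        System.out.println(decoded.getNestedTuple(4));    // null: no centroid recorded yet
    }
}
```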

+ * This method serializes the provided {@link EntryNodeReference} into a key-value pair. The key is determined by + * a dedicated subspace for entry nodes, and the value is a tuple containing the layer, primary key, and vector from + * the reference. After writing the data, it notifies the provided {@link OnWriteListener}. + * @param transaction the database transaction to use for the write operation + * @param subspace the subspace where the entry node reference will be stored + * @param accessInfo the {@link AccessInfo} object to write + * @param onWriteListener the listener to be notified after the key-value pair is written + */ + static void writeAccessInfo(@Nonnull final Transaction transaction, + @Nonnull final Subspace subspace, + @Nonnull final AccessInfo accessInfo, + @Nonnull final OnWriteListener onWriteListener) { + final Subspace entryNodeSubspace = accessInfoSubspace(subspace); + final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference(); + final RealVector centroid = accessInfo.getNegatedCentroid(); + final byte[] key = entryNodeSubspace.pack(); + final byte[] value = Tuple.from(entryNodeReference.getLayer(), + entryNodeReference.getPrimaryKey(), + // getting underlying is okay as it is only written to the database + StorageAdapter.tupleFromVector(entryNodeReference.getVector()), + accessInfo.getRotatorSeed(), + centroid == null ? null : StorageAdapter.tupleFromVector(centroid)).pack(); + transaction.set(key, value); + onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value); + } + + @Nonnull + static CompletableFuture> consumeSampledVectors(@Nonnull final Transaction transaction, + @Nonnull final Subspace subspace, + final int numMaxVectors, + @Nonnull final OnReadListener onReadListener) { + final Subspace prefixSubspace = samplesSubspace(subspace); + final byte[] prefixKey = prefixSubspace.pack(); + final ReadTransaction snapshot = transaction.snapshot(); + final Range range = Range.startsWith(prefixKey); + + return AsyncUtil.collect(snapshot.getRange(range, numMaxVectors, true, StreamingMode.ITERATOR), + snapshot.getExecutor()) + .thenApply(keyValues -> { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (final KeyValue keyValue : keyValues) { + final byte[] key = keyValue.getKey(); + final byte[] value = keyValue.getValue(); + resultBuilder.add(aggregatedVectorFromRaw(prefixSubspace, key, value)); + transaction.addReadConflictKey(key); + transaction.clear(key); + onReadListener.onKeyValueRead(-1, key, value); + } + return resultBuilder.build(); + }); + } + + static void appendSampledVector(@Nonnull final Transaction transaction, + @Nonnull final Subspace subspace, + final int partialCount, + @Nonnull final Transformed vector, + @Nonnull final OnWriteListener onWriteListener) { + final Subspace prefixSubspace = samplesSubspace(subspace); + final Subspace keySubspace = prefixSubspace.subspace(Tuple.from(partialCount, UUID.randomUUID())); + final byte[] prefixKey = keySubspace.pack(); + // getting underlying is okay as it is only written to the database + final byte[] value = tupleFromVector(vector.getUnderlyingVector().toDoubleRealVector()).pack(); + transaction.set(prefixKey, value); + onWriteListener.onKeyValueWritten(-1, prefixKey, value); + } + + static void removeAllSampledVectors(@Nonnull final Transaction transaction, @Nonnull final Subspace subspace) { + final Subspace prefixSubspace = samplesSubspace(subspace); + + final byte[] prefixKey = prefixSubspace.pack(); + final Range range = 
Range.startsWith(prefixKey); + transaction.clear(range); + } + + @Nonnull + private static AggregatedVector aggregatedVectorFromRaw(@Nonnull final Subspace prefixSubspace, + @Nonnull final byte[] key, + @Nonnull final byte[] value) { + final Tuple keyTuple = prefixSubspace.unpack(key); + final int partialCount = Math.toIntExact(keyTuple.getLong(0)); + final RealVector vector = DoubleRealVector.fromBytes(Tuple.fromBytes(value).getBytes(0)); + + return new AggregatedVector(partialCount, AffineOperator.identity().transform(vector)); + } + + @Nonnull + static Subspace accessInfoSubspace(@Nonnull final Subspace rootSubspace) { + return rootSubspace.subspace(Tuple.from(SUBSPACE_PREFIX_ACCESS_INFO)); + } + + @Nonnull + static Subspace samplesSubspace(@Nonnull final Subspace rootSubspace) { + return rootSubspace.subspace(Tuple.from(SUBSPACE_PREFIX_SAMPLES)); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java new file mode 100644 index 0000000000..3cac6f4826 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageTransform.java @@ -0,0 +1,70 @@ +/* + * StorageTransform.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.FhtKacRotator; +import com.apple.foundationdb.linear.LinearOperator; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.rabitq.EncodedRealVector; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * A special affine operator that uses a random rotator seeded by the current {@link AccessInfo} and a given + * (pre-rotated) centroid. This operator is used inside the HNSW to transform back and forth between the coordinate + * system of the client and the coordinate system that is currently employed in the HNSW. + */ +class StorageTransform extends AffineOperator { + public StorageTransform(final long seed, final int numDimensions, + @Nonnull final RealVector translationVector) { + this(new FhtKacRotator(seed, numDimensions, 10), translationVector); + } + + public StorageTransform(@Nullable final LinearOperator linearOperator, + @Nullable final RealVector translationVector) { + super(linearOperator, translationVector); + } + + @Nonnull + @Override + public RealVector apply(@Nonnull final RealVector vector) { + // + // Only transform the vector if it is needed. We make the decision based on whether the vector is encoded or + // not. When we switch on encoding, we apply the new coordinate system from that point onwards meaning that all + // vectors inserted before use the client coordinate system. 
Therefore, we must transform all regular vectors + // and ignore all encoded vectors. + // + // TODO This could be done better in the future by keeping something like a generation id with the vector + // so we would know in what coordinate system the vector is. + if (vector instanceof EncodedRealVector) { + return vector; + } + return super.apply(vector); + } + + @Nonnull + @Override + public RealVector invertedApply(@Nonnull final RealVector vector) { + return super.invertedApply(vector); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java new file mode 100644 index 0000000000..2fdee89f0b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java @@ -0,0 +1,24 @@ +/* + * package-info.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Classes and interfaces related to the HNSW implementation as used for vector indexes. + */ +package com.apple.foundationdb.async.hnsw; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java index 2623cff1dc..5d7df33d98 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java @@ -3,7 +3,7 @@ * * This source file is part of the FoundationDB open source project * - * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AffineOperator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AffineOperator.java new file mode 100644 index 0000000000..eb01db3f9d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/AffineOperator.java @@ -0,0 +1,93 @@ +/* + * AffineOperator.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.google.common.base.Preconditions; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Vector operator that applies/unapplies a linear operator and an addition to a vector. + */ +@SpotBugsSuppressWarnings(value = "SING_SINGLETON_HAS_NONPRIVATE_CONSTRUCTOR", justification = "Singleton designation is a false positive") +public class AffineOperator implements VectorOperator { + private static final AffineOperator IDENTITY_OPERATOR = new AffineOperator(null, null); + + @Nullable + private final LinearOperator linearOperator; + @Nullable + private final RealVector translationVector; + + public AffineOperator(@Nullable final LinearOperator linearOperator, @Nullable final RealVector translationVector) { + Preconditions.checkArgument(linearOperator == null || translationVector == null || + linearOperator.getNumColumnDimensions() == translationVector.getNumDimensions()); + this.linearOperator = linearOperator; + this.translationVector = translationVector; + } + + @Override + public int getNumDimensions() { + return linearOperator != null + ? linearOperator.getNumDimensions() + : (translationVector != null + ? translationVector.getNumDimensions() + : -1); + } + + @Nonnull + @Override + public RealVector apply(@Nonnull final RealVector vector) { + RealVector result = vector; + + if (linearOperator != null) { + result = linearOperator.apply(result); + } + + if (translationVector != null) { + result = result.add(translationVector); + } + + return result; + } + + @Nonnull + @Override + public RealVector invertedApply(@Nonnull final RealVector vector) { + RealVector result = vector; + + if (translationVector != null) { + result = result.subtract(translationVector); + } + + if (linearOperator != null) { + result = linearOperator.transposedApply(result); + } + + return result; + } + + @Nonnull + public static AffineOperator identity() { + return IDENTITY_OPERATOR; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java index a6e58ea05d..ca844e172a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/ColumnMajorRealMatrix.java @@ -41,18 +41,17 @@ public ColumnMajorRealMatrix(@Nonnull final double[][] data) { } @Nonnull - @Override - public double[][] getData() { + private double[][] getData() { return data; } @Override - public int getRowDimension() { + public int getNumRowDimensions() { return data[0].length; } @Override - public int getColumnDimension() { + public int getNumColumnDimensions() { return data.length; } @@ -68,9 +67,9 @@ public double[] getColumn(final int column) { @Nonnull @Override - public RealMatrix transpose() { - int n = getRowDimension(); - int m = getColumnDimension(); + public ColumnMajorRealMatrix transpose() { + int n = getNumRowDimensions(); + int m = getNumColumnDimensions(); double[][] result = new double[n][m]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { @@ -82,11 +81,11 @@ public RealMatrix transpose() { @Nonnull @Override - public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { - Preconditions.checkArgument(getColumnDimension() == 
otherMatrix.getRowDimension()); - int n = getRowDimension(); - int m = otherMatrix.getColumnDimension(); - int common = getColumnDimension(); + public ColumnMajorRealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { + Preconditions.checkArgument(getNumColumnDimensions() == otherMatrix.getNumRowDimensions()); + int n = getNumRowDimensions(); + int m = otherMatrix.getNumColumnDimensions(); + int common = getNumColumnDimensions(); double[][] result = new double[m][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { @@ -100,7 +99,8 @@ public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { @Nonnull @Override - public RealMatrix subMatrix(final int startRow, final int lengthRow, final int startColumn, final int lengthColumn) { + public ColumnMajorRealMatrix subMatrix(final int startRow, final int lengthRow, + final int startColumn, final int lengthColumn) { final double[][] subData = new double[lengthColumn][lengthRow]; for (int j = startColumn; j < startColumn + lengthColumn; j ++) { @@ -113,7 +113,13 @@ public RealMatrix subMatrix(final int startRow, final int lengthRow, final int s @Nonnull @Override public RowMajorRealMatrix toRowMajor() { - return new RowMajorRealMatrix(transpose().getData()); + return new RowMajorRealMatrix(getRowMajorData()); + } + + @Nonnull + @Override + public double[][] getRowMajorData() { + return transpose().getData(); } @Nonnull @@ -124,12 +130,30 @@ public ColumnMajorRealMatrix toColumnMajor() { @Nonnull @Override - public RealMatrix quickTranspose() { - return new RowMajorRealMatrix(data); + public double[][] getColumnMajorData() { + return getData(); + } + + @Nonnull + @Override + public RowMajorRealMatrix quickTranspose() { + return new RowMajorRealMatrix(getColumnMajorData()); + } + + @Nonnull + @Override + public RowMajorRealMatrix flipMajor() { + return (RowMajorRealMatrix)RealMatrix.super.flipMajor(); } @Override public final boolean equals(final Object o) { + if (o == null) { + return false; + } + if (this == o) { + return true; + } if (o instanceof ColumnMajorRealMatrix) { final ColumnMajorRealMatrix that = (ColumnMajorRealMatrix)o; return Arrays.deepEquals(data, that.data); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java index b11377a688..4c140d43fc 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Estimator.java @@ -30,6 +30,11 @@ * vector (the query) is compared against many stored vectors. */ public interface Estimator { + default double distance(@Nonnull final Transformed query, + @Nonnull final Transformed storedVector) { + return distance(query.getUnderlyingVector(), storedVector.getUnderlyingVector()); + } + /** * Calculates the distance between a pre-rotated and translated query vector and a stored vector. *
+     * <p>
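+     * For example (a sketch; {@code estimator} and both vectors are assumed to be given, with the query already
+     * rotated and translated as this estimator requires):
+     * <pre>{@code
+     * double d = estimator.distance(rotatedQuery, storedVector);
+     * }</pre>
+     * <p>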
@@ -41,6 +46,6 @@ public interface Estimator { * @param storedVector the stored vector to which the distance is calculated, cannot be null. * @return a non-negative {@code double} representing the distance between the two vectors. */ - double distance(@Nonnull RealVector query, // pre-rotated query q + double distance(@Nonnull RealVector query, @Nonnull RealVector storedVector); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java index 9ff2f776f5..4aabdd6087 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/FhtKacRotator.java @@ -95,12 +95,12 @@ public FhtKacRotator(final long seed, final int numDimensions, final int rounds) } @Override - public int getRowDimension() { + public int getNumRowDimensions() { return numDimensions; } @Override - public int getColumnDimension() { + public int getNumColumnDimensions() { return numDimensions; } @@ -111,7 +111,7 @@ public boolean isTransposable() { @Nonnull @Override - public RealVector operate(@Nonnull final RealVector x) { + public RealVector apply(@Nonnull final RealVector x) { return new DoubleRealVector(operate(x.getData())); } @@ -142,12 +142,12 @@ private double[] operate(@Nonnull final double[] x) { @Nonnull @Override - public RealVector operateTranspose(@Nonnull final RealVector x) { + public RealVector transposedApply(@Nonnull final RealVector x) { return new DoubleRealVector(operateTranspose(x.getData())); } @Nonnull - public double[] operateTranspose(@Nonnull final double[] x) { + private double[] operateTranspose(@Nonnull final double[] x) { if (x.length != numDimensions) { throw new IllegalArgumentException("dimensionality of x != n"); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java index f19f02f50b..aedec85ffb 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/LinearOperator.java @@ -22,20 +22,28 @@ import javax.annotation.Nonnull; -public interface LinearOperator { - int getRowDimension(); +public interface LinearOperator extends VectorOperator { + int getNumRowDimensions(); - int getColumnDimension(); + @Override + default int getNumDimensions() { + return getNumColumnDimensions(); + } + + int getNumColumnDimensions(); default boolean isSquare() { - return getRowDimension() == getColumnDimension(); + return getNumRowDimensions() == getNumColumnDimensions(); } boolean isTransposable(); @Nonnull - RealVector operate(@Nonnull RealVector vector); + @Override + default RealVector invertedApply(@Nonnull RealVector vector) { + return transposedApply(vector); + } @Nonnull - RealVector operateTranspose(@Nonnull RealVector vector); + RealVector transposedApply(@Nonnull RealVector vector); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java index da7038e578..1f6787a93b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/QRDecomposition.java @@ -62,10 +62,10 @@ private QRDecomposition() { public static Result decomposeMatrix(@Nonnull final RealMatrix matrix) { 
Preconditions.checkArgument(matrix.isSquare()); - final double[] rDiagonal = new double[matrix.getRowDimension()]; - final double[][] qrt = matrix.toRowMajor().transpose().getData(); + final double[] rDiagonal = new double[matrix.getNumRowDimensions()]; + final double[][] qrt = matrix.transpose().getRowMajorData(); - for (int minor = 0; minor < matrix.getRowDimension(); minor++) { + for (int minor = 0; minor < matrix.getNumRowDimensions(); minor++) { performHouseholderReflection(minor, qrt, rDiagonal); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java index b8018a7320..6e851f2b9a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Quantizer.java @@ -41,19 +41,24 @@ public interface Quantizer { @Nonnull Estimator estimator(); + @Nonnull + default Transformed encode(@Nonnull final Transformed vector) { + return new Transformed<>(encode(vector.getUnderlyingVector())); + } + /** * Encodes the given data vector into another vector representation. *
+     * <p>
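+     * For illustration (a sketch; {@code quantizer} stands for any concrete {@code Quantizer} implementation and
+     * {@code preprocessedVector} is assumed to satisfy its preprocessing requirements):
+     * <pre>{@code
+     * RealVector encoded = quantizer.encode(preprocessedVector);
+     * double estimatedDistance = quantizer.estimator().distance(query, encoded);
+     * }</pre>
+     * <p>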
* This method transforms the raw input data into a different, quantized format, which is often a vector more * suitable for processing/storing the data. The specifics of the encoding depend on the implementation of the class. * - * @param data the input {@link RealVector} to be encoded. Must not be {@code null} and is assumed to have been + * @param vector the input {@link RealVector} to be encoded. Must not be {@code null} and is assumed to have been * preprocessed, such as by rotation and/or translation. The preprocessing has to align with the requirements * of the specific quantizer. * @return the encoded vector representation of the input data, guaranteed to be non-null. */ @Nonnull - RealVector encode(@Nonnull RealVector data); + RealVector encode(@Nonnull RealVector vector); /** * Creates a no-op {@code Quantizer} that does not perform any data transformation. @@ -79,8 +84,8 @@ public Estimator estimator() { @Nonnull @Override - public RealVector encode(@Nonnull final RealVector data) { - return data; + public RealVector encode(@Nonnull final RealVector vector) { + return vector; } }; } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java index 4d9e4638f3..60fc8cff00 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealMatrix.java @@ -26,9 +26,6 @@ import javax.annotation.Nullable; public interface RealMatrix extends LinearOperator { - @Nonnull - double[][] getData(); - double getEntry(int row, int column); @Override @@ -41,12 +38,12 @@ default boolean isTransposable() { @Nonnull @Override - default RealVector operate(@Nonnull final RealVector vector) { - Verify.verify(getColumnDimension() == vector.getNumDimensions()); - final double[] result = new double[getRowDimension()]; - for (int i = 0; i < getRowDimension(); i ++) { + default RealVector apply(@Nonnull final RealVector vector) { + Verify.verify(getNumColumnDimensions() == vector.getNumDimensions()); + final double[] result = new double[getNumRowDimensions()]; + for (int i = 0; i < getNumRowDimensions(); i ++) { double sum = 0.0d; - for (int j = 0; j < getColumnDimension(); j ++) { + for (int j = 0; j < getNumColumnDimensions(); j ++) { sum += getEntry(i, j) * vector.getComponent(j); } result[i] = sum; @@ -56,12 +53,12 @@ default RealVector operate(@Nonnull final RealVector vector) { @Nonnull @Override - default RealVector operateTranspose(@Nonnull final RealVector vector) { - Verify.verify(getRowDimension() == vector.getNumDimensions()); - final double[] result = new double[getColumnDimension()]; - for (int j = 0; j < getColumnDimension(); j ++) { + default RealVector transposedApply(@Nonnull final RealVector vector) { + Verify.verify(getNumRowDimensions() == vector.getNumDimensions()); + final double[] result = new double[getNumColumnDimensions()]; + for (int j = 0; j < getNumColumnDimensions(); j ++) { double sum = 0.0d; - for (int i = 0; i < getRowDimension(); i ++) { + for (int i = 0; i < getNumRowDimensions(); i ++) { sum += getEntry(i, j) * vector.getComponent(i); } result[j] = sum; @@ -78,25 +75,36 @@ default RealVector operateTranspose(@Nonnull final RealVector vector) { @Nonnull RowMajorRealMatrix toRowMajor(); + @Nonnull + double[][] getRowMajorData(); + @Nonnull ColumnMajorRealMatrix toColumnMajor(); + @Nonnull + double[][] getColumnMajorData(); + @Nonnull RealMatrix quickTranspose(); + @Nonnull + 
default RealMatrix flipMajor() { + return transpose().quickTranspose(); + } + default boolean valueEquals(@Nullable final Object o) { if (!(o instanceof RealMatrix)) { return false; } final RealMatrix that = (RealMatrix)o; - if (getRowDimension() != that.getRowDimension() || - getColumnDimension() != that.getColumnDimension()) { + if (getNumRowDimensions() != that.getNumRowDimensions() || + getNumColumnDimensions() != that.getNumColumnDimensions()) { return false; } - for (int i = 0; i < getRowDimension(); i ++) { - for (int j = 0; j < getColumnDimension(); j ++) { + for (int i = 0; i < getNumRowDimensions(); i ++) { + for (int j = 0; j < getNumColumnDimensions(); j ++) { if (getEntry(i, j) != that.getEntry(i, j)) { return false; } @@ -107,8 +115,8 @@ default boolean valueEquals(@Nullable final Object o) { default int valueBasedHashCode() { int hashCode = 0; - for (int i = 0; i < getRowDimension(); i ++) { - for (int j = 0; j < getColumnDimension(); j ++) { + for (int i = 0; i < getNumRowDimensions(); i ++) { + for (int j = 0; j < getNumColumnDimensions(); j ++) { hashCode += 31 * Double.hashCode(getEntry(i, j)); } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java index 44ff3c826d..b0c79513f2 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java @@ -182,10 +182,10 @@ default RealVector subtract(final double scalar) { } @Nonnull - default RealVector multiply(final double scalar) { + default RealVector multiply(final double scalarFactor) { final double[] result = new double[getNumDimensions()]; for (int i = 0; i < getNumDimensions(); i ++) { - result[i] = getComponent(i) * scalar; + result[i] = getComponent(i) * scalarFactor; } return withData(result); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java index 502c13d081..8a197c68a1 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/RowMajorRealMatrix.java @@ -41,18 +41,17 @@ public RowMajorRealMatrix(@Nonnull final double[][] data) { } @Nonnull - @Override - public double[][] getData() { + private double[][] getData() { return data; } @Override - public int getRowDimension() { + public int getNumRowDimensions() { return data.length; } @Override - public int getColumnDimension() { + public int getNumColumnDimensions() { return data[0].length; } @@ -68,9 +67,9 @@ public double[] getRow(final int row) { @Nonnull @Override - public RealMatrix transpose() { - int n = getRowDimension(); - int m = getColumnDimension(); + public RowMajorRealMatrix transpose() { + int n = getNumRowDimensions(); + int m = getNumColumnDimensions(); double[][] result = new double[m][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { @@ -82,11 +81,11 @@ public RealMatrix transpose() { @Nonnull @Override - public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { - Preconditions.checkArgument(getColumnDimension() == otherMatrix.getRowDimension()); - final int n = getRowDimension(); - final int m = otherMatrix.getColumnDimension(); - final int common = getColumnDimension(); + public RowMajorRealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { + 
Preconditions.checkArgument(getNumColumnDimensions() == otherMatrix.getNumRowDimensions()); + final int n = getNumRowDimensions(); + final int m = otherMatrix.getNumColumnDimensions(); + final int common = getNumColumnDimensions(); double[][] result = new double[n][m]; for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { @@ -100,7 +99,8 @@ public RealMatrix multiply(@Nonnull final RealMatrix otherMatrix) { @Nonnull @Override - public RealMatrix subMatrix(final int startRow, final int lengthRow, final int startColumn, final int lengthColumn) { + public RowMajorRealMatrix subMatrix(final int startRow, final int lengthRow, + final int startColumn, final int lengthColumn) { final double[][] subData = new double[lengthRow][lengthColumn]; for (int i = startRow; i < startRow + lengthRow; i ++) { @@ -116,20 +116,44 @@ public RowMajorRealMatrix toRowMajor() { return this; } + @Nonnull + @Override + public double[][] getRowMajorData() { + return getData(); + } + @Nonnull @Override public ColumnMajorRealMatrix toColumnMajor() { - return new ColumnMajorRealMatrix(transpose().getData()); + return new ColumnMajorRealMatrix(getColumnMajorData()); } @Nonnull @Override - public RealMatrix quickTranspose() { - return new ColumnMajorRealMatrix(data); + public double[][] getColumnMajorData() { + return transpose().getData(); + } + + @Nonnull + @Override + public ColumnMajorRealMatrix quickTranspose() { + return new ColumnMajorRealMatrix(getRowMajorData()); + } + + @Nonnull + @Override + public ColumnMajorRealMatrix flipMajor() { + return (ColumnMajorRealMatrix)RealMatrix.super.flipMajor(); } @Override public final boolean equals(final Object o) { + if (o == null) { + return false; + } + if (this == o) { + return true; + } if (o instanceof RowMajorRealMatrix) { final RowMajorRealMatrix that = (RowMajorRealMatrix)o; return Arrays.deepEquals(data, that.data); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Transformed.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Transformed.java new file mode 100644 index 0000000000..ad22d79f26 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/Transformed.java @@ -0,0 +1,111 @@ +/* + * Transformed.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import javax.annotation.Nonnull; +import java.util.Objects; + +/** + * This class aims to reduce potential logic problems with respect to coordinate transformations by soliciting help from + * Java's type system. + *
+ * <p>
+ * While implementing complex algorithms that require coordinate transformations, some problems seemed to occur
+ * repeatedly and the following observations were made:
+ * <ul>
+ *     <li>
+ *         A few algorithms use an API that passes vectors back and forth in a coordinate system given by the user.
+ *         Internally, however, the same algorithms transform these vectors into some other coordinate system that is
+ *         more advantageous to the algorithm in some way. Therefore, vectors are constantly transformed back and
+ *         forth between the respective coordinate systems.
+ *     </li>
+ *     <li>
+ *         We observed cases where vectors were handled within the same methods using a mixture of coordinate
+ *         systems, i.e. some vectors were expressed using the internal and some using the external coordinate
+ *         system. Problems occur when these vectors are intermingled and the coordinate system mappings of the
+ *         individual vectors are lost.
+ *     </li>
+ *     <li>
+ *         We observed cases where a vector is transformed from one coordinate system to another one and then
+ *         erroneously transformed a second time.
+ *     </li>
+ * </ul>
+ * <p>
+ * The following approach only makes sense for scenarios that deal with exactly two coordinate systems. + *
+ * <p>
+ * We would like to express vectors in one system as plain {@link RealVector}s, whereas vectors in the secondary
+ * system are expressed as {@link Transformed} instances of {@link RealVector}. The hope is that Java's compiler can
+ * then assist in avoiding the use of the wrong sort of vector in the wrong situation. While it is possible to
+ * circumvent these best-effort, type-system-imposed restrictions, this class is meant to be utilized in a pragmatic
+ * way.
+ * <p>
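+ * The intended usage looks roughly like this (a sketch; {@code operator} is some {@link VectorOperator} and
+ * {@code clientVector} is a vector in the client's coordinate system):
+ * <pre>{@code
+ * Transformed<RealVector> internal = operator.transform(clientVector); // into the internal system
+ * RealVector backAgain = operator.untransform(internal);               // back into the client system
+ * }</pre>
+ * <p>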
+ * Objects of this class wrap some vector of type {@code V}, creating a transformed vector. The visibility of
+ * this class' constructor is package-private by design. Only operators implementing {@link VectorOperator} can
+ * transform an instance of some type {@code V extends RealVector} into a {@code Transformed<V>} object. The same is
+ * true for inverse transformations: only operators can transform a {@code Transformed} vector back to the original
+ * vector.
+ * <p>
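+ * The kind of mistake this wrapper is meant to surface looks like this (a sketch with two hypothetical operators):
+ * <pre>{@code
+ * Transformed<RealVector> once = operatorA.transform(v);
+ * // BUG (usually): the already-transformed vector is unwrapped and transformed a second time
+ * Transformed<RealVector> twice = operatorB.transform(once.getUnderlyingVector());
+ * }</pre>
+ * <p>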
+ * In other places where {@code Transformed}s are created (and deconstructed), users should be aware of exactly what
+ * happens and why. We tried to restrict the visibilities of constructors and accessors, but due to Java's limited
+ * type-system expressiveness, this remains a best-effort approach. If a {@code Transformed} is deconstructed using
+ * {@link #getUnderlyingVector()}, the user should ensure that the resulting vector is not further transformed by
+ * e.g. another affine operator. In short, we want to keep users from writing code similar to
+ * {@code someNewOperator.transform(oldTransformed.getUnderlyingVector())}, as the result would be a transformed
+ * vector that is in fact transformed a second time. Note that this can make sense in some cases; in the described
+ * use case, however, it mostly does not.
+ * @param <V> the wrapped kind of {@link RealVector}
+ */
+public final class Transformed<V extends RealVector> {
+    @Nonnull
+    private final V transformedVector;
+
+    Transformed(@Nonnull final V transformedVector) {
+        this.transformedVector = transformedVector;
+    }
+
+    @Nonnull
+    public V getUnderlyingVector() {
+        return transformedVector;
+    }
+
+    public Transformed<RealVector> add(@Nonnull Transformed<RealVector> other) {
+        return new Transformed<>(transformedVector.add(other.transformedVector));
+    }
+
+    public Transformed<RealVector> multiply(double scalarFactor) {
+        return new Transformed<>(transformedVector.multiply(scalarFactor));
+    }
+
+    @Override
+    public boolean equals(final Object o) {
+        if (!(o instanceof Transformed)) {
+            return false;
+        }
+        final Transformed<?> that = (Transformed<?>)o;
+        return Objects.equals(transformedVector, that.transformedVector);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(transformedVector);
+    }
+
+    @Override
+    public String toString() {
+        return transformedVector.toString();
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorOperator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorOperator.java
new file mode 100644
index 0000000000..883c5b7fb3
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/linear/VectorOperator.java
@@ -0,0 +1,78 @@
+/*
+ * VectorOperator.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.linear;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Interface that represents the root of all linear and affine operators, including matrices. A vector operator can
+ * be applied to a vector. Mathematically, there is another operator that, if applied to the resulting vector,
+ * results in the original vector. Instead of modeling this duality as two distinct operators, each with its own
+ * {@code apply()}, we use a single operator object that can both {@code apply()} and {@code invertedApply()}.
+ * The invariants {@code apply(invertedApply(v)) == v} and {@code invertedApply(apply(v)) == v} both hold.
+ */
+public interface VectorOperator {
+    /**
+     * Returns the number of dimensions a vector must have so that this operator can be applied or
+     * inverted-applied to it.
+     * @return the number of dimensions this vector operator supports; can be {@code -1} if any number of dimensions
+     *         is supported.
+     */
+    int getNumDimensions();
+
+    /**
+     * Applies this operator to the vector passed in.
+     * @param vector the vector
+     * @return a new vector
+     */
+    @Nonnull
+    RealVector apply(@Nonnull RealVector vector);
+
+    /**
+     * Applies the inverted operator to the vector passed in. {@code invertedApply(apply(v)) == v} should hold.
+     * @param vector the vector
+     * @return a new vector
+     */
+    @Nonnull
+    RealVector invertedApply(@Nonnull RealVector vector);
+
+    /**
+     * Applies the operator to the vector that is passed in and creates a {@link Transformed} wrapper around the
+     * result.
+     * @param vector the vector
+     * @return a {@link Transformed}-wrapped result
+     */
+    @Nonnull
+    default Transformed<RealVector> transform(@Nonnull final RealVector vector) {
+        return new Transformed<>(apply(vector));
+    }
+
+    /**
+     * Inverted-applies the operator to a transformed vector that is passed in and returns a naked (unwrapped)
+     * {@link RealVector}.
+     * @param vector the vector
+     * @return the unwrapped {@link RealVector} result
+     */
+    @Nonnull
+    default RealVector untransform(@Nonnull final Transformed<? extends RealVector> vector) {
+        return invertedApply(vector.getUnderlyingVector());
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java
index 2cc299e3b1..2eb01e7ae8 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitEstimator.java
@@ -48,8 +48,7 @@ public int getNumExBits() {
     }
 
     @Override
-    public double distance(@Nonnull final RealVector query,
-                           @Nonnull final RealVector storedVector) {
+    public double distance(@Nonnull final RealVector query, @Nonnull final RealVector storedVector) {
         if (!(query instanceof EncodedRealVector) && storedVector instanceof EncodedRealVector) {
             // only use the estimator if the first (by convention) vector is not encoded, but the second is
             return distance(query, (EncodedRealVector)storedVector);
@@ -61,13 +60,12 @@ public double distance(@Nonnull final RealVector query,
         return metric.distance(query, storedVector);
     }
 
-    private double distance(@Nonnull final RealVector query, // pre-rotated query q
-                            @Nonnull final EncodedRealVector encodedVector) {
+    private double distance(@Nonnull final RealVector query, @Nonnull final EncodedRealVector encodedVector) {
         return estimateDistanceAndErrorBound(query, encodedVector).getDistance();
     }
 
     @Nonnull
-    public Result estimateDistanceAndErrorBound(@Nonnull final RealVector query, // pre-rotated query q
+    public Result estimateDistanceAndErrorBound(@Nonnull final RealVector query,
                                                 @Nonnull final EncodedRealVector encodedVector) {
         final double cb = (1 << numExBits) - 0.5;
         final double gAdd = query.dot(query);
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java
index 6204d2b909..f589cd0081 100644
---
b/fdb-extensions/src/main/java/com/apple/foundationdb/rabitq/RaBitQuantizer.java @@ -106,15 +106,17 @@ public RaBitEstimator estimator() { * core encoding logic to an internal helper method and returns the final * {@link EncodedRealVector}. * - * @param data the {@link RealVector} to be encoded; must not be null. The vector must be pre-rotated and - * translated. + * @param vector the {@link RealVector} to be encoded; must not be null. * * @return the resulting {@link EncodedRealVector}, guaranteed to be non-null. */ @Nonnull @Override - public EncodedRealVector encode(@Nonnull final RealVector data) { - return encodeInternal(data).getEncodedVector(); + public EncodedRealVector encode(@Nonnull final RealVector vector) { + if (vector instanceof EncodedRealVector) { + return (EncodedRealVector)vector; + } + return encodeInternal(vector).getEncodedVector(); } /** diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java new file mode 100644 index 0000000000..c3a5c69117 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/ConfigTest.java @@ -0,0 +1,128 @@ +/* + * ConfigTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.Metric; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class ConfigTest { + @Test + void testConfig() { + final Config defaultConfig = HNSW.defaultConfig(768); + + Assertions.assertThat(HNSW.newConfigBuilder().build(768)).isEqualTo(defaultConfig); + Assertions.assertThat(defaultConfig.toBuilder().build(768)).isEqualTo(defaultConfig); + + final long randomSeed = 1L; + final Metric metric = Metric.COSINE_METRIC; + final boolean useInlining = true; + final int m = Config.DEFAULT_M + 1; + final int mMax = Config.DEFAULT_M_MAX + 1; + final int mMax0 = Config.DEFAULT_M_MAX_0 + 1; + final int efConstruction = Config.DEFAULT_EF_CONSTRUCTION + 1; + final boolean extendCandidates = true; + final boolean keepPrunedConnections = true; + final int statsThreshold = 1; + final double sampleVectorStatsProbability = 0.000001d; + final double maintainStatsProbability = 0.000002d; + + final boolean useRaBitQ = true; + final int raBitQNumExBits = Config.DEFAULT_RABITQ_NUM_EX_BITS + 1; + + final int maxNumConcurrentNodeFetches = 1; + final int maxNumConcurrentNeighborhoodFetches = 2; + + Assertions.assertThat(defaultConfig.getRandomSeed()).isNotEqualTo(randomSeed); + Assertions.assertThat(defaultConfig.getMetric()).isNotSameAs(metric); + Assertions.assertThat(defaultConfig.isUseInlining()).isNotEqualTo(useInlining); + Assertions.assertThat(defaultConfig.getM()).isNotEqualTo(m); + Assertions.assertThat(defaultConfig.getMMax()).isNotEqualTo(mMax); + Assertions.assertThat(defaultConfig.getMMax0()).isNotEqualTo(mMax0); + Assertions.assertThat(defaultConfig.getEfConstruction()).isNotEqualTo(efConstruction); + Assertions.assertThat(defaultConfig.isExtendCandidates()).isNotEqualTo(extendCandidates); + Assertions.assertThat(defaultConfig.isKeepPrunedConnections()).isNotEqualTo(keepPrunedConnections); + + Assertions.assertThat(defaultConfig.getSampleVectorStatsProbability()).isNotEqualTo(sampleVectorStatsProbability); + Assertions.assertThat(defaultConfig.getMaintainStatsProbability()).isNotEqualTo(maintainStatsProbability); + Assertions.assertThat(defaultConfig.getStatsThreshold()).isNotEqualTo(statsThreshold); + + Assertions.assertThat(defaultConfig.isUseRaBitQ()).isNotEqualTo(useRaBitQ); + Assertions.assertThat(defaultConfig.getRaBitQNumExBits()).isNotEqualTo(raBitQNumExBits); + + Assertions.assertThat(defaultConfig.getMaxNumConcurrentNodeFetches()).isNotEqualTo(maxNumConcurrentNodeFetches); + Assertions.assertThat(defaultConfig.getMaxNumConcurrentNeighborhoodFetches()).isNotEqualTo(maxNumConcurrentNeighborhoodFetches); + + final Config newConfig = + defaultConfig.toBuilder() + .setRandomSeed(randomSeed) + .setMetric(metric) + .setUseInlining(useInlining) + .setM(m) + .setMMax(mMax) + .setMMax0(mMax0) + .setEfConstruction(efConstruction) + .setExtendCandidates(extendCandidates) + .setKeepPrunedConnections(keepPrunedConnections) + .setSampleVectorStatsProbability(sampleVectorStatsProbability) + .setMaintainStatsProbability(maintainStatsProbability) + .setStatsThreshold(statsThreshold) + .setUseRaBitQ(useRaBitQ) + .setRaBitQNumExBits(raBitQNumExBits) + .setMaxNumConcurrentNodeFetches(maxNumConcurrentNodeFetches) + .setMaxNumConcurrentNeighborhoodFetches(maxNumConcurrentNeighborhoodFetches) + .build(768); + + Assertions.assertThat(newConfig.getRandomSeed()).isEqualTo(randomSeed); + Assertions.assertThat(newConfig.getMetric()).isSameAs(metric); + 
Assertions.assertThat(newConfig.isUseInlining()).isEqualTo(useInlining); + Assertions.assertThat(newConfig.getM()).isEqualTo(m); + Assertions.assertThat(newConfig.getMMax()).isEqualTo(mMax); + Assertions.assertThat(newConfig.getMMax0()).isEqualTo(mMax0); + Assertions.assertThat(newConfig.getEfConstruction()).isEqualTo(efConstruction); + Assertions.assertThat(newConfig.isExtendCandidates()).isEqualTo(extendCandidates); + Assertions.assertThat(newConfig.isKeepPrunedConnections()).isEqualTo(keepPrunedConnections); + + Assertions.assertThat(newConfig.getSampleVectorStatsProbability()).isEqualTo(sampleVectorStatsProbability); + Assertions.assertThat(newConfig.getMaintainStatsProbability()).isEqualTo(maintainStatsProbability); + Assertions.assertThat(newConfig.getStatsThreshold()).isEqualTo(statsThreshold); + + Assertions.assertThat(newConfig.isUseRaBitQ()).isEqualTo(useRaBitQ); + Assertions.assertThat(newConfig.getRaBitQNumExBits()).isEqualTo(raBitQNumExBits); + + Assertions.assertThat(newConfig.getMaxNumConcurrentNodeFetches()).isEqualTo(maxNumConcurrentNodeFetches); + Assertions.assertThat(newConfig.getMaxNumConcurrentNeighborhoodFetches()).isEqualTo(maxNumConcurrentNeighborhoodFetches); + } + + @Test + void testEqualsHashCodeAndToString() { + final Config config1 = HNSW.newConfigBuilder().build(768); + final Config config2 = HNSW.newConfigBuilder().build(768); + final Config config3 = HNSW.newConfigBuilder().setM(1).build(768); + + Assertions.assertThat(config1.hashCode()).isEqualTo(config2.hashCode()); + Assertions.assertThat(config1).isEqualTo(config2); + Assertions.assertThat(config3).isNotEqualTo(config1); + + Assertions.assertThat(config1.toString()).isEqualTo(config2.toString()); + Assertions.assertThat(config1.toString()).isNotEqualTo(config3.toString()); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java new file mode 100644 index 0000000000..62fcd89076 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java @@ -0,0 +1,219 @@ +/* + * DataRecordsTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.AffineOperator; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.RealVectorTest; +import com.apple.foundationdb.linear.Transformed; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.test.RandomSeedSource; +import com.google.common.collect.ImmutableList; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; + +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Random; +import java.util.function.Function; + +class DataRecordsTest { + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testAccessInfo(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::accessInfo); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testAggregatedVector(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::aggregatedVector); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testCompactNode(final long randomSeed) { + final Random random = new Random(randomSeed); + final long dependentRandomSeed = random.nextLong(); + + final CompactNode compactNode1 = compactNode(new Random(dependentRandomSeed)); + final CompactNode compactNode1Clone = compactNode(new Random(dependentRandomSeed)); + Assertions.assertThat(compactNode1).hasToString(compactNode1Clone.toString()); + + final CompactNode compactNode2 = compactNode(random); + Assertions.assertThat(compactNode1).doesNotHaveToString(compactNode2.toString()); + + Assertions.assertThatThrownBy(compactNode1::asInliningNode).isInstanceOf(IllegalStateException.class); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testInliningNode(final long randomSeed) { + final Random random = new Random(randomSeed); + final long dependentRandomSeed = random.nextLong(); + + final InliningNode inliningNode1 = inliningNode(new Random(dependentRandomSeed)); + final InliningNode inliningNode1Clone = inliningNode(new Random(dependentRandomSeed)); + Assertions.assertThat(inliningNode1).hasToString(inliningNode1Clone.toString()); + + final InliningNode inliningNode2 = inliningNode(random); + Assertions.assertThat(inliningNode1).doesNotHaveToString(inliningNode2.toString()); + + Assertions.assertThatThrownBy(inliningNode1::asCompactNode).isInstanceOf(IllegalStateException.class); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testEntryNodeReference(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::entryNodeReference); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testNodeReference(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReference); + final NodeReference nodeReference = nodeReference(new Random(randomSeed)); + Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isFalse(); + Assertions.assertThatThrownBy(nodeReference::asNodeReferenceWithVector).isInstanceOf(IllegalStateException.class); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testNodeReferenceWithVector(final long randomSeed) { + 
assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithVector); + final NodeReferenceWithVector nodeReference = nodeReferenceWithVector(new Random(randomSeed)); + Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isTrue(); + Assertions.assertThat(nodeReference.asNodeReferenceWithVector()).isInstanceOf(NodeReferenceWithVector.class); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testNodeReferenceWithDistance(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::nodeReferenceWithDistance); + final NodeReferenceWithDistance nodeReference = nodeReferenceWithDistance(new Random(randomSeed)); + Assertions.assertThat(nodeReference.isNodeReferenceWithVector()).isTrue(); + Assertions.assertThat(nodeReference.asNodeReferenceWithVector()).isInstanceOf(NodeReferenceWithDistance.class); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testResultEntry(final long randomSeed) { + assertHashCodeEqualsToString(randomSeed, DataRecordsTest::resultEntry); + } + + private static void assertHashCodeEqualsToString(final long randomSeed, final Function createFunction) { + final Random random = new Random(randomSeed); + final long dependentRandomSeed = random.nextLong(); + final T t1 = createFunction.apply(new Random(dependentRandomSeed)); + final T t1Clone = createFunction.apply(new Random(dependentRandomSeed)); + Assertions.assertThat(t1.hashCode()).isEqualTo(t1Clone.hashCode()); + Assertions.assertThat(t1).isEqualTo(t1Clone); + Assertions.assertThat(t1).hasToString(t1Clone.toString()); + + final T t2 = createFunction.apply(random); + Assertions.assertThat(t1).isNotEqualTo(t2); + Assertions.assertThat(t1).doesNotHaveToString(t2.toString()); + } + + @Nonnull + private static ResultEntry resultEntry(@Nonnull final Random random) { + return new ResultEntry(primaryKey(random), rawVector(random), random.nextDouble(), random.nextInt(100)); + } + + @Nonnull + private static CompactNode compactNode(@Nonnull final Random random) { + return CompactNode.factory() + .create(primaryKey(random), vector(random), nodeReferences(random)) + .asCompactNode(); + } + + @Nonnull + private static InliningNode inliningNode(@Nonnull final Random random) { + return InliningNode.factory() + .create(primaryKey(random), vector(random), nodeReferenceWithVectors(random)) + .asInliningNode(); + } + + @Nonnull + private static NodeReferenceWithDistance nodeReferenceWithDistance(@Nonnull final Random random) { + return new NodeReferenceWithDistance(primaryKey(random), vector(random), random.nextDouble()); + } + + @Nonnull + private static List nodeReferenceWithVectors(@Nonnull final Random random) { + int size = random.nextInt(20); + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (int i = 0; i < size; i ++) { + resultBuilder.add(nodeReferenceWithVector(random)); + } + return resultBuilder.build(); + } + + @Nonnull + private static NodeReferenceWithVector nodeReferenceWithVector(@Nonnull final Random random) { + return new NodeReferenceWithVector(primaryKey(random), vector(random)); + } + + @Nonnull + private static List nodeReferences(@Nonnull final Random random) { + int size = random.nextInt(20); + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (int i = 0; i < size; i ++) { + resultBuilder.add(nodeReference(random)); + } + return resultBuilder.build(); + } + + @Nonnull + private static NodeReference 
nodeReference(@Nonnull final Random random) { + return new NodeReference(primaryKey(random)); + } + + @Nonnull + private static AggregatedVector aggregatedVector(@Nonnull final Random random) { + return new AggregatedVector(random.nextInt(100), vector(random)); + } + + @Nonnull + private static AccessInfo accessInfo(@Nonnull final Random random) { + return new AccessInfo(entryNodeReference(random), random.nextLong(), rawVector(random)); + } + + @Nonnull + private static EntryNodeReference entryNodeReference(@Nonnull final Random random) { + return new EntryNodeReference(primaryKey(random), vector(random), random.nextInt(10)); + } + + @Nonnull + private static Tuple primaryKey(@Nonnull final Random random) { + return Tuple.from(random.nextInt(100)); + } + + @Nonnull + private static Transformed vector(@Nonnull final Random random) { + return AffineOperator.identity().transform(rawVector(random)); + } + + @Nonnull + private static RealVector rawVector(@Nonnull final Random random) { + return RealVectorTest.createRandomDoubleVector(random, 768); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java new file mode 100644 index 0000000000..94dc7a9804 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -0,0 +1,661 @@ +/* + * HNSWTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Database;
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.async.rtree.RTree;
+import com.apple.foundationdb.linear.AffineOperator;
+import com.apple.foundationdb.linear.DoubleRealVector;
+import com.apple.foundationdb.linear.HalfRealVector;
+import com.apple.foundationdb.linear.Metric;
+import com.apple.foundationdb.linear.Quantizer;
+import com.apple.foundationdb.linear.RealVector;
+import com.apple.foundationdb.linear.StoredVecsIterator;
+import com.apple.foundationdb.rabitq.EncodedRealVector;
+import com.apple.foundationdb.test.TestDatabaseExtension;
+import com.apple.foundationdb.test.TestExecutors;
+import com.apple.foundationdb.test.TestSubspaceExtension;
+import com.apple.foundationdb.tuple.Tuple;
+import com.apple.test.RandomSeedSource;
+import com.apple.test.RandomizedTestUtils;
+import com.apple.test.SuperSlow;
+import com.apple.test.Tags;
+import com.google.common.base.Verify;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.ObjectArrays;
+import com.google.common.collect.Sets;
+import org.assertj.core.api.Assertions;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import java.util.stream.Stream;
+
+import static com.apple.foundationdb.linear.RealVectorTest.createRandomDoubleVector;
+import static com.apple.foundationdb.linear.RealVectorTest.createRandomHalfVector;
+import static org.assertj.core.api.Assertions.within;
+
+/**
+ * Tests covering inserts, updates, and deletes of data in {@link HNSW} structures (mirroring the corresponding
+ * tests for {@link RTree}s).
+ */ +@Execution(ExecutionMode.CONCURRENT) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +@Tag(Tags.RequiresFDB) +@Tag(Tags.Slow) +class HNSWTest { + private static final Logger logger = LoggerFactory.getLogger(HNSWTest.class); + + @RegisterExtension + static final TestDatabaseExtension dbExtension = new TestDatabaseExtension(); + @RegisterExtension + TestSubspaceExtension rtSubspace = new TestSubspaceExtension(dbExtension); + @RegisterExtension + TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); + + private Database db; + + @BeforeEach + public void setUpDb() { + db = dbExtension.getDatabase(); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testCompactSerialization(final long seed) { + final Random random = new Random(seed); + final int numDimensions = 768; + final CompactStorageAdapter storageAdapter = + new CompactStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), CompactNode.factory(), + rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); + final AbstractNode originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final AbstractNode randomCompactNode = + createRandomCompactNode(random, nodeFactory, numDimensions, 16); + + writeNode(tr, storageAdapter, randomCompactNode, 0); + return randomCompactNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, AffineOperator.identity(), 0, + originalNode.getPrimaryKey()) + .thenAccept(node -> + Assertions.assertThat(node).satisfies( + n -> Assertions.assertThat(n).isInstanceOf(CompactNode.class), + n -> Assertions.assertThat(n.getKind()).isSameAs(NodeKind.COMPACT), + n -> Assertions.assertThat((Object)n.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), + n -> Assertions.assertThat(n.asCompactNode().getVector()) + .isEqualTo(originalNode.asCompactNode().getVector()), + n -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertThat(neighbors).isEqualTo(originalNeighbors); + } + )).join()); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testInliningSerialization(final long seed) { + final Random random = new Random(seed); + final int numDimensions = 768; + final InliningStorageAdapter storageAdapter = + new InliningStorageAdapter(HNSW.newConfigBuilder().build(numDimensions), + InliningNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final AbstractNode randomInliningNode = + createRandomInliningNode(random, nodeFactory, numDimensions, 16); + + writeNode(tr, storageAdapter, randomInliningNode, 0); + return randomInliningNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, AffineOperator.identity(), 0, + originalNode.getPrimaryKey()) + .thenAccept(node -> + Assertions.assertThat(node).satisfies( + n -> Assertions.assertThat(n).isInstanceOf(InliningNode.class), + n -> Assertions.assertThat(n.getKind()).isSameAs(NodeKind.INLINING), + n -> Assertions.assertThat((Object)node.getPrimaryKey()).isEqualTo(originalNode.getPrimaryKey()), + n -> { + final ArrayList neighbors = + 
Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); // sorting should not be necessary given how neighbors are stored, but we sort defensively + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertThat(neighbors).isEqualTo(originalNeighbors); + } + )).join()); + } + + static Stream randomSeedsWithOptions() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL) + .flatMap(seed -> Sets.cartesianProduct(ImmutableSet.of(true, false), + ImmutableSet.of(true, false), + ImmutableSet.of(true, false), + ImmutableSet.of(true, false)).stream() + .map(arguments -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + } + + @ParameterizedTest(name = "seed={0} useInlining={1} extendCandidates={2} keepPrunedConnections={3} useRaBitQ={4}") + @MethodSource("randomSeedsWithOptions") + void testBasicInsert(final long seed, final boolean useInlining, final boolean extendCandidates, + final boolean keepPrunedConnections, final boolean useRaBitQ) { + final Random random = new Random(seed); + final Metric metric = Metric.EUCLIDEAN_METRIC; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final int numDimensions = 128; + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.newConfigBuilder().setMetric(metric) + .setUseInlining(useInlining).setExtendCandidates(extendCandidates) + .setKeepPrunedConnections(keepPrunedConnections) + .setUseRaBitQ(useRaBitQ) + .setRaBitQNumExBits(5) + .setSampleVectorStatsProbability(1.0d) + .setMaintainStatsProbability(0.1d) + .setStatsThreshold(100) + .setM(32).setMMax(32).setMMax0(64).build(numDimensions), + OnWriteListener.NOOP, onReadListener); + + final int k = 50; + final HalfRealVector queryVector = createRandomHalfVector(random, numDimensions); + final TreeSet recordsOrderedByDistance = + new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); + + for (int i = 0; i < 1000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfRealVector dataVector = createRandomHalfVector(random, numDimensions); + final double distance = metric.distance(dataVector, queryVector); + final PrimaryKeyVectorAndDistance record = + new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); + recordsOrderedByDistance.add(record); + if (recordsOrderedByDistance.size() > k) { + recordsOrderedByDistance.pollLast(); + } + return record; + }); + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List results = + db.run(tr -> + hnsw.kNearestNeighborsSearch(tr, k, 100, true, queryVector).join()); + final long endTs = System.nanoTime(); + + final ImmutableSet trueNN = + recordsOrderedByDistance.stream() + .map(PrimaryKeyVectorAndDistance::getPrimaryKey) + .collect(ImmutableSet.toImmutableSet()); + + int recallCount = 0; + for (ResultEntry resultEntry : results) { + logger.info("nodeId={} at distance={}", resultEntry.getPrimaryKey().getLong(0), + resultEntry.getDistance()); + if (trueNN.contains(resultEntry.getPrimaryKey())) { + recallCount ++; + } + } + final double recall = (double)recallCount / (double)k; + logger.info("search transaction took elapsedTime={}ms; read nodes={}, read bytes={}, recall={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), +
onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), + String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + Assertions.assertThat(recall).isGreaterThan(0.9); + + final Set insertedIds = + LongStream.range(0, 1000) + .boxed() + .collect(Collectors.toSet()); + + final Set readIds = Sets.newHashSet(); + hnsw.scanLayer(db, 0, 100, + node -> Assertions.assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); + Assertions.assertThat(readIds).isEqualTo(insertedIds); + + readIds.clear(); + hnsw.scanLayer(db, 1, 100, + node -> Assertions.assertThat(readIds.add(node.getPrimaryKey().getLong(0))).isTrue()); + Assertions.assertThat(readIds.size()).isBetween(10, 50); + } + + @ParameterizedTest + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) + void testBasicInsertWithRaBitQEncodings(final long seed) { + final Random random = new Random(seed); + final Metric metric = Metric.EUCLIDEAN_METRIC; + + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + final int numDimensions = 128; + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.newConfigBuilder().setMetric(metric) + .setUseRaBitQ(true) + .setRaBitQNumExBits(5) + .setSampleVectorStatsProbability(1.0d) // every vector is sampled + .setMaintainStatsProbability(1.0d) // stats are maintained for every vector + .setStatsThreshold(950) // after 950 vectors we enable RaBitQ + .setM(32).setMMax(32).setMMax0(64).build(numDimensions), + OnWriteListener.NOOP, OnReadListener.NOOP); + + final int k = 499; + final DoubleRealVector queryVector = createRandomDoubleVector(random, numDimensions); + final Map dataMap = Maps.newHashMap(); + final TreeSet recordsOrderedByDistance = + new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); + + for (int i = 0; i < 1000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, new TestOnReadListener(), + tr -> { + final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final DoubleRealVector dataVector = createRandomDoubleVector(random, numDimensions); + final double distance = metric.distance(dataVector, queryVector); + dataMap.put(primaryKey, dataVector); + + final PrimaryKeyVectorAndDistance record = + new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); + recordsOrderedByDistance.add(record); + if (recordsOrderedByDistance.size() > k) { + recordsOrderedByDistance.pollLast(); + } + return record; + }); + } + + // + // When we fetch the current state back from the db, some vectors are stored as plain vectors and + // some are RaBitQ-encoded. Since that information is not surfaced through the API, we scan layer 0 + // and read all vectors directly from disk (encoded/not-encoded, transformed/not-transformed) to + // check that transformations/reconstructions are applied properly. + // + final Map fromDBMap = Maps.newHashMap(); + hnsw.scanLayer(db, 0, 100, + node -> fromDBMap.put(node.getPrimaryKey(), + node.asCompactNode().getVector().getUnderlyingVector())); + + // + // Still run a kNN search to make sure that recall is satisfactory.
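+ // Recall is the fraction of the true k nearest neighbors that the search returns. As a worked
+ // example with the k = 499 used here: if 475 of the returned primary keys are among the true
+ // 499 nearest neighbors, recall is 475 / 499 ≈ 0.95, which clears the 0.9 assertion below.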
+ // + final List results = + db.run(tr -> + hnsw.kNearestNeighborsSearch(tr, k, 500, true, queryVector).join()); + + final ImmutableSet trueNN = + recordsOrderedByDistance.stream() + .map(PrimaryKeyAndVector::getPrimaryKey) + .collect(ImmutableSet.toImmutableSet()); + + int recallCount = 0; + int exactVectorCount = 0; + int encodedVectorCount = 0; + for (final ResultEntry resultEntry : results) { + if (trueNN.contains(resultEntry.getPrimaryKey())) { + recallCount ++; + } + + final RealVector originalVector = dataMap.get(resultEntry.getPrimaryKey()); + Assertions.assertThat(originalVector).isNotNull(); + final RealVector fromDBVector = fromDBMap.get(resultEntry.getPrimaryKey()); + Assertions.assertThat(fromDBVector).isNotNull(); + if (!(fromDBVector instanceof EncodedRealVector)) { + Assertions.assertThat(originalVector).isEqualTo(fromDBVector); + exactVectorCount ++; + final double distance = metric.distance(originalVector, + Objects.requireNonNull(resultEntry.getVector())); + Assertions.assertThat(distance).isCloseTo(0.0d, within(2E-12)); + } else { + encodedVectorCount ++; + final double distance = metric.distance(originalVector, + Objects.requireNonNull(resultEntry.getVector()).toDoubleRealVector()); + Assertions.assertThat(distance).isCloseTo(0.0d, within(20.0d)); + } + } + final double recall = (double)recallCount / (double)k; + Assertions.assertThat(recall).isGreaterThan(0.9); + // must have both kinds + Assertions.assertThat(exactVectorCount).isGreaterThan(0); + Assertions.assertThat(encodedVectorCount).isGreaterThan(0); + } + + private int basicInsertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + for (int i = 0; i < batchSize; i ++) { + final var record = insertFunction.apply(tr); + if (record == null) { + return i; + } + hnsw.insert(tr, record.getPrimaryKey(), record.getVector()).join(); + } + final long endTs = System.nanoTime(); + logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, readBytes={}", + batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + return batchSize; + }); + } + + @Test + @SuperSlow + void testSIFTInsertSmall() throws Exception { + final Metric metric = Metric.EUCLIDEAN_METRIC; + final int k = 100; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.newConfigBuilder().setUseRaBitQ(true).setRaBitQNumExBits(5) + .setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(128), + OnWriteListener.NOOP, onReadListener); + + final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); + + final Map dataMap = Maps.newHashMap(); + + try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new StoredVecsIterator.StoredFVecsIterator(fileChannel); + + int i = 0; + final AtomicReference sumReference = new AtomicReference<>(null); + while (vectorIterator.hasNext()) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + if (!vectorIterator.hasNext()) { + return 
null; + } + final DoubleRealVector doubleVector = vectorIterator.next(); + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfRealVector currentVector = doubleVector.toHalfRealVector(); + + if (sumReference.get() == null) { + sumReference.set(currentVector); + } else { + sumReference.set(sumReference.get().add(currentVector)); + } + + dataMap.put(Math.toIntExact(currentPrimaryKey.getLong(0)), currentVector); + return new PrimaryKeyAndVector(currentPrimaryKey, currentVector); + }); + } + Assertions.assertThat(i).isEqualTo(10000); + } + + validateSIFTSmall(hnsw, dataMap, k); + } + + private void validateSIFTSmall(@Nonnull final HNSW hnsw, @Nonnull final Map dataMap, final int k) throws IOException { + final Metric metric = hnsw.getConfig().getMetric(); + final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs"); + final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs"); + + final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener(); + + try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ); + final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) { + final Iterator queryIterator = new StoredVecsIterator.StoredFVecsIterator(queryChannel); + final Iterator groundTruthIterator = new StoredVecsIterator.StoredIVecsIterator(groundTruthChannel); + + Verify.verify(queryIterator.hasNext() == groundTruthIterator.hasNext()); + + while (queryIterator.hasNext()) { + final HalfRealVector queryVector = queryIterator.next().toHalfRealVector(); + final Set groundTruthIndices = ImmutableSet.copyOf(groundTruthIterator.next()); + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, + true, queryVector).join()); + final long endTs = System.nanoTime(); + logger.info("retrieved result in elapsedTimeMs={}, reading numNodes={}, readBytes={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + + int recallCount = 0; + for (final ResultEntry resultEntry : results) { + final int primaryKeyIndex = (int)resultEntry.getPrimaryKey().getLong(0); + + // + // Assert that the original vector and the reconstructed vector are effectively the same vector + // (up to reconstruction error). How close they must be depends on the encoding quality settings, + // the dimensionality, and the metric in use. For now, we set the tolerance to 20.0: loose enough + // to avoid false positives from reconstruction error, yet tight enough to trip on actual logic + // errors, since the expected distance between unrelated vectors is far larger.
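+ // As a rough sanity check (assuming SIFT-style descriptors, whose 128 components are on the
+ // order of 10^2): Euclidean distances between unrelated descriptors typically land in the
+ // hundreds, so a tolerance of 20.0 stays well below anything a logic error would produce.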
+ // + final RealVector originalVector = dataMap.get(primaryKeyIndex); + Assertions.assertThat(originalVector).isNotNull(); + final double distance = metric.distance(originalVector, + Objects.requireNonNull(resultEntry.getVector()).toDoubleRealVector()); + Assertions.assertThat(distance).isCloseTo(0.0d, within(20.0d)); + + logger.trace("retrieved result nodeId = {} at distance = {} ", + primaryKeyIndex, resultEntry.getDistance()); + if (groundTruthIndices.contains(primaryKeyIndex)) { + recallCount ++; + } + } + + final double recall = (double)recallCount / k; + Assertions.assertThat(recall).isGreaterThan(0.93); + + logger.info("query returned results recall={}", String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + } + } + } + + private void writeNode(@Nonnull final Transaction transaction, + @Nonnull final StorageAdapter storageAdapter, + @Nonnull final AbstractNode node, + final int layer) { + final NeighborsChangeSet insertChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + node.getNeighbors()); + storageAdapter.writeNode(transaction, Quantizer.noOpQuantizer(Metric.EUCLIDEAN_METRIC), node, layer, + insertChangeSet); + } + + @Nonnull + private AbstractNode createRandomCompactNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int numDimensions, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReference(random)); + } + + return nodeFactory.create(primaryKey, + AffineOperator.identity().transform(createRandomHalfVector(random, numDimensions)), + neighborsBuilder.build()); + } + + @Nonnull + private AbstractNode createRandomInliningNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int numDimensions, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReferenceWithVector(random, numDimensions)); + } + + return nodeFactory.create(primaryKey, + AffineOperator.identity().transform(createRandomHalfVector(random, numDimensions)), + neighborsBuilder.build()); + } + + @Nonnull + private NodeReference createRandomNodeReference(@Nonnull final Random random) { + return new NodeReference(createRandomPrimaryKey(random)); + } + + @Nonnull + private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull final Random random, + final int dimensionality) { + return new NodeReferenceWithVector(createRandomPrimaryKey(random), + AffineOperator.identity().transform(createRandomHalfVector(random, dimensionality))); + } + + @Nonnull + private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { + return Tuple.from(random.nextLong()); + } + + @Nonnull + private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { + return Tuple.from(nextIdAtomic.getAndIncrement()); + } + + private static class TestOnReadListener implements OnReadListener { + final Map nodeCountByLayer; + final Map sumMByLayer; + final Map bytesReadByLayer; + + public TestOnReadListener() { + this.nodeCountByLayer = Maps.newConcurrentMap(); + this.sumMByLayer = Maps.newConcurrentMap(); + this.bytesReadByLayer = Maps.newConcurrentMap(); + } + + public Map getNodeCountByLayer() { + return 
nodeCountByLayer; + } + + public Map getBytesReadByLayer() { + return bytesReadByLayer; + } + + public Map getSumMByLayer() { + return sumMByLayer; + } + + public void reset() { + nodeCountByLayer.clear(); + bytesReadByLayer.clear(); + sumMByLayer.clear(); + } + + @Override + public void onNodeRead(final int layer, @Nonnull final Node node) { + nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + node.getNeighbors().size()); + } + + @Override + public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nullable final byte[] value) { + bytesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + + key.length + (value == null ? 0 : value.length)); + } + } + + private static class PrimaryKeyAndVector { + @Nonnull + private final Tuple primaryKey; + @Nonnull + private final RealVector vector; + + public PrimaryKeyAndVector(@Nonnull final Tuple primaryKey, + @Nonnull final RealVector vector) { + this.primaryKey = primaryKey; + this.vector = vector; + } + + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nonnull + public RealVector getVector() { + return vector; + } + } + + private static class PrimaryKeyVectorAndDistance extends PrimaryKeyAndVector { + private final double distance; + + public PrimaryKeyVectorAndDistance(@Nonnull final Tuple primaryKey, + @Nonnull final RealVector vector, + final double distance) { + super(primaryKey, vector); + this.distance = distance; + } + + public double getDistance() { + return distance; + } + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/RealVectorSerializationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/RealVectorSerializationTest.java new file mode 100644 index 0000000000..6b56283ad3 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/RealVectorSerializationTest.java @@ -0,0 +1,79 @@ +/* + * RealVectorSerializationTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.linear.DoubleRealVector; +import com.apple.foundationdb.linear.FloatRealVector; +import com.apple.foundationdb.linear.HalfRealVector; +import com.apple.foundationdb.linear.RealVector; +import com.apple.foundationdb.linear.RealVectorTest; +import com.apple.test.RandomizedTestUtils; +import com.google.common.collect.ImmutableSet; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.util.Random; +import java.util.stream.Stream; + +class RealVectorSerializationTest { + @Nonnull + private static Stream randomSeedsWithNumDimensions() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) + .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768, 1000).stream() + .map(numDimensions -> Arguments.of(seed, numDimensions))); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testSerializationDeserializationHalfVector(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final HalfRealVector randomVector = RealVectorTest.createRandomHalfVector(random, numDimensions); + final RealVector deserializedVector = + StorageAdapter.vectorFromBytes(HNSW.newConfigBuilder().build(numDimensions), randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(HalfRealVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testSerializationDeserializationFloatVector(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final FloatRealVector randomVector = RealVectorTest.createRandomFloatVector(random, numDimensions); + final RealVector deserializedVector = + StorageAdapter.vectorFromBytes(HNSW.newConfigBuilder().build(numDimensions), randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(FloatRealVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testSerializationDeserializationDoubleVector(final long seed, final int numDimensions) { + final Random random = new Random(seed); + final DoubleRealVector randomVector = RealVectorTest.createRandomDoubleVector(random, numDimensions); + final RealVector deserializedVector = + StorageAdapter.vectorFromBytes(HNSW.newConfigBuilder().build(numDimensions), randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(DoubleRealVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/AffineOperatorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/AffineOperatorTest.java new file mode 100644 index 0000000000..f2da1d53a2 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/AffineOperatorTest.java @@ -0,0 +1,59 @@ +/* + * AffineOperatorTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.linear; + +import com.apple.test.RandomizedTestUtils; +import com.google.common.collect.ImmutableSet; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.util.Random; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.within; + +class AffineOperatorTest { + @Nonnull + private static Stream randomSeedsWithNumDimensions() { + return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) + .flatMap(seed -> ImmutableSet.of(3, 5, 10, 128, 768, 1000).stream() + .map(numDimensions -> Arguments.of(seed, numDimensions))); + } + + @ParameterizedTest + @MethodSource("randomSeedsWithNumDimensions") + void testSimpleRotationAndBack(final long seed, final int numDimensions) { + final FhtKacRotator rotator = new FhtKacRotator(seed, numDimensions, 10); + final Random random = new Random(seed); + final RealVector translation = RealVectorTest.createRandomDoubleVector(random, numDimensions); + final AffineOperator affineOperator = new AffineOperator(rotator, translation); + Assertions.assertThat(affineOperator.getNumDimensions()).isEqualTo(numDimensions); + + final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions); + final RealVector y = affineOperator.apply(x); + final RealVector z = affineOperator.invertedApply(y); + + Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(x, z)).isCloseTo(0, within(2E-10)); + } +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java index 9b44987174..9fb8e7466d 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/FhtKacRotatorTest.java @@ -33,7 +33,7 @@ import static org.assertj.core.api.Assertions.within; -public class FhtKacRotatorTest { +class FhtKacRotatorTest { @Nonnull private static Stream randomSeedsWithNumDimensions() { return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) @@ -48,8 +48,8 @@ void testSimpleRotationAndBack(final long seed, final int numDimensions) { final Random random = new Random(seed); final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions); - final RealVector y = rotator.operate(x); - final RealVector z = rotator.operateTranspose(y); + final RealVector y = rotator.apply(x); + final RealVector z = rotator.invertedApply(y); Assertions.assertThat(Metric.EUCLIDEAN_METRIC.distance(x, z)).isCloseTo(0, within(2E-10)); } @@ -64,8 +64,8 @@ void testRotationIsStable(final long seed, final int numDimensions) { final Random random = new Random(seed); final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions); - final RealVector x_ = rotator1.operate(x); - final RealVector x__ = rotator2.operate(x); + final RealVector x_ = 
rotator1.apply(x); + final RealVector x__ = rotator2.apply(x); Assertions.assertThat(x_).isEqualTo(x__); } @@ -74,10 +74,10 @@ void testRotationIsStable(final long seed, final int numDimensions) { @MethodSource("randomSeedsWithNumDimensions") void testOrthogonality(final long seed, final int numDimensions) { final FhtKacRotator rotator = new FhtKacRotator(seed, numDimensions, 10); - final ColumnMajorRealMatrix p = new ColumnMajorRealMatrix(rotator.computeP().transpose().getData()); + final ColumnMajorRealMatrix p = rotator.computeP().transpose().quickTranspose(); for (int j = 0; j < numDimensions; j ++) { - final RealVector rotated = rotator.operateTranspose(new DoubleRealVector(p.getColumn(j))); + final RealVector rotated = rotator.invertedApply(new DoubleRealVector(p.getColumn(j))); for (int i = 0; i < numDimensions; i++) { double expected = (i == j) ? 1.0 : 0.0; Assertions.assertThat(Math.abs(rotated.getComponent(i) - expected)) diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java index 39dcfe668a..42468ec623 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/MetricTest.java @@ -40,7 +40,7 @@ import static com.apple.foundationdb.linear.Metric.EUCLIDEAN_SQUARE_METRIC; import static com.apple.foundationdb.linear.Metric.MANHATTAN_METRIC; -public class MetricTest { +class MetricTest { static Stream metricAndExpectedDistance() { // Distance between (1.0, 2.0) and (4.0, 6.0) final RealVector v1 = v(1.0, 2.0); diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java index 3dfe46f2d1..e633fab347 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/QRDecompositionTest.java @@ -35,7 +35,7 @@ import static org.assertj.core.api.Assertions.within; @SuppressWarnings("checkstyle:AbbreviationAsWordInName") -public class QRDecompositionTest { +class QRDecompositionTest { @Nonnull private static Stream randomSeedsWithNumDimensions() { return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) @@ -50,8 +50,8 @@ void testQREqualsM(final long seed, final int numDimensions) { final RealMatrix m = MatrixHelpers.randomOrthogonalMatrix(random, numDimensions); final QRDecomposition.Result result = QRDecomposition.decomposeMatrix(m); final RealMatrix product = result.getQ().multiply(result.getR()); - for (int i = 0; i < product.getRowDimension(); i++) { - for (int j = 0; j < product.getColumnDimension(); j++) { + for (int i = 0; i < product.getNumRowDimensions(); i++) { + for (int j = 0; j < product.getNumColumnDimensions(); j++) { assertThat(product.getEntry(i, j)).isCloseTo(m.getEntry(i, j), within(2E-14)); } } @@ -66,8 +66,8 @@ void testRepeatedQR(final long seed, final int numDimensions) { final QRDecomposition.Result secondResult = QRDecomposition.decomposeMatrix(firstResult.getQ()); final RealMatrix r = secondResult.getR(); - for (int i = 0; i < r.getRowDimension(); i++) { - for (int j = 0; j < r.getColumnDimension(); j++) { + for (int i = 0; i < r.getNumRowDimensions(); i++) { + for (int j = 0; j < r.getNumColumnDimensions(); j++) { assertThat(Math.abs(r.getEntry(i, j))).isCloseTo((i == j) ? 
1.0d : 0.0d, within(2E-14)); } } @@ -91,8 +91,8 @@ void testZeroes() { final QRDecomposition.Result result = QRDecomposition.decomposeMatrix(m); final RealMatrix product = result.getQ().multiply(result.getR()); - for (int i = 0; i < product.getRowDimension(); i++) { - for (int j = 0; j < product.getColumnDimension(); j++) { + for (int i = 0; i < product.getNumRowDimensions(); i++) { + for (int j = 0; j < product.getNumColumnDimensions(); j++) { assertThat(product.getEntry(i, j)).isCloseTo(m.getEntry(i, j), within(2E-14)); } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java index 439f41ffcf..b210500f2a 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/RealMatrixTest.java @@ -22,6 +22,7 @@ import com.apple.test.RandomizedTestUtils; import com.google.common.collect.ImmutableSet; +import org.assertj.core.api.Assertions; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -33,7 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.within; -public class RealMatrixTest { +class RealMatrixTest { @Nonnull private static Stream randomSeedsWithNumDimensions() { return RandomizedTestUtils.randomSeeds(0xdeadc0deL, 0xfdb5ca1eL, 0xf005ba1L) @@ -48,9 +49,9 @@ void testTranspose(final long seed, final int numDimensions) { final int numRows = random.nextInt(numDimensions) + 1; final int numColumns = random.nextInt(numDimensions) + 1; final RealMatrix matrix = MatrixHelpers.randomGaussianMatrix(random, numRows, numColumns); - final RealMatrix otherMatrix = flip(matrix); + final RealMatrix otherMatrix = matrix.flipMajor(); assertThat(otherMatrix).isEqualTo(matrix); - final RealMatrix anotherMatrix = flip(otherMatrix); + final RealMatrix anotherMatrix = otherMatrix.flipMajor(); assertThat(anotherMatrix).isEqualTo(otherMatrix); assertThat(anotherMatrix).isEqualTo(matrix); assertThat(anotherMatrix.getClass()).isSameAs(matrix.getClass()); @@ -87,29 +88,15 @@ void testDifferentMajor(final long seed, final int numDimensions) { assertThat(anotherMatrix).isEqualTo(matrix); } - - @Nonnull - private static RealMatrix flip(@Nonnull final RealMatrix matrix) { - assertThat(matrix) - .satisfiesAnyOf(m -> assertThat(m).isInstanceOf(RowMajorRealMatrix.class), - m -> assertThat(m).isInstanceOf(ColumnMajorRealMatrix.class)); - final double[][] data = matrix.transpose().getData(); - if (matrix instanceof RowMajorRealMatrix) { - return new ColumnMajorRealMatrix(data); - } else { - return new RowMajorRealMatrix(data); - } - } - @ParameterizedTest @MethodSource("randomSeedsWithNumDimensions") - void testOperateAndBack(final long seed, final int numDimensions) { + void testApplyAndBack(final long seed, final int numDimensions) { final Random random = new Random(seed); final RealMatrix matrix = MatrixHelpers.randomOrthogonalMatrix(random, numDimensions); assertThat(matrix.isTransposable()).isTrue(); final RealVector x = RealVectorTest.createRandomDoubleVector(random, numDimensions); - final RealVector y = matrix.operate(x); - final RealVector z = matrix.operateTranspose(y); + final RealVector y = matrix.apply(x); + final RealVector z = matrix.transposedApply(y); assertThat(Metric.EUCLIDEAN_METRIC.distance(x, z)).isCloseTo(0, within(2E-10)); 
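+        // For an orthogonal matrix, the transpose equals the inverse (Q^T Q = I), so
+        // transposedApply(apply(x)) recovers x up to floating-point rounding, which is
+        // why the round-trip bound can be as tight as 2E-10.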
} @@ -141,7 +128,7 @@ void testMultiplyRowMajorMatrix(final long seed, final int d) { @MethodSource("randomSeedsWithNumDimensions") void testMultiplyColumnMajorMatrix(final long seed, final int d) { final Random random = new Random(seed); - final RealMatrix r = flip(MatrixHelpers.randomOrthogonalMatrix(random, d)); + final RealMatrix r = MatrixHelpers.randomOrthogonalMatrix(random, d).flipMajor(); assertMultiplyMxMT(d, random, r); } @@ -158,11 +145,11 @@ private static void assertMultiplyMxMT(final int d, @Nonnull final Random random final RealMatrix product = m.multiply(mT); assertThat(product) - .satisfies(p -> assertThat(p.getRowDimension()).isEqualTo(numResultRows), - p -> assertThat(p.getColumnDimension()).isEqualTo(numResultColumns)); + .satisfies(p -> assertThat(p.getNumRowDimensions()).isEqualTo(numResultRows), + p -> assertThat(p.getNumColumnDimensions()).isEqualTo(numResultColumns)); - for (int i = 0; i < product.getRowDimension(); i++) { - for (int j = 0; j < product.getColumnDimension(); j++) { + for (int i = 0; i < product.getNumRowDimensions(); i++) { + for (int j = 0; j < product.getNumColumnDimensions(); j++) { double expected = (i == j) ? 1.0 : 0.0; assertThat(Math.abs(product.getEntry(i, j) - expected)) .isCloseTo(0, within(2E-14)); @@ -182,12 +169,15 @@ void testMultiplyMatrix2(final long seed, final int d) { final RealMatrix product = m1.multiply(m2); - for (int i = 0; i < product.getRowDimension(); i++) { - for (int j = 0; j < product.getColumnDimension(); j++) { + for (int i = 0; i < product.getNumRowDimensions(); i++) { + for (int j = 0; j < product.getNumColumnDimensions(); j++) { final double expected = new DoubleRealVector(m1.getRow(i)).dot(new DoubleRealVector(m2.getColumn(j))); assertThat(Math.abs(product.getEntry(i, j) - expected)) .isCloseTo(0, within(2E-14)); } } + + Assertions.assertThat(m1.toRowMajor()).isSameAs(m1); + Assertions.assertThat(m2.toColumnMajor()).isSameAs(m2); } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java index cd9c18684f..f1814e2ec1 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/linear/StoredVecsIteratorTest.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Set; -public class StoredVecsIteratorTest { +class StoredVecsIteratorTest { @SuppressWarnings("checkstyle:AbbreviationAsWordInName") @Test void readSIFT() throws IOException { diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java index 5785c30fc9..7c1b78f738 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/rabitq/RaBitQuantizerTest.java @@ -20,11 +20,13 @@ package com.apple.foundationdb.rabitq; +import com.apple.foundationdb.linear.AffineOperator; import com.apple.foundationdb.linear.DoubleRealVector; import com.apple.foundationdb.linear.FhtKacRotator; import com.apple.foundationdb.linear.Metric; import com.apple.foundationdb.linear.RealVector; import com.apple.foundationdb.linear.RealVectorTest; +import com.apple.foundationdb.linear.Transformed; import com.apple.test.RandomizedTestUtils; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; @@ -112,20 +114,26 @@ 
private static Stream estimationArgs() { @MethodSource("estimationArgs") void basicEncodeWithEstimationTestSpecialValues(final double[] centroidData, final double[] vData, final double[] qData, final double expectedDistance) { - final RealVector v = new DoubleRealVector(vData); - final RealVector q = new DoubleRealVector(qData); + final RealVector centroid = new DoubleRealVector(centroidData); + final AffineOperator operator = new AffineOperator(null, centroid.multiply(-1.0d)); + final Transformed v = operator.transform(new DoubleRealVector(vData)); + final Transformed q = operator.transform(new DoubleRealVector(qData)); final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, 7); - final EncodedRealVector encodedVector = quantizer.encode(v); + final Transformed encodedVector = quantizer.encode(v); final RaBitEstimator estimator = quantizer.estimator(); - final RaBitEstimator.Result estimatedDistanceResult = estimator.estimateDistanceAndErrorBound(q, encodedVector); + final RaBitEstimator.Result estimatedDistanceResult = + estimator.estimateDistanceAndErrorBound(q.getUnderlyingVector(), + (EncodedRealVector)encodedVector.getUnderlyingVector()); logger.info("estimated distance result = {}", estimatedDistanceResult); Assertions.assertThat(estimatedDistanceResult.getDistance()) .isCloseTo(expectedDistance, Offset.offset(0.01d)); - final EncodedRealVector encodedVector2 = quantizer.encode(v); + final Transformed encodedVector2 = quantizer.encode(v); Assertions.assertThat(encodedVector2.hashCode()).isEqualTo(encodedVector.hashCode()); Assertions.assertThat(encodedVector2).isEqualTo(encodedVector); + Assertions.assertThat(encodedVector.hashCode()).isEqualTo(encodedVector.getUnderlyingVector().hashCode()); + Assertions.assertThat(encodedVector.toString()).isEqualTo(encodedVector.getUnderlyingVector().toString()); } @ParameterizedTest @@ -165,28 +173,28 @@ void encodeManyWithEstimationsTest(final long seed, final int numDimensions, fin logger.trace("v = {}", v); logger.trace("centroid = {}", centroid); - final RealVector centroidRot = rotator.operateTranspose(centroid); - final RealVector qTrans = rotator.operateTranspose(q).subtract(centroidRot); - final RealVector vTrans = rotator.operateTranspose(v).subtract(centroidRot); + final RealVector centroidRot = rotator.apply(centroid); + final AffineOperator operator = new AffineOperator(rotator, centroidRot.multiply(-1.0d)); + final Transformed qTrans = operator.transform(q); + final Transformed vTrans = operator.transform(v); logger.trace("qTrans = {}", qTrans); logger.trace("vTrans = {}", vTrans); logger.trace("centroidRot = {}", centroidRot); final RaBitQuantizer quantizer = new RaBitQuantizer(Metric.EUCLIDEAN_SQUARE_METRIC, numExBits); - final RaBitQuantizer.Result resultV = quantizer.encodeInternal(vTrans); - final EncodedRealVector encodedV = resultV.encodedVector; - logger.trace("fAddEx vor v = {}", encodedV.getAddEx()); - logger.trace("fRescaleEx vor v = {}", encodedV.getRescaleEx()); - logger.trace("fErrorEx vor v = {}", encodedV.getErrorEx()); - - final EncodedRealVector encodedQ = quantizer.encode(qTrans); + final Transformed encodedV = quantizer.encode(vTrans); + final Transformed encodedQ = quantizer.encode(qTrans); final RaBitEstimator estimator = quantizer.estimator(); - final RealVector reconstructedQ = rotator.operate(encodedQ.add(centroidRot)); - final RealVector reconstructedV = rotator.operate(encodedV.add(centroidRot)); - final RaBitEstimator.Result estimatedDistance = 
estimator.estimateDistanceAndErrorBound(qTrans, encodedV); + final RealVector reconstructedQ = operator.untransform(encodedQ); + final RealVector reconstructedV = operator.untransform(encodedV); + final RaBitEstimator.Result estimatedDistance = + estimator.estimateDistanceAndErrorBound(qTrans.getUnderlyingVector(), + (EncodedRealVector)encodedV.getUnderlyingVector()); logger.trace("estimated ||qRot - vRot||^2 = {}", estimatedDistance); - final double trueDistance = Metric.EUCLIDEAN_SQUARE_METRIC.distance(vTrans, qTrans); + final double trueDistance = + Metric.EUCLIDEAN_SQUARE_METRIC.distance(vTrans.getUnderlyingVector(), + qTrans.getUnderlyingVector()); logger.trace("true ||qRot - vRot||^2 = {}", trueDistance); if (trueDistance >= estimatedDistance.getDistance() - estimatedDistance.getErr() && trueDistance < estimatedDistance.getDistance() + estimatedDistance.getErr()) { @@ -207,7 +215,7 @@ void encodeManyWithEstimationsTest(final long seed, final int numDimensions, fin logger.info("estimator better than reconstructed distance = {}%", String.format(Locale.ROOT, "%.2f", (double)numEstimationBetter * 100.0d / numRounds)); logger.info("relative error = {}%", String.format(Locale.ROOT, "%.2f", sumRelativeError * 100.0d / numRounds)); - Assertions.assertThat((double)numEstimationWithinBounds / numRounds).isGreaterThan(0.9); + Assertions.assertThat((double)numEstimationWithinBounds / numRounds).isGreaterThan(0.8); Assertions.assertThat((double)numEstimationBetter / numRounds).isBetween(0.3, 0.7); Assertions.assertThat(sumRelativeError / numRounds).isLessThan(0.1d); }
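// A minimal sketch (hypothetical helper, not part of this change) of the interval bookkeeping
// performed by encodeManyWithEstimationsTest above: the RaBitQ estimator returns a distance
// estimate together with an error radius, and the test counts how often the true distance lands
// inside [estimate - err, estimate + err). Relaxing the within-bounds assertion from 0.9 to 0.8
// is consistent with that bound being probabilistic rather than exact.
private static boolean isWithinErrorBound(final double trueDistance,
                                          final RaBitEstimator.Result estimated) {
    return trueDistance >= estimated.getDistance() - estimated.getErr()
            && trueDistance < estimated.getDistance() + estimated.getErr();
}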