From 0c89250a955370ab45c98261f5b1932786fa09a1 Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Tue, 16 Sep 2025 09:38:55 +0200 Subject: [PATCH 01/10] initial code drop from hnsw-poc --- fdb-extensions/fdb-extensions.gradle | 1 + .../foundationdb/async/MoreAsyncUtil.java | 64 + .../foundationdb/async/hnsw/AbstractNode.java | 63 + .../async/hnsw/AbstractStorageAdapter.java | 144 ++ .../async/hnsw/BaseNeighborsChangeSet.java | 61 + .../foundationdb/async/hnsw/CompactNode.java | 103 ++ .../async/hnsw/CompactStorageAdapter.java | 177 +++ .../async/hnsw/DeleteNeighborsChangeSet.java | 83 ++ .../async/hnsw/EntryNodeReference.java | 56 + .../apple/foundationdb/async/hnsw/HNSW.java | 1246 +++++++++++++++++ .../foundationdb/async/hnsw/HNSWHelpers.java | 63 + .../foundationdb/async/hnsw/InliningNode.java | 94 ++ .../async/hnsw/InliningStorageAdapter.java | 181 +++ .../async/hnsw/InsertNeighborsChangeSet.java | 89 ++ .../apple/foundationdb/async/hnsw/Metric.java | 161 +++ .../foundationdb/async/hnsw/Metrics.java | 43 + .../async/hnsw/NeighborsChangeSet.java | 42 + .../apple/foundationdb/async/hnsw/Node.java | 59 + .../foundationdb/async/hnsw/NodeFactory.java | 37 + .../foundationdb/async/hnsw/NodeKind.java | 60 + .../async/hnsw/NodeReference.java | 72 + .../async/hnsw/NodeReferenceAndNode.java | 57 + .../async/hnsw/NodeReferenceWithDistance.java | 58 + .../async/hnsw/NodeReferenceWithVector.java | 76 + .../async/hnsw/OnReadListener.java | 46 + .../async/hnsw/OnWriteListener.java | 49 + .../async/hnsw/StorageAdapter.java | 184 +++ .../apple/foundationdb/async/hnsw/Vector.java | 224 +++ .../foundationdb/async/hnsw/package-info.java | 24 + .../foundationdb/async/rtree/NodeHelpers.java | 2 +- .../async/rtree/StorageAdapter.java | 1 - .../async/hnsw/HNSWModificationTest.java | 666 +++++++++ gradle/codequality/pmd-rules.xml | 1 + gradle/libs.versions.toml | 2 + gradle/scripts/log4j-test.properties | 2 +- 35 files changed, 4288 insertions(+), 3 deletions(-) create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java create mode 100644 
fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java create mode 100644 fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java create mode 100644 fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle index 6c77f13db9..3601bd1e4a 100644 --- a/fdb-extensions/fdb-extensions.gradle +++ b/fdb-extensions/fdb-extensions.gradle @@ -27,6 +27,7 @@ dependencies { } api(libs.fdbJava) implementation(libs.guava) + implementation(libs.half4j) implementation(libs.slf4j.api) compileOnly(libs.jsr305) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java index 563dec11a6..64e6d6b732 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java @@ -23,12 +23,14 @@ import com.apple.foundationdb.annotation.API; import com.apple.foundationdb.util.LoggableException; import com.google.common.base.Suppliers; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.ThreadFactoryBuilder; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -42,9 +44,13 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.IntPredicate; +import java.util.function.IntUnaryOperator; import java.util.function.Predicate; import java.util.function.Supplier; @@ -1051,6 +1057,64 @@ public static CompletableFuture swallowException(@Nonnull CompletableFutur return result; } + @Nonnull + public static CompletableFuture forLoop(final int startI, @Nullable final U startU, + @Nonnull final IntPredicate conditionPredicate, + @Nonnull final IntUnaryOperator stepFunction, + @Nonnull final BiFunction> body, + @Nonnull final Executor executor) { + final AtomicInteger loopVariableAtomic = new AtomicInteger(startI); + final 
AtomicReference lastResultAtomic = new AtomicReference<>(startU); + return whileTrue(() -> { + final int loopVariable = loopVariableAtomic.get(); + if (!conditionPredicate.test(loopVariable)) { + return AsyncUtil.READY_FALSE; + } + return body.apply(loopVariable, lastResultAtomic.get()) + .thenApply(result -> { + loopVariableAtomic.set(stepFunction.applyAsInt(loopVariable)); + lastResultAtomic.set(result); + return true; + }); + }, executor).thenApply(ignored -> lastResultAtomic.get()); + } + + @SuppressWarnings("unchecked") + public static CompletableFuture> forEach(@Nonnull final Iterable items, + @Nonnull final Function> body, + final int parallelism, + @Nonnull final Executor executor) { + // this deque is only modified by once upon creation + final ArrayDeque toBeProcessed = new ArrayDeque<>(); + for (final T item : items) { + toBeProcessed.addLast(item); + } + + final List> working = Lists.newArrayList(); + final AtomicInteger indexAtomic = new AtomicInteger(0); + final Object[] resultArray = new Object[toBeProcessed.size()]; + + return whileTrue(() -> { + working.removeIf(CompletableFuture::isDone); + + while (working.size() <= parallelism) { + final T currentItem = toBeProcessed.pollFirst(); + if (currentItem == null) { + break; + } + + final int index = indexAtomic.getAndIncrement(); + working.add(body.apply(currentItem) + .thenAccept(result -> resultArray[index] = result)); + } + + if (working.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + return whenAny(working).thenApply(ignored -> true); + }, executor).thenApply(ignored -> Arrays.asList((U[])resultArray)); + } + /** * A {@code Boolean} function that is always true. * @param the type of the (ignored) argument to the function diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java new file mode 100644 index 0000000000..aa062e8700 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java @@ -0,0 +1,63 @@ +/* + * AbstractNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +/** + * TODO. + * @param node type class. 
+ */ +abstract class AbstractNode implements Node { + @Nonnull + private final Tuple primaryKey; + + @Nonnull + private final List neighbors; + + protected AbstractNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + this.primaryKey = primaryKey; + this.neighbors = ImmutableList.copyOf(neighbors); + } + + @Nonnull + @Override + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nonnull + @Override + public List getNeighbors() { + return neighbors; + } + + @Nonnull + @Override + public N getNeighbor(final int index) { + return neighbors.get(index); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java new file mode 100644 index 0000000000..e3d0c943fc --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -0,0 +1,144 @@ +/* + * AbstractStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * Implementations and attributes common to all concrete implementations of {@link StorageAdapter}. 
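+ * <p>
+ * A storage adapter bundles the {@link HNSW.Config}, the {@link NodeFactory} used to materialize
+ * fetched nodes, the index {@link Subspace} together with its data subspace (prefixed by
+ * {@code SUBSPACE_PREFIX_DATA}), and the {@link OnWriteListener} and {@link OnReadListener} hooks.
+ * Concrete subclasses only implement {@code fetchNodeInternal()} and {@code writeNodeInternal()};
+ * this class wraps those methods with common invariant checking and logging.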
+ */ +abstract class AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(AbstractStorageAdapter.class); + + @Nonnull + private final HNSW.Config config; + @Nonnull + private final NodeFactory nodeFactory; + @Nonnull + private final Subspace subspace; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + private final Subspace dataSubspace; + + protected AbstractStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + this.config = config; + this.nodeFactory = nodeFactory; + this.subspace = subspace; + this.onWriteListener = onWriteListener; + this.onReadListener = onReadListener; + this.dataSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA)); + } + + @Override + @Nonnull + public HNSW.Config getConfig() { + return config; + } + + @Nonnull + @Override + public NodeFactory getNodeFactory() { + return nodeFactory; + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return getNodeFactory().getNodeKind(); + } + + @Override + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + @Override + @Nonnull + public Subspace getDataSubspace() { + return dataSubspace; + } + + @Override + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + @Override + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + @Nonnull + @Override + public CompletableFuture> fetchNode(@Nonnull final ReadTransaction readTransaction, + int layer, @Nonnull Tuple primaryKey) { + return fetchNodeInternal(readTransaction, layer, primaryKey).thenApply(this::checkNode); + } + + @Nonnull + protected abstract CompletableFuture> fetchNodeInternal(@Nonnull ReadTransaction readTransaction, + int layer, @Nonnull Tuple primaryKey); + + /** + * Method to perform basic invariant check(s) on a newly-fetched node. + * + * @param node the node to check + * was passed in + * + * @return the node that was passed in + */ + @Nullable + private Node checkNode(@Nullable final Node node) { + return node; + } + + @Override + public void writeNode(@Nonnull Transaction transaction, @Nonnull Node node, int layer, + @Nonnull NeighborsChangeSet changeSet) { + writeNodeInternal(transaction, node, layer, changeSet); + if (logger.isDebugEnabled()) { + logger.debug("written node with key={} at layer={}", node.getPrimaryKey(), layer); + } + } + + protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Node node, int layer, + @Nonnull NeighborsChangeSet changeSet); + +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java new file mode 100644 index 0000000000..bb8271af39 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -0,0 +1,61 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.function.Predicate; + +/** + * TODO. + */ +class BaseNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private final List neighbors; + + public BaseNeighborsChangeSet(@Nonnull final List neighbors) { + this.neighbors = ImmutableList.copyOf(neighbors); + } + + @Nullable + @Override + public BaseNeighborsChangeSet getParent() { + return null; + } + + @Nonnull + @Override + public List merge() { + return neighbors; + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, + @Nonnull final Predicate primaryKeyPredicate) { + // nothing to be written + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java new file mode 100644 index 0000000000..a6a28e778d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java @@ -0,0 +1,103 @@ +/* + * CompactNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * TODO. 
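+ * A node representation that stores the node's own vector and refers to its neighbors by primary key
+ * only (plain {@link NodeReference}s), so resolving a neighbor's vector requires fetching that
+ * neighbor's node. Contrast this with {@link InliningNode}, whose neighbor references carry their
+ * vectors inline.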
+ */ +public class CompactNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public Node create(@Nonnull final Tuple primaryKey, @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new CompactNode(primaryKey, Objects.requireNonNull(vector), (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.COMPACT; + } + }; + + @Nonnull + private final Vector vector; + + public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + this.vector = vector; + } + + @Nonnull + @Override + public NodeReference getSelfReference(@Nullable final Vector vector) { + return new NodeReference(getPrimaryKey()); + } + + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.COMPACT; + } + + @Nonnull + public Vector getVector() { + return vector; + } + + @Nonnull + @Override + public CompactNode asCompactNode() { + return this; + } + + @Nonnull + @Override + public InliningNode asInliningNode() { + throw new IllegalStateException("this is not an inlining node"); + } + + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "C[primaryKey=" + getPrimaryKey() + + ";vector=" + vector + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java new file mode 100644 index 0000000000..c3a04f86a2 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -0,0 +1,177 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * TODO. 
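+ * Storage adapter that persists {@link CompactNode}s as one key/value pair per node. The key is the
+ * data subspace packed with {@code (layer, primaryKey)}; the value is a packed tuple of the shape
+ * <pre>{@code
+ * (nodeKind, (vector components ...), ((neighbor primary key), (neighbor primary key), ...))
+ * }</pre>
+ * as written by {@code writeNodeInternal()} and read back by {@code nodeFromTuples()}.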
+ */ +class CompactStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class); + + public CompactStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + return this; + } + + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + throw new IllegalStateException("cannot call this method on a compact storage adapter"); + } + + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] keyBytes = getDataSubspace().pack(Tuple.from(layer, primaryKey)); + + return readTransaction.get(keyBytes) + .thenApply(valueBytes -> { + if (valueBytes == null) { + throw new IllegalStateException("cannot fetch node"); + } + return nodeFromRaw(layer, primaryKey, keyBytes, valueBytes); + }); + } + + @Nonnull + private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, + @Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) { + final Tuple nodeTuple = Tuple.fromBytes(valueBytes); + final Node node = nodeFromTuples(primaryKey, nodeTuple); + final OnReadListener onReadListener = getOnReadListener(); + onReadListener.onNodeRead(layer, node); + onReadListener.onKeyValueRead(layer, keyBytes, valueBytes); + return node; + } + + @Nonnull + private Node nodeFromTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple valueTuple) { + final NodeKind nodeKind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0)); + Verify.verify(nodeKind == NodeKind.COMPACT); + + final Tuple vectorTuple; + final Tuple neighborsTuple; + + vectorTuple = valueTuple.getNestedTuple(1); + neighborsTuple = valueTuple.getNestedTuple(2); + return compactNodeFromTuples(primaryKey, vectorTuple, neighborsTuple); + } + + @Nonnull + private Node compactNodeFromTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple vectorTuple, + @Nonnull final Tuple neighborsTuple) { + final Vector vector = StorageAdapter.vectorFromTuple(vectorTuple); + final List nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size()); + + for (int i = 0; i < neighborsTuple.size(); i ++) { + final Tuple neighborTuple = neighborsTuple.getNestedTuple(i); + nodeReferences.add(new NodeReference(neighborTuple)); + } + + return getNodeFactory().create(primaryKey, vector, nodeReferences); + } + + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())); + + final List nodeItems = Lists.newArrayListWithExpectedSize(3); + nodeItems.add(NodeKind.COMPACT.getSerialized()); + final CompactNode compactNode = node.asCompactNode(); + nodeItems.add(StorageAdapter.tupleFromVector(compactNode.getVector())); + + final Iterable neighbors = neighborsChangeSet.merge(); + + final List neighborItems = Lists.newArrayList(); + for (final NodeReference neighborReference : neighbors) { + neighborItems.add(neighborReference.getPrimaryKey()); + } + 
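+        // neighbors are serialized as a nested tuple that contains only their primary keys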
nodeItems.add(Tuple.fromList(neighborItems)); + + final Tuple nodeTuple = Tuple.fromList(nodeItems); + + final byte[] value = nodeTuple.pack(); + transaction.set(key, value); + getOnWriteListener().onNodeWritten(layer, node); + getOnWriteListener().onKeyValueWritten(layer, key, value); + + if (logger.isDebugEnabled()) { + logger.debug("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(), + node.getNeighbors().size(), neighborItems.size()); + } + } + + @Nonnull + @Override + public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, + @Nullable final Tuple lastPrimaryKey, int maxNumRead) { + final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer)); + final Range range = + lastPrimaryKey == null + ? Range.startsWith(layerPrefix) + : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))), + ByteArrayUtil.strinc(layerPrefix)); + final AsyncIterable itemsIterable = + readTransaction.getRange(range, maxNumRead, false, StreamingMode.ITERATOR); + + return AsyncUtil.mapIterable(itemsIterable, keyValue -> { + final byte[] key = keyValue.getKey(); + final byte[] value = keyValue.getValue(); + final Tuple primaryKey = getDataSubspace().unpack(key).getNestedTuple(1); + return nodeFromRaw(layer, primaryKey, key, value); + }); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java new file mode 100644 index 0000000000..e431561119 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -0,0 +1,83 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Set; +import java.util.function.Predicate; + +/** + * TODO. 
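+ * A {@link NeighborsChangeSet} layered on top of a parent change set that records neighbor deletions:
+ * {@code merge()} filters the deleted primary keys out of the parent's merged view, while
+ * {@code writeDelta()} clears the corresponding neighbor entries through
+ * {@code InliningStorageAdapter.deleteNeighbor()}.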
+ */ +class DeleteNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(DeleteNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Set deletedNeighborsPrimaryKeys; + + public DeleteNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final Collection deletedNeighborsPrimaryKeys) { + this.parent = parent; + this.deletedNeighborsPrimaryKeys = ImmutableSet.copyOf(deletedNeighborsPrimaryKeys); + } + + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + @Nonnull + @Override + public Iterable merge() { + return Iterables.filter(getParent().merge(), + current -> !deletedNeighborsPrimaryKeys.contains(current.getPrimaryKey())); + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !deletedNeighborsPrimaryKeys.contains(tuple))); + + for (final Tuple deletedNeighborPrimaryKey : deletedNeighborsPrimaryKeys) { + if (tuplePredicate.test(deletedNeighborPrimaryKey)) { + storageAdapter.deleteNeighbor(transaction, layer, node.asInliningNode(), deletedNeighborPrimaryKey); + if (logger.isDebugEnabled()) { + logger.debug("deleted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + deletedNeighborPrimaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java new file mode 100644 index 0000000000..db81252e17 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -0,0 +1,56 @@ +/* + * NodeWithLayer.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import java.util.Objects; + +class EntryNodeReference extends NodeReferenceWithVector { + private final int layer; + + public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + public int getLayer() { + return layer; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof EntryNodeReference)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((EntryNodeReference)o).layer; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java new file mode 100644 index 0000000000..fb177c9d77 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -0,0 +1,1246 @@ +/* + * HNSW.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.MoreAsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.Streams; +import com.google.common.collect.TreeMultimap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; +import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; + +/** + * TODO. + */ +@API(API.Status.EXPERIMENTAL) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSW { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(HNSW.class); + + public static final int MAX_CONCURRENT_NODE_READS = 16; + public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3; + public static final int MAX_CONCURRENT_SEARCHES = 10; + @Nonnull public static final Random DEFAULT_RANDOM = new Random(0L); + @Nonnull public static final Metric DEFAULT_METRIC = new Metric.EuclideanMetric(); + public static final int DEFAULT_M = 16; + public static final int DEFAULT_M_MAX = DEFAULT_M; + public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; + public static final int DEFAULT_EF_SEARCH = 100; + public static final int DEFAULT_EF_CONSTRUCTION = 200; + public static final boolean DEFAULT_EXTEND_CANDIDATES = false; + public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; + + @Nonnull + public static final Config DEFAULT_CONFIG = new Config(); + + @Nonnull + private final Subspace subspace; + @Nonnull + private final Executor executor; + @Nonnull + private final Config config; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + /** + * Configuration settings for a {@link HNSW}. 
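+     * <p>
+     * The settings follow the parameters of the HNSW paper: {@code m} is the number of neighbors a new
+     * node acquires on insertion, {@code mMax} and {@code mMax0} cap the number of neighbors kept per
+     * node on the upper layers and on layer {@code 0} respectively, and {@code efSearch} and
+     * {@code efConstruction} size the candidate queues used during search and construction.
+     * {@code extendCandidates} and {@code keepPrunedConnections} toggle the heuristic variants of
+     * neighbor selection described by the paper.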
+     */
+    @SuppressWarnings("checkstyle:MemberName")
+    public static class Config {
+        @Nonnull
+        private final Random random;
+        @Nonnull
+        private final Metric metric;
+        private final int m;
+        private final int mMax;
+        private final int mMax0;
+        private final int efSearch;
+        private final int efConstruction;
+        private final boolean extendCandidates;
+        private final boolean keepPrunedConnections;
+
+        protected Config() {
+            this.random = DEFAULT_RANDOM;
+            this.metric = DEFAULT_METRIC;
+            this.m = DEFAULT_M;
+            this.mMax = DEFAULT_M_MAX;
+            this.mMax0 = DEFAULT_M_MAX_0;
+            this.efSearch = DEFAULT_EF_SEARCH;
+            this.efConstruction = DEFAULT_EF_CONSTRUCTION;
+            this.extendCandidates = DEFAULT_EXTEND_CANDIDATES;
+            this.keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS;
+        }
+
+        protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final int m, final int mMax,
+                         final int mMax0, final int efSearch, final int efConstruction, final boolean extendCandidates,
+                         final boolean keepPrunedConnections) {
+            this.random = random;
+            this.metric = metric;
+            this.m = m;
+            this.mMax = mMax;
+            this.mMax0 = mMax0;
+            this.efSearch = efSearch;
+            this.efConstruction = efConstruction;
+            this.extendCandidates = extendCandidates;
+            this.keepPrunedConnections = keepPrunedConnections;
+        }
+
+        @Nonnull
+        public Random getRandom() {
+            return random;
+        }
+
+        @Nonnull
+        public Metric getMetric() {
+            return metric;
+        }
+
+        public int getM() {
+            return m;
+        }
+
+        public int getMMax() {
+            return mMax;
+        }
+
+        public int getMMax0() {
+            return mMax0;
+        }
+
+        public int getEfSearch() {
+            return efSearch;
+        }
+
+        public int getEfConstruction() {
+            return efConstruction;
+        }
+
+        public boolean isExtendCandidates() {
+            return extendCandidates;
+        }
+
+        public boolean isKeepPrunedConnections() {
+            return keepPrunedConnections;
+        }
+
+        @Nonnull
+        public ConfigBuilder toBuilder() {
+            return new ConfigBuilder(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(),
+                    getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections());
+        }
+
+        @Override
+        @Nonnull
+        public String toString() {
+            return "Config[metric=" + getMetric() + ", M=" + getM() + ", MMax=" + getMMax() + ", MMax0=" + getMMax0() +
+                   ", efSearch=" + getEfSearch() + ", efConstruction=" + getEfConstruction() +
+                   ", isExtendCandidates=" + isExtendCandidates() +
+                   ", isKeepPrunedConnections=" + isKeepPrunedConnections() + "]";
+        }
+    }
+
+    /**
+     * Builder for {@link Config}.
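+     * <p>
+     * For example (illustrative values only):
+     * <pre>{@code
+     * HNSW.Config config = HNSW.newConfigBuilder()
+     *         .setMetric(new Metric.EuclideanMetric())
+     *         .setM(32)
+     *         .setEfConstruction(400)
+     *         .build();
+     * }</pre>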
+ * + * @see #newConfigBuilder + */ + @CanIgnoreReturnValue + @SuppressWarnings("checkstyle:MemberName") + public static class ConfigBuilder { + @Nonnull + private Random random = DEFAULT_RANDOM; + @Nonnull + private Metric metric = DEFAULT_METRIC; + private int m = DEFAULT_M; + private int mMax = DEFAULT_M_MAX; + private int mMax0 = DEFAULT_M_MAX_0; + private int efSearch = DEFAULT_EF_SEARCH; + private int efConstruction = DEFAULT_EF_CONSTRUCTION; + private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES; + private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; + + public ConfigBuilder() { + } + + public ConfigBuilder(@Nonnull Random random, @Nonnull final Metric metric, final int m, final int mMax, + final int mMax0, final int efSearch, final int efConstruction, + final boolean extendCandidates, final boolean keepPrunedConnections) { + this.random = random; + this.metric = metric; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efSearch = efSearch; + this.efConstruction = efConstruction; + this.extendCandidates = extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + } + + @Nonnull + public Random getRandom() { + return random; + } + + @Nonnull + public ConfigBuilder setRandom(@Nonnull final Random random) { + this.random = random; + return this; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + @Nonnull + public ConfigBuilder setMetric(@Nonnull final Metric metric) { + this.metric = metric; + return this; + } + + public int getM() { + return m; + } + + @Nonnull + public ConfigBuilder setM(final int m) { + this.m = m; + return this; + } + + public int getMMax() { + return mMax; + } + + @Nonnull + public ConfigBuilder setMMax(final int mMax) { + this.mMax = mMax; + return this; + } + + public int getMMax0() { + return mMax0; + } + + @Nonnull + public ConfigBuilder setMMax0(final int mMax0) { + this.mMax0 = mMax0; + return this; + } + + public int getEfSearch() { + return efSearch; + } + + public ConfigBuilder setEfSearch(final int efSearch) { + this.efSearch = efSearch; + return this; + } + + public int getEfConstruction() { + return efConstruction; + } + + public ConfigBuilder setEfConstruction(final int efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + public ConfigBuilder setExtendCandidates(final boolean extendCandidates) { + this.extendCandidates = extendCandidates; + return this; + } + + public boolean isKeepPrunedConnections() { + return keepPrunedConnections; + } + + public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnections) { + this.keepPrunedConnections = keepPrunedConnections; + return this; + } + + public Config build() { + return new Config(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(), + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + } + } + + /** + * Start building a {@link Config}. + * @return a new {@code Config} that can be altered and then built for use with a {@link HNSW} + * @see ConfigBuilder#build + */ + public static ConfigBuilder newConfigBuilder() { + return new ConfigBuilder(); + } + + /** + * TODO. + */ + public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor) { + this(subspace, executor, DEFAULT_CONFIG, OnWriteListener.NOOP, OnReadListener.NOOP); + } + + /** + * TODO. 
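+     * <p>
+     * A minimal construction might look like this (illustrative; the no-op listeners mirror the
+     * two-argument constructor):
+     * <pre>{@code
+     * HNSW hnsw = new HNSW(subspace, executor, HNSW.newConfigBuilder().build(),
+     *         OnWriteListener.NOOP, OnReadListener.NOOP);
+     * }</pre>
+     * @param subspace the subspace in which this index is stored
+     * @param executor the executor used to run asynchronous tasks
+     * @param config the {@link Config} to use
+     * @param onWriteListener the {@link OnWriteListener} to be notified of writes
+     * @param onReadListener the {@link OnReadListener} to be notified of reads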
+     */
+    public HNSW(@Nonnull final Subspace subspace,
+                @Nonnull final Executor executor, @Nonnull final Config config,
+                @Nonnull final OnWriteListener onWriteListener,
+                @Nonnull final OnReadListener onReadListener) {
+        this.subspace = subspace;
+        this.executor = executor;
+        this.config = config;
+        this.onWriteListener = onWriteListener;
+        this.onReadListener = onReadListener;
+    }
+
+    @Nonnull
+    public Subspace getSubspace() {
+        return subspace;
+    }
+
+    /**
+     * Get the executor used by this HNSW.
+     * @return executor used when running asynchronous tasks
+     */
+    @Nonnull
+    public Executor getExecutor() {
+        return executor;
+    }
+
+    /**
+     * Get this HNSW's configuration.
+     * @return HNSW configuration
+     */
+    @Nonnull
+    public Config getConfig() {
+        return config;
+    }
+
+    /**
+     * Get the on-write listener.
+     * @return the on-write listener
+     */
+    @Nonnull
+    public OnWriteListener getOnWriteListener() {
+        return onWriteListener;
+    }
+
+    /**
+     * Get the on-read listener.
+     * @return the on-read listener
+     */
+    @Nonnull
+    public OnReadListener getOnReadListener() {
+        return onReadListener;
+    }
+
+    //
+    // Read Path
+    //
+
+    /**
+     * Search the index for the {@code k} nearest neighbors of {@code queryVector}. The search greedily
+     * descends the upper layers starting from the entry node and then runs a layer-{@code 0} search
+     * using a candidate list of size {@code efSearch}, returning the top {@code k} results in
+     * increasing order of distance to the query vector.
+     * @param readTransaction the transaction to use for reads
+     * @param k the number of nearest neighbors to return
+     * @param efSearch the size of the candidate list used on layer {@code 0}
+     * @param queryVector the query vector
+     * @return a future containing the (up to) {@code k} nearest neighbors, or {@code null} if the index is empty
+     */
+    @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper
+    @Nonnull
+    public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction,
+                                                        final int k,
+                                                        final int efSearch,
+                                                        @Nonnull final Vector queryVector) {
+        return StorageAdapter.fetchEntryNodeReference(readTransaction, getSubspace(), getOnReadListener())
+                .thenCompose(entryPointAndLayer -> {
+                    if (entryPointAndLayer == null) {
+                        return CompletableFuture.completedFuture(null); // not a single node in the index
+                    }
+
+                    final Metric metric = getConfig().getMetric();
+
+                    final NodeReferenceWithDistance entryState =
+                            new NodeReferenceWithDistance(entryPointAndLayer.getPrimaryKey(),
+                                    entryPointAndLayer.getVector(),
+                                    Vector.comparativeDistance(metric, entryPointAndLayer.getVector(), queryVector));
+
+                    final var entryLayer = entryPointAndLayer.getLayer();
+                    if (entryLayer == 0) {
+                        // entry data points to a node in layer 0 directly
+                        return CompletableFuture.completedFuture(entryState);
+                    }
+
+                    return forLoop(entryLayer, entryState,
+                            layer -> layer > 0,
+                            layer -> layer - 1,
+                            (layer, previousNodeReference) -> {
+                                final var storageAdapter = getStorageAdapterForLayer(layer);
+                                return greedySearchLayer(storageAdapter, readTransaction, previousNodeReference,
+                                        layer, queryVector);
+                            }, executor);
+                }).thenCompose(nodeReference -> {
+                    if (nodeReference == null) {
+                        return CompletableFuture.completedFuture(null);
+                    }
+
+                    final var storageAdapter = getStorageAdapterForLayer(0);
+
+                    return searchLayer(storageAdapter, readTransaction,
+                            ImmutableList.of(nodeReference), 0, efSearch,
+                            Maps.newConcurrentMap(), queryVector)
+                            .thenApply(searchResult -> {
+                                // collect the k nearest results, ordered by distance (ties broken by primary key)
+                                final TreeMultimap> sortedTopK =
+                                        TreeMultimap.create(Comparator.naturalOrder(),
+                                                Comparator.comparing(nodeReferenceAndNode -> nodeReferenceAndNode.getNode().getPrimaryKey()));
+
+                                for (final NodeReferenceAndNode nodeReferenceAndNode : searchResult) {
+                                    if (sortedTopK.size() < k || sortedTopK.keySet().last() >
+                                            nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance()) {
+                                        sortedTopK.put(nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance(),
+                                                nodeReferenceAndNode);
+                                    }
+
+                                    if (sortedTopK.size() > k) {
+                                        final Double lastKey = sortedTopK.keySet().last();
+                                        final NodeReferenceAndNode lastNode = sortedTopK.get(lastKey).last();
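+                                        // evict the entry that is currently farthest from the query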
sortedTopK.remove(lastKey, lastNode); + } + } + + return ImmutableList.copyOf(sortedTopK.values()); + }); + }); + } + + @Nonnull + private CompletableFuture greedySearchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + if (storageAdapter.getNodeKind() == NodeKind.INLINING) { + return greedySearchInliningLayer(storageAdapter.asInliningStorageAdapter(), readTransaction, entryNeighbor, layer, queryVector); + } else { + return searchLayer(storageAdapter, readTransaction, ImmutableList.of(entryNeighbor), layer, 1, Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> Iterables.getOnlyElement(searchResult).getNodeReferenceWithDistance()); + } + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture greedySearchInliningLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + Verify.verify(layer > 0); + final Metric metric = getConfig().getMetric(); + final AtomicReference currentNodeReferenceAtomic = + new AtomicReference<>(entryNeighbor); + + return AsyncUtil.whileTrue(() -> onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, currentNodeReferenceAtomic.get().getPrimaryKey())) + .thenApply(node -> { + if (node == null) { + throw new IllegalStateException("unable to fetch node"); + } + final InliningNode inliningNode = node.asInliningNode(); + final List neighbors = inliningNode.getNeighbors(); + + final NodeReferenceWithDistance currentNodeReference = currentNodeReferenceAtomic.get(); + double minDistance = currentNodeReference.getDistance(); + + NodeReferenceWithVector nearestNeighbor = null; + for (final NodeReferenceWithVector neighbor : neighbors) { + final double distance = + Vector.comparativeDistance(metric, neighbor.getVector(), queryVector); + if (distance < minDistance) { + minDistance = distance; + nearestNeighbor = neighbor; + } + } + + if (nearestNeighbor == null) { + return false; + } + + currentNodeReferenceAtomic.set( + new NodeReferenceWithDistance(nearestNeighbor.getPrimaryKey(), nearestNeighbor.getVector(), + minDistance)); + return true; + }), executor).thenApply(ignored -> currentNodeReferenceAtomic.get()); + } + + /** + * TODO. 
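+     * Search a single layer for the nearest neighbors of {@code queryVector}, following the
+     * SEARCH-LAYER routine of the HNSW paper: a min-queue of candidates and a max-queue of the
+     * {@code efSearch} nearest results found so far are maintained, and the loop terminates once the
+     * closest remaining candidate is farther away than the current farthest result. Visited nodes are
+     * tracked by primary key; fetched nodes are memoized in {@code nodeCache} so that each node is
+     * read at most once per search.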
+ */ + @Nonnull + private CompletableFuture>> searchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Collection entryNeighbors, + final int layer, + final int efSearch, + @Nonnull final Map> nodeCache, + @Nonnull final Vector queryVector) { + final Set visited = Sets.newConcurrentHashSet(NodeReference.primaryKeys(entryNeighbors)); + final Queue candidates = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(entryNeighbors); + final Queue nearestNeighbors = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance).reversed()); + nearestNeighbors.addAll(entryNeighbors); + final Metric metric = getConfig().getMetric(); + + return AsyncUtil.whileTrue(() -> { + if (candidates.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + + final NodeReferenceWithDistance candidate = candidates.poll(); + final NodeReferenceWithDistance furthestNeighbor = Objects.requireNonNull(nearestNeighbors.peek()); + + if (candidate.getDistance() > furthestNeighbor.getDistance()) { + return AsyncUtil.READY_FALSE; + } + + return fetchNodeIfNotCached(storageAdapter, readTransaction, layer, candidate, nodeCache) + .thenApply(candidateNode -> + Iterables.filter(candidateNode.getNeighbors(), + neighbor -> !visited.contains(neighbor.getPrimaryKey()))) + .thenCompose(neighborReferences -> fetchNeighborhood(storageAdapter, readTransaction, + layer, neighborReferences, nodeCache)) + .thenApply(neighborReferences -> { + for (final NodeReferenceWithVector current : neighborReferences) { + visited.add(current.getPrimaryKey()); + final double furthestDistance = + Objects.requireNonNull(nearestNeighbors.peek()).getDistance(); + + final double currentDistance = + Vector.comparativeDistance(metric, current.getVector(), queryVector); + if (currentDistance < furthestDistance || nearestNeighbors.size() < efSearch) { + final NodeReferenceWithDistance currentWithDistance = + new NodeReferenceWithDistance(current.getPrimaryKey(), current.getVector(), + currentDistance); + candidates.add(currentWithDistance); + nearestNeighbors.add(currentWithDistance); + if (nearestNeighbors.size() > efSearch) { + nearestNeighbors.poll(); + } + } + } + return true; + }); + }).thenCompose(ignored -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache)) + .thenApply(searchResult -> { + debug(l -> l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch, + searchResult.stream() + .map(nodeReferenceAndNode -> + "(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(",")))); + return searchResult; + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final NodeReference nodeReference, + @Nonnull final Map> nodeCache) { + return fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, nodeReference, + nR -> nodeCache.get(nR.getPrimaryKey()), + (nR, node) -> { + nodeCache.put(nR.getPrimaryKey(), node); + return node; + }); + } + + /** + * TODO. 
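+     * Fetch a node and apply {@code biMapFunction} to it, unless {@code fetchBypassFunction} already
+     * produces a result for the reference (for instance, from a node cache), in which case no read is
+     * issued against the database.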
+ */ + @Nonnull + private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final R nodeReference, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + final U bypass = fetchBypassFunction.apply(nodeReference); + if (bypass != null) { + return CompletableFuture.completedFuture(bypass); + } + + return onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, nodeReference.getPrimaryKey())) + .thenApply(node -> biMapFunction.apply(nodeReference, node)); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture> fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable neighborReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, neighborReferences, + neighborReference -> { + if (neighborReference instanceof NodeReferenceWithVector) { + return (NodeReferenceWithVector)neighborReference; + } + final Node neighborNode = nodeCache.get(neighborReference.getPrimaryKey()); + if (neighborNode == null) { + return null; + } + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }, + (neighborReference, neighborNode) -> { + nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, nodeReferences, + nodeReference -> { + final Node node = nodeCache.get(nodeReference.getPrimaryKey()); + if (node == null) { + return null; + } + return new NodeReferenceAndNode<>(nodeReference, node); + }, + (nodeReferenceWithDistance, node) -> { + nodeCache.put(nodeReferenceWithDistance.getPrimaryKey(), node); + return new NodeReferenceAndNode<>(nodeReferenceWithDistance, node); + }); + } + + /** + * TODO. 
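+     * Bulk variant of {@code fetchNodeIfNecessaryAndApply()}: resolves all {@code nodeReferences},
+     * issuing at most {@link #MAX_CONCURRENT_NODE_READS} reads concurrently via
+     * {@code MoreAsyncUtil.forEach()}.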
+ */ + @Nonnull + private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + return forEach(nodeReferences, + currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, + currentNeighborReference, fetchBypassFunction, biMapFunction), MAX_CONCURRENT_NODE_READS, + getExecutor()); + } + + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final NodeReferenceWithVector nodeReferenceWithVector) { + return insert(transaction, nodeReferenceWithVector.getPrimaryKey(), nodeReferenceWithVector.getVector()); + } + + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + final Metric metric = getConfig().getMetric(); + + final int insertionLayer = insertionLayer(getConfig().getRandom()); + debug(l -> l.debug("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer)); + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenApply(entryNodeReference -> { + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, -1); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + } else { + final int lMax = entryNodeReference.getLayer(); + if (insertionLayer > lMax) { + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, lMax); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + } + } + return entryNodeReference; + }).thenCompose(entryNodeReference -> { + if (entryNodeReference == null) { + return AsyncUtil.DONE; + } + + final int lMax = entryNodeReference.getLayer(); + debug(l -> l.debug("entry node with key {} at layer {}", entryNodeReference.getPrimaryKey(), + lMax)); + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), newVector)); + return forLoop(lMax, initialNodeReference, + layer -> layer > insertionLayer, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, newVector); + }, executor) + .thenCompose(nodeReference -> + insertIntoLayers(transaction, newPrimaryKey, newVector, nodeReference, + lMax, insertionLayer)); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + public CompletableFuture insertBatch(@Nonnull final Transaction transaction, + @Nonnull List batch) { + final Metric metric = getConfig().getMetric(); + + // determine the layer each item should be inserted at + final Random random = getConfig().getRandom(); + final 
List batchWithLayers = Lists.newArrayListWithCapacity(batch.size()); + for (final NodeReferenceWithVector current : batch) { + batchWithLayers.add(new NodeReferenceWithLayer(current.getPrimaryKey(), current.getVector(), + insertionLayer(random))); + } + // sort the layers in reverse order + batchWithLayers.sort(Comparator.comparing(NodeReferenceWithLayer::getLayer).reversed()); + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenCompose(entryNodeReference -> { + final int lMax = entryNodeReference == null ? -1 : entryNodeReference.getLayer(); + + return forEach(batchWithLayers, + item -> { + if (lMax == -1) { + return CompletableFuture.completedFuture(null); + } + + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), itemVector)); + + return forLoop(lMax, initialNodeReference, + layer -> layer > itemL, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, itemVector); + }, executor); + }, MAX_CONCURRENT_SEARCHES, getExecutor()) + .thenCompose(searchEntryReferences -> + forLoop(0, entryNodeReference, + index -> index < batchWithLayers.size(), + index -> index + 1, + (index, currentEntryNodeReference) -> { + final NodeReferenceWithLayer item = batchWithLayers.get(index); + final Tuple itemPrimaryKey = item.getPrimaryKey(); + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final EntryNodeReference newEntryNodeReference; + final int currentLMax; + + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, -1); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + + return CompletableFuture.completedFuture(newEntryNodeReference); + } else { + currentLMax = currentEntryNodeReference.getLayer(); + if (itemL > currentLMax) { + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, lMax); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + } else { + newEntryNodeReference = entryNodeReference; + } + } + + debug(l -> l.debug("entry node with key {} at layer {}", + currentEntryNodeReference.getPrimaryKey(), currentLMax)); + + final var currentSearchEntry = + searchEntryReferences.get(index); + + return insertIntoLayers(transaction, itemPrimaryKey, itemVector, currentSearchEntry, + lMax, itemL).thenApply(ignored -> newEntryNodeReference); + }, getExecutor())); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector, + @Nonnull final 
NodeReferenceWithDistance nodeReference, + final int lMax, + final int insertionLayer) { + debug(l -> l.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey())); + return MoreAsyncUtil.>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference), + layer -> layer >= 0, + layer -> layer - 1, + (layer, previousNodeReferences) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return insertIntoLayer(storageAdapter, transaction, + previousNodeReferences, layer, newPrimaryKey, newVector); + }, executor).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + private CompletableFuture> insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final List nearestNeighbors, + int layer, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + debug(l -> l.debug("begin insert key={} at layer={}", newPrimaryKey, layer)); + final Map> nodeCache = Maps.newConcurrentMap(); + + return searchLayer(storageAdapter, transaction, + nearestNeighbors, layer, config.getEfConstruction(), nodeCache, newVector) + .thenCompose(searchResult -> { + final List references = NodeReferenceAndNode.getReferences(searchResult); + + return selectNeighbors(storageAdapter, transaction, searchResult, layer, getConfig().getM(), + getConfig().isExtendCandidates(), nodeCache, newVector) + .thenCompose(selectedNeighbors -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node newNode = + nodeFactory.create(newPrimaryKey, newVector, + NodeReferenceAndNode.getReferences(selectedNeighbors)); + + final NeighborsChangeSet newNodeChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + newNode.getNeighbors()); + + storageAdapter.writeNode(transaction, newNode, layer, newNodeChangeSet); + + // create change sets for each selected neighbor and insert new node into them + final Map> neighborChangeSetMap = + Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { + final NeighborsChangeSet baseSet = + new BaseNeighborsChangeSet<>(selectedNeighbor.getNode().getNeighbors()); + final NeighborsChangeSet insertSet = + new InsertNeighborsChangeSet<>(baseSet, ImmutableList.of(newNode.getSelfReference(newVector))); + neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), + insertSet); + } + + final int currentMMax = layer == 0 ? 
getConfig().getMMax0() : getConfig().getMMax(); + return forEach(selectedNeighbors, + selectedNeighbor -> { + final Node selectedNeighborNode = selectedNeighbor.getNode(); + final NeighborsChangeSet changeSet = + Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + selectedNeighbor, layer, currentMMax, changeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return changeSet; + } + return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); + }); + }, MAX_CONCURRENT_NEIGHBOR_FETCHES, getExecutor()) + .thenApply(changeSets -> { + for (int i = 0; i < selectedNeighbors.size(); i++) { + final NodeReferenceAndNode selectedNeighbor = selectedNeighbors.get(i); + final NeighborsChangeSet changeSet = changeSets.get(i); + storageAdapter.writeNode(transaction, selectedNeighbor.getNode(), + layer, changeSet); + } + return ImmutableList.copyOf(references); + }); + }); + }).thenApply(nodeReferencesWithDistances -> { + debug(l -> l.debug("end insert key={} at layer={}", newPrimaryKey, layer)); + return nodeReferencesWithDistances; + }); + } + + private NeighborsChangeSet resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, + @Nonnull final Iterable> afterNeighbors) { + final Map beforeNeighborsMap = Maps.newLinkedHashMap(); + for (final N n : beforeChangeSet.merge()) { + beforeNeighborsMap.put(n.getPrimaryKey(), n); + } + + final Map afterNeighborsMap = Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + + afterNeighborsMap.put(nodeReferenceWithDistance.getPrimaryKey(), + nodeReferenceAndNode.getNode().getSelfReference(nodeReferenceWithDistance.getVector())); + } + + final ImmutableList.Builder toBeDeletedBuilder = ImmutableList.builder(); + for (final Map.Entry beforeNeighborEntry : beforeNeighborsMap.entrySet()) { + if (!afterNeighborsMap.containsKey(beforeNeighborEntry.getKey())) { + toBeDeletedBuilder.add(beforeNeighborEntry.getValue().getPrimaryKey()); + } + } + final List toBeDeleted = toBeDeletedBuilder.build(); + + final ImmutableList.Builder toBeInsertedBuilder = ImmutableList.builder(); + for (final Map.Entry afterNeighborEntry : afterNeighborsMap.entrySet()) { + if (!beforeNeighborsMap.containsKey(afterNeighborEntry.getKey())) { + toBeInsertedBuilder.add(afterNeighborEntry.getValue()); + } + } + final List toBeInserted = toBeInsertedBuilder.build(); + + NeighborsChangeSet changeSet = beforeChangeSet; + + if (!toBeDeleted.isEmpty()) { + changeSet = new DeleteNeighborsChangeSet<>(changeSet, toBeDeleted); + } + if (!toBeInserted.isEmpty()) { + changeSet = new InsertNeighborsChangeSet<>(changeSet, toBeInserted); + } + return changeSet; + } + + @Nonnull + private CompletableFuture>> pruneNeighborsIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final NodeReferenceAndNode selectedNeighbor, + int layer, + int mMax, + @Nonnull final NeighborsChangeSet neighborChangeSet, + @Nonnull final Map> nodeCache) { + final Metric metric = getConfig().getMetric(); + final Node selectedNeighborNode = selectedNeighbor.getNode(); + if (selectedNeighborNode.getNeighbors().size() < mMax) { + return CompletableFuture.completedFuture(null); + } else { + debug(l -> l.debug("pruning neighborhood of key={} 
which has numNeighbors={} out of mMax={}", + selectedNeighborNode.getPrimaryKey(), selectedNeighborNode.getNeighbors().size(), mMax)); + return fetchNeighborhood(storageAdapter, transaction, layer, neighborChangeSet.merge(), nodeCache) + .thenCompose(nodeReferenceWithVectors -> { + final ImmutableList.Builder nodeReferencesWithDistancesBuilder = + ImmutableList.builder(); + for (final NodeReferenceWithVector nodeReferenceWithVector : nodeReferenceWithVectors) { + final var vector = nodeReferenceWithVector.getVector(); + final double distance = + Vector.comparativeDistance(metric, vector, + selectedNeighbor.getNodeReferenceWithDistance().getVector()); + nodeReferencesWithDistancesBuilder.add( + new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(), + vector, distance)); + } + return fetchSomeNodesIfNotCached(storageAdapter, transaction, layer, + nodeReferencesWithDistancesBuilder.build(), nodeCache); + }) + .thenCompose(nodeReferencesAndNodes -> + selectNeighbors(storageAdapter, transaction, + nodeReferencesAndNodes, layer, + mMax, false, nodeCache, + selectedNeighbor.getNodeReferenceWithDistance().getVector())); + } + } + + private CompletableFuture>> selectNeighbors(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> nearestNeighbors, + final int layer, + final int m, + final boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { + return extendCandidatesIfNecessary(storageAdapter, readTransaction, nearestNeighbors, layer, isExtendCandidates, nodeCache, vector) + .thenApply(extendedCandidates -> { + final List selected = Lists.newArrayListWithExpectedSize(m); + final Queue candidates = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(extendedCandidates); + final Queue discardedCandidates = + getConfig().isKeepPrunedConnections() + ? 
new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)) + : null; + + final Metric metric = getConfig().getMetric(); + + while (!candidates.isEmpty() && selected.size() < m) { + final NodeReferenceWithDistance nearestCandidate = candidates.poll(); + boolean shouldSelect = true; + for (final NodeReferenceWithDistance alreadySelected : selected) { + if (Vector.comparativeDistance(metric, nearestCandidate.getVector(), + alreadySelected.getVector()) < nearestCandidate.getDistance()) { + shouldSelect = false; + break; + } + } + if (shouldSelect) { + selected.add(nearestCandidate); + } else if (discardedCandidates != null) { + discardedCandidates.add(nearestCandidate); + } + } + + if (discardedCandidates != null) { // isKeepPrunedConnections is set to true + while (!discardedCandidates.isEmpty() && selected.size() < m) { + selected.add(discardedCandidates.poll()); + } + } + + return ImmutableList.copyOf(selected); + }).thenCompose(selectedNeighbors -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, selectedNeighbors, nodeCache)) + .thenApply(selectedNeighbors -> { + debug(l -> + l.debug("selected neighbors={}", + selectedNeighbors.stream() + .map(selectedNeighbor -> + "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(",")))); + return selectedNeighbors; + }); + } + + private CompletableFuture> extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> candidates, + int layer, + boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { + if (isExtendCandidates) { + final Metric metric = getConfig().getMetric(); + + final Set candidatesSeen = Sets.newConcurrentHashSet(); + for (final NodeReferenceAndNode candidate : candidates) { + candidatesSeen.add(candidate.getNode().getPrimaryKey()); + } + + final ImmutableList.Builder neighborsOfCandidatesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + for (final N neighbor : candidate.getNode().getNeighbors()) { + final Tuple neighborPrimaryKey = neighbor.getPrimaryKey(); + if (!candidatesSeen.contains(neighborPrimaryKey)) { + candidatesSeen.add(neighborPrimaryKey); + neighborsOfCandidatesBuilder.add(neighbor); + } + } + } + + final Iterable neighborsOfCandidates = neighborsOfCandidatesBuilder.build(); + + return fetchNeighborhood(storageAdapter, readTransaction, layer, neighborsOfCandidates, nodeCache) + .thenApply(withVectors -> { + final ImmutableList.Builder extendedCandidatesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + for (final NodeReferenceWithVector withVector : withVectors) { + final double distance = Vector.comparativeDistance(metric, withVector.getVector(), vector); + extendedCandidatesBuilder.add(new NodeReferenceWithDistance(withVector.getPrimaryKey(), + withVector.getVector(), distance)); + } + return extendedCandidatesBuilder.build(); + }); + } else { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + resultBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + return CompletableFuture.completedFuture(resultBuilder.build()); + } + } 
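+
+    // Illustration (hypothetical helper, not part of this class): selectNeighbors above implements the
+    // select-neighbors heuristic of the HNSW paper -- a candidate, visited in ascending distance from
+    // the query, is kept only if it is closer to the query than to every neighbor selected so far.
+    // Reduced to plain Double[] vectors, the same rule reads:
+    //
+    //     static List<Double[]> selectByHeuristic(final Double[] query, final List<Double[]> candidates,
+    //                                             final int m, final Metric metric) {
+    //         // visit candidates in ascending distance from the query
+    //         candidates.sort(Comparator.comparingDouble(candidate ->
+    //                 metric.comparativeDistance(candidate, query)));
+    //         final List<Double[]> selected = new ArrayList<>();
+    //         for (final Double[] candidate : candidates) {
+    //             if (selected.size() >= m) {
+    //                 break;
+    //             }
+    //             final double distanceToQuery = metric.comparativeDistance(candidate, query);
+    //             boolean keep = true;
+    //             for (final Double[] alreadySelected : selected) {
+    //                 // reject a candidate that is closer to an already-selected neighbor than to the query
+    //                 if (metric.comparativeDistance(candidate, alreadySelected) < distanceToQuery) {
+    //                     keep = false;
+    //                     break;
+    //                 }
+    //             }
+    //             if (keep) {
+    //                 selected.add(candidate);
+    //             }
+    //         }
+    //         return selected;
+    //     }
+    //
+    // (The production path above additionally back-fills from the discarded candidates when
+    // isKeepPrunedConnections() is set.)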
+
+    private void writeLonelyNodes(@Nonnull final Transaction transaction,
+                                  @Nonnull final Tuple primaryKey,
+                                  @Nonnull final Vector vector,
+                                  final int highestLayerInclusive,
+                                  final int lowestLayerExclusive) {
+        for (int layer = highestLayerInclusive; layer > lowestLayerExclusive; layer--) {
+            final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer);
+            writeLonelyNodeOnLayer(storageAdapter, transaction, layer, primaryKey, vector);
+        }
+    }
+
+    private void writeLonelyNodeOnLayer(@Nonnull final StorageAdapter storageAdapter,
+                                        @Nonnull final Transaction transaction,
+                                        final int layer,
+                                        @Nonnull final Tuple primaryKey,
+                                        @Nonnull final Vector vector) {
+        storageAdapter.writeNode(transaction,
+                storageAdapter.getNodeFactory()
+                        .create(primaryKey, vector, ImmutableList.of()), layer,
+                new BaseNeighborsChangeSet<>(ImmutableList.of()));
+        debug(l -> l.debug("written lonely node at key={} on layer={}", primaryKey, layer));
+    }
+
+    public void scanLayer(@Nonnull final Database db,
+                          final int layer,
+                          final int batchSize,
+                          @Nonnull final Consumer> nodeConsumer) {
+        final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer);
+        final AtomicReference lastPrimaryKeyAtomic = new AtomicReference<>();
+        Tuple newPrimaryKey;
+        do {
+            final Tuple lastPrimaryKey = lastPrimaryKeyAtomic.get();
+            lastPrimaryKeyAtomic.set(null);
+            newPrimaryKey = db.run(tr -> {
+                Streams.stream(storageAdapter.scanLayer(tr, layer, lastPrimaryKey, batchSize))
+                        .forEach(node -> {
+                            nodeConsumer.accept(node);
+                            lastPrimaryKeyAtomic.set(node.getPrimaryKey());
+                        });
+                return lastPrimaryKeyAtomic.get();
+            }, executor);
+        } while (newPrimaryKey != null);
+    }
+
+    @Nonnull
+    private StorageAdapter getStorageAdapterForLayer(final int layer) {
+        // note: the inlining storage adapter is currently disabled -- the constant-false guard below
+        // makes every layer use the compact storage format
+        return false && layer > 0
+               ? 
new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()) + : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()); + } + + private int insertionLayer(@Nonnull final Random random) { + double lambda = 1.0 / Math.log(getConfig().getM()); + double u = 1.0 - random.nextDouble(); // Avoid log(0) + return (int) Math.floor(-Math.log(u) * lambda); + } + + @SuppressWarnings("PMD.UnusedPrivateMethod") + private void info(@Nonnull final Consumer loggerConsumer) { + if (logger.isInfoEnabled()) { + loggerConsumer.accept(logger); + } + } + + private void debug(@Nonnull final Consumer loggerConsumer) { + if (logger.isDebugEnabled()) { + loggerConsumer.accept(logger); + } + } + + private static class NodeReferenceWithLayer extends NodeReferenceWithVector { + private final int layer; + + public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + public int getLayer() { + return layer; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithLayer)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((NodeReferenceWithLayer)o).layer; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java new file mode 100644 index 0000000000..322b4f85b0 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java @@ -0,0 +1,63 @@ +/* + * HNSWHelpers.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; + +/** + * Some helper methods for {@link Node}s. + */ +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSWHelpers { + private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); + + private HNSWHelpers() { + // nothing + } + + /** + * Helper method to format bytes as hex strings for logging and debugging. 
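+ * <p>Leading zeros are stripped from the hex digits, e.g. {@code bytesToHex(new byte[] {0x00, 0x1A})}
+ * returns {@code "0x1A"}.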
+ * @param bytes an array of bytes + * @return a {@link String} containing the hexadecimal representation of the byte array passed in + */ + @Nonnull + public static String bytesToHex(byte[] bytes) { + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return "0x" + new String(hexChars).replaceFirst("^0+(?!$)", ""); + } + + @Nonnull + public static Half halfValueOf(final double d) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(d))); + } + + @Nonnull + public static Half halfValueOf(final float f) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(f))); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java new file mode 100644 index 0000000000..48e2398950 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -0,0 +1,94 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * TODO. 
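+ * An inlining node keeps the vector of each neighbor inline with the neighbor reference (see
+ * {@link NodeReferenceWithVector}); the node itself does not persist a vector of its own, which is why
+ * {@link #getSelfReference(Vector)} requires the caller to supply one.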
+ */ +class InliningNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + public Node create(@Nonnull final Tuple primaryKey, + @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new InliningNode(primaryKey, (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.INLINING; + } + }; + + public InliningNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + } + + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public NodeReferenceWithVector getSelfReference(@Nullable final Vector vector) { + return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); + } + + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.INLINING; + } + + @Nonnull + @Override + public CompactNode asCompactNode() { + throw new IllegalStateException("this is not a compact node"); + } + + @Nonnull + @Override + public InliningNode asInliningNode() { + return this; + } + + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "I[primaryKey=" + getPrimaryKey() + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java new file mode 100644 index 0000000000..ebbfd4d698 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -0,0 +1,181 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * TODO. 
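+ * A storage adapter that persists a node as one key/value pair per neighbor: the key is the tuple
+ * {@code (layer, nodePrimaryKey, neighborPrimaryKey)} within the data subspace, and the value is the
+ * packed vector of that neighbor. Fetching a node therefore amounts to a single range read over its
+ * neighbor entries.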
+ */ +class InliningStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + public InliningStorageAdapter(@Nonnull final HNSW.Config config, + @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + throw new IllegalStateException("cannot call this method on an inlining storage adapter"); + } + + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + return this; + } + + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] rangeKey = getNodeKey(layer, primaryKey); + + return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey), + ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor()) + .thenApply(keyValues -> nodeFromRaw(layer, primaryKey, keyValues)); + } + + @Nonnull + private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, final List keyValues) { + final OnReadListener onReadListener = getOnReadListener(); + + final ImmutableList.Builder nodeReferencesWithVectorBuilder = ImmutableList.builder(); + for (final KeyValue keyValue : keyValues) { + nodeReferencesWithVectorBuilder.add(neighborFromRaw(layer, keyValue.getKey(), keyValue.getValue())); + } + + final Node node = + getNodeFactory().create(primaryKey, null, nodeReferencesWithVectorBuilder.build()); + onReadListener.onNodeRead(layer, node); + return node; + } + + @Nonnull + private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull byte[] key, final byte[] value) { + final OnReadListener onReadListener = getOnReadListener(); + + onReadListener.onKeyValueRead(layer, key, value); + final Tuple neighborKeyTuple = getDataSubspace().unpack(key); + final Tuple neighborValueTuple = Tuple.fromBytes(value); + + final Tuple neighborPrimaryKey = neighborKeyTuple.getNestedTuple(2); // neighbor primary key + final Vector neighborVector = StorageAdapter.vectorFromTuple(neighborValueTuple); // the entire value is the vector + return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector); + } + + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final InliningNode inliningNode = node.asInliningNode(); + + neighborsChangeSet.writeDelta(this, transaction, layer, inliningNode, t -> true); + getOnWriteListener().onNodeWritten(layer, node); + } + + @Nonnull + private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { + return getDataSubspace().pack(Tuple.from(layer, primaryKey)); + } + + public void writeNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Node node, @Nonnull final NodeReferenceWithVector neighbor) { + final byte[] neighborKey = getNeighborKey(layer, node, neighbor.getPrimaryKey()); + final byte[] value = StorageAdapter.tupleFromVector(neighbor.getVector()).pack(); + transaction.set(neighborKey, + value); + getOnWriteListener().onNeighborWritten(layer, node, neighbor); + getOnWriteListener().onKeyValueWritten(layer, neighborKey, value); + } + + public void deleteNeighbor(@Nonnull final Transaction 
transaction, final int layer,
+                               @Nonnull final Node node, @Nonnull final Tuple neighborPrimaryKey) {
+        transaction.clear(getNeighborKey(layer, node, neighborPrimaryKey));
+        getOnWriteListener().onNeighborDeleted(layer, node, neighborPrimaryKey);
+    }
+
+    @Nonnull
+    private byte[] getNeighborKey(final int layer,
+                                  @Nonnull final Node node,
+                                  @Nonnull final Tuple neighborPrimaryKey) {
+        return getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey(), neighborPrimaryKey));
+    }
+
+    @Nonnull
+    @Override
+    public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer,
+                               @Nullable final Tuple lastPrimaryKey, int maxNumRead) {
+        final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer));
+        final Range range =
+                lastPrimaryKey == null
+                ? Range.startsWith(layerPrefix)
+                : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))),
+                        ByteArrayUtil.strinc(layerPrefix));
+        final AsyncIterable itemsIterable =
+                readTransaction.getRange(range,
+                        maxNumRead, false, StreamingMode.ITERATOR);
+        int numRead = 0;
+        Tuple nodePrimaryKey = null;
+        ImmutableList.Builder> nodeBuilder = ImmutableList.builder();
+        ImmutableList.Builder neighborsBuilder = ImmutableList.builder();
+        for (final KeyValue item : itemsIterable) {
+            final NodeReferenceWithVector neighbor =
+                    neighborFromRaw(layer, item.getKey(), item.getValue());
+            final Tuple primaryKeyFromNodeReference = neighbor.getPrimaryKey();
+            if (nodePrimaryKey == null) {
+                nodePrimaryKey = primaryKeyFromNodeReference;
+            } else if (!nodePrimaryKey.equals(primaryKeyFromNodeReference)) {
+                // the primary key changed -- emit the node whose neighbors we have been collecting
+                // and start collecting for the next node
+                nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build()));
+                neighborsBuilder = ImmutableList.builder();
+                nodePrimaryKey = primaryKeyFromNodeReference;
+            }
+            neighborsBuilder.add(neighbor);
+            numRead++;
+        }
+
+        // emit the trailing node, but only if the scan was exhaustive; if we hit maxNumRead, the last
+        // node may have been read only partially and will be picked up by the next batch
+        if (numRead > 0 && numRead < maxNumRead) {
+            nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build()));
+        }
+
+        return nodeBuilder.build();
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java
new file mode 100644
index 0000000000..d68d3ae933
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java
@@ -0,0 +1,89 @@
+/*
+ * InsertNeighborsChangeSet.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.tuple.Tuple;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+
+/**
+ * A {@link NeighborsChangeSet} that adds a set of neighbors on top of the neighbors produced by its
+ * parent change set.
+ */ +class InsertNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(InsertNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Map insertedNeighborsMap; + + public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final List insertedNeighbors) { + this.parent = parent; + final ImmutableMap.Builder insertedNeighborsMapBuilder = ImmutableMap.builder(); + for (final N insertedNeighbor : insertedNeighbors) { + insertedNeighborsMapBuilder.put(insertedNeighbor.getPrimaryKey(), insertedNeighbor); + } + + this.insertedNeighborsMap = insertedNeighborsMapBuilder.build(); + } + + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + @Nonnull + @Override + public Iterable merge() { + return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !insertedNeighborsMap.containsKey(tuple))); + + for (final Map.Entry entry : insertedNeighborsMap.entrySet()) { + final Tuple primaryKey = entry.getKey(); + if (tuplePredicate.test(primaryKey)) { + storageAdapter.writeNeighbor(transaction, layer, node.asInliningNode(), + entry.getValue().asNodeReferenceWithVector()); + if (logger.isDebugEnabled()) { + logger.debug("inserted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + primaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java new file mode 100644 index 0000000000..6e236a5d10 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -0,0 +1,161 @@ +/* + * Metric.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; + +public interface Metric { + double distance(Double[] vector1, Double[] vector2); + + default double comparativeDistance(Double[] vector1, Double[] vector2) { + return distance(vector1, vector2); + } + + /** + * A helper method to validate that vectors can be compared. + * @param vector1 The first vector. + * @param vector2 The second vector. 
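+ * @throws IllegalArgumentException if either vector is {@code null}, if the vectors differ in
+ *         dimensionality, or if they are empty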
+ */ + private static void validate(Double[] vector1, Double[] vector2) { + if (vector1 == null || vector2 == null) { + throw new IllegalArgumentException("Vectors cannot be null"); + } + if (vector1.length != vector2.length) { + throw new IllegalArgumentException( + "Vectors must have the same dimensionality. Got " + vector1.length + " and " + vector2.length + ); + } + if (vector1.length == 0) { + throw new IllegalArgumentException("Vectors cannot be empty."); + } + } + + class ManhattanMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double sumOfAbsDiffs = 0.0; + for (int i = 0; i < vector1.length; i++) { + sumOfAbsDiffs += Math.abs(vector1[i] - vector2[i]); + } + return sumOfAbsDiffs; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class EuclideanMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + return Math.sqrt(EuclideanSquareMetric.distanceInternal(vector1, vector2)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class EuclideanSquareMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + return distanceInternal(vector1, vector2); + } + + private static double distanceInternal(final Double[] vector1, final Double[] vector2) { + double sumOfSquares = 0.0d; + for (int i = 0; i < vector1.length; i++) { + double diff = vector1[i] - vector2[i]; + sumOfSquares += diff * diff; + } + return sumOfSquares; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class CosineMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + normA += vector1[i] * vector1[i]; + normB += vector2[i] * vector2[i]; + } + + // Handle the case of zero-vectors to avoid division by zero + if (normA == 0.0 || normB == 0.0) { + return Double.POSITIVE_INFINITY; + } + + return 1.0d - dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class DotProductMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + throw new UnsupportedOperationException("dot product metric is not a true metric and can only be used for ranking"); + } + + @Override + public double comparativeDistance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double product = 0.0d; + for (int i = 0; i < vector1.length; i++) { + product += vector1[i] * vector2[i]; + } + return -product; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java new file mode 100644 index 0000000000..8c30faf852 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java @@ -0,0 
+1,43 @@
+/*
+ * Metrics.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Enumeration of the supported {@link Metric} implementations.
+ */
+public enum Metrics {
+    MANHATTAN_METRIC(new Metric.ManhattanMetric()),
+    EUCLIDEAN_METRIC(new Metric.EuclideanMetric()),
+    EUCLIDEAN_SQUARE_METRIC(new Metric.EuclideanSquareMetric()),
+    COSINE_METRIC(new Metric.CosineMetric()),
+    DOT_PRODUCT_METRIC(new Metric.DotProductMetric());
+
+    @Nonnull
+    private final Metric metric;
+
+    Metrics(@Nonnull final Metric metric) {
+        this.metric = metric;
+    }
+
+    @Nonnull
+    public Metric getMetric() {
+        return metric;
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java
new file mode 100644
index 0000000000..b7f38ef1a7
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java
@@ -0,0 +1,42 @@
+/*
+ * NeighborsChangeSet.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.tuple.Tuple;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.function.Predicate;
+
+/**
+ * A change set over the neighbors of a {@link Node}. Implementations form a chain: {@link #merge()}
+ * materializes the resulting neighbor list, while {@link #writeDelta} persists only the difference
+ * relative to what is already stored.
+ */
+interface NeighborsChangeSet {
+    @Nullable
+    NeighborsChangeSet getParent();
+
+    @Nonnull
+    Iterable merge();
+
+    void writeDelta(@Nonnull InliningStorageAdapter storageAdapter, @Nonnull Transaction transaction, int layer,
+                    @Nonnull Node node, @Nonnull Predicate primaryKeyPredicate);
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java
new file mode 100644
index 0000000000..f2c623f882
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java
@@ -0,0 +1,59 @@
+/*
+ * Node.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc.
and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * TODO. + * @param neighbor type + */ +public interface Node { + @Nonnull + Tuple getPrimaryKey(); + + @Nonnull + N getSelfReference(@Nullable Vector vector); + + @Nonnull + List getNeighbors(); + + @Nonnull + N getNeighbor(int index); + + /** + * Return the kind of the node, i.e. {@link NodeKind#COMPACT} or {@link NodeKind#INLINING}. + * @return the kind of this node as a {@link NodeKind} + */ + @Nonnull + NodeKind getKind(); + + @Nonnull + CompactNode asCompactNode(); + + @Nonnull + InliningNode asInliningNode(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java new file mode 100644 index 0000000000..321e3f53d8 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -0,0 +1,37 @@ +/* + * NodeFactory.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +public interface NodeFactory { + @Nonnull + Node create(@Nonnull Tuple primaryKey, @Nullable Vector vector, + @Nonnull List neighbors); + + @Nonnull + NodeKind getNodeKind(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java new file mode 100644 index 0000000000..13d71a1b9b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java @@ -0,0 +1,60 @@ +/* + * NodeKind.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; + +/** + * Enum to capture the kind of node. + */ +public enum NodeKind { + COMPACT((byte)0x00), + INLINING((byte)0x01); + + private final byte serialized; + + NodeKind(final byte serialized) { + this.serialized = serialized; + } + + public byte getSerialized() { + return serialized; + } + + @Nonnull + static NodeKind fromSerializedNodeKind(byte serializedNodeKind) { + final NodeKind nodeKind; + switch (serializedNodeKind) { + case 0x00: + nodeKind = NodeKind.COMPACT; + break; + case 0x01: + nodeKind = NodeKind.INLINING; + break; + default: + throw new IllegalArgumentException("unknown node kind"); + } + Verify.verify(nodeKind.getSerialized() == serializedNodeKind); + return nodeKind; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java new file mode 100644 index 0000000000..59b831d04d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java @@ -0,0 +1,72 @@ +/* + * NodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.Streams; + +import javax.annotation.Nonnull; +import java.util.Objects; + +public class NodeReference { + @Nonnull + private final Tuple primaryKey; + + public NodeReference(@Nonnull final Tuple primaryKey) { + this.primaryKey = primaryKey; + } + + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nonnull + public NodeReferenceWithVector asNodeReferenceWithVector() { + throw new IllegalStateException("method should not be called"); + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReference)) { + return false; + } + final NodeReference that = (NodeReference)o; + return Objects.equals(primaryKey, that.primaryKey); + } + + @Override + public int hashCode() { + return Objects.hashCode(primaryKey); + } + + @Override + public String toString() { + return "NR[primaryKey=" + primaryKey + "]"; + } + + @Nonnull + public static Iterable primaryKeys(@Nonnull Iterable neighbors) { + return () -> Streams.stream(neighbors) + .map(NodeReference::getPrimaryKey) + .iterator(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java new file mode 100644 index 0000000000..bbf74e864a --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -0,0 +1,57 @@ +/* + * NodeReferenceAndNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +public class NodeReferenceAndNode { + @Nonnull + private final NodeReferenceWithDistance nodeReferenceWithDistance; + @Nonnull + private final Node node; + + public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, @Nonnull final Node node) { + this.nodeReferenceWithDistance = nodeReferenceWithDistance; + this.node = node; + } + + @Nonnull + public NodeReferenceWithDistance getNodeReferenceWithDistance() { + return nodeReferenceWithDistance; + } + + @Nonnull + public Node getNode() { + return node; + } + + @Nonnull + public static List getReferences(@Nonnull List> referencesAndNodes) { + final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode referenceWithNode : referencesAndNodes) { + referencesBuilder.add(referenceWithNode.getNodeReferenceWithDistance()); + } + return referencesBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java new file mode 100644 index 0000000000..bc9470735c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -0,0 +1,58 @@ +/* + * NodeReferenceWithDistance.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import java.util.Objects; + +public class NodeReferenceWithDistance extends NodeReferenceWithVector { + private final double distance; + + public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final double distance) { + super(primaryKey, vector); + this.distance = distance; + } + + public double getDistance() { + return distance; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithDistance)) { + return false; + } + if (!super.equals(o)) { + return false; + } + final NodeReferenceWithDistance that = (NodeReferenceWithDistance)o; + return Double.compare(distance, that.distance) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), distance); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java new file mode 100644 index 0000000000..e21b221622 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -0,0 +1,76 @@ +/* + * NodeReferenceWithVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Objects; + +import javax.annotation.Nonnull; + +public class NodeReferenceWithVector extends NodeReference { + @Nonnull + private final Vector vector; + + public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { + super(primaryKey); + this.vector = vector; + } + + @Nonnull + public Vector getVector() { + return vector; + } + + @Nonnull + public Vector getDoubleVector() { + return vector.toDoubleVector(); + } + + @Nonnull + @Override + public NodeReferenceWithVector asNodeReferenceWithVector() { + return this; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithVector)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equal(vector, ((NodeReferenceWithVector)o).vector); + } + + @Override + public int hashCode() { + return Objects.hashCode(super.hashCode(), vector); + } + + @Override + public String toString() { + return "NRV[primaryKey=" + getPrimaryKey() + + ";vector=" + vector.toString(3) + + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java new file mode 100644 index 0000000000..753648cf77 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java @@ -0,0 +1,46 @@ +/* + * OnReadListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; +import java.util.concurrent.CompletableFuture; + +/** + * Function interface for a call back whenever we read the slots for a node. + */ +public interface OnReadListener { + OnReadListener NOOP = new OnReadListener() { + }; + + default CompletableFuture> onAsyncRead(@Nonnull CompletableFuture> future) { + return future; + } + + default void onNodeRead(int layer, @Nonnull Node node) { + // nothing + } + + default void onKeyValueRead(int layer, + @Nonnull byte[] key, + @Nonnull byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java new file mode 100644 index 0000000000..fd4a096208 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -0,0 +1,49 @@ +/* + * OnWriteListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. 
and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.tuple.Tuple;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Functional interface for callbacks that are invoked whenever the slots of a node are written.
+ */
+public interface OnWriteListener {
+    OnWriteListener NOOP = new OnWriteListener() {
+    };
+
+    default void onNodeWritten(final int layer, @Nonnull final Node<?> node) {
+        // nothing
+    }
+
+    default void onNeighborWritten(final int layer, @Nonnull final Node<?> node, final NodeReference neighbor) {
+        // nothing
+    }
+
+    default void onNeighborDeleted(final int layer, @Nonnull final Node<?> node, @Nonnull Tuple neighborPrimaryKey) {
+        // nothing
+    }
+
+    default void onKeyValueWritten(final int layer, @Nonnull byte[] key, @Nonnull byte[] value) {
+        // nothing
+    }
+}
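The defaults above make OnWriteListener purely opt-in: an implementation overrides only the events it cares about, exactly like the NOOP instance. As a minimal sketch of a metrics listener in that spirit (the class name and its map-based counter are illustrative, not part of this patch; java.util.concurrent imports are assumed):

    // Accumulates raw bytes written per layer, mirroring what the test suite
    // further below does for reads with its TestOnReadListener.
    class WriteMetricsListener implements OnWriteListener {
        private final ConcurrentMap<Integer, AtomicLong> bytesWrittenByLayer = new ConcurrentHashMap<>();

        @Override
        public void onKeyValueWritten(final int layer, final byte[] key, final byte[] value) {
            bytesWrittenByLayer.computeIfAbsent(layer, ignored -> new AtomicLong())
                    .addAndGet(key.length + value.length);
        }

        public long bytesWritten(final int layer) {
            final AtomicLong bytes = bytesWrittenByLayer.get(layer);
            return bytes == null ? 0L : bytes.get();
        }
    }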
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java
new file mode 100644
index 0000000000..82bd281c62
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java
@@ -0,0 +1,184 @@
+/*
+ * StorageAdapter.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.ReadTransaction;
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.subspace.Subspace;
+import com.apple.foundationdb.tuple.Tuple;
+import com.christianheina.langx.half4j.Half;
+import com.google.common.base.Verify;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Storage adapter used for serialization and deserialization of nodes.
+ */
+interface StorageAdapter<N extends NodeReference> {
+    byte SUBSPACE_PREFIX_ENTRY_NODE = 0x01;
+    byte SUBSPACE_PREFIX_DATA = 0x02;
+
+    /**
+     * Get the {@link HNSW.Config} associated with this storage adapter.
+     * @return the configuration used by this storage adapter
+     */
+    @Nonnull
+    HNSW.Config getConfig();
+
+    @Nonnull
+    NodeFactory<N> getNodeFactory();
+
+    @Nonnull
+    NodeKind getNodeKind();
+
+    @Nonnull
+    StorageAdapter<NodeReference> asCompactStorageAdapter();
+
+    @Nonnull
+    StorageAdapter<NodeReferenceWithVector> asInliningStorageAdapter();
+
+    /**
+     * Get the subspace used to store this HNSW.
+     *
+     * @return the HNSW subspace
+     */
+    @Nonnull
+    Subspace getSubspace();
+
+    @Nonnull
+    Subspace getDataSubspace();
+
+    /**
+     * Get the on-write listener.
+     *
+     * @return the on-write listener.
+     */
+    @Nonnull
+    OnWriteListener getOnWriteListener();
+
+    /**
+     * Get the on-read listener.
+     *
+     * @return the on-read listener.
+     */
+    @Nonnull
+    OnReadListener getOnReadListener();
+
+    @Nonnull
+    CompletableFuture<Node<N>> fetchNode(@Nonnull ReadTransaction readTransaction,
+                                         int layer,
+                                         @Nonnull Tuple primaryKey);
+
+    void writeNode(@Nonnull Transaction transaction, @Nonnull Node<N> node, int layer,
+                   @Nonnull NeighborsChangeSet<N> changeSet);
+
+    Iterable<Node<N>> scanLayer(@Nonnull ReadTransaction readTransaction, int layer, @Nullable Tuple lastPrimaryKey,
+                                int maxNumRead);
+
+    @Nonnull
+    static CompletableFuture<EntryNodeReference> fetchEntryNodeReference(@Nonnull final ReadTransaction readTransaction,
+                                                                         @Nonnull final Subspace subspace,
+                                                                         @Nonnull final OnReadListener onReadListener) {
+        final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE));
+        final byte[] key = entryNodeSubspace.pack();
+
+        return readTransaction.get(key)
+                .thenApply(valueBytes -> {
+                    if (valueBytes == null) {
+                        return null; // not a single node in the index
+                    }
+                    onReadListener.onKeyValueRead(-1, key, valueBytes);
+
+                    final Tuple entryTuple = Tuple.fromBytes(valueBytes);
+                    final int lMax = (int)entryTuple.getLong(0);
+                    final Tuple primaryKey = entryTuple.getNestedTuple(1);
+                    final Tuple vectorTuple = entryTuple.getNestedTuple(2);
+                    return new EntryNodeReference(primaryKey, StorageAdapter.vectorFromTuple(vectorTuple), lMax);
+                });
+    }
+
+    static void writeEntryNodeReference(@Nonnull final Transaction transaction,
+                                        @Nonnull final Subspace subspace,
+                                        @Nonnull final EntryNodeReference entryNodeReference,
+                                        @Nonnull final OnWriteListener onWriteListener) {
+        final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE));
+        final byte[] key = entryNodeSubspace.pack();
+        final byte[] value = Tuple.from(entryNodeReference.getLayer(),
+                entryNodeReference.getPrimaryKey(),
+                StorageAdapter.tupleFromVector(entryNodeReference.getVector())).pack();
+        transaction.set(key,
+                value);
+        onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value);
+    }
+
+    @Nonnull
+    static Vector.HalfVector vectorFromTuple(final Tuple vectorTuple) {
+        return vectorFromBytes(vectorTuple.getBytes(0));
+    }
+
+    @Nonnull
+    static Vector.HalfVector vectorFromBytes(final byte[] vectorBytes) {
+        final int bytesLength = vectorBytes.length;
+        Verify.verify(bytesLength % 2 == 0);
+        final int componentSize = bytesLength >>> 1;
+        final Half[] vectorHalfs = new Half[componentSize];
+        for (int i = 0; i < componentSize; i ++) {
+            vectorHalfs[i] = Half.shortBitsToHalf(shortFromBytes(vectorBytes, i << 1));
+        }
+        return new Vector.HalfVector(vectorHalfs);
+    }
+
+    @Nonnull
+    @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod")
+    static Tuple tupleFromVector(final Vector<Half> vector) {
+        return Tuple.from(bytesFromVector(vector));
+    }
+
+    @Nonnull
+    static byte[] bytesFromVector(final Vector<Half> vector) {
+        final byte[] vectorBytes = new byte[2 * vector.size()];
+        for (int i = 0; i < vector.size(); i ++) {
+            final byte[] componentBytes = bytesFromShort(Half.halfToShortBits(vector.getComponent(i)));
+            final int indexTimesTwo = i << 1;
+            vectorBytes[indexTimesTwo] = componentBytes[0];
+            vectorBytes[indexTimesTwo + 1] = componentBytes[1];
+        }
+        return vectorBytes;
+    }
+
+    static short shortFromBytes(final byte[] bytes, final int offset) {
+        Verify.verify(offset % 2 == 0);
+        int high = bytes[offset] & 0xFF; // Convert to unsigned int
+        int low = bytes[offset + 1] & 0xFF;
+
+        return (short) ((high << 8) | low);
+    }
+
+    static byte[] bytesFromShort(final short value) {
+        byte[] result = new byte[2];
+        result[0] = (byte) ((value >> 8) & 0xFF); // high byte first
+        result[1] = (byte) (value & 0xFF); // low byte second
+        return result;
+    }
+}
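The static helpers above fully determine the on-disk vector format: each component is an IEEE 754 half-precision float stored big-endian in two bytes, so a vector of dimensionality d always serializes to exactly 2 * d bytes inside the enclosing tuple. A round-trip sketch using only calls that already appear in this file (the variable names are illustrative):

    // 0x3C00 is the half-precision bit pattern of 1.0
    final short bits = (short)0x3C00;
    final byte[] encoded = StorageAdapter.bytesFromShort(bits);      // yields {0x3C, 0x00}, high byte first
    final short decoded = StorageAdapter.shortFromBytes(encoded, 0); // offset must be even
    final Half one = Half.shortBitsToHalf(decoded);                  // back to the value 1.0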
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
new file mode 100644
index 0000000000..e1c7e34e10
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
@@ -0,0 +1,224 @@
+/*
+ * Vector.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.christianheina.langx.half4j.Half;
+import com.google.common.base.Suppliers;
+
+import javax.annotation.Nonnull;
+import java.util.Arrays;
+import java.util.Objects;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+/**
+ * A fixed-dimensionality vector of numeric components; concrete subclasses pin the
+ * component representation to half or double precision.
+ * @param representation type + */ +public abstract class Vector { + @Nonnull + protected R[] data; + @Nonnull + protected Supplier hashCodeSupplier; + + public Vector(@Nonnull final R[] data) { + this.data = data; + this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode); + } + + public int size() { + return data.length; + } + + @Nonnull + R getComponent(int dimension) { + return data[dimension]; + } + + @Nonnull + public R[] getData() { + return data; + } + + @Nonnull + public abstract byte[] getRawData(); + + @Nonnull + public abstract Vector toHalfVector(); + + @Nonnull + public abstract DoubleVector toDoubleVector(); + + public abstract int precision(); + + @Override + public boolean equals(final Object o) { + if (!(o instanceof Vector)) { + return false; + } + final Vector vector = (Vector)o; + return Objects.deepEquals(data, vector.data); + } + + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } + + private int computeHashCode() { + return Arrays.hashCode(data); + } + + @Override + public String toString() { + return toString(3); + } + + public String toString(final int limitDimensions) { + if (limitDimensions < data.length) { + return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions)) + .map(String::valueOf) + .collect(Collectors.joining(",")) + ", ...]"; + } else { + return "[" + Arrays.stream(data) + .map(String::valueOf) + .collect(Collectors.joining(",")) + "]"; + } + } + + public static class HalfVector extends Vector { + @Nonnull + private final Supplier toDoubleVectorSupplier; + @Nonnull + private final Supplier toRawDataSupplier; + + public HalfVector(@Nonnull final Half[] data) { + super(data); + this.toDoubleVectorSupplier = Suppliers.memoize(this::computeDoubleVector); + this.toRawDataSupplier = Suppliers.memoize(this::computeRawData); + } + + @Nonnull + @Override + public Vector toHalfVector() { + return this; + } + + @Nonnull + @Override + public DoubleVector toDoubleVector() { + return toDoubleVectorSupplier.get(); + } + + @Override + public int precision() { + return 16; + } + + @Nonnull + public DoubleVector computeDoubleVector() { + Double[] result = new Double[data.length]; + for (int i = 0; i < data.length; i ++) { + result[i] = data[i].doubleValue(); + } + return new DoubleVector(result); + } + + @Nonnull + @Override + public byte[] getRawData() { + return toRawDataSupplier.get(); + } + + @Nonnull + private byte[] computeRawData() { + return StorageAdapter.bytesFromVector(this); + } + + @Nonnull + public static HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes) { + return StorageAdapter.vectorFromBytes(vectorBytes); + } + } + + public static class DoubleVector extends Vector { + @Nonnull + private final Supplier toHalfVectorSupplier; + + public DoubleVector(@Nonnull final Double[] data) { + super(data); + this.toHalfVectorSupplier = Suppliers.memoize(this::computeHalfVector); + } + + @Nonnull + @Override + public HalfVector toHalfVector() { + return toHalfVectorSupplier.get(); + } + + @Nonnull + public HalfVector computeHalfVector() { + Half[] result = new Half[data.length]; + for (int i = 0; i < data.length; i ++) { + result[i] = Half.valueOf(data[i]); + } + return new HalfVector(result); + } + + @Nonnull + @Override + public DoubleVector toDoubleVector() { + return this; + } + + @Override + public int precision() { + return 64; + } + + @Nonnull + @Override + public byte[] getRawData() { + // TODO + throw new UnsupportedOperationException("not implemented yet"); + } + } + + public static double 
distance(@Nonnull Metric metric,
+                                  @Nonnull final Vector<?> vector1,
+                                  @Nonnull final Vector<?> vector2) {
+        return metric.distance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData());
+    }
+
+    static double comparativeDistance(@Nonnull Metric metric,
+                                      @Nonnull final Vector<?> vector1,
+                                      @Nonnull final Vector<?> vector2) {
+        return metric.comparativeDistance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData());
+    }
+
+    public static Vector<?> fromBytes(@Nonnull final byte[] bytes, int precision) {
+        if (precision == 16) {
+            return HalfVector.halfVectorFromBytes(bytes);
+        }
+        // TODO
+        throw new UnsupportedOperationException("not implemented yet");
+    }
+}
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java
new file mode 100644
index 0000000000..5565b7f9f6
--- /dev/null
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java
@@ -0,0 +1,24 @@
+/*
+ * package-info.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Classes and interfaces related to the HNSW implementation.
+ */
+package com.apple.foundationdb.async.hnsw;
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java
index db4e4cf636..a11ac8b462 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java
@@ -1,5 +1,5 @@
 /*
- * NodeHelpers.java
+ * HNSWHelpers.java
 *
 * This source file is part of the FoundationDB open source project
 *
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
index f60c17da63..2623cff1dc 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java
@@ -36,7 +36,6 @@
  * Storage adapter used for serialization and deserialization of nodes.
  */
 interface StorageAdapter {
-
     /**
      * Get the {@link RTree.Config} associated with this storage adapter.
     * @return the configuration used by this storage adapter
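Note that both distance variants in Vector promote their operands to double precision and delegate to the Metric; comparativeDistance only has to preserve the ordering of distances, so for the Euclidean metric it can presumably return the squared distance and skip the square root. A usage sketch with made-up two-dimensional vectors (it assumes it runs inside the com.apple.foundationdb.async.hnsw package, since comparativeDistance is package-private):

    final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
    final Vector.HalfVector origin = new Vector.HalfVector(new Half[] {Half.valueOf(0.0), Half.valueOf(0.0)});
    final Vector.HalfVector point = new Vector.HalfVector(new Half[] {Half.valueOf(3.0), Half.valueOf(4.0)});
    final double d = Vector.distance(metric, origin, point);            // 5.0 under the Euclidean metric
    final double c = Vector.comparativeDistance(metric, origin, point); // ranks candidates exactly like d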
diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java
new file mode 100644
index 0000000000..dc070c2066
--- /dev/null
+++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java
@@ -0,0 +1,666 @@
+/*
+ * HNSWModificationTest.java
+ *
+ * This source file is part of the FoundationDB open source project
+ *
+ * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.apple.foundationdb.async.hnsw;
+
+import com.apple.foundationdb.Database;
+import com.apple.foundationdb.Transaction;
+import com.apple.foundationdb.async.hnsw.Vector.HalfVector;
+import com.apple.foundationdb.async.rtree.RTree;
+import com.apple.foundationdb.test.TestDatabaseExtension;
+import com.apple.foundationdb.test.TestExecutors;
+import com.apple.foundationdb.test.TestSubspaceExtension;
+import com.apple.foundationdb.tuple.Tuple;
+import com.apple.test.Tags;
+import com.christianheina.langx.half4j.Half;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableSet;
+import java.util.Objects;
+import java.util.Random;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+/**
+ * Tests covering inserts, updates, and deletes of data in an {@link HNSW}.
+ */ +@Execution(ExecutionMode.CONCURRENT) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +@Tag(Tags.RequiresFDB) +@Tag(Tags.Slow) +public class HNSWModificationTest { + private static final Logger logger = LoggerFactory.getLogger(HNSWModificationTest.class); + private static final int NUM_TEST_RUNS = 5; + private static final int NUM_SAMPLES = 10_000; + + @RegisterExtension + static final TestDatabaseExtension dbExtension = new TestDatabaseExtension(); + @RegisterExtension + TestSubspaceExtension rtSubspace = new TestSubspaceExtension(dbExtension); + @RegisterExtension + TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); + + private Database db; + + @BeforeEach + public void setUpDb() { + db = dbExtension.getDatabase(); + } + + @Test + public void testCompactSerialization() { + final Random random = new Random(0); + final CompactStorageAdapter storageAdapter = + new CompactStorageAdapter(HNSW.DEFAULT_CONFIG, CompactNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomCompactNode = + createRandomCompactNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomCompactNode, 0); + return randomCompactNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> { + Assertions.assertAll( + () -> Assertions.assertInstanceOf(CompactNode.class, node), + () -> Assertions.assertEquals(NodeKind.COMPACT, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> Assertions.assertEquals(node.asCompactNode().getVector(), + originalNode.asCompactNode().getVector()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertEquals(neighbors, originalNeighbors); + } + ); + }).join()); + } + + @Test + public void testInliningSerialization() { + final Random random = new Random(0); + final InliningStorageAdapter storageAdapter = + new InliningStorageAdapter(HNSW.DEFAULT_CONFIG, InliningNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomInliningNode = + createRandomInliningNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomInliningNode, 0); + return randomInliningNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> Assertions.assertAll( + () -> Assertions.assertInstanceOf(InliningNode.class, node), + () -> Assertions.assertEquals(NodeKind.INLINING, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); // should not be necessary the way it is stored + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertEquals(neighbors, 
originalNeighbors); + } + )).join()); + } + + @Test + public void testBasicInsert() { + final Random random = new Random(0); + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final int dimensions = 128; + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) + .setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + for (int i = 0; i < 1000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> new NodeReferenceWithVector(createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, dimensions))); + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> result = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 100, createRandomVector(random, dimensions)).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : result) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + private int basicInsertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference != null) { + hnsw.insert(tr, newNodeReference).join(); + } + } + final long endTs = System.nanoTime(); + logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + private int insertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + final ImmutableList.Builder nodeReferenceWithVectorBuilder = + ImmutableList.builder(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference != null) { + nodeReferenceWithVectorBuilder.add(newNodeReference); + } + } + hnsw.insertBatch(tr, nodeReferenceWithVectorBuilder.build()).join(); + final long endTs = System.nanoTime(); + logger.info("inserted batch batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void 
testSIFTInsert10k() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 10; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + + final AtomicReference queryVectorAtomic = new AtomicReference<>(); + final NavigableSet trueResults = new ConcurrentSkipListSet<>( + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 10000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = new HalfVector(halfs); + final HalfVector queryVector = queryVectorAtomic.get(); + if (queryVector == null) { + queryVectorAtomic.set(currentVector); + return null; + } else { + final double currentDistance = + Vector.comparativeDistance(metric, currentVector, queryVector); + if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { + trueResults.add( + new NodeReferenceWithDistance(currentPrimaryKey, currentVector, + Vector.comparativeDistance(metric, currentVector, queryVector))); + } + if (trueResults.size() > k) { + trueResults.remove(trueResults.last()); + } + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + } + }); + } + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { + logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void testSIFTInsert10kWithBatchInsert() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 10; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = 
new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + + final AtomicReference queryVectorAtomic = new AtomicReference<>(); + final NavigableSet trueResults = new ConcurrentSkipListSet<>( + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 10000;) { + i += insertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = new HalfVector(halfs); + final HalfVector queryVector = queryVectorAtomic.get(); + if (queryVector == null) { + queryVectorAtomic.set(currentVector); + return null; + } else { + final double currentDistance = + Vector.comparativeDistance(metric, currentVector, queryVector); + if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { + trueResults.add( + new NodeReferenceWithDistance(currentPrimaryKey, currentVector, + Vector.comparativeDistance(metric, currentVector, queryVector))); + } + if (trueResults.size() > k) { + trueResults.remove(trueResults.last()); + } + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + } + }); + } + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { + logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + @Test + public void testBasicInsertAndScanLayer() throws Exception { + final Random random = new Random(0); + final AtomicLong nextNodeId = new AtomicLong(0L); + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setM(4).setMMax(4).setMMax0(4).build(), + OnWriteListener.NOOP, OnReadListener.NOOP); + + db.run(tr -> { + for (int i = 0; i < 100; i ++) { + hnsw.insert(tr, createNextPrimaryKey(nextNodeId), createRandomVector(random, 2)).join(); + } + return null; + }); 
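+        // Walk the hierarchy bottom-up below: layer 0 holds all 100 inserted nodes,
+        // every higher layer holds a progressively smaller sample of them, and the
+        // first layer that yields no nodes (dumpLayer returning false) ends the scan.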
+ + int layer = 0; + while (true) { + if (!dumpLayer(hnsw, layer++)) { + break; + } + } + } + + @Test + public void testManyRandomVectors() { + final Random random = new Random(); + for (long l = 0L; l < 3000000; l ++) { + final HalfVector randomVector = createRandomVector(random, 768); + final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector); + final Vector roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple); + Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), randomVector, roundTripVector); + Assertions.assertEquals(randomVector, roundTripVector); + } + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void testSIFTVectors() throws Exception { + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) + .setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + final var referenceVector = createRandomVector(new Random(0), dimensions); + long count = 0L; + double mean = 0.0d; + double mean2 = 0.0d; + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 100_000; i ++) { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final HalfVector newVector = new HalfVector(halfs); + final double distance = Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), + referenceVector, newVector); + count++; + final double delta = distance - mean; + mean += delta / count; + final double delta2 = distance - mean; + mean2 += delta * delta2; + } + } + final double sampleVariance = mean2 / (count - 1); + final double standardDeviation = Math.sqrt(sampleVariance); + logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, + standardDeviation / mean); + } + + + @ParameterizedTest + @ValueSource(ints = {2, 3, 10, 100, 768}) + public void testManyVectorsStandardDeviation(final int dimensionality) { + final Random random = new Random(); + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + long count = 0L; + double mean = 0.0d; + double mean2 = 0.0d; + for (long i = 0L; i < 100000; i ++) { + final HalfVector vector1 = createRandomVector(random, dimensionality); + final HalfVector vector2 = createRandomVector(random, dimensionality); + final double distance = Vector.comparativeDistance(metric, vector1, vector2); + count = i + 1; + final double delta = distance - mean; + mean += delta / count; + final double delta2 = distance - mean; + mean2 += delta * delta2; + } + final double sampleVariance = mean2 / (count - 1); + final double standardDeviation = Math.sqrt(sampleVariance); + logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, + standardDeviation / mean); + } + + private boolean dumpLayer(final HNSW hnsw, final int layer) throws IOException 
{ + final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv"; + final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv"; + + final AtomicLong numReadAtomic = new AtomicLong(0L); + try (final BufferedWriter verticesWriter = new BufferedWriter(new FileWriter(verticesFileName)); + final BufferedWriter edgesWriter = new BufferedWriter(new FileWriter(edgesFileName))) { + hnsw.scanLayer(db, layer, 100, node -> { + final CompactNode compactNode = node.asCompactNode(); + final Vector vector = compactNode.getVector(); + try { + verticesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + + vector.getComponent(0) + "," + + vector.getComponent(1)); + verticesWriter.newLine(); + + for (final var neighbor : compactNode.getNeighbors()) { + edgesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + + neighbor.getPrimaryKey().getLong(0)); + edgesWriter.newLine(); + } + numReadAtomic.getAndIncrement(); + } catch (final IOException e) { + throw new RuntimeException("unable to write to file", e); + } + }); + } + return numReadAtomic.get() != 0; + } + + private void writeNode(@Nonnull final Transaction transaction, + @Nonnull final StorageAdapter storageAdapter, + @Nonnull final Node node, + final int layer) { + final NeighborsChangeSet insertChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + node.getNeighbors()); + storageAdapter.writeNode(transaction, node, layer, insertChangeSet); + } + + @Nonnull + private Node createRandomCompactNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReference(random)); + } + + return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private Node createRandomInliningNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReferenceWithVector(random, dimensionality)); + } + + return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private NodeReference createRandomNodeReference(@Nonnull final Random random) { + return new NodeReference(createRandomPrimaryKey(random)); + } + + @Nonnull + private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull final Random random, final int dimensionality) { + return new NodeReferenceWithVector(createRandomPrimaryKey(random), createRandomVector(random, dimensionality)); + } + + @Nonnull + private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { + return Tuple.from(random.nextLong()); + } + + @Nonnull + private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { + return Tuple.from(nextIdAtomic.getAndIncrement()); + } + + @Nonnull + private HalfVector createRandomVector(@Nonnull final Random random, final int dimensionality) { + final Half[] components = new Half[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // 
don't ask + components[d] = HNSWHelpers.halfValueOf(random.nextDouble()); + } + return new HalfVector(components); + } + + private static class TestOnReadListener implements OnReadListener { + final Map nodeCountByLayer; + final Map sumMByLayer; + final Map bytesReadByLayer; + + public TestOnReadListener() { + this.nodeCountByLayer = Maps.newConcurrentMap(); + this.sumMByLayer = Maps.newConcurrentMap(); + this.bytesReadByLayer = Maps.newConcurrentMap(); + } + + public Map getNodeCountByLayer() { + return nodeCountByLayer; + } + + public Map getBytesReadByLayer() { + return bytesReadByLayer; + } + + public Map getSumMByLayer() { + return sumMByLayer; + } + + public void reset() { + nodeCountByLayer.clear(); + bytesReadByLayer.clear(); + sumMByLayer.clear(); + } + + @Override + public void onNodeRead(final int layer, @Nonnull final Node node) { + nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + node.getNeighbors().size()); + } + + @Override + public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + bytesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + + key.length + value.length); + } + } +} diff --git a/gradle/codequality/pmd-rules.xml b/gradle/codequality/pmd-rules.xml index 500ef17c69..4d8745d875 100644 --- a/gradle/codequality/pmd-rules.xml +++ b/gradle/codequality/pmd-rules.xml @@ -16,6 +16,7 @@ + diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c4e6482b97..419df00cd0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -37,6 +37,7 @@ generatedAnnotation = "1.3.2" grpc = "1.64.1" grpc-commonProtos = "2.37.0" guava = "33.3.1-jre" +half4j = "0.0.2" h2 = "1.3.148" icu = "69.1" lucene = "8.11.1" @@ -95,6 +96,7 @@ grpc-services = { module = "io.grpc:grpc-services", version.ref = "grpc" } grpc-stub = { module = "io.grpc:grpc-stub", version.ref = "grpc" } grpc-util = { module = "io.grpc:grpc-util", version.ref = "grpc" } guava = { module = "com.google.guava:guava", version.ref = "guava" } +half4j = { module = "com.christianheina.langx:half4j", version.ref = "half4j"} icu = { module = "com.ibm.icu:icu4j", version.ref = "icu" } javaPoet = { module = "com.squareup:javapoet", version.ref = "javaPoet" } jsr305 = { module = "com.google.code.findbugs:jsr305", version.ref = "jsr305" } diff --git a/gradle/scripts/log4j-test.properties b/gradle/scripts/log4j-test.properties index 447ee2f55a..1ae7583751 100644 --- a/gradle/scripts/log4j-test.properties +++ b/gradle/scripts/log4j-test.properties @@ -26,7 +26,7 @@ appender.console.name = STDOUT appender.console.layout.type = PatternLayout appender.console.layout.pattern = %d [%level] %logger{1.} - %m %X%n%ex{full} -rootLogger.level = debug +rootLogger.level = info rootLogger.appenderRefs = stdout rootLogger.appenderRef.stdout.ref = STDOUT From 75f270824c5be92664cdf4abb83c89666fad6a60 Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Tue, 16 Sep 2025 20:17:02 +0200 Subject: [PATCH 02/10] adding tests --- fdb-extensions/fdb-extensions.gradle | 32 +++++ .../apple/foundationdb/async/hnsw/Vector.java | 114 ++++++++++++++++++ .../foundationdb/async/rtree/NodeHelpers.java | 2 +- .../async/hnsw/HNSWModificationTest.java | 102 +++++++--------- 4 files changed, 193 insertions(+), 57 deletions(-) diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle index 3601bd1e4a..200324d730 100644 --- 
a/fdb-extensions/fdb-extensions.gradle +++ b/fdb-extensions/fdb-extensions.gradle @@ -42,6 +42,38 @@ dependencies { testFixturesAnnotationProcessor(libs.autoService) } +def siftSmallFile = layout.buildDirectory.file('downloads/siftsmall.tar.gz') +def extractDir = layout.buildDirectory.dir("extracted") + +// Task that downloads the CSV exactly once unless it changed +tasks.register('downloadSiftSmall', de.undercouch.gradle.tasks.download.Download) { + src 'https://huggingface.co/datasets/vecdata/siftsmall/resolve/3106e1b83049c44713b1ce06942d0ab474bbdfb6/siftsmall.tar.gz' + dest siftSmallFile.get().asFile + onlyIfModified true + tempAndMove true + retries 3 +} + +tasks.register('extractSiftSmall', Copy) { + dependsOn 'downloadSiftSmall' + from(tarTree(resources.gzip(siftSmallFile))) + into extractDir + + doLast { + println "Extracted files into: ${extractDir.get().asFile}" + fileTree(extractDir).visit { details -> + if (!details.isDirectory()) { + println " - ${details.file}" + } + } + } +} + +test { + dependsOn tasks.named('extractSiftSmall') + inputs.dir extractDir +} + publishing { publications { library(MavenPublication) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java index e1c7e34e10..725c1b6123 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java @@ -22,9 +22,18 @@ import com.christianheina.langx.half4j.Half; import com.google.common.base.Suppliers; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableList; import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -221,4 +230,109 @@ public static Vector fromBytes(@Nonnull final byte[] bytes, int precision) { // TODO throw new UnsupportedOperationException("not implemented yet"); } + + public abstract static class StoredVecsIterator extends AbstractIterator { + @Nonnull + private final FileChannel fileChannel; + + protected StoredVecsIterator(@Nonnull final FileChannel fileChannel) { + this.fileChannel = fileChannel; + } + + @Nonnull + protected abstract N[] newComponentArray(int size); + + @Nonnull + protected abstract N toComponent(@Nonnull ByteBuffer byteBuffer); + + @Nonnull + protected abstract T toTarget(@Nonnull N[] components); + + + @Nullable + @Override + protected T computeNext() { + try { + final ByteBuffer headerBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + // allocate a buffer for reading floats later; you may reuse + headerBuf.clear(); + final int bytesRead = fileChannel.read(headerBuf); + if (bytesRead < 4) { + if (bytesRead == -1) { + return endOfData(); + } + throw new IOException("corrupt fvecs file"); + } + headerBuf.flip(); + final int dims = headerBuf.getInt(); + if (dims <= 0) { + throw new IOException("Invalid dimension " + dims + " at position " + (fileChannel.position() - 4)); + } + final ByteBuffer vecBuf = ByteBuffer.allocate(dims * 4).order(ByteOrder.LITTLE_ENDIAN); + while (vecBuf.hasRemaining()) { + int read = fileChannel.read(vecBuf); + if (read < 0) { + throw new EOFException("unexpected EOF when reading 
vector data"); + } + } + vecBuf.flip(); + final N[] rawVecData = newComponentArray(dims); + for (int i = 0; i < dims; i++) { + rawVecData[i] = toComponent(vecBuf); + } + + return toTarget(rawVecData); + } catch (final IOException ioE) { + throw new RuntimeException(ioE); + } + } + } + + public static class StoredFVecsIterator extends StoredVecsIterator { + public StoredFVecsIterator(@Nonnull final FileChannel fileChannel) { + super(fileChannel); + } + + @Nonnull + @Override + protected Double[] newComponentArray(final int size) { + return new Double[size]; + } + + @Nonnull + @Override + protected Double toComponent(@Nonnull final ByteBuffer byteBuffer) { + return (double)byteBuffer.getFloat(); + } + + @Nonnull + @Override + protected DoubleVector toTarget(@Nonnull final Double[] components) { + return new DoubleVector(components); + } + } + + public static class StoredIVecsIterator extends StoredVecsIterator> { + public StoredIVecsIterator(@Nonnull final FileChannel fileChannel) { + super(fileChannel); + } + + @Nonnull + @Override + protected Integer[] newComponentArray(final int size) { + return new Integer[size]; + } + + @Nonnull + @Override + protected Integer toComponent(@Nonnull final ByteBuffer byteBuffer) { + return byteBuffer.getInt(); + } + + @Nonnull + @Override + protected List toTarget(@Nonnull final Integer[] components) { + return ImmutableList.copyOf(components); + } + } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java index a11ac8b462..db4e4cf636 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java @@ -1,5 +1,5 @@ /* - * HNSWHelpers.java + * NodeHelpers.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java index dc070c2066..7a8bf73e0d 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java @@ -30,12 +30,12 @@ import com.apple.foundationdb.tuple.Tuple; import com.apple.test.Tags; import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import org.assertj.core.util.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -53,8 +53,13 @@ import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.Comparator; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NavigableSet; @@ -208,9 +213,10 @@ private int basicInsertBatch(final HNSW hnsw, final int batchSize, final long beginTs = System.nanoTime(); for (int i = 0; i < batchSize; i ++) { final var newNodeReference = insertFunction.apply(tr); - if (newNodeReference != null) { - hnsw.insert(tr, 
newNodeReference).join(); + if (newNodeReference == null) { + return i; } + hnsw.insert(tr, newNodeReference).join(); } final long endTs = System.nanoTime(); logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, @@ -243,7 +249,6 @@ private int insertBatch(final HNSW hnsw, final int batchSize, } @Test - @Timeout(value = 150, unit = TimeUnit.MINUTES) public void testSIFTInsert10k() throws Exception { final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); final int k = 10; @@ -255,76 +260,62 @@ public void testSIFTInsert10k() throws Exception { HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), OnWriteListener.NOOP, onReadListener); - final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; - final int dimensions = 128; + final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); - final AtomicReference queryVectorAtomic = new AtomicReference<>(); - final NavigableSet trueResults = new ConcurrentSkipListSet<>( - Comparator.comparing(NodeReferenceWithDistance::getDistance)); + try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new Vector.StoredFVecsIterator(fileChannel); - try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { - for (int i = 0; i < 10000;) { + int i = 0; + while (vectorIterator.hasNext()) { i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, tr -> { - final String line; - try { - line = br.readLine(); - } catch (IOException e) { - throw new RuntimeException(e); + if (!vectorIterator.hasNext()) { + return null; } - final String[] values = Objects.requireNonNull(line).split("\t"); - Assertions.assertEquals(dimensions, values.length); - final Half[] halfs = new Half[dimensions]; + final Vector.DoubleVector doubleVector = vectorIterator.next(); - for (int c = 0; c < values.length; c++) { - final String value = values[c]; - halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); - } final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); - final HalfVector currentVector = new HalfVector(halfs); - final HalfVector queryVector = queryVectorAtomic.get(); - if (queryVector == null) { - queryVectorAtomic.set(currentVector); - return null; - } else { - final double currentDistance = - Vector.comparativeDistance(metric, currentVector, queryVector); - if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { - trueResults.add( - new NodeReferenceWithDistance(currentPrimaryKey, currentVector, - Vector.comparativeDistance(metric, currentVector, queryVector))); - } - if (trueResults.size() > k) { - trueResults.remove(trueResults.last()); - } - return new NodeReferenceWithVector(currentPrimaryKey, currentVector); - } + final HalfVector currentVector = doubleVector.toHalfVector(); + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); }); } } - onReadListener.reset(); - final long beginTs = System.nanoTime(); - final List> results = - db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); - final long endTs = System.nanoTime(); + final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs"); + final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs"); - for (NodeReferenceAndNode nodeReferenceAndNode : results) { - final NodeReferenceWithDistance 
nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); - logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), - nodeReferenceWithDistance.getDistance()); - } - for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { - logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), - nodeReferenceWithDistance.getDistance()); + try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ); + final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) { + final Iterator queryIterator = new Vector.StoredFVecsIterator(queryChannel); + final Iterator> groundTruthIterator = new Vector.StoredIVecsIterator(groundTruthChannel); + + Verify.verify(queryIterator.hasNext() == groundTruthIterator.hasNext()); + + while (queryIterator.hasNext()) { + final HalfVector queryVector = queryIterator.next().toHalfVector(); + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join()); + final long endTs = System.nanoTime(); + logger.info("retrieved result in elapsedTimeMs={}", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("retrieved result nodeId = {} at distance = {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + logger.info("true result vector={}", groundTruthIterator.next()); + } } System.out.println(onReadListener.getNodeCountByLayer()); System.out.println(onReadListener.getBytesReadByLayer()); - logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + // logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); } @Test @@ -499,7 +490,6 @@ public void testSIFTVectors() throws Exception { standardDeviation / mean); } - @ParameterizedTest @ValueSource(ints = {2, 3, 10, 100, 768}) public void testManyVectorsStandardDeviation(final int dimensionality) { From 130ffee7dae683b3acf7ef460d3d77010028d7b1 Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Wed, 17 Sep 2025 09:19:05 +0200 Subject: [PATCH 03/10] adding javadocs --- .../apple/foundationdb/async/hnsw/HNSW.java | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index fb177c9d77..b41eaf7a0f 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -62,9 +62,6 @@ import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; -/** - * TODO. - */ @API(API.Status.EXPERIMENTAL) @SuppressWarnings("checkstyle:AbbreviationAsWordInName") public class HNSW { @@ -335,16 +332,10 @@ public static ConfigBuilder newConfigBuilder() { return new ConfigBuilder(); } - /** - * TODO. 
- */ public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor) { this(subspace, executor, DEFAULT_CONFIG, OnWriteListener.NOOP, OnReadListener.NOOP); } - /** - * TODO. - */ public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor, @Nonnull final Config config, @Nonnull final OnWriteListener onWriteListener, @@ -402,9 +393,6 @@ public OnReadListener getOnReadListener() { // Read Path // - /** - * TODO. - */ @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper @Nonnull public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, @@ -487,9 +475,6 @@ private CompletableFuture g } } - /** - * TODO. - */ @Nonnull private CompletableFuture greedySearchInliningLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -534,9 +519,6 @@ private CompletableFuture greedySearchInliningLayer(@ }), executor).thenApply(ignored -> currentNodeReferenceAtomic.get()); } - /** - * TODO. - */ @Nonnull private CompletableFuture>> searchLayer(@Nonnull StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -608,9 +590,6 @@ private CompletableFuture }); } - /** - * TODO. - */ @Nonnull private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -625,9 +604,6 @@ private CompletableFuture> fetchNodeIfNotCache }); } - /** - * TODO. - */ @Nonnull private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -645,9 +621,6 @@ private CompletableFuture< .thenApply(node -> biMapFunction.apply(nodeReference, node)); } - /** - * TODO. - */ @Nonnull private CompletableFuture> fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -671,9 +644,6 @@ private CompletableFuture CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -694,9 +664,6 @@ private CompletableFuture }); } - /** - * TODO. - */ @Nonnull private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, From 57708ea9192c3c4f6275d9f07ac01c21970290e1 Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Wed, 17 Sep 2025 11:15:24 +0200 Subject: [PATCH 04/10] adding comments --- .../apple/foundationdb/async/hnsw/HNSW.java | 662 ++++++++++++++++-- 1 file changed, 618 insertions(+), 44 deletions(-) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index b41eaf7a0f..798ff7e1a1 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -62,6 +62,21 @@ import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; +/** + * An implementation of the Hierarchical Navigable Small World (HNSW) algorithm for + * efficient approximate nearest neighbor (ANN) search. + *

+ * HNSW constructs a multi-layer graph, where each layer is a subset of the one below it. + * The top layers serve as fast entry points to navigate the graph, while the bottom layer + * contains all the data points. This structure allows for logarithmic-time complexity + * for search operations, making it suitable for large-scale, high-dimensional datasets. + *

+ * This class provides methods for building the graph ({@link #insert(Transaction, Tuple, Vector)}) + * and performing k-NN searches ({@link #kNearestNeighborsSearch(ReadTransaction, int, int, Vector)}). + * It is designed to be used with a transactional storage backend, managed via a {@link Subspace}. + * + * @see Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs + */ @API(API.Status.EXPERIMENTAL) @SuppressWarnings("checkstyle:AbbreviationAsWordInName") public class HNSW { @@ -332,10 +347,33 @@ public static ConfigBuilder newConfigBuilder() { return new ConfigBuilder(); } + /** + * Creates a new {@code HNSW} instance using the default configuration, write listener, and read listener. + *
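+ * A minimal usage sketch (illustrative only; {@code subspace} and {@code executor} are assumed to be
+ * supplied by the caller):
+ * <pre>{@code
+ * HNSW hnsw = new HNSW(subspace, executor); // uses DEFAULT_CONFIG and the NOOP listeners
+ * }</pre>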

+ * This constructor delegates to the main constructor, providing default values for configuration + * and listeners, simplifying the instantiation process for common use cases. + * + * @param subspace the non-null {@link Subspace} to build the HNSW graph for. + * @param executor the non-null {@link Executor} for concurrent operations, such as building the graph. + */ public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor) { this(subspace, executor, DEFAULT_CONFIG, OnWriteListener.NOOP, OnReadListener.NOOP); } + /** + * Constructs a new HNSW graph instance. + *
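+ * For illustration, a fully configured instance might be created as follows (a sketch; {@code subspace},
+ * {@code executor}, and {@code metric} are assumed to exist):
+ * <pre>{@code
+ * HNSW hnsw = new HNSW(subspace, executor,
+ *         HNSW.newConfigBuilder().setMetric(metric).setM(32).build(),
+ *         OnWriteListener.NOOP, OnReadListener.NOOP);
+ * }</pre>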

+ * This constructor initializes the HNSW graph with the necessary components for storage, + * execution, configuration, and event handling. All parameters are mandatory and must not be null. + * + * @param subspace the {@link Subspace} where the graph data is stored. + * @param executor the {@link Executor} service to use for concurrent operations. + * @param config the {@link Config} object containing HNSW algorithm parameters. + * @param onWriteListener a listener to be notified of write events on the graph. + * @param onReadListener a listener to be notified of read events on the graph. + * + * @throws NullPointerException if any of the parameters are {@code null}. + */ public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor, @Nonnull final Config config, @Nonnull final OnWriteListener onWriteListener, @@ -348,6 +386,11 @@ public HNSW(@Nonnull final Subspace subspace, } + /** + * Gets the subspace associated with this object. + * + * @return the non-null subspace + */ @Nonnull public Subspace getSubspace() { return subspace; @@ -393,6 +436,27 @@ public OnReadListener getOnReadListener() { // Read Path // + /** + * Performs a k-nearest neighbors (k-NN) search for a given query vector. + *

+ * This method implements the search algorithm for an HNSW graph. The search begins at an entry point in the + * highest layer and greedily traverses down through the layers. In each layer, it finds the node closest to the + * {@code queryVector}. This node then serves as the entry point for the search in the layer below. + *
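+ * For example, a top-10 search with {@code efSearch = 100} might be issued as follows (illustrative only;
+ * {@code db}, {@code hnsw}, and {@code queryVector} are assumed to exist):
+ * <pre>{@code
+ * final var results = db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 100, queryVector).join());
+ * }</pre>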

+ * Once the search reaches the base layer (layer 0), it performs a more exhaustive search starting from the + * determined entry point. It explores the graph, maintaining a dynamic list of the best candidates found so far. + * The size of this candidate list is controlled by the {@code efSearch} parameter. Finally, the method selects + * the top {@code k} nodes from the search results, sorted by their distance to the query vector. + * + * @param readTransaction the transaction to use for reading from the database + * @param k the number of nearest neighbors to return + * @param efSearch the size of the dynamic candidate list for the search. A larger value increases accuracy + * at the cost of performance. + * @param queryVector the vector to find the nearest neighbors for + * + * @return a {@link CompletableFuture} that will complete with a list of the {@code k} nearest neighbors, + * sorted by distance in ascending order. The future completes with {@code null} if the index is empty. + */ @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper @Nonnull public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, @@ -461,6 +525,29 @@ public CompletableFuture + * This method finds the node on the specified layer that is closest to the given query vector, + * starting the search from a designated entry point. The search is "greedy" because it aims to find + * only the single best neighbor. + *

+ * The implementation strategy depends on the {@link NodeKind} of the provided {@link StorageAdapter}. + * If the node kind is {@code INLINING}, it delegates to the specialized {@link #greedySearchInliningLayer} method. + * Otherwise, it uses the more general {@link #searchLayer} method with a search size (ef) of 1. + * The operation is asynchronous. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} for accessing the graph data + * @param readTransaction the {@link ReadTransaction} to use for the search + * @param entryNeighbor the starting point for the search on this layer, which includes the node and its distance to + * the query vector + * @param layer the zero-based index of the layer to search within + * @param queryVector the query vector for which to find the nearest neighbor + * + * @return a {@link CompletableFuture} that, upon completion, will contain the closest node found on the layer, + * represented as a {@link NodeReferenceWithDistance} + */ @Nonnull private CompletableFuture greedySearchLayer(@Nonnull StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -475,6 +562,32 @@ private CompletableFuture g } } + /** + * Performs a greedy search for the nearest neighbor to a query vector within a single, non-zero layer of the HNSW + * graph. + *

+ * This search is performed on layers that use {@code InliningNode}s, where neighbor vectors are stored directly + * within the node. + * The search starts from a given {@code entryNeighbor} and iteratively moves to the closest neighbor in the current + * node's + * neighbor list, until no closer neighbor can be found. + *
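+ * In outline, the greedy descent within the layer is (a simplified sketch, not the exact code):
+ * <pre>{@code
+ * current = entryNeighbor
+ * loop:
+ *     best = the neighbor of current closest to queryVector
+ *     if distance(best) >= distance(current): return current
+ *     current = best
+ * }</pre>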

+ * The entire process is asynchronous, returning a {@link CompletableFuture} that will complete with the best node + * found in this layer. + * + * @param storageAdapter the storage adapter to fetch nodes from the graph + * @param readTransaction the transaction context for database reads + * @param entryNeighbor the entry point for the search in this layer, typically the result from a search in a higher + * layer + * @param layer the layer number to perform the search in. Must be greater than 0. + * @param queryVector the vector for which to find the nearest neighbor + * + * @return a {@link CompletableFuture} that, upon completion, will hold the {@link NodeReferenceWithDistance} of the nearest + * neighbor found in this layer's greedy search + * + * @throws IllegalStateException if a node that is expected to exist cannot be fetched from the + * {@code storageAdapter} during the search + */ @Nonnull private CompletableFuture greedySearchInliningLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -519,6 +632,33 @@ private CompletableFuture greedySearchInliningLayer(@ }), executor).thenApply(ignored -> currentNodeReferenceAtomic.get()); } + /** + * Searches a single layer of the graph to find the nearest neighbors to a query vector. + *

+ * This method implements the greedy search algorithm used in HNSW (Hierarchical Navigable Small World) + * graphs for a specific layer. It begins with a set of entry points and iteratively explores the graph, + * always moving towards nodes that are closer to the {@code queryVector}. + *
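+ * In outline, this corresponds to the SEARCH-LAYER routine of the HNSW paper (a simplified sketch,
+ * not the exact implementation):
+ * <pre>{@code
+ * candidates = min-heap of entry points, ordered by distance to query
+ * nearest    = max-heap of entry points, ordered by distance to query
+ * visited    = set of entry points
+ * while candidates is not empty:
+ *     c = candidates.poll()                          // closest unexplored candidate
+ *     if distance(c) > distance(nearest.peek()):     // peek() yields the furthest result
+ *         break                                      // no remaining candidate can improve the results
+ *     for each neighbor e of c with e not in visited:
+ *         visited.add(e)
+ *         if |nearest| < efSearch or distance(e) < distance(nearest.peek()):
+ *             candidates.add(e); nearest.add(e)
+ *             if |nearest| > efSearch: nearest.poll()    // evict the furthest
+ * return nearest
+ * }</pre>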

+ * It maintains a priority queue of candidates to visit and a result set of the nearest neighbors found so far. + * The size of the dynamic candidate list is controlled by the {@code efSearch} parameter, which balances + * search quality and performance. The entire process is asynchronous, leveraging + * {@link java.util.concurrent.CompletableFuture} + * to handle I/O operations (fetching nodes) without blocking. + * + * @param The type of the node reference, extending {@link NodeReference}. + * @param storageAdapter The storage adapter for accessing node data from the underlying storage. + * @param readTransaction The transaction context for all database read operations. + * @param entryNeighbors A collection of starting nodes for the search in this layer, with their distances + * to the query vector already calculated. + * @param layer The zero-based index of the layer to search. + * @param efSearch The size of the dynamic candidate list. A larger value increases recall at the + * cost of performance. + * @param nodeCache A cache of nodes that have already been fetched from storage to avoid redundant I/O. + * @param queryVector The vector for which to find the nearest neighbors. + * + * @return A {@link java.util.concurrent.CompletableFuture} that, upon completion, will contain a list of the + * best candidate nodes found in this layer, paired with their full node data. + */ @Nonnull private CompletableFuture>> searchLayer(@Nonnull StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -580,16 +720,41 @@ private CompletableFuture }).thenCompose(ignored -> fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache)) .thenApply(searchResult -> { - debug(l -> l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch, - searchResult.stream() - .map(nodeReferenceAndNode -> - "(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + - ",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") - .collect(Collectors.joining(",")))); + if (logger.isDebugEnabled()) { + logger.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch, + searchResult.stream() + .map(nodeReferenceAndNode -> + "(primaryKey=" + + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(","))); + } return searchResult; }); } + /** + * Asynchronously fetches a node if it is not already present in the cache. + *

+ * This method first attempts to retrieve the node from the provided {@code nodeCache} using the + * primary key of the {@code nodeReference}. If the node is not found in the cache, it is + * fetched from the underlying storage using the {@code storageAdapter}. Once fetched, the node + * is added to the {@code nodeCache} before the future is completed. + *
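+ * Conceptually, this is the cache-aside pattern (a simplified sketch of the behavior):
+ * <pre>{@code
+ * Node<N> cached = nodeCache.get(nodeReference.getPrimaryKey());
+ * return cached != null
+ *        ? CompletableFuture.completedFuture(cached)
+ *        : storageAdapter.fetchNode(readTransaction, layer, nodeReference.getPrimaryKey())
+ *                .thenApply(node -> {
+ *                    nodeCache.put(node.getPrimaryKey(), node);
+ *                    return node;
+ *                });
+ * }</pre>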

+ * This is a convenience method that delegates to + * {@link #fetchNodeIfNecessaryAndApply(StorageAdapter, ReadTransaction, int, NodeReference, + * java.util.function.Function, java.util.function.BiFunction)}. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param storageAdapter the storage adapter used to fetch the node from persistent storage + * @param readTransaction the transaction to use for reading from storage + * @param layer the layer index where the node is located + * @param nodeReference the reference to the node to fetch + * @param nodeCache the cache to check for the node and to which the node will be added if fetched + * + * @return a {@link CompletableFuture} that will be completed with the fetched or cached {@link Node} + */ @Nonnull private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -604,6 +769,34 @@ private CompletableFuture> fetchNodeIfNotCache }); } + /** + * Conditionally fetches a node from storage and applies a function to it. + *

+ * This method first attempts to generate a result by applying the {@code fetchBypassFunction}. + * If this function returns a non-null value, that value is returned immediately in a + * completed {@link CompletableFuture}, and no storage access occurs. This provides an + * optimization path, for example, if the required data is already available in a cache. + *

+ * If the bypass function returns {@code null}, the method proceeds to asynchronously fetch the + * node from the given {@code StorageAdapter}. Once the node is retrieved, the + * {@code biMapFunction} is applied to the original {@code nodeReference} and the fetched + * {@code Node} to produce the final result. + * + * @param The type of the input node reference. + * @param The type of the node reference used by the storage adapter. + * @param The type of the result. + * @param storageAdapter The storage adapter used to fetch the node if necessary. + * @param readTransaction The read transaction context for the storage operation. + * @param layer The layer index from which to fetch the node. + * @param nodeReference The reference to the node that may need to be fetched. + * @param fetchBypassFunction A function that provides a potential shortcut. If it returns a + * non-null value, the node fetch is bypassed. + * @param biMapFunction A function to be applied after a successful node fetch, combining the + * original reference and the fetched node to produce the final result. + * + * @return A {@link CompletableFuture} that will complete with the result from either the + * {@code fetchBypassFunction} or the {@code biMapFunction}. + */ @Nonnull private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -621,6 +814,26 @@ private CompletableFuture< .thenApply(node -> biMapFunction.apply(nodeReference, node)); } + /** + * Asynchronously fetches neighborhood nodes and returns them as {@link NodeReferenceWithVector} instances, + * which include the node's vector. + *

+ * This method efficiently retrieves node data by first checking an in-memory {@code nodeCache}. If a node is not + * in the cache, it is fetched from the {@link StorageAdapter}. Fetched nodes are then added to the cache to + * optimize subsequent lookups. It also handles cases where the input {@code neighborReferences} may already + * contain {@link NodeReferenceWithVector} instances, avoiding redundant work. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes from if they are not in the cache + * @param readTransaction the transaction context for database read operations + * @param layer the graph layer from which to fetch the nodes + * @param neighborReferences an iterable of references to the neighbor nodes to be fetched + * @param nodeCache a map serving as an in-memory cache for nodes. This map will be populated with any + * nodes fetched from storage. + * + * @return a {@link CompletableFuture} that, upon completion, will contain a list of + * {@link NodeReferenceWithVector} objects for the specified neighbors + */ @Nonnull private CompletableFuture> fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -644,6 +857,28 @@ private CompletableFuture + * This method iterates through the provided {@code nodeReferences}. For each reference, it + * first checks the {@code nodeCache}. If the corresponding {@link Node} is found, it is + * used directly. If not, the node is fetched from the {@link StorageAdapter}. Any nodes + * fetched from storage are then added to the {@code nodeCache} to optimize subsequent lookups. + * The entire operation is performed asynchronously. + * + * @param The type of the node reference, which must extend {@link NodeReference}. + * @param storageAdapter The storage adapter used to fetch nodes from storage if they are not in the cache. + * @param readTransaction The transaction context for the read operation. + * @param layer The layer from which to fetch the nodes. + * @param nodeReferences An {@link Iterable} of {@link NodeReferenceWithDistance} objects identifying the nodes to + * be fetched. + * @param nodeCache A map used as a cache. It is checked for existing nodes and updated with any newly fetched + * nodes. + * + * @return A {@link CompletableFuture} which will complete with a {@link List} of + * {@link NodeReferenceAndNode} objects, pairing each requested reference with its corresponding node. + */ @Nonnull private CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -664,6 +899,31 @@ private CompletableFuture }); } + /** + * Asynchronously fetches a collection of nodes from storage and applies a function to each. + *

+ * For each {@link NodeReference} in the provided iterable, this method concurrently fetches the corresponding + * {@code Node} using the given {@link StorageAdapter}. The logic delegates to + * {@code fetchNodeIfNecessaryAndApply}, which determines whether a full node fetch is required. + * If a node is fetched from storage, the {@code biMapFunction} is applied. If the fetch is bypassed + * (e.g., because the reference itself contains sufficient information), the {@code fetchBypassFunction} is used + * instead. + * + * @param The type of the node references to be processed, extending {@link NodeReference}. + * @param The type of the key references within the nodes, extending {@link NodeReference}. + * @param The type of the result after applying one of the mapping functions. + * @param storageAdapter The {@link StorageAdapter} used to fetch nodes from the underlying storage. + * @param readTransaction The {@link ReadTransaction} context for the read operations. + * @param layer The layer index from which the nodes are being fetched. + * @param nodeReferences An {@link Iterable} of {@link NodeReference}s for the nodes to be fetched and processed. + * @param fetchBypassFunction The function to apply to a node reference when the actual node fetch is bypassed, + * mapping the reference directly to a result of type {@code U}. + * @param biMapFunction The function to apply when a node is successfully fetched, mapping the original + * reference and the fetched {@link Node} to a result of type {@code U}. + * + * @return A {@link CompletableFuture} that, upon completion, will hold a {@link java.util.List} of results + * of type {@code U}, corresponding to each processed node reference. + */ @Nonnull private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @@ -677,18 +937,52 @@ private CompletableFuture< getExecutor()); } + /** + * Asynchronously inserts a node reference and its corresponding vector into the index. + *

+ * This is a convenience method that extracts the primary key and vector from the + * provided {@link NodeReferenceWithVector} and delegates to the + * {@link #insert(Transaction, Tuple, Vector)} method. + * + * @param transaction the transaction context for the operation. Must not be {@code null}. + * @param nodeReferenceWithVector a container object holding the primary key of the node + * and its vector representation. Must not be {@code null}. + * + * @return a {@link CompletableFuture} that will complete when the insertion operation is finished. + */ @Nonnull public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final NodeReferenceWithVector nodeReferenceWithVector) { return insert(transaction, nodeReferenceWithVector.getPrimaryKey(), nodeReferenceWithVector.getVector()); } + /** + * Inserts a new vector with its associated primary key into the HNSW graph. + *

+ * The method first determines a random layer for the new node, called the {@code insertionLayer}. + * It then traverses the graph from the entry point downwards, greedily searching for the nearest + * neighbors to the {@code newVector} at each layer. This search identifies the optimal + * connection points for the new node. + *
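+ * For example (illustrative only; {@code db}, {@code hnsw}, {@code primaryKey}, and {@code vector} are
+ * assumed to exist):
+ * <pre>{@code
+ * db.run(tr -> hnsw.insert(tr, primaryKey, vector).join());
+ * }</pre>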

+ * Once the nearest neighbors are found, the new node is linked into the graph structure at all + * layers up to its {@code insertionLayer}. Special handling is included for inserting the + * first-ever node into the graph or when a new node's layer is higher than any existing node, + * which updates the graph's entry point. All operations are performed asynchronously. + * + * @param transaction the {@link Transaction} context for all database operations + * @param newPrimaryKey the unique {@link Tuple} primary key for the new node being inserted + * @param newVector the {@link Vector} data to be inserted into the graph + * + * @return a {@link CompletableFuture} that completes when the insertion operation is finished + */ @Nonnull public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, @Nonnull final Vector newVector) { final Metric metric = getConfig().getMetric(); final int insertionLayer = insertionLayer(getConfig().getRandom()); - debug(l -> l.debug("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer)); + if (logger.isDebugEnabled()) { + logger.debug("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer); + } return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) .thenApply(entryNodeReference -> { @@ -697,14 +991,18 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, -1); StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); - debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer); + } } else { final int lMax = entryNodeReference.getLayer(); if (insertionLayer > lMax) { writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, lMax); StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); - debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer); + } } } return entryNodeReference; @@ -714,8 +1012,10 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N } final int lMax = entryNodeReference.getLayer(); - debug(l -> l.debug("entry node with key {} at layer {}", entryNodeReference.getPrimaryKey(), - lMax)); + if (logger.isDebugEnabled()) { + logger.debug("entry node with key {} at layer {}", entryNodeReference.getPrimaryKey(), + lMax); + } final NodeReferenceWithDistance initialNodeReference = new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), @@ -735,6 +1035,31 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N }).thenCompose(ignored -> AsyncUtil.DONE); } + /** + * Inserts a batch of nodes into the HNSW graph asynchronously. + * + *

This method orchestrates the batch insertion of nodes into the HNSW graph structure. + * For each node in the input {@code batch}, it first assigns a random layer based on the configured + * probability distribution. The batch is then sorted in descending order of these assigned layers to + * ensure higher-layer nodes are processed first, which can optimize subsequent insertions by providing + * better entry points.

+ * + *

The insertion logic proceeds in two main asynchronous stages: + *

  1. Search Phase: For each node to be inserted, the method concurrently performs a greedy search from the graph's main entry point down to the node's target layer. This identifies the nearest neighbors at each level, which will serve as entry points for the insertion phase.
  2. Insertion Phase: The method then iterates through the nodes and inserts each one into the graph from its target layer downwards, connecting it to its nearest neighbors. If a node's assigned layer is higher than the current maximum layer of the graph, it becomes the new main entry point.
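+ * A usage sketch (illustrative only; {@code db}, {@code hnsw}, and the key/vector pairs are assumed to
+ * exist):
+ * <pre>{@code
+ * ImmutableList.Builder<NodeReferenceWithVector> batchBuilder = ImmutableList.builder();
+ * batchBuilder.add(new NodeReferenceWithVector(primaryKey, vector));
+ * db.run(tr -> hnsw.insertBatch(tr, batchBuilder.build()).join());
+ * }</pre>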
+ * All underlying storage operations are performed within the context of the provided {@link Transaction}.

+ * + * @param transaction the transaction to use for all storage operations; must not be {@code null} + * @param batch a {@code List} of {@link NodeReferenceWithVector} objects to insert; must not be {@code null} + * + * @return a {@link CompletableFuture} that completes with {@code null} when the entire batch has been inserted + */ @Nonnull public CompletableFuture insertBatch(@Nonnull final Transaction transaction, @Nonnull List batch) { @@ -797,7 +1122,9 @@ public CompletableFuture insertBatch(@Nonnull final Transaction transactio new EntryNodeReference(itemPrimaryKey, itemVector, itemL); StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), newEntryNodeReference, getOnWriteListener()); - debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL); + } return CompletableFuture.completedFuture(newEntryNodeReference); } else { @@ -808,14 +1135,18 @@ public CompletableFuture insertBatch(@Nonnull final Transaction transactio new EntryNodeReference(itemPrimaryKey, itemVector, itemL); StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), newEntryNodeReference, getOnWriteListener()); - debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + if (logger.isDebugEnabled()) { + logger.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL); + } } else { newEntryNodeReference = entryNodeReference; } } - debug(l -> l.debug("entry node with key {} at layer {}", - currentEntryNodeReference.getPrimaryKey(), currentLMax)); + if (logger.isDebugEnabled()) { + logger.debug("entry node with key {} at layer {}", + currentEntryNodeReference.getPrimaryKey(), currentLMax); + } final var currentSearchEntry = searchEntryReferences.get(index); @@ -826,6 +1157,29 @@ public CompletableFuture insertBatch(@Nonnull final Transaction transactio }).thenCompose(ignored -> AsyncUtil.DONE); } + /** + * Inserts a new vector into the HNSW graph across multiple layers, starting from a given entry point. + *

+ * This method implements the second phase of the HNSW insertion algorithm. It begins at a starting layer, which is + * the minimum of the graph's maximum layer ({@code lMax}) and the new node's randomly assigned + * {@code insertionLayer}. It then iterates downwards to layer 0. In each layer, it invokes + * {@link #insertIntoLayer(StorageAdapter, Transaction, List, int, Tuple, Vector)} to perform the search and + * connect the new node. The set of nearest neighbors found at layer {@code L} serves as the entry points for the + * search at layer {@code L-1}. + *

+ * + * @param transaction the transaction to use for database operations + * @param newPrimaryKey the primary key of the new node being inserted + * @param newVector the vector data of the new node + * @param nodeReference the initial entry point for the search, typically the nearest neighbor found in the highest + * layer + * @param lMax the maximum layer number in the HNSW graph + * @param insertionLayer the randomly determined layer for the new node. The node will be inserted into all layers + * from this layer down to 0. + * + * @return a {@link CompletableFuture} that completes when the new node has been successfully inserted into all + * its designated layers + */ @Nonnull private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, @@ -833,7 +1187,9 @@ private CompletableFuture insertIntoLayers(@Nonnull final Transaction tran @Nonnull final NodeReferenceWithDistance nodeReference, final int lMax, final int insertionLayer) { - debug(l -> l.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey())); + if (logger.isDebugEnabled()) { + logger.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey()); + } return MoreAsyncUtil.>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference), layer -> layer >= 0, layer -> layer - 1, @@ -844,6 +1200,39 @@ private CompletableFuture insertIntoLayers(@Nonnull final Transaction tran }, executor).thenCompose(ignored -> AsyncUtil.DONE); } + /** + * Inserts a new node into a specified layer of the HNSW graph. + *

+ * This method orchestrates the complete insertion process for a single layer. It begins by performing a search + * within the given layer, starting from the provided {@code nearestNeighbors} as entry points, to find a set of + * candidate neighbors for the new node. From this candidate set, it selects the best connections based on the + * graph's parameters (M). + *

+ *

+ * After selecting the neighbors, it creates the new node and links it to them. It then reciprocally updates + * the selected neighbors to link back to the new node. If adding this new link causes a neighbor to exceed its + * maximum allowed connections, its connections are pruned. All changes, including the new node and the updated + * neighbors, are persisted to storage within the given transaction. + *

+ *

+ * The operation is asynchronous and returns a {@link CompletableFuture}. The future completes with the list of + * nodes found during the initial search phase, which are then used as the entry points for insertion into the + * next lower layer. + *

+ * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter for reading from and writing to the graph + * @param transaction the transaction context for the database operations + * @param nearestNeighbors the list of nearest neighbors from the layer above, used as entry points for the search + * in this layer + * @param layer the layer number to insert the new node into + * @param newPrimaryKey the primary key of the new node to be inserted + * @param newVector the vector associated with the new node + * + * @return a {@code CompletableFuture} that completes with a list of the nearest neighbors found during the + * initial search phase. This list serves as the entry point for insertion into the next lower layer + * (i.e., {@code layer - 1}). + */ @Nonnull private CompletableFuture> insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final Transaction transaction, @@ -851,7 +1240,9 @@ private CompletableFuture newVector) { - debug(l -> l.debug("begin insert key={} at layer={}", newPrimaryKey, layer)); + if (logger.isDebugEnabled()) { + logger.debug("begin insert key={} at layer={}", newPrimaryKey, layer); + } final Map> nodeCache = Maps.newConcurrentMap(); return searchLayer(storageAdapter, transaction, @@ -912,11 +1303,33 @@ private CompletableFuture { - debug(l -> l.debug("end insert key={} at layer={}", newPrimaryKey, layer)); + if (logger.isDebugEnabled()) { + logger.debug("end insert key={} at layer={}", newPrimaryKey, layer); + } return nodeReferencesWithDistances; }); } + /** + * Calculates the delta between a current set of neighbors and a new set, producing a + * {@link NeighborsChangeSet} that represents the required insertions and deletions. + *

+ * This method compares the neighbors present in the initial {@code beforeChangeSet} with + * the provided {@code afterNeighbors}. It identifies which neighbors from the "before" state + * are missing in the "after" state (to be deleted) and which new neighbors are present in the + * "after" state but not in the "before" state (to be inserted). It then constructs a new + * {@code NeighborsChangeSet} by wrapping the original one with {@link DeleteNeighborsChangeSet} + * and {@link InsertNeighborsChangeSet} as needed. + * + * @param the type of the node reference, which must extend {@link NodeReference} + * @param beforeChangeSet the change set representing the state of neighbors before the update. + * This is used as the base for calculating changes. Must not be null. + * @param afterNeighbors an iterable collection of the desired neighbors after the update. + * Must not be null. + * + * @return a new {@code NeighborsChangeSet} that includes the necessary deletion and insertion + * operations to transform the neighbors from the "before" state to the "after" state. + */ private NeighborsChangeSet resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, @Nonnull final Iterable> afterNeighbors) { final Map beforeNeighborsMap = Maps.newLinkedHashMap(); @@ -959,6 +1372,27 @@ private NeighborsChangeSet resolveChangeSetFromNewN return changeSet; } + /** + * Prunes the neighborhood of a given node if its number of connections exceeds the maximum allowed ({@code mMax}). + *

+ * This is a maintenance operation for the HNSW graph. When new nodes are added, an existing node's neighborhood + * might temporarily grow beyond its limit. This method identifies such cases and trims the neighborhood back down + * to the {@code mMax} best connections, based on the configured distance metric. If the neighborhood size is + * already within the limit, this method does nothing. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes from the database + * @param transaction the transaction context for database operations + * @param selectedNeighbor the node whose neighborhood is being considered for pruning + * @param layer the graph layer on which the operation is performed + * @param mMax the maximum number of neighbors a node is allowed to have on this layer + * @param neighborChangeSet a set of pending changes to the neighborhood that must be included in the pruning + * calculation + * @param nodeCache a cache of nodes to avoid redundant database fetches + * + * @return a {@link CompletableFuture} which completes with a list of the newly selected neighbors for the pruned node. + * If no pruning was necessary, it completes with {@code null}. + */ @Nonnull private CompletableFuture>> pruneNeighborsIfNecessary(@Nonnull final StorageAdapter storageAdapter, @Nonnull final Transaction transaction, @@ -972,8 +1406,10 @@ private CompletableFuture if (selectedNeighborNode.getNeighbors().size() < mMax) { return CompletableFuture.completedFuture(null); } else { - debug(l -> l.debug("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}", - selectedNeighborNode.getPrimaryKey(), selectedNeighborNode.getNeighbors().size(), mMax)); + if (logger.isDebugEnabled()) { + logger.debug("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}", + selectedNeighborNode.getPrimaryKey(), selectedNeighborNode.getNeighbors().size(), mMax); + } return fetchNeighborhood(storageAdapter, transaction, layer, neighborChangeSet.merge(), nodeCache) .thenCompose(nodeReferenceWithVectors -> { final ImmutableList.Builder nodeReferencesWithDistancesBuilder = @@ -998,6 +1434,36 @@ private CompletableFuture } } + /** + * Selects the {@code m} best neighbors for a new node from a set of candidates using the HNSW selection heuristic. + *

+ * This method implements the core logic for neighbor selection within a layer of the HNSW graph. It starts with an + * initial set of candidates ({@code nearestNeighbors}), which can be optionally extended by fetching their own + * neighbors. + * It then iteratively refines this set using a greedy best-first search. + *

+ * The selection heuristic ensures diversity among neighbors. A candidate is added to the result set only if it is + * closer to the query {@code vector} than to any node already in the result set. This prevents selecting neighbors + * that are clustered together. If the {@code keepPrunedConnections} configuration is enabled, candidates that are + * pruned by this heuristic are kept and may be added at the end if the result set is not yet full. + *
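+ * In pseudocode, the heuristic proceeds roughly as follows (a simplified sketch):
+ * <pre>{@code
+ * selected = {}; pruned = {}
+ * for each candidate c, in ascending distance to vector:
+ *     if |selected| == m: break
+ *     if distance(c, vector) < distance(c, s) for every s in selected:
+ *         selected.add(c)           // c is closer to the query than to any chosen neighbor
+ *     else if keepPrunedConnections:
+ *         pruned.add(c)
+ * while keepPrunedConnections and |selected| < m and pruned is not empty:
+ *     selected.add(pruned.pollClosest())
+ * }</pre>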

+ * The process is asynchronous and returns a {@link CompletableFuture} that will eventually contain the list of + * selected neighbors with their full node data. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the storage adapter to fetch nodes and their neighbors + * @param readTransaction the transaction for performing database reads + * @param nearestNeighbors the initial pool of candidate neighbors, typically from a search in a higher layer + * @param layer the layer in the HNSW graph where the selection is being performed + * @param m the maximum number of neighbors to select + * @param isExtendCandidates a flag indicating whether to extend the initial candidate pool by fetching the + * neighbors of the {@code nearestNeighbors} + * @param nodeCache a cache of nodes to avoid redundant storage lookups + * @param vector the query vector for which neighbors are being selected + * + * @return a {@link CompletableFuture} which will complete with a list of the selected neighbors, + * each represented as a {@link NodeReferenceAndNode} + */ private CompletableFuture>> selectNeighbors(@Nonnull final StorageAdapter storageAdapter, @Nonnull final ReadTransaction readTransaction, @Nonnull final Iterable> nearestNeighbors, @@ -1048,24 +1514,49 @@ private CompletableFuture }).thenCompose(selectedNeighbors -> fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, selectedNeighbors, nodeCache)) .thenApply(selectedNeighbors -> { - debug(l -> - l.debug("selected neighbors={}", - selectedNeighbors.stream() - .map(selectedNeighbor -> - "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() + - ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")") - .collect(Collectors.joining(",")))); + if (logger.isDebugEnabled()) { + logger.debug("selected neighbors={}", + selectedNeighbors.stream() + .map(selectedNeighbor -> + "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(","))); + } return selectedNeighbors; }); } - private CompletableFuture> extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, - @Nonnull final ReadTransaction readTransaction, - @Nonnull final Iterable> candidates, - int layer, - boolean isExtendCandidates, - @Nonnull final Map> nodeCache, - @Nonnull final Vector vector) { + /** + * Conditionally extends a set of candidate nodes by fetching and evaluating their neighbors. + *

+ * If {@code isExtendCandidates} is {@code true}, this method gathers the neighbors of the provided + * {@code candidates}, fetches their full node data, and calculates their distance to the given + * {@code vector}. The resulting list will contain both the original candidates and their newly + * evaluated neighbors. + *

+ * If {@code isExtendCandidates} is {@code false}, the method simply returns a list containing + * only the original candidates. This operation is asynchronous and returns a {@link CompletableFuture}. + * + * @param the type of the {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access node data from storage + * @param readTransaction the active {@link ReadTransaction} for database access + * @param candidates an {@link Iterable} of initial candidate nodes, which have already been evaluated + * @param layer the graph layer from which to fetch nodes + * @param isExtendCandidates a boolean flag; if {@code true}, the candidate set is extended with neighbors + * @param nodeCache a cache mapping primary keys to {@link Node} objects to avoid redundant fetches + * @param vector the query vector used to calculate distances for any new neighbor nodes + * + * @return a {@link CompletableFuture} which will complete with a list of {@link NodeReferenceWithDistance}, + * containing the original candidates and potentially their neighbors + */ + private CompletableFuture> + extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> candidates, + int layer, + boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { if (isExtendCandidates) { final Metric metric = getConfig().getMetric(); @@ -1089,7 +1580,8 @@ private CompletableFuture { - final ImmutableList.Builder extendedCandidatesBuilder = ImmutableList.builder(); + final ImmutableList.Builder extendedCandidatesBuilder = + ImmutableList.builder(); for (final NodeReferenceAndNode candidate : candidates) { extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance()); } @@ -1111,6 +1603,21 @@ private CompletableFuture + * A "lonely node" is a node in the layered structure that does not have a right + * sibling. This method iterates downwards from the {@code highestLayerInclusive} + * to the {@code lowestLayerExclusive}. For each layer in this range, it + * retrieves the appropriate {@link StorageAdapter} and calls + * {@link #writeLonelyNodeOnLayer} to persist the node's information. + * + * @param transaction the transaction to use for writing to the database + * @param primaryKey the primary key of the record for which lonely nodes are being written + * @param vector the search path vector that was followed to find this key + * @param highestLayerInclusive the highest layer (inclusive) to begin writing lonely nodes on + * @param lowestLayerExclusive the lowest layer (exclusive) at which to stop writing lonely nodes + */ private void writeLonelyNodes(@Nonnull final Transaction transaction, @Nonnull final Tuple primaryKey, @Nonnull final Vector vector, @@ -1122,6 +1629,21 @@ private void writeLonelyNodes(@Nonnull final Transaction transaction, } } + /** + * Writes a new, isolated ('lonely') node to a specified layer within the graph. + *

+ * This method uses the provided {@link StorageAdapter} to create a new node with the + * given primary key and vector but with an empty set of neighbors. The write + * operation is performed as part of the given {@link Transaction}. This is typically + * used to insert the very first node into an empty graph layer. + * + * @param the type of the node reference, extending {@link NodeReference} + * @param storageAdapter the {@link StorageAdapter} used to access the data store and create nodes; must not be null + * @param transaction the {@link Transaction} context for the write operation; must not be null + * @param layer the layer index where the new node will be written + * @param primaryKey the primary key for the new node; must not be null + * @param vector the vector data for the new node; must not be null + */ private void writeLonelyNodeOnLayer(@Nonnull final StorageAdapter storageAdapter, @Nonnull final Transaction transaction, final int layer, @@ -1131,9 +1653,25 @@ private void writeLonelyNodeOnLayer(@Nonnull final Sto storageAdapter.getNodeFactory() .create(primaryKey, vector, ImmutableList.of()), layer, new BaseNeighborsChangeSet<>(ImmutableList.of())); - debug(l -> l.debug("written lonely node at key={} on layer={}", primaryKey, layer)); + if (logger.isDebugEnabled()) { + logger.debug("written lonely node at key={} on layer={}", primaryKey, layer); + } } + /** + * Scans all nodes within a given layer of the database. + *

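+ * For example, counting all nodes on layer 0 (illustrative only; {@code db} and {@code hnsw} are assumed
+ * to exist):
+ * <pre>{@code
+ * AtomicLong count = new AtomicLong();
+ * hnsw.scanLayer(db, 0, 100, node -> count.incrementAndGet());
+ * }</pre>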
+ * The scan is performed transactionally in batches to avoid loading the entire layer + * into memory at once. Each discovered node is passed to the provided {@link Consumer} + * for processing. The operation continues fetching batches until all nodes in the + * specified layer have been processed. + * + * @param db the non-null {@link Database} instance to run the scan against. + * @param layer the specific layer index to scan. + * @param batchSize the number of nodes to retrieve and process in each batch. + * @param nodeConsumer the non-null {@link Consumer} that will accept each {@link Node} + * found in the layer. + */ public void scanLayer(@Nonnull final Database db, final int layer, final int batchSize, @@ -1155,19 +1693,61 @@ public void scanLayer(@Nonnull final Database db, } while (newPrimaryKey != null); } + /** + * Gets the appropriate storage adapter for a given layer. + *

+ * This method selects a {@link StorageAdapter} implementation based on the layer number. + * The logic is intended to use an {@code InliningStorageAdapter} for layers greater + * than 0 and a {@code CompactStorageAdapter} for layer 0. However, the switch to + * the inlining adapter is currently disabled with a hardcoded {@code false}, + * so this method will always return a {@code CompactStorageAdapter}. + * + * @param layer the layer number for which to get the storage adapter; currently unused + * + * @return a non-null {@link StorageAdapter} instance, which will always be a + * {@link CompactStorageAdapter} in the current implementation + */ @Nonnull private StorageAdapter getStorageAdapterForLayer(final int layer) { return false && layer > 0 - ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()) - : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()); + ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), + getOnReadListener()) + : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), + getOnReadListener()); } + /** + * Calculates a random layer for a new element to be inserted. + *

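+ * As a worked example of the formula described below (assuming {@code M = 32}): {@code lambda} is
+ * {@code 1 / ln(32)}, roughly {@code 0.289}, so a uniform draw of {@code u = 0.5} yields
+ * {@code floor(-ln(0.5) * 0.289) = floor(0.2) = 0}, i.e. layer 0. Only draws with {@code u <= 1/32}
+ * (about 3% of insertions) reach layer 1 or higher.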
+ * The layer is selected according to a logarithmic distribution, which ensures that + * the probability of choosing a higher layer decreases exponentially. This is + * achieved by applying the inverse transform sampling method. The specific formula + * is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random + * number and {@code lambda} is a normalization factor derived from a system + * configuration parameter {@code M}. + * + * @param random the {@link Random} object used for generating a random number. + * It must not be null. + * + * @return a non-negative integer representing the randomly selected layer. + */ private int insertionLayer(@Nonnull final Random random) { double lambda = 1.0 / Math.log(getConfig().getM()); double u = 1.0 - random.nextDouble(); // Avoid log(0) return (int) Math.floor(-Math.log(u) * lambda); } + /** + * Logs a message at the INFO level, using a consumer for lazy evaluation. + *

+ * This approach avoids the cost of constructing the log message if the INFO + * level is disabled. The provided {@link java.util.function.Consumer} will be + * executed only when {@code logger.isInfoEnabled()} returns {@code true}. + * + * @param loggerConsumer the {@link java.util.function.Consumer} that will be + * accepted if logging is enabled. It receives the + * {@code Logger} instance and must not be null. + */ @SuppressWarnings("PMD.UnusedPrivateMethod") private void info(@Nonnull final Consumer loggerConsumer) { if (logger.isInfoEnabled()) { @@ -1175,12 +1755,6 @@ private void info(@Nonnull final Consumer loggerConsumer) { } } - private void debug(@Nonnull final Consumer loggerConsumer) { - if (logger.isDebugEnabled()) { - loggerConsumer.accept(logger); - } - } - private static class NodeReferenceWithLayer extends NodeReferenceWithVector { private final int layer; From 7704cb0e86a8c4cd022423286a9f6d38a704623d Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Wed, 17 Sep 2025 22:20:53 +0200 Subject: [PATCH 05/10] more javadoc and tests --- .../foundationdb/async/hnsw/AbstractNode.java | 39 ++- .../async/hnsw/AbstractStorageAdapter.java | 134 ++++++++- .../async/hnsw/BaseNeighborsChangeSet.java | 36 ++- .../foundationdb/async/hnsw/CompactNode.java | 62 ++++- .../async/hnsw/CompactStorageAdapter.java | 3 - .../async/hnsw/DeleteNeighborsChangeSet.java | 56 +++- .../foundationdb/async/hnsw/InliningNode.java | 57 +++- .../async/hnsw/InliningStorageAdapter.java | 169 +++++++++++- .../apple/foundationdb/async/hnsw/Metric.java | 36 +++ .../async/hnsw/NeighborsChangeSet.java | 40 ++- .../apple/foundationdb/async/hnsw/Node.java | 56 +++- .../async/hnsw/NodeReferenceWithVector.java | 48 ++++ .../async/hnsw/HNSWHelpersTest.java | 75 +++++ .../async/hnsw/HNSWModificationTest.java | 256 +++--------------- .../foundationdb/async/hnsw/MetricTest.java | 174 ++++++++++++ 15 files changed, 1007 insertions(+), 234 deletions(-) create mode 100644 fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java create mode 100644 fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java index aa062e8700..252185f38b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java @@ -27,8 +27,14 @@ import java.util.List; /** - * TODO. - * @param node type class. + * An abstract base class implementing the {@link Node} interface. + *

+ * This class provides the fundamental structure for a node within the HNSW graph, + * managing a unique {@link Tuple} primary key and an immutable list of its neighbors. + * Subclasses are expected to provide concrete implementations, potentially adding + * more state or behavior. + * + * @param the type of the node reference used for neighbors, which must extend {@link NodeReference} */ abstract class AbstractNode implements Node { @Nonnull @@ -37,24 +43,53 @@ abstract class AbstractNode implements Node { @Nonnull private final List neighbors; + /** + * Constructs a new {@code AbstractNode} with a specified primary key and a list of neighbors. + *

+ * This constructor creates a defensive, immutable copy of the provided {@code neighbors} list. + * This ensures that the internal state of the node cannot be modified by external + * changes to the original list after construction. + * + * @param primaryKey the unique identifier for this node; must not be {@code null} + * @param neighbors the list of nodes connected to this node; must not be {@code null} + */ protected AbstractNode(@Nonnull final Tuple primaryKey, @Nonnull final List neighbors) { this.primaryKey = primaryKey; this.neighbors = ImmutableList.copyOf(neighbors); } + /** + * Gets the primary key that uniquely identifies this object. + * @return the primary key {@link Tuple}, which will never be {@code null}. + */ @Nonnull @Override public Tuple getPrimaryKey() { return primaryKey; } + /** + * Gets the list of neighbors connected to this node. + *

+ * This method returns a direct reference to the internal list, which is + * immutable. + * @return a non-null, possibly empty, list of neighbors. + */ @Nonnull @Override public List<N> getNeighbors() { return neighbors; } + /** + * Gets the neighbor at the specified index. + *

+ * This method provides access to a specific neighbor by its zero-based position + * in the internal list of neighbors. + * @param index the zero-based index of the neighbor to retrieve. + * @return the neighbor at the specified index, guaranteed to be non-null. + */ @Nonnull @Override public N getNeighbor(final int index) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java index e3d0c943fc..2b0e17da69 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -32,7 +32,14 @@ import java.util.concurrent.CompletableFuture; /** - * Implementations and attributes common to all concrete implementations of {@link StorageAdapter}. + * An abstract base class for {@link StorageAdapter} implementations. + *

+ * This class provides the common infrastructure for managing HNSW graph data within {@link Subspace}. + * It handles the configuration, node creation, and listener management, while delegating the actual + * storage-specific read and write operations to concrete subclasses through the {@code fetchNodeInternal} + * and {@code writeNodeInternal} abstract methods. + * + * @param the type of {@link NodeReference} used to reference nodes in the graph */ abstract class AbstractStorageAdapter implements StorageAdapter { @Nonnull @@ -51,6 +58,19 @@ abstract class AbstractStorageAdapter implements Storag private final Subspace dataSubspace; + /** + * Constructs a new {@code AbstractStorageAdapter}. + *

+ * This constructor initializes the adapter with the necessary configuration, + * factories, and listeners for managing an HNSW graph. It also sets up a + * dedicated data subspace within the provided main subspace for storing node data. + * + * @param config the HNSW graph configuration + * @param nodeFactory the factory to create new nodes of type {@code } + * @param subspace the primary subspace for storing all graph-related data + * @param onWriteListener the listener to be called on write operations + * @param onReadListener the listener to be called on read operations + */ protected AbstractStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, @Nonnull final Subspace subspace, @Nonnull final OnWriteListener onWriteListener, @@ -63,48 +83,117 @@ protected AbstractStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull fin this.dataSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA)); } + /** + * Returns the configuration used to build and search this HNSW graph. + * + * @return the current {@link HNSW.Config} object, never {@code null}. + */ @Override @Nonnull public HNSW.Config getConfig() { return config; } + /** + * Gets the factory responsible for creating new nodes. + *

+ * This factory is used to instantiate nodes of the generic type {@code N} + * for the current context. The {@code @Nonnull} annotation guarantees that + * this method will never return {@code null}. + * + * @return the non-null {@link NodeFactory} instance. + */ @Nonnull @Override public NodeFactory<N> getNodeFactory() { return nodeFactory; } + /** + * Gets the kind of nodes produced by this adapter's factory. + * <p>

+ * This method is an override and provides a way to determine the concrete + * type of node without using {@code instanceof} checks. + * + * @return the non-null {@link NodeKind} representing the type of node managed by this adapter. + */ @Nonnull @Override public NodeKind getNodeKind() { return getNodeFactory().getNodeKind(); } + /** + * Gets the subspace in which this HNSW graph is stored. + * <p>

+ * This subspace provides a logical separation for keys within the underlying key-value store. + * + * @return the non-null {@link Subspace} for this context + */ @Override @Nonnull public Subspace getSubspace() { return subspace; } + /** + * Gets the subspace for the data associated with this component. + *

+ * The data subspace defines the portion of the directory space where the data + * for this component is stored. + * + * @return the non-null {@link Subspace} for the data + */ @Override @Nonnull public Subspace getDataSubspace() { return dataSubspace; } + /** + * Returns the listener that is notified upon write events. + *

+ * This method is an override and guarantees a non-null return value, + * as indicated by the {@code @Nonnull} annotation. + * + * @return the configured {@link OnWriteListener} instance; will never be {@code null}. + */ @Override @Nonnull public OnWriteListener getOnWriteListener() { return onWriteListener; } + /** + * Gets the listener that is notified upon completion of a read operation. + *

+ * This method is an override and provides the currently configured listener instance. + * The returned listener is guaranteed to be non-null as indicated by the + * {@code @Nonnull} annotation. + * + * @return the non-null {@link OnReadListener} instance. + */ @Override @Nonnull public OnReadListener getOnReadListener() { return onReadListener; } + /** + * Asynchronously fetches a node from a specific layer of the HNSW. + *
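+ * A usage sketch; the {@code storageAdapter}, the {@code readTransaction}, and the primary key
+ * are assumed here for illustration:
+ * <pre>{@code
+ * // fetch the node stored under primary key (42) on layer 0 and wait for it
+ * Node<N> node = storageAdapter.fetchNode(readTransaction, 0, Tuple.from(42L)).join();
+ * }</pre>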

+ * The node is identified by its {@code layer} and {@code primaryKey}. The entire fetch operation is + * performed within the given {@link ReadTransaction}. After the underlying + * fetch operation completes, the retrieved node is validated by the + * {@link #checkNode(Node)} method before the returned future is completed. + * + * @param readTransaction the non-null transaction to use for the read operation + * @param layer the layer of the graph from which to fetch the node + * @param primaryKey the non-null primary key that identifies the node to fetch + * + * @return a {@link CompletableFuture} that will complete with the fetched {@link Node} + * once it has been read from storage and validated + */ @Nonnull @Override public CompletableFuture<Node<N>> fetchNode(@Nonnull final ReadTransaction readTransaction, @@ -112,6 +201,20 @@ public CompletableFuture> fetchNode(@Nonnull final ReadTransaction readT return fetchNodeInternal(readTransaction, layer, primaryKey).thenApply(this::checkNode); } + /** + * Asynchronously fetches a specific node from the data store for a given layer and primary key. + * <p>

+ * This is an internal, abstract method that concrete subclasses must implement to define + * the storage-specific logic for retrieving a node. The operation is performed within the + * context of the provided {@link ReadTransaction}. + * + * @param readTransaction the transaction to use for the read operation; must not be {@code null} + * @param layer the layer index from which to fetch the node + * @param primaryKey the primary key that uniquely identifies the node to be fetched; must not be {@code null} + * + * @return a {@link CompletableFuture} that will be completed with the fetched {@link Node}. + * The future will complete with {@code null} if no node is found for the given key and layer. + */ @Nonnull protected abstract CompletableFuture> fetchNodeInternal(@Nonnull ReadTransaction readTransaction, int layer, @Nonnull Tuple primaryKey); @@ -129,6 +232,21 @@ private Node checkNode(@Nullable final Node node) { return node; } + /** + * Writes a given node and its neighbor modifications to the underlying storage. + *

+ * This operation is executed within the context of the provided {@link Transaction}. + * It handles persisting the node's data at a specific {@code layer} and applies + * the changes to its neighbors as defined in the {@link NeighborsChangeSet}. + * This method delegates the core writing logic to an internal method and provides + * debug logging upon completion. + * + * @param transaction the non-null {@link Transaction} context for this write operation + * @param node the non-null {@link Node} to be written to storage + * @param layer the layer index where the node is being written + * @param changeSet the non-null {@link NeighborsChangeSet} detailing the modifications + * to the node's neighbors + */ @Override public void writeNode(@Nonnull Transaction transaction, @Nonnull Node node, int layer, @Nonnull NeighborsChangeSet changeSet) { @@ -138,6 +256,20 @@ public void writeNode(@Nonnull Transaction transaction, @Nonnull Node node, i } } + /** + * Writes a single node to the data store as part of a larger transaction. + *
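+ * An override sketch for a hypothetical subclass; {@code serializeNode} is an assumed helper,
+ * not part of this patch:
+ * <pre>{@code
+ * @Override
+ * protected void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Node<N> node,
+ *                                  int layer, @Nonnull NeighborsChangeSet<N> changeSet) {
+ *     // persist the node under (layer, primaryKey) within the data subspace
+ *     transaction.set(getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())),
+ *             serializeNode(node, changeSet));
+ * }
+ * }</pre>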

+ * This is an abstract method that concrete implementations must provide. + * It is responsible for the low-level persistence of the given {@code node} at a + * specific {@code layer}. The implementation should also handle the modifications + * to the node's neighbors, as detailed in the {@code changeSet}. + * + * @param transaction the non-null transaction context for the write operation + * @param node the non-null {@link Node} to write + * @param layer the layer or level of the node in the structure + * @param changeSet the non-null {@link NeighborsChangeSet} detailing additions or + * removals of neighbor links + */ protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Node node, int layer, @Nonnull NeighborsChangeSet changeSet); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java index bb8271af39..794bd5ae4c 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -30,28 +30,62 @@ import java.util.function.Predicate; /** - * TODO. + * A base implementation of the {@link NeighborsChangeSet} interface. + *
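+ * A construction sketch; the current neighbor list is assumed to exist:
+ * <pre>{@code
+ * // snapshot the neighbors as they are; merge() returns exactly this list
+ * NeighborsChangeSet<NodeReference> baseSet = new BaseNeighborsChangeSet<>(currentNeighbors);
+ * }</pre>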

+ * This class represents a complete, non-delta state of a node's neighbors. It holds a fixed, immutable + * list of neighbors provided at construction time. As such, it does not support parent change sets or writing deltas. + * + * @param the type of the node reference, which must extend {@link NodeReference} */ class BaseNeighborsChangeSet implements NeighborsChangeSet { @Nonnull private final List neighbors; + /** + * Creates a new change set with the specified neighbors. + *

+ * This constructor creates an immutable copy of the provided list. + * + * @param neighbors the list of neighbors for this change set; must not be null. + */ public BaseNeighborsChangeSet(@Nonnull final List neighbors) { this.neighbors = ImmutableList.copyOf(neighbors); } + /** + * Gets the parent change set. + *

+ * This implementation always returns {@code null}, as this type of change set + * does not have a parent. + * + * @return always {@code null}. + */ @Nullable @Override public BaseNeighborsChangeSet getParent() { return null; } + /** + * Retrieves the list of neighbors associated with this object. + *

+ * This implementation fulfills the {@code merge} contract by simply returning the + * existing list of neighbors without performing any additional merging logic. + * @return a non-null list of neighbors. The generic type {@code N} represents + * the type of the neighboring elements. + */ @Nonnull @Override public List merge() { return neighbors; } + /** + * {@inheritDoc} + * + *

This implementation is a no-op and does not write any delta information, + * as indicated by the empty method body. + */ @Override public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, final int layer, @Nonnull final Node node, diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java index a6a28e778d..e58f005dd1 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java @@ -30,7 +30,14 @@ import java.util.Objects; /** - * TODO. + * Represents a compact node within a graph structure, extending {@link AbstractNode}. + *
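+ * A construction sketch; the {@code Vector<Half>} instance {@code vector} is assumed to exist:
+ * <pre>{@code
+ * CompactNode node = new CompactNode(Tuple.from(42L), vector,
+ *         ImmutableList.of(new NodeReference(Tuple.from(7L))));
+ * }</pre>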

+ * This node type is considered "compact" because it directly stores its associated + * data vector of type {@code Vector}. It is used to represent a vector in a + * vector space and maintains references to its neighbors via {@link NodeReference} objects. + * + * @see AbstractNode + * @see NodeReference */ public class CompactNode extends AbstractNode { @Nonnull @@ -54,41 +61,94 @@ public NodeKind getNodeKind() { @Nonnull private final Vector vector; + /** + * Constructs a new {@code CompactNode} instance. + *

+ * This constructor initializes the node with its primary key, a data vector, + * and a list of its neighbors. It delegates the initialization of the + * {@code primaryKey} and {@code neighbors} to the superclass constructor. + * + * @param primaryKey the primary key that uniquely identifies this node; must not be {@code null}. + * @param vector the data vector of type {@code Vector} associated with this node; must not be {@code null}. + * @param neighbors a list of {@link NodeReference} objects representing the neighbors of this node; must not be + * {@code null}. + */ public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, @Nonnull final List neighbors) { super(primaryKey, neighbors); this.vector = vector; } + /** + * Returns a {@link NodeReference} that uniquely identifies this node. + *

+ * This implementation creates the reference using the node's primary key, obtained via {@code getPrimaryKey()}. It + * ignores the provided {@code vector} parameter, which exists to fulfill the contract of the overridden method. + * + * @param vector the vector context, which is ignored in this implementation. + * Per the {@code @Nullable} annotation, this can be {@code null}. + * + * @return a non-null {@link NodeReference} to this node. + */ @Nonnull @Override public NodeReference getSelfReference(@Nullable final Vector vector) { return new NodeReference(getPrimaryKey()); } + /** + * Gets the kind of this node. + * This implementation always returns {@link NodeKind#COMPACT}. + * @return the node kind, which is guaranteed to be {@link NodeKind#COMPACT}. + */ @Nonnull @Override public NodeKind getKind() { return NodeKind.COMPACT; } + /** + * Gets the vector of {@code Half} objects. + * @return the non-null vector of {@link Half} objects. + */ @Nonnull public Vector getVector() { return vector; } + /** + * Returns this node as a {@code CompactNode}. As this class is already a {@code CompactNode}, this method provides + * {@code this}. + * @return this object cast as a {@code CompactNode}, which is guaranteed to be non-null. + */ @Nonnull @Override public CompactNode asCompactNode() { return this; } + /** + * Returns this node as an {@link InliningNode}. + *

+ * A {@code CompactNode} is not an inlining node; as such, this method + * will always fail. + * @return this method never returns normally + * @throws IllegalStateException always, as this is not an inlining node + */ @Nonnull @Override public InliningNode asInliningNode() { throw new IllegalStateException("this is not an inlining node"); } + /** + * Gets the shared factory instance for creating {@link NodeReference} objects. + * <p>

+ * This static factory method is the preferred way to obtain a {@code NodeFactory} + * for {@link NodeReference} instances, as it returns a shared, pre-configured object. + * + * @return a shared, non-null instance of {@code NodeFactory} + */ @Nonnull public static NodeFactory factory() { return FACTORY; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index c3a04f86a2..4d9497ba0a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -41,9 +41,6 @@ import java.util.List; import java.util.concurrent.CompletableFuture; -/** - * TODO. - */ class CompactStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { @Nonnull private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java index e431561119..e70515531e 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -33,7 +33,12 @@ import java.util.function.Predicate; /** - * TODO. + * A {@link NeighborsChangeSet} that represents the deletion of a set of neighbors from a parent change set. + *
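+ * A layering sketch; the {@code parent} change set is assumed to exist:
+ * <pre>{@code
+ * // hide the neighbor with primary key (7) from the view produced by parent.merge()
+ * NeighborsChangeSet<NodeReference> withDeletes =
+ *         new DeleteNeighborsChangeSet<>(parent, ImmutableList.of(Tuple.from(7L)));
+ * }</pre>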

+ * This class acts as a filter, wrapping a parent {@link NeighborsChangeSet} and providing a view of the neighbors + * that excludes those whose primary keys have been marked for deletion. + * + * @param the type of the node reference, which must extend {@link NodeReference} */ class DeleteNeighborsChangeSet implements NeighborsChangeSet { @Nonnull @@ -45,18 +50,50 @@ class DeleteNeighborsChangeSet implements NeighborsChan @Nonnull private final Set deletedNeighborsPrimaryKeys; + /** + * Constructs a new {@code DeleteNeighborsChangeSet}. + *

+ * This object represents a set of changes where specific neighbors are marked for deletion. + * It holds a reference to a parent {@link NeighborsChangeSet} and creates an immutable copy + * of the primary keys for the neighbors to be deleted. + * + * @param parent the parent {@link NeighborsChangeSet} to which this deletion change belongs. Must not be null. + * @param deletedNeighborsPrimaryKeys a {@link Collection} of primary keys, represented as {@link Tuple}s, + * identifying the neighbors to be deleted. Must not be null. + */ public DeleteNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, @Nonnull final Collection deletedNeighborsPrimaryKeys) { this.parent = parent; this.deletedNeighborsPrimaryKeys = ImmutableSet.copyOf(deletedNeighborsPrimaryKeys); } + /** + * Gets the parent change set from which this change set was derived. + *

+ * In a sequence of modifications, each {@code NeighborsChangeSet} is derived from a previous state, which is + * considered its parent. This method allows traversing the history of changes backward. + * + * @return the parent {@link NeighborsChangeSet} + */ @Nonnull @Override public NeighborsChangeSet getParent() { return parent; } + /** + * Merges the neighbors from the parent context, filtering out any neighbors that have been marked as deleted. + *

+ * This implementation retrieves the collection of neighbors from its parent by calling + * {@code getParent().merge()}. + * It then filters this collection, removing any neighbor whose primary key is present in the + * {@code deletedNeighborsPrimaryKeys} set. + * This ensures the resulting {@link Iterable} represents a consistent view of neighbors, respecting deletions made + * in the current context. + * + * @return an {@link Iterable} of the merged neighbors, excluding those marked as deleted. This method never returns + * {@code null}. + */ @Nonnull @Override public Iterable merge() { @@ -64,6 +101,23 @@ public Iterable merge() { current -> !deletedNeighborsPrimaryKeys.contains(current.getPrimaryKey())); } + /** + * Writes the delta of changes for a given node to the storage layer. + *

+ * This implementation first delegates to the parent's {@code writeDelta} method to handle its changes, but modifies + * the predicate to exclude any neighbors that are marked for deletion in this delta. + *

+ * It then iterates through the set of locally deleted neighbor primary keys. For each key that matches the supplied + * {@code tuplePredicate}, it instructs the {@link InliningStorageAdapter} to delete the corresponding neighbor + * relationship for the given {@code node}. + * + * @param storageAdapter the storage adapter to which the changes are written + * @param transaction the transaction context for the write operations + * @param layer the layer index where the write operations should occur + * @param node the node for which the delta is being written + * @param tuplePredicate a predicate to filter which neighbor tuples should be processed; + * only deletions matching this predicate will be written + */ @Override public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, final int layer, @Nonnull final Node<N> node, @Nonnull final Predicate<Tuple> tuplePredicate) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java index 48e2398950..56d39227d1 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -30,9 +30,14 @@ import java.util.Objects; /** - * TODO. + * Represents a node in an HNSW graph whose neighbor references carry their vectors inline. + * <p>

+ * This node type extends {@link AbstractNode} but does not store its own vector; instead, it manages neighbors + * of type {@link NodeReferenceWithVector}, each of which stores a vector. + * It provides a concrete implementation for an "inlining" node, distinguishing it from other node types such as + * {@link CompactNode}. */ -class InliningNode extends AbstractNode<NodeReferenceWithVector> { +public class InliningNode extends AbstractNode<NodeReferenceWithVector> { @Nonnull private static final NodeFactory<NodeReferenceWithVector> FACTORY = new NodeFactory<>() { @SuppressWarnings("unchecked") @@ -51,11 +56,32 @@ public NodeKind getNodeKind() { } }; + /** + * Constructs a new {@code InliningNode} with a specified primary key and a list of its neighbors. + * <p>

+ * This constructor initializes the node by calling the constructor of its superclass, + * passing the primary key and neighbor list. + * + * @param primaryKey the non-null primary key of the node, represented by a {@link Tuple}. + * @param neighbors the non-null list of neighbors for this node, where each neighbor + * is a {@link NodeReferenceWithVector}. + */ public InliningNode(@Nonnull final Tuple primaryKey, @Nonnull final List neighbors) { super(primaryKey, neighbors); } + /** + * Gets a reference to this node. + * + * @param vector the vector to be associated with the node reference. Despite the + * {@code @Nullable} annotation, this parameter must not be null. + * + * @return a new {@link NodeReferenceWithVector} instance containing the node's + * primary key and the provided vector; will never be null. + * + * @throws NullPointerException if the provided {@code vector} is null. + */ @Nonnull @Override @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") @@ -63,24 +89,51 @@ public NodeReferenceWithVector getSelfReference(@Nullable final Vector vec return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); } + /** + * Gets the kind of this node. + * @return the non-null {@link NodeKind} of this node, which is always + * {@code NodeKind.INLINING}. + */ @Nonnull @Override public NodeKind getKind() { return NodeKind.INLINING; } + /** + * Casts this node to a {@link CompactNode}. + *

+ * This implementation always throws an exception because this specific node type + * cannot be represented as a compact node. + * @return this node as a {@link CompactNode}, never {@code null} + * @throws IllegalStateException always, as this node is not a compact node + */ @Nonnull @Override public CompactNode asCompactNode() { throw new IllegalStateException("this is not a compact node"); } + /** + * Returns this object as an {@link InliningNode}. + *

+ * As this class is already an instance of {@code InliningNode}, this method simply returns {@code this}. + * @return this object, which is guaranteed to be an {@code InliningNode} and never {@code null}. + */ @Nonnull @Override public InliningNode asInliningNode() { return this; } + /** + * Returns the singleton factory instance used to create {@link NodeReferenceWithVector} objects. + *

+ * This method provides a standard way to obtain the factory, ensuring that a single, shared instance is used + * throughout the application. + * + * @return the singleton {@link NodeFactory} instance, never {@code null}. + */ @Nonnull public static NodeFactory factory() { return FACTORY; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java index ebbfd4d698..2835427ca4 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -39,9 +39,31 @@ import java.util.concurrent.CompletableFuture; /** - * TODO. + * An implementation of {@link StorageAdapter} for an HNSW graph that stores node vectors "in-line" with the node's + * neighbor information. + *
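+ * A sketch of the resulting layout, one key-value pair per neighbor (subspace prefix shown
+ * symbolically; see the following paragraphs):
+ * <pre>{@code
+ * key:   dataSubspace / (layer, nodePrimaryKey, neighborPrimaryKey)
+ * value: packed vector of the neighbor
+ * }</pre>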

+ * In this storage model, each key-value pair in the database represents a single neighbor relationship. The key + * contains the primary keys of both the source node and the neighbor node, while the value contains the neighbor's + * vector. This contrasts with a "compact" storage model where a compact node represents a vector and all of its + * neighbors. This adapter is responsible for serializing and deserializing these structures to and from the underlying + * key-value store. + * + * @see StorageAdapter + * @see HNSW */ class InliningStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + /** + * Constructs a new {@code InliningStorageAdapter} with the given configuration and components. + *

+ * This constructor initializes the storage adapter by passing all necessary components + * to its superclass. + * + * @param config the HNSW configuration to use for the graph + * @param nodeFactory the factory to create new {@link NodeReferenceWithVector} instances + * @param subspace the subspace where the HNSW graph data is stored + * @param onWriteListener the listener to be notified on write operations + * @param onReadListener the listener to be notified on read operations + */ public InliningStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, @Nonnull final Subspace subspace, @@ -50,18 +72,48 @@ public InliningStorageAdapter(@Nonnull final HNSW.Config config, super(config, nodeFactory, subspace, onWriteListener, onReadListener); } + /** + * Throws {@link IllegalStateException} because an inlining storage adapter cannot be converted to a compact one. + *

+ * This operation is fundamentally not supported for this type of adapter. An inlining adapter stores data directly + * within a parent structure, which is incompatible with the standalone nature of a compact storage format. + * @return This method never returns a value as it always throws an exception. + * @throws IllegalStateException always, as this operation is not supported. + */ @Nonnull @Override public StorageAdapter asCompactStorageAdapter() { throw new IllegalStateException("cannot call this method on an inlining storage adapter"); } + /** + * Returns this object instance as a {@code StorageAdapter} that supports inlining. + *

+ * This implementation returns the current instance ({@code this}) because the class itself is designed to handle + * inlining directly, thus no separate adapter object is needed. + * @return a non-null reference to this object as an {@link StorageAdapter} for inlining. + */ @Nonnull @Override public StorageAdapter asInliningStorageAdapter() { return this; } + /** + * Asynchronously fetches a single node from a given layer by its primary key. + *

+ * This internal method constructs a prefix key based on the {@code layer} and {@code primaryKey}. + * It then performs an asynchronous range scan to retrieve all key-value pairs associated with that prefix. + * Finally, it reconstructs the complete {@link Node} object from the collected raw data using + * the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for reading from the database + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a {@link CompletableFuture} that will complete with the fetched {@link Node} containing + * {@link NodeReferenceWithVector}s + */ @Nonnull @Override protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, @@ -74,8 +126,27 @@ protected CompletableFuture> fetchNodeInternal(@No .thenApply(keyValues -> nodeFromRaw(layer, primaryKey, keyValues)); } + /** + * Constructs a {@code Node} from its raw key-value representation from storage. + *

+ * This method is responsible for deserializing a node and its neighbors. It processes a list of {@code KeyValue} + * pairs, where each pair represents a neighbor of the node being constructed. Each neighbor is converted from its + * raw form into a {@link NodeReferenceWithVector} by calling the {@link #neighborFromRaw(int, byte[], byte[])} + * method. + *

+ * Once the node is created with its primary key and list of neighbors, it notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param layer the layer in the graph where this node exists + * @param primaryKey the primary key that uniquely identifies the node + * @param keyValues a list of {@code KeyValue} pairs representing the raw data of the node's neighbors + * + * @return a non-null, fully constructed {@code Node} object with its neighbors + */ @Nonnull - private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, final List keyValues) { + private Node nodeFromRaw(final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final List keyValues) { final OnReadListener onReadListener = getOnReadListener(); final ImmutableList.Builder nodeReferencesWithVectorBuilder = ImmutableList.builder(); @@ -89,6 +160,19 @@ private Node nodeFromRaw(final int layer, final @Nonnul return node; } + /** + * Constructs a {@code NodeReferenceWithVector} from raw key and value byte arrays retrieved from storage. + *

+ * This helper method deserializes a neighbor's data. It unpacks the provided {@code key} to extract the neighbor's + * primary key and unpacks the {@code value} to extract the neighbor's vector. It also notifies the configured + * {@link OnReadListener} of the read operation. + * + * @param layer the layer of the graph where the neighbor node is located. + * @param key the raw byte array key from the database, which contains the neighbor's primary key. + * @param value the raw byte array value from the database, which represents the neighbor's vector. + * @return a new {@link NodeReferenceWithVector} instance representing the deserialized neighbor. + * @throws IllegalArgumentException if the key or value byte arrays are malformed and cannot be unpacked. + */ @Nonnull private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull byte[] key, final byte[] value) { final OnReadListener onReadListener = getOnReadListener(); @@ -102,6 +186,21 @@ private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector); } + /** + * Writes a given node and its neighbor changes to the specified layer within a transaction. + *

+ * This implementation first converts the provided {@link Node} to an {@link InliningNode}. It then delegates the + * writing of neighbor modifications to the {@link NeighborsChangeSet#writeDelta} method. After the changes are + * written, it notifies the registered {@code OnWriteListener} that the node has been processed + * via {@code getOnWriteListener().onNodeWritten()}. + * + * @param transaction the transaction context for the write operation; must not be null + * @param node the node to be written, which is expected to be an + * {@code InliningNode}; must not be null + * @param layer the layer index where the node and its neighbor changes should be written + * @param neighborsChangeSet the set of changes to the node's neighbors to be + * persisted; must not be null + */ @Override public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { @@ -111,11 +210,36 @@ public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull f getOnWriteListener().onNodeWritten(layer, node); } + /** + * Constructs the raw database key for a node based on its layer and primary key. + *

+ * This key is created by packing a tuple containing the specified {@code layer} and the node's {@code primaryKey} + * within the data subspace. The resulting byte array is suitable for use in direct database lookups and preserves + * the sort order of the components. + * + * @param layer the layer index where the node resides + * @param primaryKey the primary key that uniquely identifies the node within its layer, + * encapsulated in a {@link Tuple} + * + * @return a byte array representing the packed key for the specified node + */ @Nonnull private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { return getDataSubspace().pack(Tuple.from(layer, primaryKey)); } + /** + * Writes a neighbor for a given node to the underlying storage within a specific transaction. + *

+ * This method serializes the neighbor's vector and constructs a unique key based on the layer, the source + * {@code node}, and the neighbor's primary key. It then persists this key-value pair using the provided + * {@link Transaction}. After a successful write, it notifies any registered listeners. + * + * @param transaction the {@link Transaction} to use for the write operation + * @param layer the layer index where the node and its neighbor reside + * @param node the source {@link Node} for which the neighbor is being written + * @param neighbor the {@link NodeReferenceWithVector} representing the neighbor to persist + */ public void writeNeighbor(@Nonnull final Transaction transaction, final int layer, @Nonnull final Node node, @Nonnull final NodeReferenceWithVector neighbor) { final byte[] neighborKey = getNeighborKey(layer, node, neighbor.getPrimaryKey()); @@ -126,12 +250,35 @@ public void writeNeighbor(@Nonnull final Transaction transaction, final int laye getOnWriteListener().onKeyValueWritten(layer, neighborKey, value); } + /** + * Deletes a neighbor edge from a given node within a specific layer. + *

+ * This operation removes the key-value pair representing the neighbor relationship from the database within the + * given {@link Transaction}. It also notifies the {@code onWriteListener} about the deletion. + * + * @param transaction the transaction in which to perform the deletion + * @param layer the layer of the graph where the node resides + * @param node the node from which the neighbor edge is removed + * @param neighborPrimaryKey the primary key of the neighbor node to be deleted + */ public void deleteNeighbor(@Nonnull final Transaction transaction, final int layer, @Nonnull final Node node, @Nonnull final Tuple neighborPrimaryKey) { transaction.clear(getNeighborKey(layer, node, neighborPrimaryKey)); getOnWriteListener().onNeighborDeleted(layer, node, neighborPrimaryKey); } + /** + * Constructs the key for a specific neighbor of a node within a given layer. + *

+ * This key is used to uniquely identify and store the neighbor relationship in the underlying data store. It is + * formed by packing a {@link Tuple} containing the {@code layer}, the primary key of the source {@code node}, and + * the {@code neighborPrimaryKey}. + * + * @param layer the layer of the graph where the node and its neighbor reside + * @param node the non-null source node for which the neighbor key is being generated + * @param neighborPrimaryKey the non-null primary key of the neighbor node + * @return a non-null byte array representing the packed key for the neighbor relationship + */ @Nonnull private byte[] getNeighborKey(final int layer, @Nonnull final Node node, @@ -139,6 +286,24 @@ private byte[] getNeighborKey(final int layer, return getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey(), neighborPrimaryKey)); } + /** + * Scans a specific layer of the graph, reconstructing nodes and their neighbors from the underlying key-value + * store. + *
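+ * A pagination sketch; the {@code adapter} and {@code readTransaction} are assumed, reading up to
+ * 1000 raw records per page:
+ * <pre>{@code
+ * Tuple lastPrimaryKey = null; // start from the beginning of the layer
+ * for (Node<NodeReferenceWithVector> node :
+ *         adapter.scanLayer(readTransaction, 0, lastPrimaryKey, 1000)) {
+ *     lastPrimaryKey = node.getPrimaryKey(); // remember where to resume
+ * }
+ * // pass the remembered lastPrimaryKey into the next scanLayer() call to continue
+ * }</pre>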

+ * This method reads raw {@link com.apple.foundationdb.KeyValue} records from the database within a given layer. + * It groups adjacent records that belong to the same parent node and uses a {@link NodeFactory} to construct + * {@link Node} objects. The method supports pagination through the {@code lastPrimaryKey} parameter, allowing for + * incremental scanning of large layers. + * + * @param readTransaction the transaction to use for reading data + * @param layer the layer of the graph to scan + * @param lastPrimaryKey the primary key of the last node read in a previous scan, used for pagination. + * If {@code null}, the scan starts from the beginning of the layer. + * @param maxNumRead the maximum number of raw key-value records to read from the database + * + * @return an {@code Iterable} of {@link Node} objects reconstructed from the scanned layer. Each node contains + * its neighbors within that layer. + */ @Nonnull @Override public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java index 6e236a5d10..f5fe817e53 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -22,9 +22,45 @@ import javax.annotation.Nonnull; +/** + * Defines a metric for measuring the distance or similarity between n-dimensional vectors. + *
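+ * Since {@link #distance(Double[], Double[])} is the only abstract method, a metric can be
+ * written as a lambda. A minimal sketch (a Manhattan distance, not part of this patch):
+ * <pre>{@code
+ * Metric manhattan = (v1, v2) -> {
+ *     if (v1.length != v2.length) {
+ *         throw new IllegalArgumentException("dimension mismatch");
+ *     }
+ *     double sum = 0.0d;
+ *     for (int i = 0; i < v1.length; i++) {
+ *         sum += Math.abs(v1[i] - v2[i]);
+ *     }
+ *     return sum;
+ * };
+ * }</pre>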

+ * This interface provides a contract for various distance calculation algorithms, such as Euclidean, Manhattan, + * and Cosine distance. Implementations of this interface can be used in algorithms that require a metric for + * comparing data vectors, like clustering or nearest neighbor searches. + */ public interface Metric { + /** + * Calculates a distance between two n-dimensional vectors. + *

+ * The two vectors are represented as arrays of {@link Double} and must be of the + * same length (i.e., have the same number of dimensions). + * + * @param vector1 the first vector. Must not be null. + * @param vector2 the second vector. Must not be null and must have the same + * length as {@code vector1}. + * + * @return the calculated distance as a {@code double}. + * + * @throws IllegalArgumentException if the vectors have different lengths. + * @throws NullPointerException if either {@code vector1} or {@code vector2} is null. + */ double distance(Double[] vector1, Double[] vector2); + /** + * Calculates a comparative distance between two vectors. The comparative distance is used in contexts such as + * ranking where the caller needs to "compare" two distances. In contrast to a true metric, the distances computed + * by this method do not need to satisfy proper metric invariants: the distance can be negative, and it need not + * satisfy the triangle inequality. + * <p>

+ * Under normal circumstances, this method is an alias for {@link #distance(Double[], Double[])}. It is not, + * however, for implementations such as {@link DotProductMetric}, where the comparative distance is the negative + * dot product. + * + * @param vector1 the first vector, represented as an array of {@code Double}. + * @param vector2 the second vector, represented as an array of {@code Double}. + * + * @return the distance between the two vectors. + */ default double comparativeDistance(Double[] vector1, Double[] vector2) { return distance(vector1, vector2); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java index b7f38ef1a7..081523de5b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -28,15 +28,53 @@ import java.util.function.Predicate; /** - * TODO. + * Represents a set of changes to the neighbors of a node within an HNSW graph. + * <p>
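+ * A layering sketch; the {@code initialNeighbors} list is assumed, and both implementations used
+ * here are part of this package:
+ * <pre>{@code
+ * NeighborsChangeSet<NodeReference> changes =
+ *         new DeleteNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(initialNeighbors),
+ *                 ImmutableList.of(Tuple.from(7L)));
+ * // changes.merge() yields the initial neighbors without the neighbor keyed (7)
+ * }</pre>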

+ * Implementations of this interface manage modifications, such as additions or removals of neighbors, often in a + * layered fashion. This allows for composing changes before they are committed to storage. The {@link #getParent()} + * method returns the next element in this layered structure while {@link #merge()} consolidates changes into + * a final neighbor list. + * + * @param <N> the type of the node reference, which must extend {@link NodeReference} + */ interface NeighborsChangeSet<N extends NodeReference> { + /** + * Gets the parent change set from which this change set was derived. + * <p>

+ * Change sets can be layered, forming a chain of modifications. + * This method allows for traversing up this chain to the preceding set of changes. + * + * @return the parent {@code NeighborsChangeSet}, or {@code null} if this change set + * is the root of the change chain and has no parent. + */ @Nullable NeighborsChangeSet<N> getParent(); + /** + * Merges multiple internal sequences into a single, consolidated iterable sequence. + * <p>

+ * This method combines distinct internal change sets into one continuous stream of neighbors. The specific order + * of the merged elements depends on the implementation. + * + * @return a non-null {@code Iterable} containing the merged sequence of elements. + */ @Nonnull Iterable<N> merge(); + /** + * Writes the neighbor delta for a given {@link Node} to the specified storage layer. + * <p>

+ * This method processes the provided {@code node} and writes only the records that match the given + * {@code primaryKeyPredicate} to the storage system via the {@link InliningStorageAdapter}. The entire operation + * is performed within the context of the supplied {@link Transaction}. + * + * @param storageAdapter the storage adapter to which the delta will be written; must not be null + * @param transaction the transaction context for the write operation; must not be null + * @param layer the specific storage layer to write the delta to + * @param node the source node containing the data to be written; must not be null + * @param primaryKeyPredicate a predicate to filter records by their primary key. Only records + * for which the predicate returns {@code true} will be written. Must not be null. + */ void writeDelta(@Nonnull InliningStorageAdapter storageAdapter, @Nonnull Transaction transaction, int layer, @Nonnull Node node, @Nonnull Predicate primaryKeyPredicate); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java index f2c623f882..3ddae2ec74 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java @@ -28,19 +28,57 @@ import java.util.List; /** - * TODO. - * @param neighbor type + * Represents a node within an HNSW (Hierarchical Navigable Small World) structure. + *

+ * A node corresponds to a data point (vector) in the structure and maintains a list of its neighbors. + * This interface defines the common contract for different node representations, such as {@link CompactNode} + * and {@link InliningNode}. + *
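+ * A dispatch sketch for consumers that need the concrete representation; the {@code node}
+ * instance is assumed:
+ * <pre>{@code
+ * if (node.getKind() == NodeKind.COMPACT) {
+ *     Vector<Half> vector = node.asCompactNode().getVector(); // compact nodes carry a vector
+ * } else {
+ *     InliningNode inliningNode = node.asInliningNode(); // neighbors carry the vectors instead
+ * }
+ * }</pre>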

+ * + * @param <N> the type of reference used to point to other nodes, which must extend {@link NodeReference} + */ public interface Node<N extends NodeReference> { + /** + * Gets the primary key for this object. + * <p>

+ * The primary key is represented as a {@link Tuple} and uniquely identifies + * the object within its storage context. This method is guaranteed never to + * return a null value. + * + * @return the primary key as a {@code Tuple}, which is never {@code null} + */ @Nonnull Tuple getPrimaryKey(); + /** + * Returns a storage-independent reference to this node. This allows the creation of node references + * that contain a vector and are independent of the storage implementation. + * @param vector the vector of {@code Half} objects to associate with the reference. This parameter + * is optional and can be {@code null}. + * + * @return a non-null node reference referring to this node + */ @Nonnull N getSelfReference(@Nullable Vector<Half> vector); + /** + * Gets the list of neighboring nodes. + * <p>

+ * This method is guaranteed to not return {@code null}. If there are no neighbors, an empty list is returned. + * + * @return a non-null list of neighboring nodes. + */ @Nonnull List getNeighbors(); + /** + * Gets the neighbor at the specified index. + *

+ * This method provides access to the neighbors of a particular node or element, identified by a zero-based index. + * @param index the zero-based index of the neighbor to retrieve. + * @return the neighbor at the specified index; this method will never return {@code null}. + */ @Nonnull N getNeighbor(int index); @@ -51,9 +89,23 @@ public interface Node { @Nonnull NodeKind getKind(); + /** + * Converts this node into its {@link CompactNode} representation. + *

+ * A {@code CompactNode} is a space-efficient implementation of {@code Node}. This method provides the + * conversion logic to transform the current object into that compact form. + * + * @return a non-null {@link CompactNode} representing the current node. + */ @Nonnull CompactNode asCompactNode(); + /** + * Converts this node into its {@link InliningNode} representation. + * @return this object cast to an {@link InliningNode}; never {@code null}. + * @throws IllegalStateException if this object is not actually an instance of + * {@link InliningNode}. + */ @Nonnull InliningNode asInliningNode(); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java index e21b221622..837c88fb00 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -26,31 +26,71 @@ import javax.annotation.Nonnull; +/** + * Represents a reference to a node that includes an associated vector. + * <p>
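+ * A construction sketch; the half-precision vector {@code vector} is assumed to exist:
+ * <pre>{@code
+ * NodeReferenceWithVector reference = new NodeReferenceWithVector(Tuple.from(42L), vector);
+ * }</pre>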

+ * This class extends {@link NodeReference} by adding a {@code Vector} field. It encapsulates both the primary key + * of a node and its corresponding vector data, which is particularly useful in vector-based search and + * indexing scenarios. Primarily, node references of this kind are used to refer to {@link Node}s in a + * storage-independent way: a {@code NodeReferenceWithVector} always contains the vector of a node, while the node + * itself (depending on the storage adapter) may not. */ public class NodeReferenceWithVector extends NodeReference { @Nonnull private final Vector<Half> vector; + /** + * Constructs a new {@code NodeReferenceWithVector} with a specified primary key and vector. + * <p>

+ * The primary key is used to initialize the parent class via a call to {@code super()}, + * while the vector is stored as a field in this instance. Both parameters are expected + * to be non-null. + * + * @param primaryKey the primary key of the node, must not be null + * @param vector the vector associated with the node, must not be null + */ public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { super(primaryKey); this.vector = vector; } + /** + * Gets the vector of {@code Half} objects. + *

+ * This method provides access to the internal vector. The returned vector is guaranteed + * not to be null, as indicated by the {@code @Nonnull} annotation. + * + * @return the vector of {@code Half} objects; will never be {@code null}. + */ @Nonnull public Vector getVector() { return vector; } + /** + * Gets the vector as a {@code Vector} of {@code Double}s. + * @return a non-null {@code Vector} containing the elements of this vector. + */ @Nonnull public Vector getDoubleVector() { return vector.toDoubleVector(); } + /** + * Returns this instance cast as a {@code NodeReferenceWithVector}. + * @return this instance as a {@code NodeReferenceWithVector}, which is never {@code null}. + */ @Nonnull @Override public NodeReferenceWithVector asNodeReferenceWithVector() { return this; } + /** + * Compares this {@code NodeReferenceWithVector} to the specified object for equality. + * @param o the object to compare with this {@code NodeReferenceWithVector}. + * @return {@code true} if the objects are equal; {@code false} otherwise. + */ @Override public boolean equals(final Object o) { if (!(o instanceof NodeReferenceWithVector)) { @@ -62,11 +102,19 @@ public boolean equals(final Object o) { return Objects.equal(vector, ((NodeReferenceWithVector)o).vector); } + /** + * Computes the hash code for this object. + * @return a hash code value for this object. + */ @Override public int hashCode() { return Objects.hashCode(super.hashCode(), vector); } + /** + * Returns a string representation of this object. + * @return a concise string representation of this object. + */ @Override public String toString() { return "NRV[primaryKey=" + getPrimaryKey() + diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java new file mode 100644 index 0000000000..831d3774d1 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java @@ -0,0 +1,75 @@ +/* + * HNSWHelpersTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSWHelpersTest { + @Test + public void bytesToHex_MultipleBytesWithLeadingZeros_ReturnsTrimmedHexTest() { + final byte[] bytes = new byte[] {0, 1, 16, (byte)255}; // Represents 000110FF + final String result = HNSWHelpers.bytesToHex(bytes); + assertEquals("0x110FF", result); + } + + @Test + public void bytesToHex_NegativeByteValues_ReturnsCorrectUnsignedHexTest() { + final byte[] bytes = new byte[] {-1, -2}; // 0xFFFE + final String result = HNSWHelpers.bytesToHex(bytes); + assertEquals("0xFFFE", result); + } + + @Test + public void halfValueOf_NegativeFloat_ReturnsCorrectHalfValue_Test() { + final float inputValue = -56.75f; + final Half expected = Half.valueOf(inputValue); + final Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_PositiveFloat_ReturnsCorrectHalfValue_Test() { + final float inputValue = 123.4375f; + Half expected = Half.valueOf(inputValue); + Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_NegativeDouble_ReturnsCorrectHalfValue_Test() { + final double inputValue = -56.75d; + final Half expected = Half.valueOf(inputValue); + final Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } + + @Test + public void halfValueOf_PositiveDouble_ReturnsCorrectHalfValue_Test() { + final double inputValue = 123.4375d; + Half expected = Half.valueOf(inputValue); + Half result = HNSWHelpers.halfValueOf(inputValue); + assertEquals(expected, result); + } +} \ No newline at end of file diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java index 7a8bf73e0d..c746516a03 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java @@ -32,26 +32,20 @@ import com.christianheina.langx.half4j.Half; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import org.assertj.core.util.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.FileReader; -import java.io.FileWriter; import java.io.IOException; import java.nio.channels.FileChannel; import java.nio.file.Path; @@ -62,13 +56,10 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.NavigableSet; -import java.util.Objects; import java.util.Random; -import java.util.concurrent.ConcurrentSkipListSet; 
+import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; /** @@ -249,9 +240,9 @@ private int insertBatch(final HNSW hnsw, final int batchSize, } @Test - public void testSIFTInsert10k() throws Exception { + public void testSIFTInsertSmall() throws Exception { final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); - final int k = 10; + final int k = 100; final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); final TestOnReadListener onReadListener = new TestOnReadListener(); @@ -272,9 +263,7 @@ public void testSIFTInsert10k() throws Exception { if (!vectorIterator.hasNext()) { return null; } - final Vector.DoubleVector doubleVector = vectorIterator.next(); - final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); final HalfVector currentVector = doubleVector.toHalfVector(); return new NodeReferenceWithVector(currentPrimaryKey, currentVector); @@ -282,9 +271,14 @@ public void testSIFTInsert10k() throws Exception { } } + validateSIFTSmall(hnsw, k); + } + + private void validateSIFTSmall(@Nonnull final HNSW hnsw, final int k) throws IOException { final Path siftSmallGroundTruthPath = Paths.get(".out/extracted/siftsmall/siftsmall_groundtruth.ivecs"); final Path siftSmallQueryPath = Paths.get(".out/extracted/siftsmall/siftsmall_query.fvecs"); + final TestOnReadListener onReadListener = (TestOnReadListener)hnsw.getOnReadListener(); try (final var queryChannel = FileChannel.open(siftSmallQueryPath, StandardOpenOption.READ); final var groundTruthChannel = FileChannel.open(siftSmallGroundTruthPath, StandardOpenOption.READ)) { @@ -295,34 +289,39 @@ public void testSIFTInsert10k() throws Exception { while (queryIterator.hasNext()) { final HalfVector queryVector = queryIterator.next().toHalfVector(); + final Set groundTruthIndices = ImmutableSet.copyOf(groundTruthIterator.next()); onReadListener.reset(); final long beginTs = System.nanoTime(); final List> results = db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join()); final long endTs = System.nanoTime(); - logger.info("retrieved result in elapsedTimeMs={}", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + logger.trace("retrieved result in elapsedTimeMs={}", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + int recallCount = 0; for (NodeReferenceAndNode nodeReferenceAndNode : results) { - final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); - logger.info("retrieved result nodeId = {} at distance = {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), - nodeReferenceWithDistance.getDistance()); + final NodeReferenceWithDistance nodeReferenceWithDistance = + nodeReferenceAndNode.getNodeReferenceWithDistance(); + final int primaryKeyIndex = (int)nodeReferenceWithDistance.getPrimaryKey().getLong(0); + logger.trace("retrieved result nodeId = {} at distance = {} reading numNodes={}, readBytes={}", + primaryKeyIndex, nodeReferenceWithDistance.getDistance(), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + if (groundTruthIndices.contains(primaryKeyIndex)) { + recallCount ++; + } } - logger.info("true result vector={}", groundTruthIterator.next()); + final double recall = (double)recallCount / k; + Assertions.assertTrue(recall > 0.93); + + logger.info("query returned results recall={}", String.format("%.2f", recall * 100.0d)); } } - - 
System.out.println(onReadListener.getNodeCountByLayer()); - System.out.println(onReadListener.getBytesReadByLayer()); - - // logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); } @Test - @Timeout(value = 150, unit = TimeUnit.MINUTES) - public void testSIFTInsert10kWithBatchInsert() throws Exception { + public void testSIFTInsertSmallUsingBatchAPI() throws Exception { final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); - final int k = 10; + final int k = 100; final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); final TestOnReadListener onReadListener = new TestOnReadListener(); @@ -331,99 +330,26 @@ public void testSIFTInsert10kWithBatchInsert() throws Exception { HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), OnWriteListener.NOOP, onReadListener); - final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; - final int dimensions = 128; + final Path siftSmallPath = Paths.get(".out/extracted/siftsmall/siftsmall_base.fvecs"); - final AtomicReference queryVectorAtomic = new AtomicReference<>(); - final NavigableSet trueResults = new ConcurrentSkipListSet<>( - Comparator.comparing(NodeReferenceWithDistance::getDistance)); + try (final var fileChannel = FileChannel.open(siftSmallPath, StandardOpenOption.READ)) { + final Iterator vectorIterator = new Vector.StoredFVecsIterator(fileChannel); - try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { - for (int i = 0; i < 10000;) { + int i = 0; + while (vectorIterator.hasNext()) { i += insertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, tr -> { - final String line; - try { - line = br.readLine(); - } catch (IOException e) { - throw new RuntimeException(e); - } - - final String[] values = Objects.requireNonNull(line).split("\t"); - Assertions.assertEquals(dimensions, values.length); - final Half[] halfs = new Half[dimensions]; - - for (int c = 0; c < values.length; c++) { - final String value = values[c]; - halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); - } - final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); - final HalfVector currentVector = new HalfVector(halfs); - final HalfVector queryVector = queryVectorAtomic.get(); - if (queryVector == null) { - queryVectorAtomic.set(currentVector); + if (!vectorIterator.hasNext()) { return null; - } else { - final double currentDistance = - Vector.comparativeDistance(metric, currentVector, queryVector); - if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { - trueResults.add( - new NodeReferenceWithDistance(currentPrimaryKey, currentVector, - Vector.comparativeDistance(metric, currentVector, queryVector))); - } - if (trueResults.size() > k) { - trueResults.remove(trueResults.last()); - } - return new NodeReferenceWithVector(currentPrimaryKey, currentVector); } + final Vector.DoubleVector doubleVector = vectorIterator.next(); + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = doubleVector.toHalfVector(); + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); }); } } - - onReadListener.reset(); - final long beginTs = System.nanoTime(); - final List> results = - db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); - final long endTs = System.nanoTime(); - - for (NodeReferenceAndNode nodeReferenceAndNode : results) { - final NodeReferenceWithDistance nodeReferenceWithDistance = 
nodeReferenceAndNode.getNodeReferenceWithDistance(); - logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), - nodeReferenceWithDistance.getDistance()); - } - - for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { - logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), - nodeReferenceWithDistance.getDistance()); - } - - System.out.println(onReadListener.getNodeCountByLayer()); - System.out.println(onReadListener.getBytesReadByLayer()); - - logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); - } - - @Test - public void testBasicInsertAndScanLayer() throws Exception { - final Random random = new Random(0); - final AtomicLong nextNodeId = new AtomicLong(0L); - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.DEFAULT_CONFIG.toBuilder().setM(4).setMMax(4).setMMax0(4).build(), - OnWriteListener.NOOP, OnReadListener.NOOP); - - db.run(tr -> { - for (int i = 0; i < 100; i ++) { - hnsw.insert(tr, createNextPrimaryKey(nextNodeId), createRandomVector(random, 2)).join(); - } - return null; - }); - - int layer = 0; - while (true) { - if (!dumpLayer(hnsw, layer++)) { - break; - } - } + validateSIFTSmall(hnsw, k); } @Test @@ -438,112 +364,6 @@ public void testManyRandomVectors() { } } - @Test - @Timeout(value = 150, unit = TimeUnit.MINUTES) - public void testSIFTVectors() throws Exception { - final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); - - final TestOnReadListener onReadListener = new TestOnReadListener(); - - final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), - HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) - .setM(32).setMMax(32).setMMax0(64).build(), - OnWriteListener.NOOP, onReadListener); - - - final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; - final int dimensions = 128; - final var referenceVector = createRandomVector(new Random(0), dimensions); - long count = 0L; - double mean = 0.0d; - double mean2 = 0.0d; - - try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { - for (int i = 0; i < 100_000; i ++) { - final String line; - try { - line = br.readLine(); - } catch (IOException e) { - throw new RuntimeException(e); - } - - final String[] values = Objects.requireNonNull(line).split("\t"); - Assertions.assertEquals(dimensions, values.length); - final Half[] halfs = new Half[dimensions]; - for (int c = 0; c < values.length; c++) { - final String value = values[c]; - halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); - } - final HalfVector newVector = new HalfVector(halfs); - final double distance = Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), - referenceVector, newVector); - count++; - final double delta = distance - mean; - mean += delta / count; - final double delta2 = distance - mean; - mean2 += delta * delta2; - } - } - final double sampleVariance = mean2 / (count - 1); - final double standardDeviation = Math.sqrt(sampleVariance); - logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, - standardDeviation / mean); - } - - @ParameterizedTest - @ValueSource(ints = {2, 3, 10, 100, 768}) - public void testManyVectorsStandardDeviation(final int dimensionality) { - final Random random = new Random(); - final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); - long 
count = 0L; - double mean = 0.0d; - double mean2 = 0.0d; - for (long i = 0L; i < 100000; i ++) { - final HalfVector vector1 = createRandomVector(random, dimensionality); - final HalfVector vector2 = createRandomVector(random, dimensionality); - final double distance = Vector.comparativeDistance(metric, vector1, vector2); - count = i + 1; - final double delta = distance - mean; - mean += delta / count; - final double delta2 = distance - mean; - mean2 += delta * delta2; - } - final double sampleVariance = mean2 / (count - 1); - final double standardDeviation = Math.sqrt(sampleVariance); - logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, - standardDeviation / mean); - } - - private boolean dumpLayer(final HNSW hnsw, final int layer) throws IOException { - final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv"; - final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv"; - - final AtomicLong numReadAtomic = new AtomicLong(0L); - try (final BufferedWriter verticesWriter = new BufferedWriter(new FileWriter(verticesFileName)); - final BufferedWriter edgesWriter = new BufferedWriter(new FileWriter(edgesFileName))) { - hnsw.scanLayer(db, layer, 100, node -> { - final CompactNode compactNode = node.asCompactNode(); - final Vector vector = compactNode.getVector(); - try { - verticesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + - vector.getComponent(0) + "," + - vector.getComponent(1)); - verticesWriter.newLine(); - - for (final var neighbor : compactNode.getNeighbors()) { - edgesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + - neighbor.getPrimaryKey().getLong(0)); - edgesWriter.newLine(); - } - numReadAtomic.getAndIncrement(); - } catch (final IOException e) { - throw new RuntimeException("unable to write to file", e); - } - }); - } - return numReadAtomic.get() != 0; - } - private void writeNode(@Nonnull final Transaction transaction, @Nonnull final StorageAdapter storageAdapter, @Nonnull final Node node, diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java new file mode 100644 index 0000000000..d751fe5f00 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java @@ -0,0 +1,174 @@ +/* + * MetricTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.apple.foundationdb.async.hnsw; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MetricTest { + private final Metric.ManhattanMetric manhattanMetric = new Metric.ManhattanMetric(); + private final Metric.EuclideanMetric euclideanMetric = new Metric.EuclideanMetric(); + private final Metric.EuclideanSquareMetric euclideanSquareMetric = new Metric.EuclideanSquareMetric(); + private final Metric.CosineMetric cosineMetric = new Metric.CosineMetric(); + private Metric.DotProductMetric dotProductMetric; + + @BeforeEach + + public void setUp() { + dotProductMetric = new Metric.DotProductMetric(); + } + + @Test + public void manhattanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + Double[] vector1 = {1.0, 2.5, -3.0}; + Double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = manhattanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void manhattanMetricDistanceWithPositiveValueVectorsShouldReturnCorrectDistanceTest() { + // Arrange + Double[] vector1 = {1.0, 2.0, 3.0}; + Double[] vector2 = {4.0, 5.0, 6.0}; + double expectedDistance = 9.0; // |1-4| + |2-5| + |3-6| = 3 + 3 + 3 + + // Act + double actualDistance = manhattanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + Double[] vector1 = {1.0, 2.5, -3.0}; + Double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = euclideanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { + // Arrange + Double[] vector1 = {1.0, 2.0}; + Double[] vector2 = {4.0, 6.0}; + double expectedDistance = 5.0; // sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5.0 + + // Act + double actualDistance = euclideanMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanSquareMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { + // Arrange + Double[] vector1 = {1.0, 2.5, -3.0}; + Double[] vector2 = {1.0, 2.5, -3.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = euclideanSquareMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void euclideanSquareMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { + // Arrange + Double[] vector1 = {1.0, 2.0}; + Double[] vector2 = {4.0, 6.0}; + double expectedDistance = 25.0; // (1-4)^2 + (2-6)^2 = 9 + 16 = 25.0 + + // Act + double actualDistance = euclideanSquareMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void cosineMetricDistanceWithIdenticalVectorsReturnsZeroTest() { + // Arrange + Double[] vector1 = {5.0, 3.0, -2.0}; + Double[] vector2 = {5.0, 3.0, -2.0}; + double expectedDistance = 0.0; + + // Act + double actualDistance = cosineMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void 
cosineMetricDistanceWithOrthogonalVectorsReturnsOneTest() { + // Arrange + Double[] vector1 = {1.0, 0.0}; + Double[] vector2 = {0.0, 1.0}; + double expectedDistance = 1.0; + + // Act + double actualDistance = cosineMetric.distance(vector1, vector2); + + // Assert + assertEquals(expectedDistance, actualDistance, 0.00001); + } + + @Test + public void dotProductMetricComparativeDistanceWithPositiveVectorsTest() { + Double[] vector1 = {1.0, 2.0, 3.0}; + Double[] vector2 = {4.0, 5.0, 6.0}; + double expected = -32.0; + + double actual = dotProductMetric.comparativeDistance(vector1, vector2); + + assertEquals(expected, actual, 0.00001); + } + + @Test + public void dotProductMetricComparativeDistanceWithOrthogonalVectorsReturnsZeroTest() { + Double[] vector1 = {1.0, 0.0}; + Double[] vector2 = {0.0, 1.0}; + double expected = -0.0; + + double actual = dotProductMetric.comparativeDistance(vector1, vector2); + + assertEquals(expected, actual, 0.00001); + } +} \ No newline at end of file From f2e9d5c061936e7c7b20e4e08b57d08f849d2fca Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Fri, 19 Sep 2025 12:42:56 +0200 Subject: [PATCH 06/10] adding a lot of java doc --- .../foundationdb/async/MoreAsyncUtil.java | 29 +++ .../async/hnsw/BaseNeighborsChangeSet.java | 2 +- .../async/hnsw/CompactStorageAdapter.java | 136 ++++++++++++- .../async/hnsw/DeleteNeighborsChangeSet.java | 2 +- .../async/hnsw/EntryNodeReference.java | 41 +++- .../foundationdb/async/hnsw/HNSWHelpers.java | 15 ++ .../async/hnsw/InliningStorageAdapter.java | 2 +- .../async/hnsw/InsertNeighborsChangeSet.java | 47 ++++- .../apple/foundationdb/async/hnsw/Metric.java | 45 +++++ .../foundationdb/async/hnsw/Metrics.java | 70 ++++++- .../async/hnsw/NeighborsChangeSet.java | 2 +- .../foundationdb/async/hnsw/NodeFactory.java | 28 +++ .../foundationdb/async/hnsw/NodeKind.java | 30 ++- .../async/hnsw/NodeReference.java | 45 +++++ .../async/hnsw/NodeReferenceAndNode.java | 29 +++ .../async/hnsw/NodeReferenceWithDistance.java | 35 ++++ .../async/hnsw/OnReadListener.java | 34 +++- .../async/hnsw/OnWriteListener.java | 45 ++++- .../async/hnsw/StorageAdapter.java | 184 ++++++++++++++++-- .../apple/foundationdb/async/hnsw/Vector.java | 167 +++++++++++++++- .../foundationdb/async/hnsw/package-info.java | 2 +- .../async/hnsw/HNSWModificationTest.java | 9 +- .../foundationdb/async/hnsw/MetricTest.java | 2 +- 23 files changed, 964 insertions(+), 37 deletions(-) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java index 64e6d6b732..e696512fdd 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java @@ -1057,6 +1057,23 @@ public static CompletableFuture swallowException(@Nonnull CompletableFutur return result; } + /** + * Method that provides the functionality of a for loop, however, in an asynchronous way. The result of this method + * is a {@link CompletableFuture} that represents the result of the last iteration of the loop body. 
+ * @param startI an integer analogous to the starting value of a loop variable in a for loop + * @param startU an object of some type {@code U} that represents some initial state that is passed to the loop's + * initial state + * @param conditionPredicate a predicate on the loop variable that must be true before the next iteration is + * entered; analogous to the condition in a for loop + * @param stepFunction a unary operator used for modifying the loop variable after each iteration + * @param body a bi-function to be called for each iteration; this function is initially invoked using + * {@code startI} and {@code startU}; the result of the body is then passed into the next iterator's body + * together with a new value for the loop variable. In this way callers can access state inside an iteration + * that was computed in a previous iteration. + * @param executor the executor + * @param the type of the result of the body {@link BiFunction} + * @return a {@link CompletableFuture} containing the result of the last iteration's body invocation. + */ @Nonnull public static CompletableFuture forLoop(final int startI, @Nullable final U startU, @Nonnull final IntPredicate conditionPredicate, @@ -1079,6 +1096,18 @@ public static CompletableFuture forLoop(final int startI, @Nullable final }, executor).thenApply(ignored -> lastResultAtomic.get()); } + /** + * Method to iterate over some items, for each of which a body is executed asynchronously. The result of each such + * executed is then collected in a list and returned as a {@link CompletableFuture} over that list. + * @param items the items to iterate over + * @param body a function to be called for each item + * @param parallelism the maximum degree of parallelism this method should use + * @param executor the executor + * @param the type of item + * @param the type of the result + * @return a {@link CompletableFuture} containing a list of results collected from the individual body invocations + */ + @Nonnull @SuppressWarnings("unchecked") public static CompletableFuture> forEach(@Nonnull final Iterable items, @Nonnull final Function> body, diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java index 794bd5ae4c..5d27783b9e 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -1,5 +1,5 @@ /* - * InliningNode.java + * BaseNeighborsChangeSet.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index 4d9497ba0a..0c38296807 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -41,10 +41,30 @@ import java.util.List; import java.util.concurrent.CompletableFuture; +/** + * The {@code CompactStorageAdapter} class is a concrete implementation of {@link StorageAdapter} for managing HNSW + * graph data in a compact format. + *
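As a usage illustration of the forLoop utility documented above, consider the following sketch. It is hypothetical and relies only on the signature shown in this hunk; "executor" may be any Executor, for example ForkJoinPool.commonPool():

    // Sums the integers 0..9 asynchronously, threading the running total (U = Integer)
    // through each iteration; the returned future completes with 45.
    static CompletableFuture<Integer> sumFirstTen(final Executor executor) {
        return MoreAsyncUtil.forLoop(0, 0,      // i = 0; initial state U = 0
                i -> i < 10,                    // condition: continue while i < 10
                i -> i + 1,                     // step: i = i + 1
                (i, partial) -> CompletableFuture.completedFuture(partial + i),
                executor);
    }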
<p>
+ * It handles the serialization and deserialization of graph nodes to and from a persistent data store. This + * implementation is optimized for space efficiency by storing nodes with their accompanying vector data and by storing + * just neighbor primary keys. It extends {@link AbstractStorageAdapter} to inherit common storage logic. + */ class CompactStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { @Nonnull private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class); + /** + * Constructs a new {@code CompactStorageAdapter}. + *
<p>
+ * This constructor initializes the adapter by delegating to the superclass, + * setting up the necessary components for managing an HNSW graph. + * + * @param config the HNSW graph configuration, must not be null. See {@link HNSW.Config}. + * @param nodeFactory the factory used to create new nodes of type {@link NodeReference}, must not be null. + * @param subspace the {@link Subspace} where the graph data is stored, must not be null. + * @param onWriteListener the listener to be notified of write events, must not be null. + * @param onReadListener the listener to be notified of read events, must not be null. + */ public CompactStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, @Nonnull final Subspace subspace, @Nonnull final OnWriteListener onWriteListener, @@ -52,18 +72,50 @@ public CompactStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final N super(config, nodeFactory, subspace, onWriteListener, onReadListener); } + /** + * Returns this storage adapter instance, as it is already a compact storage adapter. + * @return the current instance, which serves as its own compact representation. + * This will never be {@code null}. + */ @Nonnull @Override public StorageAdapter asCompactStorageAdapter() { return this; } + /** + * Returns this adapter as a {@code StorageAdapter} that supports inlining. + *
<p>
+ * This operation is not supported by a compact storage adapter. Calling this method on this implementation will + * always result in an {@code IllegalStateException}. + * + * @return an instance of {@code StorageAdapter} that supports inlining + * + * @throws IllegalStateException unconditionally, as this operation is not supported + * on a compact storage adapter. + */ @Nonnull @Override public StorageAdapter asInliningStorageAdapter() { throw new IllegalStateException("cannot call this method on a compact storage adapter"); } + /** + * Asynchronously fetches a node from the database for a given layer and primary key. + *
<p>
+ * This internal method constructs a raw byte key from the {@code layer} and {@code primaryKey} + * within the store's data subspace. It then uses the provided {@link ReadTransaction} to + * retrieve the raw value. If a value is found, it is deserialized into a {@link Node} object + * using the {@code nodeFromRaw} method. + * + * @param readTransaction the transaction to use for the read operation + * @param layer the layer of the node to fetch + * @param primaryKey the primary key of the node to fetch + * + * @return a future that will complete with the fetched {@link Node} + * + * @throws IllegalStateException if the node cannot be found in the database for the given key + */ @Nonnull @Override protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, @@ -80,20 +132,52 @@ protected CompletableFuture> fetchNodeInternal(@Nonnull fina }); } + /** + * Deserializes a raw key-value byte array pair into a {@code Node}. + *
<p>
+ * This method first converts the {@code valueBytes} into a {@link Tuple} and then, + * along with the {@code primaryKey}, constructs the final {@code Node} object. + * It also notifies any registered {@link OnReadListener} about the raw key-value + * read and the resulting node creation. + * + * @param layer the layer of the HNSW where this node resides + * @param primaryKey the primary key for the node + * @param keyBytes the raw byte representation of the node's key + * @param valueBytes the raw byte representation of the node's value, which will be deserialized + * + * @return a non-null, deserialized {@link Node} object + */ @Nonnull private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, @Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) { final Tuple nodeTuple = Tuple.fromBytes(valueBytes); - final Node node = nodeFromTuples(primaryKey, nodeTuple); + final Node node = nodeFromKeyValuesTuples(primaryKey, nodeTuple); final OnReadListener onReadListener = getOnReadListener(); onReadListener.onNodeRead(layer, node); onReadListener.onKeyValueRead(layer, keyBytes, valueBytes); return node; } + /** + * Constructs a compact {@link Node} from its representation as stored key and value tuples. + *
<p>
+ * This method deserializes a node by extracting its components from the provided tuples. It verifies that the + * node is of type {@link NodeKind#COMPACT} before delegating the final construction to + * {@link #compactNodeFromTuples(Tuple, Tuple, Tuple)}. The {@code valueTuple} is expected to have a specific + * structure: the serialized node kind at index 0, a nested tuple for the vector at index 1, and a nested + * tuple for the neighbors at index 2. + * + * @param primaryKey the tuple representing the primary key of the node + * @param valueTuple the tuple containing the serialized node data, including kind, vector, and neighbors + * + * @return the reconstructed compact {@link Node} + * + * @throws com.google.common.base.VerifyException if the node kind encoded in {@code valueTuple} is not + * {@link NodeKind#COMPACT} + */ @Nonnull - private Node nodeFromTuples(@Nonnull final Tuple primaryKey, - @Nonnull final Tuple valueTuple) { + private Node nodeFromKeyValuesTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple valueTuple) { final NodeKind nodeKind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0)); Verify.verify(nodeKind == NodeKind.COMPACT); @@ -105,6 +189,21 @@ private Node nodeFromTuples(@Nonnull final Tuple primaryKey, return compactNodeFromTuples(primaryKey, vectorTuple, neighborsTuple); } + /** + * Creates a compact in-memory representation of a graph node from its constituent storage tuples. + *
<p>
+ * This method deserializes the raw data stored in {@code Tuple} objects into their + * corresponding in-memory types. It extracts the vector, constructs a list of + * {@link NodeReference} objects for the neighbors, and then uses a factory to + * assemble the final {@code Node} object. + *
<p>
+ * + * @param primaryKey the tuple representing the node's primary key + * @param vectorTuple the tuple containing the node's vector data + * @param neighborsTuple the tuple containing a list of nested tuples, where each nested tuple represents a neighbor + * + * @return a new {@code Node} instance containing the deserialized data from the input tuples + */ @Nonnull private Node compactNodeFromTuples(@Nonnull final Tuple primaryKey, @Nonnull final Tuple vectorTuple, @@ -120,6 +219,21 @@ private Node compactNodeFromTuples(@Nonnull final Tuple primaryKe return getNodeFactory().create(primaryKey, vector, nodeReferences); } + /** + * Writes the internal representation of a compact node to the data store within a given transaction. + * This method handles the serialization of the node's vector and its final set of neighbors based on the + * provided {@code neighborsChangeSet}. + * + *
<p>
The node is stored as a {@link Tuple} with the structure {@code (NodeKind, Vector, NeighborPrimaryKeys)}. + * The key for the storage is derived from the node's layer and its primary key. After writing, it notifies any + * registered write listeners via {@code onNodeWritten} and {@code onKeyValueWritten}. + * + * @param transaction the {@link Transaction} to use for the write operation. + * @param node the {@link Node} to be serialized and written; it is processed as a {@link CompactNode}. + * @param layer the graph layer index for the node, used to construct the storage key. + * @param neighborsChangeSet a {@link NeighborsChangeSet} containing the additions and removals, which are + * merged to determine the final set of neighbors to be written. + */ @Override public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { @@ -151,6 +265,22 @@ public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull f } } + /** + * Scans a given layer for nodes, returning an iterable over the results. + *
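To make the value layout described above concrete, here is a sketch of the tuple a compact node would serialize to. It is illustrative only: the vector components and neighbor keys below are made up, and the real adapter derives them from the node being written.

    // assumes: import com.apple.foundationdb.tuple.Tuple;
    // (NodeKind, Vector, NeighborPrimaryKeys) packed as one nested FDB tuple.
    static Tuple exampleCompactValue() {
        return Tuple.from(
                (long)NodeKind.COMPACT.getSerialized(),        // index 0: node kind, 0x00
                Tuple.from(0.5d, 0.25d),                       // index 1: nested vector tuple (hypothetical components)
                Tuple.from(Tuple.from(42L), Tuple.from(43L))); // index 2: nested neighbor primary keys (hypothetical)
    }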
<p>
+ * This method reads a limited number of nodes from a specific layer in the underlying data store. + * The scan can be started from a specific point using the {@code lastPrimaryKey} parameter, which is + * useful for paginating through the nodes in a large layer. + * + * @param readTransaction the transaction to use for reading data; must not be {@code null} + * @param layer the layer to scan for nodes + * @param lastPrimaryKey the primary key of the last node from a previous scan. If {@code null}, + * the scan starts from the beginning of the layer. + * @param maxNumRead the maximum number of nodes to read in this scan + * + * @return an {@link Iterable} of {@link Node} objects found in the specified layer, + * limited by {@code maxNumRead} + */ @Nonnull @Override public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java index e70515531e..a4852b66a1 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -1,5 +1,5 @@ /* - * InliningNode.java + * DeleteNeighborsChangeSet.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java index db81252e17..4a9cbb0ae5 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -1,5 +1,5 @@ /* - * NodeWithLayer.java + * EntryNodeReference.java * * This source file is part of the FoundationDB open source project * @@ -26,18 +26,50 @@ import javax.annotation.Nonnull; import java.util.Objects; +/** + * Represents an entry reference to a node within a hierarchical graph structure. + *
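A pagination sketch over scanLayer as documented above (hypothetical wiring: "adapter" is an already-constructed storage adapter, "tr" an open ReadTransaction, and Node.getPrimaryKey() is assumed to expose the node's key):

    Tuple lastPrimaryKey = null;                 // null starts at the beginning of the layer
    int numRead;
    do {
        numRead = 0;
        for (final Node<NodeReference> node : adapter.scanLayer(tr, 0, lastPrimaryKey, 100)) {
            lastPrimaryKey = node.getPrimaryKey();   // resume point for the next page
            numRead++;
        }
    } while (numRead == 100);                    // a short page means the layer is exhausted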
<p>
+ * This class extends {@link NodeReferenceWithVector} by adding a {@code layer} + * attribute. It is used to encapsulate all the necessary information for an + * entry point into a specific layer of the graph, including its unique identifier + * (primary key), its vector representation, and its hierarchical level. + */ class EntryNodeReference extends NodeReferenceWithVector { private final int layer; + /** + * Constructs a new reference to an entry node. + *
<p>
+ * This constructor initializes the node with its primary key, its associated vector, + * and the specific layer it belongs to within a hierarchical graph structure. It calls the + * superclass constructor to set the {@code primaryKey} and {@code vector}. + * + * @param primaryKey the primary key identifying the node. Must not be {@code null}. + * @param vector the vector data associated with the node. Must not be {@code null}. + * @param layer the layer number where this entry node is located. + */ public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { super(primaryKey, vector); this.layer = layer; } + /** + * Gets the layer value for this object. + * @return the integer representing the layer + */ public int getLayer() { return layer; } + /** + * Compares this {@code EntryNodeReference} to the specified object for equality. + *
<p>
+ * The result is {@code true} if and only if the argument is an instance of {@code EntryNodeReference}, the + * superclass's {@link #equals(Object)} method returns {@code true}, and the {@code layer} fields of both objects + * are equal. + * @param o the object to compare this {@code EntryNodeReference} against. + * @return {@code true} if the given object is equal to this one; {@code false} otherwise. + */ @Override public boolean equals(final Object o) { if (!(o instanceof EntryNodeReference)) { @@ -49,6 +81,13 @@ public boolean equals(final Object o) { return layer == ((EntryNodeReference)o).layer; } + /** + * Generates a hash code for this object. + *
<p>
+ * The hash code is computed by combining the hash code of the superclass with the hash code of the {@code layer} + * field. This implementation is consistent with the contract of {@link Object#hashCode()}. + * @return a hash code value for this object. + */ @Override public int hashCode() { return Objects.hash(super.hashCode(), layer); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java index 322b4f85b0..4921f1280d 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java @@ -31,6 +31,9 @@ public class HNSWHelpers { private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); + /** + * This is a utility class and is not intended to be instantiated. + */ private HNSWHelpers() { // nothing } @@ -51,11 +54,23 @@ public static String bytesToHex(byte[] bytes) { return "0x" + new String(hexChars).replaceFirst("^0+(?!$)", ""); } + /** + * Returns a {@code Half} instance representing the specified {@code double} value, rounded to the nearest + * representable half-precision float value. + * @param d the {@code double} value to be converted. + * @return a non-null {@link Half} instance representing {@code d}. + */ @Nonnull public static Half halfValueOf(final double d) { return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(d))); } + /** + * Returns a {@code Half} instance representing the specified {@code float} value, rounded to the nearest + * representable half-precision float value. + * @param f the {@code float} value to be converted. + * @return a non-null {@link Half} instance representing {@code f}. + */ @Nonnull public static Half halfValueOf(final float f) { return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(f))); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java index 2835427ca4..c63f2135e0 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -1,5 +1,5 @@ /* - * CompactStorageAdapter.java + * InliningStorageAdapter.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java index d68d3ae933..f9894ccebd 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java @@ -1,5 +1,5 @@ /* - * InliningNode.java + * InsertNeighborsChangeSet.java * * This source file is part of the FoundationDB open source project * @@ -33,7 +33,14 @@ import java.util.function.Predicate; /** - * TODO. + * Represents an immutable change set for the neighbors of a node in the HNSW graph, specifically + * capturing the insertion of new neighbors. + *
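A quick illustration of the halfValueOf helpers above. Converting to half precision rounds to the nearest representable 16-bit value, so a double generally does not round-trip exactly (the half4j Half type is assumed to be imported):

    final Half h = HNSWHelpers.halfValueOf(0.1d);
    // h is the nearest half-precision neighbor of 0.1, approximately 0.0999755859375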
<p>
+ * This class layers new neighbors on top of a parent {@link NeighborsChangeSet}, allowing for a + * layered representation of modifications. The changes are not applied to the database until + * {@link #writeDelta} is called. + * + * @param the type of the node reference, which must extend {@link NodeReference} */ class InsertNeighborsChangeSet implements NeighborsChangeSet { @Nonnull @@ -45,6 +52,16 @@ class InsertNeighborsChangeSet implements NeighborsChan @Nonnull private final Map insertedNeighborsMap; + /** + * Creates a new {@code InsertNeighborsChangeSet}. + *
<p>
+ * This constructor initializes the change set with its parent and a list of neighbors + * to be inserted. It internally builds an immutable map of the inserted neighbors, + * keyed by their primary key for efficient lookups. + * + * @param parent the parent {@link NeighborsChangeSet} on which this insertion is based. + * @param insertedNeighbors the list of neighbors to be inserted. + */ public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, @Nonnull final List insertedNeighbors) { this.parent = parent; @@ -56,18 +73,44 @@ public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, this.insertedNeighborsMap = insertedNeighborsMapBuilder.build(); } + /** + * Gets the parent {@code NeighborsChangeSet} from which this change set was derived. + * @return the parent {@link NeighborsChangeSet}, which is never {@code null}. + */ @Nonnull @Override public NeighborsChangeSet getParent() { return parent; } + /** + * Merges the neighbors from this level of the hierarchy with all neighbors from parent levels. + *
<p>
+ * This is achieved by creating a combined view that includes the results of the parent's {@code #merge()} call and + * the neighbors that have been inserted at the current level. The resulting {@code Iterable} provides a complete + * set of neighbors from this node and all its ancestors. + * @return a non-null {@code Iterable} containing all neighbors from this node and its ancestors. + */ @Nonnull @Override public Iterable merge() { return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); } + /** + * Writes the delta of this layer to the specified storage adapter. + *
<p>
+ * This implementation first delegates to the parent to write its delta, but excludes any neighbors that have been + * newly inserted in the current context (i.e., those in {@code insertedNeighborsMap}). It then iterates through its + * own newly inserted neighbors. For each neighbor that satisfies the given {@code tuplePredicate}, it writes the + * neighbor relationship to storage via the {@link InliningStorageAdapter}. + * + * @param storageAdapter the storage adapter to write to; must not be null + * @param transaction the transaction context for the write operation; must not be null + * @param layer the layer index to write the data to + * @param node the source node for which the neighbor delta is being written; must not be null + * @param tuplePredicate a predicate to filter which neighbor tuples should be written; must not be null + */ @Override public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java index f5fe817e53..adb1b799b3 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -84,6 +84,13 @@ private static void validate(Double[] vector1, Double[] vector2) { } } + /** + * Represents the Manhattan distance metric. + *
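The layering described above can be pictured with the following sketch. It is hypothetical: the three references are made up, and a BaseNeighborsChangeSet constructor over an initial neighbor list is assumed, as that class is not shown in this hunk.

    final NodeReference refA = new NodeReference(Tuple.from(1L));
    final NodeReference refB = new NodeReference(Tuple.from(2L));
    final NodeReference refC = new NodeReference(Tuple.from(3L));

    final NeighborsChangeSet<NodeReference> base =
            new BaseNeighborsChangeSet<>(ImmutableList.of(refA, refB));   // assumed constructor
    final NeighborsChangeSet<NodeReference> withInsert =
            new InsertNeighborsChangeSet<>(base, ImmutableList.of(refC));
    // merge() yields refA, refB, refC; nothing is persisted until writeDelta(...) runs
    final Iterable<NodeReference> merged = withInsert.merge();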
<p>
+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + */ class ManhattanMetric implements Metric { @Override public double distance(final Double[] vector1, final Double[] vector2) { @@ -103,6 +110,13 @@ public String toString() { } } + /** + * Represents the Euclidean distance metric. + *
<p>
+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + */ class EuclideanMetric implements Metric { @Override public double distance(final Double[] vector1, final Double[] vector2) { @@ -118,6 +132,19 @@ public String toString() { } } + /** + * Represents the squared Euclidean distance metric. + *
<p>
+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *
<p>
+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + */ class EuclideanSquareMetric implements Metric { @Override public double distance(final Double[] vector1, final Double[] vector2) { @@ -141,6 +168,14 @@ public String toString() { } } + /** + * Represents the Cosine distance metric. + *
<p>
+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} that corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0} while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * @see Metric.CosineMetric + */ class CosineMetric implements Metric { @Override public double distance(final Double[] vector1, final Double[] vector2) { @@ -171,6 +206,16 @@ public String toString() { } } + /** + * Dot product similarity. + *
<p>
+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can + * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance + * only allows {@link Metric#comparativeDistance(Double[], Double[])} to be called. + * + * @see Dot Product + * @see DotProductMetric + */ class DotProductMetric implements Metric { @Override public double distance(final Double[] vector1, final Double[] vector2) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java index 8c30faf852..7a3e4a6a88 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java @@ -1,5 +1,5 @@ /* - * Metric.java + * Metrics.java * * This source file is part of the FoundationDB open source project * @@ -22,20 +22,88 @@ import javax.annotation.Nonnull; +/** + * Represents various distance calculation strategies (metrics) for vectors. + *
<p>
+ * Each enum constant holds a specific metric implementation, providing a type-safe way to calculate the distance + * between two points in a multidimensional space. + * + * @see Metric + */ public enum Metrics { + /** + * Represents the Manhattan distance metric, implemented by {@link Metric.ManhattanMetric}. + *
<p>
+ * This metric calculates a distance overlaying the multidimensional space with a grid-like structure only allowing + * orthogonal lines. In 2D this resembles the street structure in Manhattan where one would have to go {@code x} + * blocks north/south and {@code y} blocks east/west leading to a total distance of {@code x + y}. + * @see Metric.ManhattanMetric + */ MANHATTAN_METRIC(new Metric.ManhattanMetric()), + + /** + * Represents the Euclidean distance metric, implemented by {@link Metric.EuclideanMetric}. + *
<p>
+ * This metric calculates the "ordinary" straight-line distance between two points + * in Euclidean space. The distance is the square root of the sum of the + * squared differences between the corresponding coordinates of the two points. + * @see Metric.EuclideanMetric + */ EUCLIDEAN_METRIC(new Metric.EuclideanMetric()), + + /** + * Represents the squared Euclidean distance metric, implemented by {@link Metric.EuclideanSquareMetric}. + *
<p>
+ * This metric calculates the sum of the squared differences between the coordinates of two vectors, defined as + * {@code sum((p_i - q_i)^2)}. It is computationally less expensive than the standard Euclidean distance because it + * avoids the final square root operation. + *
<p>
+ * This is often preferred in algorithms where comparing distances is more important than the actual distance value, + * such as in clustering algorithms, as it preserves the relative ordering of distances. + * + * @see Squared Euclidean + * distance + * @see Metric.EuclideanSquareMetric + */ EUCLIDEAN_SQUARE_METRIC(new Metric.EuclideanSquareMetric()), + + /** + * Represents the Cosine distance metric, implemented by {@link Metric.CosineMetric}. + *
<p>
+ * This metric calculates a "distance" between two vectors {@code v1} and {@code v2} that ranges between + * {@code 0.0d} and {@code 2.0d} that corresponds to {@code 1 - cos(v1, v2)}, meaning that if {@code v1 == v2}, + * the distance is {@code 0} while if {@code v1} is orthogonal to {@code v2} it is {@code 1}. + * @see Metric.CosineMetric + */ COSINE_METRIC(new Metric.CosineMetric()), + + /** + * Dot product similarity, implemented by {@link Metric.DotProductMetric} + *
<p>
+ * This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can + * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance + * only allows {@link Metric#comparativeDistance(Double[], Double[])} to be called. + * + * @see Dot Product + * @see Metric.DotProductMetric + */ DOT_PRODUCT_METRIC(new Metric.DotProductMetric()); @Nonnull private final Metric metric; + /** + * Constructs a new Metrics instance with the specified metric. + * @param metric the metric to be associated with this Metrics instance; must not be null. + */ Metrics(@Nonnull final Metric metric) { this.metric = metric; } + /** + * Gets the {@code Metric} associated with this instance. + * @return the non-null {@link Metric} for this instance + */ @Nonnull public Metric getMetric() { return metric; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java index 081523de5b..2eb02e74e3 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -1,5 +1,5 @@ /* - * InliningNode.java + * NeighborsChangeSet.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java index 321e3f53d8..bbe15f8464 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -27,11 +27,39 @@ import javax.annotation.Nullable; import java.util.List; +/** + * A factory interface for creating {@link Node} instances within a Hierarchical Navigable Small World (HNSW) graph. + *
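Selecting and applying a metric through the enum above is a one-liner; the vectors here mirror the worked example in MetricTest earlier in this patch:

    final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
    final double d = metric.distance(new Double[] {1.0, 2.0}, new Double[] {4.0, 6.0});
    // d == 5.0, i.e. sqrt((1 - 4)^2 + (2 - 6)^2)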
<p>
+ * Implementations of this interface define how nodes are constructed, allowing for different node types + * or storage strategies within the HNSW structure. + * + * @param the type of {@link NodeReference} used to refer to nodes in the graph + */ public interface NodeFactory { + /** + * Creates a new node with the specified properties. + *
<p>
+ * This method is responsible for instantiating a {@code Node} object, initializing it + * with a primary key, an optional feature vector, and a list of its initial neighbors. + * + * @param primaryKey the {@link Tuple} representing the unique primary key for the new node. Must not be + * {@code null}. + * @param vector the optional feature {@link Vector} associated with the node, which can be used for similarity + * calculations. May be {@code null} if the node does not encode a vector (see {@link CompactNode} versus + * {@link InliningNode}. + * @param neighbors the list of initial {@link NodeReference}s for the new node, + * establishing its initial connections in the graph. Must not be {@code null}. + * + * @return a new, non-null {@link Node} instance configured with the provided parameters. + */ @Nonnull Node create(@Nonnull Tuple primaryKey, @Nullable Vector vector, @Nonnull List neighbors); + /** + * Gets the kind of this node. + * @return the kind of this node, never {@code null}. + */ @Nonnull NodeKind getNodeKind(); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java index 13d71a1b9b..de7aeb6572 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java @@ -25,22 +25,50 @@ import javax.annotation.Nonnull; /** - * Enum to capture the kind of node. + * Represents the different kinds of nodes, each associated with a unique byte value for serialization and + * deserialization. */ public enum NodeKind { + /** + * Compact node. Serialization and deserialization is implemented in {@link CompactNode}. + *
<p>
+ * Compact nodes store their own vector and their neighbors-list only contain the primary key for each neighbor. + */ COMPACT((byte)0x00), + + /** + * Inlining node. Serialization and deserialization is implemented in {@link InliningNode}. + *
<p>
+ * Inlining nodes do not store their own vector and their neighbors-list contain the both the primary key and the + * neighbor vector for each neighbor. Each neighbor is stored in its own key/value pair. + */ INLINING((byte)0x01); private final byte serialized; + /** + * Constructs a new {@code NodeKind} instance with its serialized representation. + * @param serialized the byte value used for serialization + */ NodeKind(final byte serialized) { this.serialized = serialized; } + /** + * Gets the serialized byte value. + * @return the serialized byte value + */ public byte getSerialized() { return serialized; } + /** + * Deserializes a byte into the corresponding {@link NodeKind}. + * @param serializedNodeKind the byte representation of the node kind. + * @return the corresponding {@link NodeKind}, never {@code null}. + * @throws IllegalArgumentException if the {@code serializedNodeKind} does not + * correspond to a known node kind. + */ @Nonnull static NodeKind fromSerializedNodeKind(byte serializedNodeKind) { final NodeKind nodeKind; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java index 59b831d04d..a302607a2c 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java @@ -26,24 +26,56 @@ import javax.annotation.Nonnull; import java.util.Objects; +/** + * Represents a reference to a node, uniquely identified by its primary key. It provides fundamental operations such as + * equality comparison, hashing, and string representation based on this key. It also serves as a base class for more + * specialized node references. + */ public class NodeReference { @Nonnull private final Tuple primaryKey; + /** + * Constructs a new {@code NodeReference} with the specified primary key. + * @param primaryKey the primary key of the node to reference; must not be {@code null}. + */ public NodeReference(@Nonnull final Tuple primaryKey) { this.primaryKey = primaryKey; } + /** + * Gets the primary key for this object. + * @return the primary key as a {@code Tuple} object, which is guaranteed to be non-null. + */ @Nonnull public Tuple getPrimaryKey() { return primaryKey; } + /** + * Casts this object to a {@link NodeReferenceWithVector}. + *
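A round-trip of the serialized byte through the enum above (fromSerializedNodeKind is package-private, so this sketch lives in the same package):

    final byte serialized = NodeKind.COMPACT.getSerialized();           // 0x00
    final NodeKind kind = NodeKind.fromSerializedNodeKind(serialized);  // NodeKind.COMPACT again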
<p>
+ * This method is intended to be used on subclasses that actually represent a node reference with a vector. For this + * base class or specific implementation, it is not a valid operation. + * @return this instance cast as a {@code NodeReferenceWithVector} + * @throws IllegalStateException always, to indicate that this object cannot be + * represented as a {@link NodeReferenceWithVector}. + */ @Nonnull public NodeReferenceWithVector asNodeReferenceWithVector() { throw new IllegalStateException("method should not be called"); } + /** + * Compares this {@code NodeReference} to the specified object for equality. + *
<p>
+ * The result is {@code true} if and only if the argument is not {@code null} and is a {@code NodeReference} object + * that has the same {@code primaryKey} as this object. + * + * @param o the object to compare with this {@code NodeReference} for equality. + * @return {@code true} if the given object is equal to this one; + * {@code false} otherwise. + */ @Override public boolean equals(final Object o) { if (!(o instanceof NodeReference)) { @@ -53,16 +85,29 @@ public boolean equals(final Object o) { return Objects.equals(primaryKey, that.primaryKey); } + /** + * Generates a hash code for this object based on the primary key. + * @return a hash code value for this object. + */ @Override public int hashCode() { return Objects.hashCode(primaryKey); } + /** + * Returns a string representation of the object. + * @return a string representation of this object. + */ @Override public String toString() { return "NR[primaryKey=" + primaryKey + "]"; } + /** + * Helper to extract the primary keys from a given collection of node references. + * @param neighbors an iterable of {@link NodeReference} objects from which to extract primary keys. + * @return a lazily-evaluated {@code Iterable} of {@link Tuple}s, representing the primary keys of the input nodes. + */ @Nonnull public static Iterable primaryKeys(@Nonnull Iterable neighbors) { return () -> Streams.stream(neighbors) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java index bbf74e864a..1a2053133d 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -25,27 +25,56 @@ import javax.annotation.Nonnull; import java.util.List; +/** + * A container class that pairs a {@link NodeReferenceWithDistance} with its corresponding {@link Node} object. + *
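The primaryKeys helper above can be exercised like this (a small sketch; the FDB Tuple and Guava ImmutableList imports are assumed):

    final ImmutableList<NodeReference> refs = ImmutableList.of(
            new NodeReference(Tuple.from(1L)),
            new NodeReference(Tuple.from(2L)));
    for (final Tuple primaryKey : NodeReference.primaryKeys(refs)) {
        System.out.println(primaryKey);   // the keys are produced lazily, one per reference
    }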
<p>
+ * This is often used during graph traversal or searching, where a reference to a node (along with its distance from a + * query point) is first identified, and then the complete node data is fetched. This class holds these two related + * pieces of information together. + * @param the type of {@link NodeReference} used within the {@link Node} + */ public class NodeReferenceAndNode { @Nonnull private final NodeReferenceWithDistance nodeReferenceWithDistance; @Nonnull private final Node node; + /** + * Constructs a new instance that pairs a node reference (with distance) with its + * corresponding {@link Node} object. + * @param nodeReferenceWithDistance the reference to a node, which also includes distance information. Must not be + * {@code null}. + * @param node the actual {@code Node} object that the reference points to. Must not be {@code null}. + */ public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, @Nonnull final Node node) { this.nodeReferenceWithDistance = nodeReferenceWithDistance; this.node = node; } + /** + * Gets the node reference and its associated distance. + * @return the non-null {@link NodeReferenceWithDistance} object. + */ @Nonnull public NodeReferenceWithDistance getNodeReferenceWithDistance() { return nodeReferenceWithDistance; } + /** + * Gets the underlying node represented by this object. + * @return the associated {@link Node} instance, never {@code null}. + */ @Nonnull public Node getNode() { return node; } + /** + * Helper to extract the references from a given collection of objects of this container class. + * @param referencesAndNodes an iterable of {@link NodeReferenceAndNode} objects from which to extract the + * references. + * @return a {@link List} of {@link NodeReferenceAndNode}s + */ @Nonnull public static List getReferences(@Nonnull List> referencesAndNodes) { final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java index bc9470735c..5acc345d65 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -26,19 +26,50 @@ import javax.annotation.Nonnull; import java.util.Objects; +/** + * Represents a reference to a node that includes its vector and its distance from a query vector. + *
<p>
+ * This class extends {@link NodeReferenceWithVector} by additionally associating a distance value, typically the result + * of a distance calculation in a nearest neighbor search. Objects of this class are immutable. + */ public class NodeReferenceWithDistance extends NodeReferenceWithVector { private final double distance; + /** + * Constructs a new instance of {@code NodeReferenceWithDistance}. + *
<p>
+ * This constructor initializes the reference with the node's primary key, its vector, and the calculated distance + * from some origin vector (e.g., a query vector). It calls the superclass constructor to set the {@code primaryKey} + * and {@code vector}. + * @param primaryKey the primary key of the referenced node, represented as a {@link Tuple}. Must not be null. + * @param vector the vector associated with the referenced node. Must not be null. + * @param distance the calculated distance of this node reference to some query vector or similar. + */ public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final double distance) { super(primaryKey, vector); this.distance = distance; } + /** + * Gets the distance. + * @return the current distance value + */ public double getDistance() { return distance; } + /** + * Compares this object against the specified object for equality. + *
<p>
+ * The result is {@code true} if and only if the argument is not {@code null}, + * is a {@code NodeReferenceWithDistance} object, has the same properties as + * determined by the superclass's {@link #equals(Object)} method, and has + * the same {@code distance} value. + * @param o the object to compare with this instance for equality. + * @return {@code true} if the specified object is equal to this {@code NodeReferenceWithDistance}; + * {@code false} otherwise. + */ @Override public boolean equals(final Object o) { if (!(o instanceof NodeReferenceWithDistance)) { @@ -51,6 +82,10 @@ public boolean equals(final Object o) { return Double.compare(distance, that.distance) == 0; } + /** + * Generates a hash code for this object. + * @return a hash code value for this object. + */ @Override public int hashCode() { return Objects.hash(super.hashCode(), distance); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java index 753648cf77..f8a009d32b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java @@ -24,20 +24,52 @@ import java.util.concurrent.CompletableFuture; /** - * Function interface for a call back whenever we read the slots for a node. + * Interface for call backs whenever we read node data from the database. */ public interface OnReadListener { OnReadListener NOOP = new OnReadListener() { }; + /** + * A callback method that can be overridden to intercept the result of an asynchronous node read. + *
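Since getDistance() exposes the precomputed distance, search candidates can be ordered without re-invoking the metric; a minimal sketch (java.util.List and java.util.Comparator assumed imported):

    static void sortByDistance(final List<NodeReferenceWithDistance> candidates) {
        candidates.sort(Comparator.comparingDouble(NodeReferenceWithDistance::getDistance));
    }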

+ * This method provides a hook for subclasses to inspect or modify the {@code CompletableFuture} after an + * asynchronous read operation is initiated. The default implementation is a no-op that simply returns the original + * future. This method is intended to be used to measure elapsed time between the creation of a + * {@link CompletableFuture} and its completion. + * @param the type of the {@code NodeReference} + * @param future the {@code CompletableFuture} representing the pending asynchronous read operation. + * @return a {@code CompletableFuture} that will complete with the read {@code Node}. + * By default, this is the same future that was passed as an argument. + */ + @SuppressWarnings("unused") default CompletableFuture> onAsyncRead(@Nonnull CompletableFuture> future) { return future; } + /** + * Callback method invoked when a node is read during a traversal process. + *
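 + *

A sketch of an implementation that counts nodes read per layer (the map is an illustrative + * assumption): + *

{@code
+ * Map<Integer, Long> nodeCountByLayer = new ConcurrentHashMap<>();
+ * OnReadListener listener = new OnReadListener() {
+ *     public void onNodeRead(int layer, Node node) {
+ *         // one more node was materialized on this layer
+ *         nodeCountByLayer.merge(layer, 1L, Long::sum);
+ *     }
+ * };
+ * }
+ *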

+ * This default implementation does nothing. Implementors can override this method to add custom logic that should + * be executed for each node encountered. This serves as an optional hook for processing nodes as they are read. + * @param layer the layer or depth of the node in the structure, starting from 0. + * @param node the {@link Node} that was just read (guaranteed to be non-null). + */ + @SuppressWarnings("unused") default void onNodeRead(int layer, @Nonnull Node node) { // nothing } + /** + * Callback invoked when a key-value pair is read from a specific layer. + *

 + * This method is typically called during a scan or iteration over data for each key/value pair. + * The default implementation is a no-op and does nothing. + * @param layer the layer from which the key-value pair was read. + * @param key the key that was read, guaranteed to be non-null. + * @param value the value associated with the key, guaranteed to be non-null. + */ + @SuppressWarnings("unused") default void onKeyValueRead(int layer, @Nonnull byte[] key, @Nonnull byte[] value) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java index fd4a096208..d645bf8421 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -25,25 +25,62 @@ import javax.annotation.Nonnull; /** - * Function interface for a call back whenever we read the slots for a node. + * Interface for callbacks invoked whenever we write data to the database. */ public interface OnWriteListener { OnWriteListener NOOP = new OnWriteListener() { }; + /** + * Callback method invoked after a node has been successfully written to a specific layer. + *
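 + *

A sketch of a listener that tallies writes per layer (the counter map is an illustrative + * assumption): + *

{@code
+ * Map<Integer, Long> nodesWrittenByLayer = new ConcurrentHashMap<>();
+ * OnWriteListener listener = new OnWriteListener() {
+ *     public void onNodeWritten(int layer, Node node) {
+ *         nodesWrittenByLayer.merge(layer, 1L, Long::sum);
+ *     }
+ * };
+ * }
+ *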

+ * This is a default method with an empty implementation, allowing implementing classes to override it only if they + * need to react to this event. + * @param layer the index of the layer where the node was written. + * @param node the {@link Node} that was written; guaranteed to be non-null. + */ + @SuppressWarnings("unused") default void onNodeWritten(final int layer, @Nonnull final Node node) { // nothing } - default void onNeighborWritten(final int layer, @Nonnull final Node node, final NodeReference neighbor) { + /** + * Callback method invoked when a neighbor is written for a specific node. + *

+ * This method serves as a notification that a neighbor relationship has been established or updated. It is + * typically called after a write operation successfully adds a {@code neighbor} to the specified {@code node} + * within a given {@code layer}. + *

+ * As a {@code default} method, the base implementation does nothing. Implementers can override this to perform + * custom actions, such as updating caches or triggering subsequent events in response to the change. + * @param layer the index of the layer where the neighbor write operation occurred + * @param node the {@link Node} for which the neighbor was written; must not be null + * @param neighbor the {@link NodeReference} of the neighbor that was written; must not be null + */ + @SuppressWarnings("unused") + default void onNeighborWritten(final int layer, @Nonnull final Node node, + @Nonnull final NodeReference neighbor) { // nothing } - default void onNeighborDeleted(final int layer, @Nonnull final Node node, @Nonnull Tuple neighborPrimaryKey) { + /** + * Callback method invoked when a neighbor of a specific node is deleted. + *

+ * This is a default method and its base implementation is a no-op. Implementors of the interface can override this + * method to react to the deletion of a neighbor node, for example, to clean up related resources or update internal + * state. + * @param layer the layer index where the deletion occurred + * @param node the {@link Node} whose neighbor was deleted + * @param neighborPrimaryKey the primary key (as a {@link Tuple}) of the neighbor that was deleted + */ + @SuppressWarnings("unused") + default void onNeighborDeleted(final int layer, @Nonnull final Node node, + @Nonnull final Tuple neighborPrimaryKey) { // nothing } - default void onKeyValueWritten(final int layer, @Nonnull byte[] key, @Nonnull byte[] value) { + @SuppressWarnings("unused") + default void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { // nothing } } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java index 82bd281c62..e4e72e593e 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ -32,45 +32,90 @@ import java.util.concurrent.CompletableFuture; /** - * Storage adapter used for serialization and deserialization of nodes. + * Defines the contract for storing and retrieving HNSW graph data to/from a persistent store. + *

+ * This interface provides an abstraction layer over the underlying database, handling the serialization and + * deserialization of HNSW graph components such as nodes, vectors, and their relationships. Implementations of this + * interface are responsible for managing the physical layout of data within a given {@link Subspace}. + * The generic type {@code N} represents the specific type of {@link NodeReference} that this storage adapter manages. + * + * @param the type of {@link NodeReference} this storage adapter manages */ interface StorageAdapter { + /** + * Subspace for entry nodes; these are kept separately from the data. + */ byte SUBSPACE_PREFIX_ENTRY_NODE = 0x01; + /** + * Subspace for data. + */ byte SUBSPACE_PREFIX_DATA = 0x02; /** - * Get the {@link HNSW.Config} associated with this storage adapter. - * @return the configuration used by this storage adapter + * Returns the configuration of the HNSW graph. + *
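 + *

A configuration is typically obtained from the builder, for example (the parameter values are + * illustrative only): + *

{@code
+ * HNSW.Config config = HNSW.DEFAULT_CONFIG.toBuilder()
+ *         .setM(32)
+ *         .setMMax(32)
+ *         .setMMax0(64)
+ *         .build();
+ * }
+ *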

+ * This configuration object contains all the parameters used to build and search the graph, + * such as the number of neighbors to connect (M), the size of the dynamic list for + * construction (efConstruction), and the beam width for searching (ef). + * @return the {@code HNSW.Config} for this graph, never {@code null}. */ @Nonnull HNSW.Config getConfig(); + /** + * Gets the factory used to create new nodes. + *

+ * This factory is responsible for instantiating new nodes of type {@code N}. + * @return the non-null factory for creating nodes. + */ @Nonnull NodeFactory getNodeFactory(); + /** + * Gets the kind of node this storage adapter manages (and instantiates if needed). + * @return the kind of this node, never {@code null} + */ @Nonnull NodeKind getNodeKind(); + /** + * Returns a view of this object as a {@code StorageAdapter} that is optimized + * for compact data representation. + * @return a non-null {@code StorageAdapter} for {@code NodeReference} objects, + * optimized for compact storage. + */ @Nonnull StorageAdapter asCompactStorageAdapter(); + /** + * Returns a view of this storage as a {@code StorageAdapter} that handles inlined vectors. + *

+ * The returned adapter is specifically designed to work with {@link NodeReferenceWithVector}, assuming that the + * vector data is stored directly within the node reference itself. + * @return a non-null {@link StorageAdapter} + */ @Nonnull StorageAdapter asInliningStorageAdapter(); /** - * Get the subspace used to store this r-tree. - * - * @return r-tree subspace + * Get the subspace used to store this HNSW structure. + * @return the subspace */ @Nonnull Subspace getSubspace(); + /** + * Gets the subspace that contains the data for this object. + *

+ * This subspace represents the portion of the keyspace dedicated to storing the actual data, as opposed to metadata + * or other system-level information. + * @return the subspace containing the data, which is guaranteed to be non-null + */ @Nonnull Subspace getDataSubspace(); /** * Get the on-write listener. - * * @return the on-write listener. */ @Nonnull @@ -78,23 +123,72 @@ interface StorageAdapter { /** * Get the on-read listener. - * * @return the on-read listener. */ @Nonnull OnReadListener getOnReadListener(); + /** + * Asynchronously fetches a node from a specific layer, identified by its primary key. + *

+ * The fetch operation is performed within the scope of the provided {@link ReadTransaction}, ensuring a consistent + * view of the data. The returned {@link CompletableFuture} will be completed with the node once it has been + * retrieved from the underlying data store. + * @param readTransaction the {@link ReadTransaction} context for this read operation + * @param layer the layer from which to fetch the node + * @param primaryKey the {@link Tuple} representing the primary key of the node to retrieve + * @return a non-null {@link CompletableFuture} which will complete with the fetched {@code Node}. + */ @Nonnull CompletableFuture> fetchNode(@Nonnull ReadTransaction readTransaction, int layer, @Nonnull Tuple primaryKey); + /** + * Writes a node and its neighbor changes to the data store within a given transaction. + *

 + * This method is responsible for persisting the state of a {@link Node} and applying any modifications to its + * neighboring nodes as defined in the {@code NeighborsChangeSet}. The entire operation is performed atomically as + * part of the provided {@link Transaction}. + * @param transaction the non-null transaction context for this write operation. + * @param node the non-null node to be written to the data store. + * @param layer the layer index where the node resides. + * @param changeSet the non-null set of changes describing additions or removals of + * neighbors for the given {@link Node}. + */ void writeNode(@Nonnull Transaction transaction, @Nonnull Node<N> node, int layer, @Nonnull NeighborsChangeSet<N> changeSet); + /** + * Scans a specified layer of the graph, returning an iterable sequence of nodes. + *
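 + *

A resumable full scan, as described below, can be driven by a loop of this shape (sketch; + * {@code process} is an assumed per-node consumer): + *

{@code
+ * Tuple lastPrimaryKey = null;
+ * while (true) {
+ *     Node lastNode = null;
+ *     for (Node node : scanLayer(readTransaction, 0, lastPrimaryKey, 100)) {
+ *         process(node);
+ *         lastNode = node;
+ *     }
+ *     if (lastNode == null) {
+ *         break; // an empty page means the layer is exhausted
+ *     }
+ *     lastPrimaryKey = lastNode.getPrimaryKey();
+ * }
+ * }
+ *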

+ * This method allows for paginated scanning of a layer. The scan can be started from the beginning of the layer by + * passing {@code null} for the {@code lastPrimaryKey}, or it can be resumed from a previous point by providing the + * key of the last item from the prior scan. The number of nodes returned is limited by {@code maxNumRead}. + * + * @param readTransaction the transaction to use for the read operation + * @param layer the index of the layer to scan + * @param lastPrimaryKey the primary key of the last node from a previous scan, + * or {@code null} to start from the beginning of the layer + * @param maxNumRead the maximum number of nodes to return in this scan + * @return an {@link Iterable} that provides the nodes found in the specified layer range + */ Iterable> scanLayer(@Nonnull ReadTransaction readTransaction, int layer, @Nullable Tuple lastPrimaryKey, int maxNumRead); + /** + * Fetches the entry node reference for the HNSW index. + *

+ * This method performs an asynchronous read to retrieve the stored entry point of the index. The entry point + * information, which includes its primary key, vector, and the layer value, is packed into a single key-value + * pair within a dedicated subspace. If no entry node is found, it indicates that the index is empty. + * + * @param readTransaction the transaction to use for the read operation + * @param subspace the subspace where the HNSW index data is stored + * @param onReadListener a listener to be notified of the key-value read operation + * @return a {@link CompletableFuture} that will complete with the {@link EntryNodeReference} + * for the index's entry point, or with {@code null} if the index is empty + */ @Nonnull static CompletableFuture fetchEntryNodeReference(@Nonnull final ReadTransaction readTransaction, @Nonnull final Subspace subspace, @@ -110,13 +204,24 @@ static CompletableFuture fetchEntryNodeReference(@Nonnull fi onReadListener.onKeyValueRead(-1, key, valueBytes); final Tuple entryTuple = Tuple.fromBytes(valueBytes); - final int lMax = (int)entryTuple.getLong(0); + final int layer = (int)entryTuple.getLong(0); final Tuple primaryKey = entryTuple.getNestedTuple(1); final Tuple vectorTuple = entryTuple.getNestedTuple(2); - return new EntryNodeReference(primaryKey, StorageAdapter.vectorFromTuple(vectorTuple), lMax); + return new EntryNodeReference(primaryKey, StorageAdapter.vectorFromTuple(vectorTuple), layer); }); } + /** + * Writes an {@code EntryNodeReference} to the database within a given transaction and subspace. + *

+ * This method serializes the provided {@link EntryNodeReference} into a key-value pair. The key is determined by + * a dedicated subspace for entry nodes, and the value is a tuple containing the layer, primary key, and vector from + * the reference. After writing the data, it notifies the provided {@link OnWriteListener}. + * @param transaction the database transaction to use for the write operation + * @param subspace the subspace where the entry node reference will be stored + * @param entryNodeReference the {@link EntryNodeReference} object to write + * @param onWriteListener the listener to be notified after the key-value pair is written + */ static void writeEntryNodeReference(@Nonnull final Transaction transaction, @Nonnull final Subspace subspace, @Nonnull final EntryNodeReference entryNodeReference, @@ -131,11 +236,33 @@ static void writeEntryNodeReference(@Nonnull final Transaction transaction, onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value); } + /** + * Creates a {@code HalfVector} from a given {@code Tuple}. + *
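 + *

A round trip through {@link #tupleFromVector(Vector)} is expected to be lossless (sketch; + * {@code vector} is assumed to be in scope): + *

{@code
+ * Tuple vectorTuple = StorageAdapter.tupleFromVector(vector);
+ * Vector.HalfVector decoded = StorageAdapter.vectorFromTuple(vectorTuple);
+ * // decoded.equals(vector) holds for any half-precision vector
+ * }
+ *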

 + * This method assumes the vector data is stored as a byte array at the first position (index 0) of the tuple. It + * extracts this byte array and then delegates to the {@link #vectorFromBytes(byte[])} method for the actual + * conversion. + * @param vectorTuple the tuple containing the vector data as a byte array at index 0. Must not be {@code null}. + * @return a new {@code HalfVector} instance created from the tuple's data. + * This method never returns {@code null}. + */ @Nonnull static Vector.HalfVector vectorFromTuple(final Tuple vectorTuple) { return vectorFromBytes(vectorTuple.getBytes(0)); } + /** + * Creates a {@link Vector.HalfVector} from a byte array. + *
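 + *

For instance, four bytes decode into a two-dimensional vector (sketch; the constants are the + * big-endian half-precision encodings of 1.0 and 2.0): + *

{@code
+ * byte[] vectorBytes = new byte[] {0x3C, 0x00, 0x40, 0x00};
+ * Vector.HalfVector vector = StorageAdapter.vectorFromBytes(vectorBytes);
+ * // vector.size() == 2
+ * }
+ *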

+ * This method interprets the input byte array as a sequence of 16-bit half-precision floating-point numbers. Each + * consecutive pair of bytes is converted into a {@code Half} value, which then becomes a component of the resulting + * vector. The byte array must have an even number of bytes. + * @param vectorBytes the non-null byte array to convert. The length of this array must be even, as each pair of + * bytes represents a single {@link Half} component. + * @return a new {@link Vector.HalfVector} instance created from the byte array. + * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} is odd, + * as verified by the internal check. + */ @Nonnull static Vector.HalfVector vectorFromBytes(final byte[] vectorBytes) { final int bytesLength = vectorBytes.length; @@ -148,13 +275,29 @@ static Vector.HalfVector vectorFromBytes(final byte[] vectorBytes) { return new Vector.HalfVector(vectorHalfs); } - + /** + * Converts a {@code Vector} into a {@code Tuple}. + *

+ * This method first serializes the given vector into a byte array using the {@link #bytesFromVector(Vector)} helper + * method. It then creates a {@link Tuple} from the resulting byte array. + * @param vector the vector of {@code Half} precision floating-point numbers to convert. Cannot be null. + * @return a new, non-null {@code Tuple} instance representing the contents of the vector. + */ @Nonnull @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod") static Tuple tupleFromVector(final Vector vector) { return Tuple.from(bytesFromVector(vector)); } + /** + * Converts a {@link Vector} of {@link Half} precision floating-point numbers into a byte array. + *

+ * This method iterates through the input vector, converting each {@link Half} element into its 16-bit short + * representation. It then serializes this short into two bytes, placing them sequentially into the resulting byte + * array. The final array's length will be {@code 2 * vector.size()}. + * @param vector the vector of {@link Half} precision numbers to convert. Must not be null. + * @return a new byte array representing the serialized vector data. This array is never null. + */ @Nonnull static byte[] bytesFromVector(final Vector vector) { final byte[] vectorBytes = new byte[2 * vector.size()]; @@ -167,6 +310,17 @@ static byte[] bytesFromVector(final Vector vector) { return vectorBytes; } + /** + * Constructs a short from two bytes in a byte array in big-endian order. + *
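 + *

A minimal sketch of the big-endian round trip with {@link #bytesFromShort(short)}: + *

{@code
+ * byte[] bytes = StorageAdapter.bytesFromShort((short)0x1234);
+ * // bytes[0] == 0x12 (most significant byte), bytes[1] == 0x34 (least significant byte)
+ * short value = StorageAdapter.shortFromBytes(bytes, 0);
+ * // value == 0x1234
+ * }
+ *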

+ * This method reads two consecutive bytes from the {@code bytes} array, starting at the given {@code offset}. The + * byte at {@code offset} is treated as the most significant byte (MSB), and the byte at {@code offset + 1} is the + * least significant byte (LSB). + * @param bytes the source byte array from which to read the short. + * @param offset the starting index in the byte array. This must be an even number + * and ensure that {@code offset + 1} is a valid index. + * @return the short value constructed from the two bytes. + */ static short shortFromBytes(final byte[] bytes, final int offset) { Verify.verify(offset % 2 == 0); int high = bytes[offset] & 0xFF; // Convert to unsigned int @@ -175,6 +329,14 @@ static short shortFromBytes(final byte[] bytes, final int offset) { return (short) ((high << 8) | low); } + /** + * Converts a {@code short} value into a 2-element byte array. + *

+ * The conversion is performed in big-endian byte order, where the most significant byte (MSB) is placed at index 0 + * and the least significant byte (LSB) is at index 1. + * @param value the {@code short} value to be converted. + * @return a new 2-element byte array representing the short value in big-endian order. + */ static byte[] bytesFromShort(final short value) { byte[] result = new byte[2]; result[0] = (byte) ((value >> 8) & 0xFF); // high byte first diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java index 725c1b6123..395159b629 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java @@ -1,5 +1,5 @@ /* - * HNSWHelpers.java + * Vector.java * * This source file is part of the FoundationDB open source project * @@ -22,6 +22,7 @@ import com.christianheina.langx.half4j.Half; import com.google.common.base.Suppliers; +import com.google.common.base.Verify; import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; @@ -39,8 +40,13 @@ import java.util.stream.Collectors; /** - * TODO. - * @param representation type + * An abstract base class representing a mathematical vector. + *

+ * This class provides a generic framework for vectors of different numerical types, + * where {@code R} is a subtype of {@link Number}. It includes common operations and functionalities like size, + * component access, equality checks, and conversions. Concrete implementations must provide specific logic for + * data type conversions and raw data representation. + * @param the type of the numbers stored in this vector, which must extend {@link Number}. */ public abstract class Vector { @Nonnull @@ -48,36 +54,106 @@ public abstract class Vector { @Nonnull protected Supplier hashCodeSupplier; + /** + * Constructs a new Vector with the given data. + *

+ * This constructor uses the provided array directly as the backing store for the vector. It does not create a + * defensive copy. Therefore, any subsequent modifications to the input array will be reflected in this vector's + * state. The contract of this constructor is that callers do not modify {@code data} after calling the constructor. + * We do not want to copy the array here for performance reasons. + * @param data the array of elements for this vector; must not be {@code null}. + * @throws NullPointerException if the provided {@code data} array is null. + */ public Vector(@Nonnull final R[] data) { this.data = data; this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode); } + /** + * Returns the number of elements in the vector. + * @return the number of elements + */ public int size() { return data.length; } + /** + * Gets the component of this object at the specified dimension. + *

+ * The dimension is a zero-based index. For a 3D vector, for example, dimension 0 might correspond to the + * x-component, 1 to the y-component, and 2 to the z-component. This method provides direct access to the + * underlying data element. + * @param dimension the zero-based index of the component to retrieve. + * @return the component at the specified dimension, which is guaranteed to be non-null. + * @throws IndexOutOfBoundsException if the {@code dimension} is negative or + * greater than or equal to the number of dimensions of this object. + */ @Nonnull R getComponent(int dimension) { return data[dimension]; } + /** + * Returns the underlying data array. + *

+ * The returned array is guaranteed to be non-null. Note that this method + * returns a direct reference to the internal array, not a copy. + * @return the data array of type {@code R[]}, never {@code null}. + */ @Nonnull public R[] getData() { return data; } + /** + * Gets the raw byte data representation of this object. + *

+ * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is + * implementation-specific and should be documented by the concrete class that implements this method. + * @return a non-null byte array containing the raw data. + */ @Nonnull public abstract byte[] getRawData(); + /** + * Converts this object into a {@code Vector} of {@link Half} precision floating-point numbers. + *

+ * As this is an abstract method, implementing classes are responsible for defining the specific conversion logic + * from their internal representation to a {@code Vector} of {@link Half} objects. If this object already is a + * {@code HalfVector} this method should return {@code this}. + * @return a non-null {@link Vector} containing the {@link Half} precision floating-point representation of this + * object. + */ @Nonnull public abstract Vector toHalfVector(); + /** + * Converts this vector into a {@link DoubleVector}. + *

 + * This method provides a way to obtain a double-precision floating-point representation of the vector. If the + * vector is already an instance of {@code DoubleVector}, this method may return the instance itself. Otherwise, + * it will create a new {@code DoubleVector} containing the same elements, which may involve a conversion of the + * underlying data type. + * @return a non-null {@link DoubleVector} representation of this vector. + */ @Nonnull public abstract DoubleVector toDoubleVector(); + /** + * Returns the precision of this vector's components in bits. + * @return the precision in bits, for example {@code 16} for a vector of half-precision components. + */ public abstract int precision(); + /** + * Compares this vector to the specified object for equality. + *

+ * The result is {@code true} if and only if the argument is not {@code null} and is a {@code Vector} object that + * has the same data elements as this object. This method performs a deep equality check on the underlying data + * elements using {@link Objects#deepEquals(Object, Object)}. + * @param o the object to compare with this {@code Vector} for equality. + * @return {@code true} if the given object is a {@code Vector} equivalent to this vector, {@code false} otherwise. + */ @Override public boolean equals(final Object o) { if (!(o instanceof Vector)) { @@ -87,21 +163,49 @@ public boolean equals(final Object o) { return Objects.deepEquals(data, vector.data); } + /** + * Returns a hash code value for this object. The hash code is computed once and memoized. + * @return a hash code value for this object. + */ @Override public int hashCode() { return hashCodeSupplier.get(); } + /** + * Computes a hash code based on the internal {@code data} array. + * @return the computed hash code for this object. + */ private int computeHashCode() { return Arrays.hashCode(data); } + /** + * Returns a string representation of the object. + *

 + * This method provides a default string representation by calling + * {@link #toString(int)} with a limit of 3 dimensions. + * + * @return a string representation of this object showing at most the first 3 dimensions. + */ @Override public String toString() { return toString(3); } + /** + * Generates a string representation of the data array, with an option to limit the number of dimensions shown. + *

+ * If the specified {@code limitDimensions} is less than the actual number of dimensions in the data array, + * the resulting string will be a truncated view, ending with {@code ", ..."} to indicate that more elements exist. + * Otherwise, the method returns a complete string representation of the entire array. + * @param limitDimensions The maximum number of array elements to include in the string. A non-positive + * value will cause an {@link com.google.common.base.VerifyException}. + * @return A string representation of the data array, potentially truncated. + * @throws com.google.common.base.VerifyException if {@code limitDimensions} is not positive + */ public String toString(final int limitDimensions) { + Verify.verify(limitDimensions > 0); if (limitDimensions < data.length) { return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions)) .map(String::valueOf) @@ -113,6 +217,10 @@ public String toString(final int limitDimensions) { } } + /** + * A vector class encoding a vector over half components. Conversion to {@link DoubleVector} is supported and + * memoized. + */ public static class HalfVector extends Vector { @Nonnull private final Supplier toDoubleVectorSupplier; @@ -168,6 +276,10 @@ public static HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes) } } + /** + * A vector class encoding a vector over double components. Conversion to {@link HalfVector} is supported and + * memoized. + */ public static class DoubleVector extends Vector { @Nonnull private final Supplier toHalfVectorSupplier; @@ -211,18 +323,54 @@ public byte[] getRawData() { } } + /** + * Calculates the distance between two vectors using a specified metric. + *
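 + *

For example (sketch; {@code vector1} and {@code vector2} are assumed to be in scope and any + * {@link Metric} implementation may be substituted): + *

{@code
+ * Metric metric = new Metric.EuclideanMetric();
+ * double d = Vector.distance(metric, vector1, vector2);
+ * }
+ *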

+ * This static utility method provides a convenient way to compute the distance by handling the conversion of + * generic {@code Vector} objects to primitive {@code double} arrays. The actual distance computation is then + * delegated to the provided {@link Metric} instance. + * @param the type of the numbers in the vectors, which must extend {@link Number}. + * @param metric the {@link Metric} to use for the distance calculation. + * @param vector1 the first vector. + * @param vector2 the second vector. + * @return the calculated distance between the two vectors as a {@code double}. + */ public static double distance(@Nonnull Metric metric, @Nonnull final Vector vector1, @Nonnull final Vector vector2) { return metric.distance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); } + /** + * Calculates the comparative distance between two vectors using a specified metric. + *

 + * This utility method converts the input vectors, which can contain any {@link Number} type, into primitive double + * arrays. It then delegates the actual distance computation to the {@code comparativeDistance} method of the + * provided {@link Metric} object. A comparative distance only needs to preserve the ordering of the true + * distances and need not be the actual distance itself (for the Euclidean metric, for example, the squared + * distance can serve this purpose), which can make it cheaper to compute. + * @param <R> the type of the numbers in the vectors, which must extend {@link Number}. + * @param metric the {@link Metric} to use for the distance calculation. Must not be null. + * @param vector1 the first vector for the comparison. Must not be null. + * @param vector2 the second vector for the comparison. Must not be null. + * @return the calculated comparative distance as a {@code double}. + * @throws NullPointerException if {@code metric}, {@code vector1}, or {@code vector2} is null. + */ static <R extends Number> double comparativeDistance(@Nonnull Metric metric, @Nonnull final Vector<R> vector1, @Nonnull final Vector<R> vector2) { return metric.comparativeDistance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); } + /** + * Creates a {@code Vector} instance from its byte representation. + *

 + * This method deserializes a byte array into a vector object. The precision parameter is crucial for correctly + * interpreting the byte data. Currently, this implementation only supports 16-bit precision, which corresponds to a + * {@code HalfVector}. + * @param bytes the non-null byte array representing the vector. + * @param precision the precision of the vector's elements in bits (e.g., 16 for half-precision floats). + * @return a new {@code Vector} instance created from the byte array. + * @throws UnsupportedOperationException if the specified {@code precision} is not yet supported. + */ public static Vector<? extends Number> fromBytes(@Nonnull final byte[] bytes, int precision) { if (precision == 16) { return HalfVector.halfVectorFromBytes(bytes); @@ -231,6 +379,12 @@ public static Vector fromBytes(@Nonnull final byte[] bytes, int precision) { throw new UnsupportedOperationException("not implemented yet"); } + /** + * Abstract iterator implementation to read the IVecs/FVecs data format that is used by publicly available + * embedding datasets. + * @param <R> the component type of the vectors which must extend {@link Number} + * @param <T> the type of object this iterator creates and uses to represent a stored vector in memory + */ public abstract static class StoredVecsIterator<R extends Number, T> extends AbstractIterator<T> { @Nonnull private final FileChannel fileChannel; @@ -288,6 +442,10 @@ protected T computeNext() { } } + /** + * Iterator to read floating-point vectors from a {@link FileChannel} providing an iterator of + * {@link DoubleVector}s. + */ public static class StoredFVecsIterator extends StoredVecsIterator<Double, DoubleVector> { public StoredFVecsIterator(@Nonnull final FileChannel fileChannel) { super(fileChannel); @@ -312,6 +470,9 @@ protected DoubleVector toTarget(@Nonnull final Double[] components) { } } + /** + * Iterator to read vectors from a {@link FileChannel} into a list of integers. + */ public static class StoredIVecsIterator extends StoredVecsIterator<Integer, List<Integer>> { public StoredIVecsIterator(@Nonnull final FileChannel fileChannel) { super(fileChannel); diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java index 5565b7f9f6..791fd0728a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java @@ -19,6 +19,6 @@ */ /** - * Classes and interfaces related to the Hilbert R-tree implementation. + * Classes and interfaces related to the HNSW implementation as used for vector indexes. 
*/ package com.apple.foundationdb.async.hnsw; diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java index c746516a03..795f70cd09 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java @@ -295,16 +295,17 @@ private void validateSIFTSmall(@Nonnull final HNSW hnsw, final int k) throws IOE final List> results = db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join()); final long endTs = System.nanoTime(); - logger.trace("retrieved result in elapsedTimeMs={}", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + logger.info("retrieved result in elapsedTimeMs={}, reading numNodes={}, readBytes={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); int recallCount = 0; for (NodeReferenceAndNode nodeReferenceAndNode : results) { final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); final int primaryKeyIndex = (int)nodeReferenceWithDistance.getPrimaryKey().getLong(0); - logger.trace("retrieved result nodeId = {} at distance = {} reading numNodes={}, readBytes={}", - primaryKeyIndex, nodeReferenceWithDistance.getDistance(), - onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer()); + logger.trace("retrieved result nodeId = {} at distance = {} ", + primaryKeyIndex, nodeReferenceWithDistance.getDistance()); if (groundTruthIndices.contains(primaryKeyIndex)) { recallCount ++; } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java index d751fe5f00..610c47c226 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java @@ -171,4 +171,4 @@ public void dotProductMetricComparativeDistanceWithOrthogonalVectorsReturnsZeroT assertEquals(expected, actual, 0.00001); } -} \ No newline at end of file +} From b2632923ac8b75ecbf84deef8eb7381c756eeffe Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Fri, 19 Sep 2025 14:25:29 +0200 Subject: [PATCH 07/10] added tests --- .../apple/foundationdb/async/hnsw/HNSW.java | 43 +++++-- .../async/hnsw/HNSWHelpersTest.java | 2 +- ...NSWModificationTest.java => HNSWTest.java} | 112 ++++++++++++++---- 3 files changed, 120 insertions(+), 37 deletions(-) rename fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/{HNSWModificationTest.java => HNSWTest.java} (82%) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index 798ff7e1a1..a92e44c3c3 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -88,6 +88,7 @@ public class HNSW { public static final int MAX_CONCURRENT_SEARCHES = 10; @Nonnull public static final Random DEFAULT_RANDOM = new Random(0L); @Nonnull public static final Metric DEFAULT_METRIC = new Metric.EuclideanMetric(); + public static final boolean DEFAULT_USE_INLINING = false; public static final int DEFAULT_M = 16; public static final int DEFAULT_M_MAX = 
 DEFAULT_M; public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; @@ -119,6 +120,7 @@ public static class Config { private final Random random; @Nonnull private final Metric metric; + private final boolean useInlining; private final int m; private final int mMax; private final int mMax0; @@ -130,6 +132,7 @@ public static class Config { protected Config() { this.random = DEFAULT_RANDOM; this.metric = DEFAULT_METRIC; + this.useInlining = DEFAULT_USE_INLINING; this.m = DEFAULT_M; this.mMax = DEFAULT_M_MAX; this.mMax0 = DEFAULT_M_MAX_0; @@ -139,11 +142,12 @@ protected Config() { this.keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; } - protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final int m, final int mMax, - final int mMax0, final int efSearch, final int efConstruction, final boolean extendCandidates, - final boolean keepPrunedConnections) { + protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining, + final int m, final int mMax, final int mMax0, final int efSearch, final int efConstruction, + final boolean extendCandidates, final boolean keepPrunedConnections) { this.random = random; this.metric = metric; + this.useInlining = useInlining; this.m = m; this.mMax = mMax; this.mMax0 = mMax0; @@ -163,6 +167,10 @@ public Metric getMetric() { return metric; } + public boolean isUseInlining() { + return useInlining; + } + public int getM() { return m; } @@ -193,16 +201,16 @@ public boolean isKeepPrunedConnections() { @Nonnull public ConfigBuilder toBuilder() { - return new ConfigBuilder(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(), - getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + return new ConfigBuilder(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), + getEfSearch(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); } @Override @Nonnull public String toString() { - return "Config[metric=" + getMetric() + "M=" + getM() + " , MMax=" + getMMax() + " , MMax0=" + getMMax0() + - ", efSearch=" + getEfSearch() + ", efConstruction=" + getEfConstruction() + - ", isExtendCandidates=" + isExtendCandidates() + + return "Config[metric=" + getMetric() + ", isUseInlining=" + isUseInlining() + ", M=" + getM() + + ", MMax=" + getMMax() + ", MMax0=" + getMMax0() + ", efSearch=" + getEfSearch() + + ", efConstruction=" + getEfConstruction() + ", isExtendCandidates=" + isExtendCandidates() + ", isKeepPrunedConnections=" + isKeepPrunedConnections() + "]"; } } @@ -219,6 +227,7 @@ public static class ConfigBuilder { private Random random = DEFAULT_RANDOM; @Nonnull private Metric metric = DEFAULT_METRIC; + private boolean useInlining = DEFAULT_USE_INLINING; private int m = DEFAULT_M; private int mMax = DEFAULT_M_MAX; private int mMax0 = DEFAULT_M_MAX_0; @@ -230,11 +239,12 @@ public static class ConfigBuilder { public ConfigBuilder() { } - public ConfigBuilder(@Nonnull Random random, @Nonnull final Metric metric, final int m, final int mMax, - final int mMax0, final int efSearch, final int efConstruction, + public ConfigBuilder(@Nonnull Random random, @Nonnull final Metric metric, final boolean useInlining, + final int m, final int mMax, final int mMax0, final int efSearch, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections) { this.random = random; this.metric = metric; + this.useInlining = useInlining; this.m = m; this.mMax = mMax; this.mMax0 = mMax0; @@ -266,6 +276,15 @@ public ConfigBuilder 
setMetric(@Nonnull final Metric metric) { return this; } + public boolean isUseInlining() { + return useInlining; + } + + public ConfigBuilder setUseInlining(final boolean useInlining) { + this.useInlining = useInlining; + return this; + } + public int getM() { return m; } @@ -333,7 +352,7 @@ public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnection } public Config build() { - return new Config(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(), + return new Config(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), getEfSearch(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); } } @@ -1709,7 +1728,7 @@ public void scanLayer(@Nonnull final Database db, */ @Nonnull private StorageAdapter getStorageAdapterForLayer(final int layer) { - return false && layer > 0 + return config.isUseInlining() && layer > 0 ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()) : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java index 831d3774d1..f138fd8417 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWHelpersTest.java @@ -72,4 +72,4 @@ public void halfValueOf_PositiveDouble_ReturnsCorrectHalfValue_Test() { Half result = HNSWHelpers.halfValueOf(inputValue); assertEquals(expected, result); } -} \ No newline at end of file +} diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java similarity index 82% rename from fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java rename to fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index 795f70cd09..a600d3030c 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -1,5 +1,5 @@ /* - * HNSWModificationTest.java + * HNSWTest.java * * This source file is part of the FoundationDB open source project * @@ -34,6 +34,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; +import com.google.common.collect.ObjectArrays; +import com.google.common.collect.Sets; import org.assertj.core.util.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -42,6 +44,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,12 +60,16 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import 
 java.util.function.Function; +import java.util.stream.LongStream; +import java.util.stream.Stream; /** * Tests testing insert/update/deletes of data into/in/from {@link RTree}s. @@ -69,8 +78,8 @@ @SuppressWarnings("checkstyle:AbbreviationAsWordInName") @Tag(Tags.RequiresFDB) @Tag(Tags.Slow) -public class HNSWModificationTest { - private static final Logger logger = LoggerFactory.getLogger(HNSWModificationTest.class); +public class HNSWTest { + private static final Logger logger = LoggerFactory.getLogger(HNSWTest.class); private static final int NUM_TEST_RUNS = 5; private static final int NUM_SAMPLES = 10_000; @@ -88,9 +97,16 @@ public void setUpDb() { db = dbExtension.getDatabase(); } - @Test - public void testCompactSerialization() { - final Random random = new Random(0); + static Stream<Long> randomSeeds() { + return LongStream.generate(() -> new Random().nextLong()) + .limit(5) + .boxed(); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + public void testCompactSerialization(final Long seed) { + final Random random = new Random(seed); final CompactStorageAdapter storageAdapter = new CompactStorageAdapter(HNSW.DEFAULT_CONFIG, CompactNode.factory(), rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); @@ -126,9 +142,10 @@ public void testCompactSerialization() { }).join()); } - @Test - public void testInliningSerialization() { - final Random random = new Random(0); + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + public void testInliningSerialization(final Long seed) { + final Random random = new Random(seed); final InliningStorageAdapter storageAdapter = new InliningStorageAdapter(HNSW.DEFAULT_CONFIG, InliningNode.factory(), rtSubspace.getSubspace(), OnWriteListener.NOOP, OnReadListener.NOOP); @@ -160,9 +177,23 @@ public void testInliningSerialization() { )).join()); } - @Test - public void testBasicInsert() { - final Random random = new Random(0); + static Stream<Arguments> randomSeedsWithOptions() { + return Sets.cartesianProduct(ImmutableSet.of(true, false), + ImmutableSet.of(true, false), + ImmutableSet.of(true, false)) + .stream() + .flatMap(arguments -> + LongStream.generate(() -> new Random().nextLong()) + .limit(2) + .mapToObj(seed -> Arguments.of(ObjectArrays.concat(seed, arguments.toArray())))); + } + + @ParameterizedTest(name = "seed={0} useInlining={1} extendCandidates={2} keepPrunedConnections={3}") + @MethodSource("randomSeedsWithOptions") + public void testBasicInsert(final long seed, final boolean useInlining, final boolean extendCandidates, + final boolean keepPrunedConnections) { + final Random random = new Random(seed); + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); final TestOnReadListener onReadListener = new TestOnReadListener(); @@ -170,29 +204,57 @@ public void testBasicInsert() { final int dimensions = 128; final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) + .setUseInlining(useInlining).setExtendCandidates(extendCandidates) + .setKeepPrunedConnections(keepPrunedConnections) .setM(32).setMMax(32).setMMax0(64).build(), OnWriteListener.NOOP, onReadListener); + final int k = 10; + final HalfVector queryVector = createRandomVector(random, dimensions); + final TreeSet<NodeReferenceWithDistance> nodesOrderedByDistance = + new 
TreeSet<>(Comparator.comparing(NodeReferenceWithDistance::getDistance)); + for (int i = 0; i < 1000;) { i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, - tr -> new NodeReferenceWithVector(createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, dimensions))); + tr -> { + final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector dataVector = createRandomVector(random, dimensions); + final double distance = Vector.comparativeDistance(metric, dataVector, queryVector); + final NodeReferenceWithDistance nodeReferenceWithDistance = + new NodeReferenceWithDistance(primaryKey, dataVector, distance); + nodesOrderedByDistance.add(nodeReferenceWithDistance); + if (nodesOrderedByDistance.size() > k) { + nodesOrderedByDistance.pollLast(); + } + return nodeReferenceWithDistance; + }); } onReadListener.reset(); final long beginTs = System.nanoTime(); - final List> result = - db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 100, createRandomVector(random, dimensions)).join()); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVector).join()); final long endTs = System.nanoTime(); - for (NodeReferenceAndNode nodeReferenceAndNode : result) { + final ImmutableSet trueNN = + ImmutableSet.copyOf(NodeReference.primaryKeys(nodesOrderedByDistance)); + + int recallCount = 0; + for (NodeReferenceAndNode nodeReferenceAndNode : results) { final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); logger.info("nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), nodeReferenceWithDistance.getDistance()); + if (trueNN.contains(nodeReferenceAndNode.getNode().getPrimaryKey())) { + recallCount ++; + } } - System.out.println(onReadListener.getNodeCountByLayer()); - System.out.println(onReadListener.getBytesReadByLayer()); + final double recall = (double)recallCount / (double)k; + Assertions.assertTrue(recall > 0.93); - logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + logger.info("search transaction took elapsedTime={}ms; read nodes={}, read bytes={}, recall={}", + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), + String.format(Locale.ROOT, "%.2f", recall * 100.0d)); } private int basicInsertBatch(final HNSW hnsw, final int batchSize, @@ -210,8 +272,9 @@ private int basicInsertBatch(final HNSW hnsw, final int batchSize, hnsw.insert(tr, newNodeReference).join(); } final long endTs = System.nanoTime(); - logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, - TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", + batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); return batchSize; }); } @@ -233,8 +296,9 @@ private int insertBatch(final HNSW hnsw, final int batchSize, } hnsw.insertBatch(tr, nodeReferenceWithVectorBuilder.build()).join(); final long endTs = System.nanoTime(); - logger.info("inserted batch batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, - TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), 
onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + logger.info("inserted batch batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", + batchSize, nextNodeId, TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), + onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); return batchSize; }); } @@ -314,7 +378,7 @@ private void validateSIFTSmall(@Nonnull final HNSW hnsw, final int k) throws IOE final double recall = (double)recallCount / k; Assertions.assertTrue(recall > 0.93); - logger.info("query returned results recall={}", String.format("%.2f", recall * 100.0d)); + logger.info("query returned results recall={}", String.format(Locale.ROOT, "%.2f", recall * 100.0d)); } } } From 3552ca6c77cbc5e27112407d33c1aad9958bb7ce Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Fri, 19 Sep 2025 16:01:13 +0200 Subject: [PATCH 08/10] increase timeout for test case --- .../test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index a600d3030c..ae31057195 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -41,6 +41,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; @@ -304,6 +305,7 @@ private int insertBatch(final HNSW hnsw, final int batchSize, } @Test + @Timeout(value = 10, unit = TimeUnit.MINUTES) public void testSIFTInsertSmall() throws Exception { final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); final int k = 100; @@ -384,6 +386,7 @@ private void validateSIFTSmall(@Nonnull final HNSW hnsw, final int k) throws IOE } @Test + @Timeout(value = 10, unit = TimeUnit.MINUTES) public void testSIFTInsertSmallUsingBatchAPI() throws Exception { final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); final int k = 100; From a804023d02d7e225e90b89a6183c3f3496a7933f Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Tue, 23 Sep 2025 16:30:37 +0200 Subject: [PATCH 09/10] refactored Vector class to be more aligned with math libraries --- .../foundationdb/async/hnsw/CompactNode.java | 14 +- .../async/hnsw/CompactStorageAdapter.java | 3 +- .../async/hnsw/EntryNodeReference.java | 3 +- .../apple/foundationdb/async/hnsw/HNSW.java | 31 ++-- .../foundationdb/async/hnsw/InliningNode.java | 5 +- .../async/hnsw/InliningStorageAdapter.java | 3 +- .../apple/foundationdb/async/hnsw/Metric.java | 28 +-- .../foundationdb/async/hnsw/Metrics.java | 2 +- .../apple/foundationdb/async/hnsw/Node.java | 3 +- .../foundationdb/async/hnsw/NodeFactory.java | 3 +- .../async/hnsw/NodeReferenceWithDistance.java | 3 +- .../async/hnsw/NodeReferenceWithVector.java | 11 +- .../async/hnsw/StorageAdapter.java | 161 +++++++++++++--- .../apple/foundationdb/async/hnsw/Vector.java | 174 +++++++++--------- .../foundationdb/async/hnsw/HNSWTest.java | 31 +--- .../foundationdb/async/hnsw/MetricTest.java | 40 ++-- .../foundationdb/async/hnsw/VectorTest.java | 79 ++++++++ 17 files changed, 385 insertions(+), 209 deletions(-) create mode 100644 
fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java index e58f005dd1..b594e70a2f 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java @@ -33,7 +33,7 @@ * Represents a compact node within a graph structure, extending {@link AbstractNode}. *

* This node type is considered "compact" because it directly stores its associated - * data vector of type {@code Vector}. It is used to represent a vector in a + * data vector of type {@link Vector}. It is used to represent a vector in a * vector space and maintains references to its neighbors via {@link NodeReference} objects. * * @see AbstractNode @@ -46,7 +46,7 @@ public class CompactNode extends AbstractNode { @Nonnull @Override @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") - public Node create(@Nonnull final Tuple primaryKey, @Nullable final Vector vector, + public Node create(@Nonnull final Tuple primaryKey, @Nullable final Vector vector, @Nonnull final List neighbors) { return new CompactNode(primaryKey, Objects.requireNonNull(vector), (List)neighbors); } @@ -59,7 +59,7 @@ public NodeKind getNodeKind() { }; @Nonnull - private final Vector vector; + private final Vector vector; /** * Constructs a new {@code CompactNode} instance. @@ -69,11 +69,11 @@ public NodeKind getNodeKind() { * {@code primaryKey} and {@code neighbors} to the superclass constructor. * * @param primaryKey the primary key that uniquely identifies this node; must not be {@code null}. - * @param vector the data vector of type {@code Vector} associated with this node; must not be {@code null}. + * @param vector the data vector of type {@code Vector} associated with this node; must not be {@code null}. * @param neighbors a list of {@link NodeReference} objects representing the neighbors of this node; must not be * {@code null}. */ - public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, @Nonnull final List neighbors) { super(primaryKey, neighbors); this.vector = vector; @@ -92,7 +92,7 @@ public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector */ @Nonnull @Override - public NodeReference getSelfReference(@Nullable final Vector vector) { + public NodeReference getSelfReference(@Nullable final Vector vector) { return new NodeReference(getPrimaryKey()); } @@ -112,7 +112,7 @@ public NodeKind getKind() { * @return the non-null vector of {@link Half} objects. 
*/ @Nonnull - public Vector getVector() { + public Vector getVector() { return vector; } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java index 0c38296807..826ba57f9b 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -30,7 +30,6 @@ import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.ByteArrayUtil; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import com.google.common.base.Verify; import com.google.common.collect.Lists; import org.slf4j.Logger; @@ -208,7 +207,7 @@ private Node nodeFromKeyValuesTuples(@Nonnull final Tuple primary private Node compactNodeFromTuples(@Nonnull final Tuple primaryKey, @Nonnull final Tuple vectorTuple, @Nonnull final Tuple neighborsTuple) { - final Vector vector = StorageAdapter.vectorFromTuple(vectorTuple); + final Vector vector = StorageAdapter.vectorFromTuple(vectorTuple); final List nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size()); for (int i = 0; i < neighborsTuple.size(); i ++) { diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java index 4a9cbb0ae5..f8b9587bdd 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -21,7 +21,6 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import javax.annotation.Nonnull; import java.util.Objects; @@ -48,7 +47,7 @@ class EntryNodeReference extends NodeReferenceWithVector { * @param vector the vector data associated with the node. Must not be {@code null}. * @param layer the layer number where this entry node is located. 
*/ - public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { + public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { super(primaryKey, vector); this.layer = layer; } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index a92e44c3c3..47ddf7117a 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -28,7 +28,6 @@ import com.apple.foundationdb.async.MoreAsyncUtil; import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -239,7 +238,7 @@ public static class ConfigBuilder { public ConfigBuilder() { } - public ConfigBuilder(@Nonnull Random random, @Nonnull final Metric metric, final boolean useInlining, + public ConfigBuilder(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining, final int m, final int mMax, final int mMax0, final int efSearch, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections) { this.random = random; @@ -481,7 +480,7 @@ public OnReadListener getOnReadListener() { public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, final int k, final int efSearch, - @Nonnull final Vector queryVector) { + @Nonnull final Vector queryVector) { return StorageAdapter.fetchEntryNodeReference(readTransaction, getSubspace(), getOnReadListener()) .thenCompose(entryPointAndLayer -> { if (entryPointAndLayer == null) { @@ -572,7 +571,7 @@ private CompletableFuture g @Nonnull final ReadTransaction readTransaction, @Nonnull final NodeReferenceWithDistance entryNeighbor, final int layer, - @Nonnull final Vector queryVector) { + @Nonnull final Vector queryVector) { if (storageAdapter.getNodeKind() == NodeKind.INLINING) { return greedySearchInliningLayer(storageAdapter.asInliningStorageAdapter(), readTransaction, entryNeighbor, layer, queryVector); } else { @@ -612,7 +611,7 @@ private CompletableFuture greedySearchInliningLayer(@ @Nonnull final ReadTransaction readTransaction, @Nonnull final NodeReferenceWithDistance entryNeighbor, final int layer, - @Nonnull final Vector queryVector) { + @Nonnull final Vector queryVector) { Verify.verify(layer > 0); final Metric metric = getConfig().getMetric(); final AtomicReference currentNodeReferenceAtomic = @@ -685,7 +684,7 @@ private CompletableFuture final int layer, final int efSearch, @Nonnull final Map> nodeCache, - @Nonnull final Vector queryVector) { + @Nonnull final Vector queryVector) { final Set visited = Sets.newConcurrentHashSet(NodeReference.primaryKeys(entryNeighbors)); final Queue candidates = new PriorityBlockingQueue<>(config.getM(), @@ -995,7 +994,7 @@ public CompletableFuture insert(@Nonnull final Transaction transaction, @N */ @Nonnull public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, - @Nonnull final Vector newVector) { + @Nonnull final Vector newVector) { final Metric metric = getConfig().getMetric(); final int insertionLayer = insertionLayer(getConfig().getRandom()); @@ -1104,7 +1103,7 @@ public CompletableFuture insertBatch(@Nonnull 
final Transaction transactio return CompletableFuture.completedFuture(null); } - final Vector itemVector = item.getVector(); + final Vector itemVector = item.getVector(); final int itemL = item.getLayer(); final NodeReferenceWithDistance initialNodeReference = @@ -1128,7 +1127,7 @@ public CompletableFuture insertBatch(@Nonnull final Transaction transactio (index, currentEntryNodeReference) -> { final NodeReferenceWithLayer item = batchWithLayers.get(index); final Tuple itemPrimaryKey = item.getPrimaryKey(); - final Vector itemVector = item.getVector(); + final Vector itemVector = item.getVector(); final int itemL = item.getLayer(); final EntryNodeReference newEntryNodeReference; @@ -1202,7 +1201,7 @@ public CompletableFuture insertBatch(@Nonnull final Transaction transactio @Nonnull private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, - @Nonnull final Vector newVector, + @Nonnull final Vector newVector, @Nonnull final NodeReferenceWithDistance nodeReference, final int lMax, final int insertionLayer) { @@ -1258,7 +1257,7 @@ private CompletableFuture nearestNeighbors, int layer, @Nonnull final Tuple newPrimaryKey, - @Nonnull final Vector newVector) { + @Nonnull final Vector newVector) { if (logger.isDebugEnabled()) { logger.debug("begin insert key={} at layer={}", newPrimaryKey, layer); } @@ -1490,7 +1489,7 @@ private CompletableFuture final int m, final boolean isExtendCandidates, @Nonnull final Map> nodeCache, - @Nonnull final Vector vector) { + @Nonnull final Vector vector) { return extendCandidatesIfNecessary(storageAdapter, readTransaction, nearestNeighbors, layer, isExtendCandidates, nodeCache, vector) .thenApply(extendedCandidates -> { final List selected = Lists.newArrayListWithExpectedSize(m); @@ -1575,7 +1574,7 @@ private CompletableFuture int layer, boolean isExtendCandidates, @Nonnull final Map> nodeCache, - @Nonnull final Vector vector) { + @Nonnull final Vector vector) { if (isExtendCandidates) { final Metric metric = getConfig().getMetric(); @@ -1639,7 +1638,7 @@ private CompletableFuture */ private void writeLonelyNodes(@Nonnull final Transaction transaction, @Nonnull final Tuple primaryKey, - @Nonnull final Vector vector, + @Nonnull final Vector vector, final int highestLayerInclusive, final int lowestLayerExclusive) { for (int layer = highestLayerInclusive; layer > lowestLayerExclusive; layer --) { @@ -1667,7 +1666,7 @@ private void writeLonelyNodeOnLayer(@Nonnull final Sto @Nonnull final Transaction transaction, final int layer, @Nonnull final Tuple primaryKey, - @Nonnull final Vector vector) { + @Nonnull final Vector vector) { storageAdapter.writeNode(transaction, storageAdapter.getNodeFactory() .create(primaryKey, vector, ImmutableList.of()), layer, @@ -1777,7 +1776,7 @@ private void info(@Nonnull final Consumer loggerConsumer) { private static class NodeReferenceWithLayer extends NodeReferenceWithVector { private final int layer; - public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { super(primaryKey, vector); this.layer = layer; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java index 56d39227d1..c8161b825c 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java +++ 
b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -22,7 +22,6 @@ import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -44,7 +43,7 @@ public class InliningNode extends AbstractNode { @Nonnull @Override public Node create(@Nonnull final Tuple primaryKey, - @Nullable final Vector vector, + @Nullable final Vector vector, @Nonnull final List neighbors) { return new InliningNode(primaryKey, (List)neighbors); } @@ -85,7 +84,7 @@ public InliningNode(@Nonnull final Tuple primaryKey, @Nonnull @Override @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") - public NodeReferenceWithVector getSelfReference(@Nullable final Vector vector) { + public NodeReferenceWithVector getSelfReference(@Nullable final Vector vector) { return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java index c63f2135e0..58d8795777 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -30,7 +30,6 @@ import com.apple.foundationdb.subspace.Subspace; import com.apple.foundationdb.tuple.ByteArrayUtil; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import com.google.common.collect.ImmutableList; import javax.annotation.Nonnull; @@ -182,7 +181,7 @@ private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull final Tuple neighborValueTuple = Tuple.fromBytes(value); final Tuple neighborPrimaryKey = neighborKeyTuple.getNestedTuple(2); // neighbor primary key - final Vector neighborVector = StorageAdapter.vectorFromTuple(neighborValueTuple); // the entire value is the vector + final Vector neighborVector = StorageAdapter.vectorFromTuple(neighborValueTuple); // the entire value is the vector return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java index adb1b799b3..a49457677f 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -45,7 +45,7 @@ public interface Metric { * @throws IllegalArgumentException if the vectors have different lengths. * @throws NullPointerException if either {@code vector1} or {@code vector2} is null. */ - double distance(Double[] vector1, Double[] vector2); + double distance(@Nonnull double[] vector1, @Nonnull double[] vector2); /** * Calculates a comparative distance between two vectors. The comparative distance is used in contexts such as @@ -53,15 +53,15 @@ public interface Metric { * by this method do not need to follow proper metric invariants: The distance can be negative; the distance * does not need to follow triangle inequality. *

- * This method is an alias for {@link #distance(Double[], Double[])} under normal circumstances. It is not for e.g. + * This method is an alias for {@link #distance(double[], double[])} under normal circumstances. It is not for e.g. * {@link DotProductMetric} where the distance is the negative dot product. * - * @param vector1 the first vector, represented as an array of {@code Double}. - * @param vector2 the second vector, represented as an array of {@code Double}. + * @param vector1 the first vector, represented as an array of {@code double}. + * @param vector2 the second vector, represented as an array of {@code double}. * * @return the distance between the two vectors. */ - default double comparativeDistance(Double[] vector1, Double[] vector2) { + default double comparativeDistance(@Nonnull double[] vector1, @Nonnull double[] vector2) { return distance(vector1, vector2); } @@ -70,7 +70,7 @@ default double comparativeDistance(Double[] vector1, Double[] vector2) { * @param vector1 The first vector. * @param vector2 The second vector. */ - private static void validate(Double[] vector1, Double[] vector2) { + private static void validate(double[] vector1, double[] vector2) { if (vector1 == null || vector2 == null) { throw new IllegalArgumentException("Vectors cannot be null"); } @@ -93,7 +93,7 @@ private static void validate(Double[] vector1, Double[] vector2) { */ class ManhattanMetric implements Metric { @Override - public double distance(final Double[] vector1, final Double[] vector2) { + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { Metric.validate(vector1, vector2); double sumOfAbsDiffs = 0.0; @@ -119,7 +119,7 @@ public String toString() { */ class EuclideanMetric implements Metric { @Override - public double distance(final Double[] vector1, final Double[] vector2) { + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { Metric.validate(vector1, vector2); return Math.sqrt(EuclideanSquareMetric.distanceInternal(vector1, vector2)); @@ -147,12 +147,12 @@ public String toString() { */ class EuclideanSquareMetric implements Metric { @Override - public double distance(final Double[] vector1, final Double[] vector2) { + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { Metric.validate(vector1, vector2); return distanceInternal(vector1, vector2); } - private static double distanceInternal(final Double[] vector1, final Double[] vector2) { + private static double distanceInternal(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { double sumOfSquares = 0.0d; for (int i = 0; i < vector1.length; i++) { double diff = vector1[i] - vector2[i]; @@ -178,7 +178,7 @@ public String toString() { */ class CosineMetric implements Metric { @Override - public double distance(final Double[] vector1, final Double[] vector2) { + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { Metric.validate(vector1, vector2); double dotProduct = 0.0; @@ -211,19 +211,19 @@ public String toString() { *

* This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance - * only allows {@link Metric#comparativeDistance(Double[], Double[])} to be called. + * only allows {@link Metric#comparativeDistance(double[], double[])} to be called. * * @see Dot Product * @see DotProductMetric */ class DotProductMetric implements Metric { @Override - public double distance(final Double[] vector1, final Double[] vector2) { + public double distance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { throw new UnsupportedOperationException("dot product metric is not a true metric and can only be used for ranking"); } @Override - public double comparativeDistance(final Double[] vector1, final Double[] vector2) { + public double comparativeDistance(@Nonnull final double[] vector1, @Nonnull final double[] vector2) { Metric.validate(vector1, vector2); double product = 0.0d; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java index 7a3e4a6a88..0af9cf7af2 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java @@ -82,7 +82,7 @@ public enum Metrics { *

* This metric calculates the inverted dot product of two vectors. It is not a true metric as the dot product can * be positive at which point the distance is negative. In order to make callers aware of this fact, this distance - * only allows {@link Metric#comparativeDistance(Double[], Double[])} to be called. + * only allows {@link Metric#comparativeDistance(double[], double[])} to be called. * * @see Dot Product * @see Metric.DotProductMetric diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java index 3ddae2ec74..88d10480ce 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java @@ -21,7 +21,6 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -60,7 +59,7 @@ public interface Node { * method calls. */ @Nonnull - N getSelfReference(@Nullable Vector vector); + N getSelfReference(@Nullable Vector vector); /** * Gets the list of neighboring nodes. diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java index bbe15f8464..814a8d9030 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -21,7 +21,6 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -53,7 +52,7 @@ public interface NodeFactory { * @return a new, non-null {@link Node} instance configured with the provided parameters. */ @Nonnull - Node create(@Nonnull Tuple primaryKey, @Nullable Vector vector, + Node create(@Nonnull Tuple primaryKey, @Nullable Vector vector, @Nonnull List neighbors); /** diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java index 5acc345d65..7b46f65f69 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -21,7 +21,6 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import javax.annotation.Nonnull; import java.util.Objects; @@ -45,7 +44,7 @@ public class NodeReferenceWithDistance extends NodeReferenceWithVector { * @param vector the vector associated with the referenced node. Must not be null. * @param distance the calculated distance of this node reference to some query vector or similar. 
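Since the metric API now works on primitive double[] arrays, a brief hedged sketch of computing a distance and attaching it to a node reference (the values are illustrative; the identifiers come from this patch):

    // Euclidean distance over primitive arrays; 5.0 for this pair of vectors.
    Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric();
    double distance = metric.distance(new double[] {1.0, 2.0}, new double[] {4.0, 6.0});
    // comparativeDistance() defaults to distance() for true metrics; only the
    // dot-product metric overrides it (ranking only, possibly negative).
    NodeReferenceWithDistance reference =
            new NodeReferenceWithDistance(Tuple.from(42L),
                    new Vector.DoubleVector(new double[] {1.0, 2.0}), distance);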
*/ - public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final double distance) { super(primaryKey, vector); this.distance = distance; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java index 837c88fb00..7b29bedb09 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -21,7 +21,6 @@ package com.apple.foundationdb.async.hnsw; import com.apple.foundationdb.tuple.Tuple; -import com.christianheina.langx.half4j.Half; import com.google.common.base.Objects; import javax.annotation.Nonnull; @@ -29,7 +28,7 @@ /** * Represents a reference to a node that includes an associated vector. *

- * This class extends {@link NodeReference} by adding a {@code Vector} field. It encapsulates both the primary key + * This class extends {@link NodeReference} by adding a {@link Vector} field. It encapsulates both the primary key * of a node and its corresponding vector data, which is particularly useful in vector-based search and * indexing scenarios. Primarily, node references are used to refer to {@link Node}s in a storage-independent way, i.e. * a node reference always contains the vector of a node while the node itself (depending on the storage adapter) @@ -37,7 +36,7 @@ */ public class NodeReferenceWithVector extends NodeReference { @Nonnull - private final Vector vector; + private final Vector vector; /** * Constructs a new {@code NodeReferenceWithVector} with a specified primary key and vector. @@ -49,7 +48,7 @@ public class NodeReferenceWithVector extends NodeReference { * @param primaryKey the primary key of the node, must not be null * @param vector the vector associated with the node, must not be null */ - public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { + public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { super(primaryKey); this.vector = vector; } @@ -63,7 +62,7 @@ public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final V * @return the vector of {@code Half} objects; will never be {@code null}. */ @Nonnull - public Vector getVector() { + public Vector getVector() { return vector; } @@ -72,7 +71,7 @@ public Vector getVector() { * @return a non-null {@code Vector} containing the elements of this vector. */ @Nonnull - public Vector getDoubleVector() { + public Vector.DoubleVector getDoubleVector() { return vector.toDoubleVector(); } diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java index e4e72e593e..dedad69f21 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ -247,46 +247,87 @@ static void writeEntryNodeReference(@Nonnull final Transaction transaction, * This method never returns {@code null}. */ @Nonnull - static Vector.HalfVector vectorFromTuple(final Tuple vectorTuple) { + static Vector vectorFromTuple(final Tuple vectorTuple) { return vectorFromBytes(vectorTuple.getBytes(0)); } + /** + * Creates a {@link Vector} from a byte array. + *

+     * This method decodes the input byte array by interpreting its first byte as the precision shift. The byte array
+     * must have the proper size, i.e. the invariant {@code (bytesLength - 1) % precision == 0} must hold.
+     * @param vectorBytes the non-null byte array to convert.
+     * @return a new {@link Vector} instance created from the byte array.
+     * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} does not meet the invariant
+     *         {@code (bytesLength - 1) % precision == 0}
+     */
+    @Nonnull
+    static Vector vectorFromBytes(final byte[] vectorBytes) {
+        final int bytesLength = vectorBytes.length;
+        final int precisionShift = (int)vectorBytes[0];
+        final int precision = 1 << precisionShift;
+        Verify.verify((bytesLength - 1) % precision == 0);
+        final int numDimensions = (bytesLength - 1) >>> precisionShift;
+        switch (precisionShift) {
+            case 1:
+                return halfVectorFromBytes(vectorBytes, 1, numDimensions);
+            case 3:
+                return doubleVectorFromBytes(vectorBytes, 1, numDimensions);
+            default:
+                throw new RuntimeException("unable to deserialize vector");
+        }
+    }
+
     /**
      * Creates a {@link Vector.HalfVector} from a byte array.
      *

     * This method interprets the input byte array as a sequence of 16-bit half-precision floating-point numbers. Each
     * consecutive pair of bytes is converted into a {@code Half} value, which then becomes a component of the resulting
-     * vector. The byte array must have an even number of bytes.
-     * @param vectorBytes the non-null byte array to convert. The length of this array must be even, as each pair of
-     *        bytes represents a single {@link Half} component.
+     * vector.
+     * @param vectorBytes the non-null byte array to convert.
+     * @param offset the index within {@code vectorBytes} at which the first component starts.
+     * @param numDimensions the number of {@link Half} components to decode.
      * @return a new {@link Vector.HalfVector} instance created from the byte array.
-     * @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} is odd,
-     *         as verified by the internal check.
      */
     @Nonnull
-    static Vector.HalfVector vectorFromBytes(final byte[] vectorBytes) {
-        final int bytesLength = vectorBytes.length;
-        Verify.verify(bytesLength % 2 == 0);
-        final int componentSize = bytesLength >>> 1;
-        final Half[] vectorHalfs = new Half[componentSize];
-        for (int i = 0; i < componentSize; i ++) {
-            vectorHalfs[i] = Half.shortBitsToHalf(shortFromBytes(vectorBytes, i << 1));
+    static Vector.HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes, final int offset, final int numDimensions) {
+        final Half[] vectorHalfs = new Half[numDimensions];
+        for (int i = 0; i < numDimensions; i ++) {
+            vectorHalfs[i] = Half.shortBitsToHalf(shortFromBytes(vectorBytes, offset + (i << 1)));
         }
         return new Vector.HalfVector(vectorHalfs);
     }
 
     /**
-     * Converts a {@code Vector} into a {@code Tuple}.
+     * Creates a {@link Vector.DoubleVector} from a byte array.
+     *

+     * This method interprets the input byte array as a sequence of 64-bit double-precision floating-point numbers. Each
+     * run of eight bytes is converted into a {@code double} value, which then becomes a component of the resulting
+     * vector.
+     * @param vectorBytes the non-null byte array to convert.
+     * @param offset the index within {@code vectorBytes} at which the first component starts.
+     * @param numDimensions the number of {@code double} components to decode.
+     * @return a new {@link Vector.DoubleVector} instance created from the byte array.
+     */
+    @Nonnull
+    static Vector.DoubleVector doubleVectorFromBytes(@Nonnull final byte[] vectorBytes, final int offset, final int numDimensions) {
+        final double[] vectorComponents = new double[numDimensions];
+        for (int i = 0; i < numDimensions; i ++) {
+            vectorComponents[i] = Double.longBitsToDouble(longFromBytes(vectorBytes, offset + (i << 3)));
+        }
+        return new Vector.DoubleVector(vectorComponents);
+    }
+
+    /**
+     * Converts a {@link Vector} into a {@link Tuple}.
+     *

-     * This method first serializes the given vector into a byte array using the {@link #bytesFromVector(Vector)} helper
+     * This method first serializes the given vector into a byte array using the {@link Vector#getRawData()} getter
      * method. It then creates a {@link Tuple} from the resulting byte array.
-     * @param vector the vector of {@code Half} precision floating-point numbers to convert. Cannot be null.
+     * @param vector the vector to convert. Cannot be null.
      * @return a new, non-null {@code Tuple} instance representing the contents of the vector.
      */
     @Nonnull
     @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod")
-    static Tuple tupleFromVector(final Vector<Half> vector) {
-        return Tuple.from(bytesFromVector(vector));
+    static Tuple tupleFromVector(final Vector vector) {
+        return Tuple.from(vector.getRawData());
     }
 
     /**
@@ -295,17 +336,46 @@ static Tuple tupleFromVector(final Vector<Half> vector) {
      * This method iterates through the input vector, converting each {@link Half} element into its 16-bit short
-     * representation. It then serializes this short into two bytes, placing them sequentially into the resulting byte
-     * array. The final array's length will be {@code 2 * vector.size()}.
-     * @param vector the vector of {@link Half} precision numbers to convert. Must not be null.
+     * representation. It then serializes each short into two bytes, placing them sequentially into the resulting byte
+     * array. The final array's length will be {@code 1 + 2 * halfVector.size()}.
+     * @param halfVector the vector of {@link Half} precision numbers to convert. Must not be null.
+     * @return a new byte array representing the serialized vector data. This array is never null.
+     */
+    @Nonnull
+    static byte[] bytesFromVector(@Nonnull final Vector.HalfVector halfVector) {
+        final byte[] vectorBytes = new byte[1 + 2 * halfVector.size()];
+        vectorBytes[0] = (byte)halfVector.precisionShift();
+        for (int i = 0; i < halfVector.size(); i ++) {
+            final byte[] componentBytes = bytesFromShort(Half.halfToShortBits(Half.valueOf(halfVector.getComponent(i))));
+            final int offset = 1 + (i << 1);
+            vectorBytes[offset] = componentBytes[0];
+            vectorBytes[offset + 1] = componentBytes[1];
+        }
+        return vectorBytes;
+    }
+
+    /**
+     * Converts a {@link Vector} of {@code double} precision floating-point numbers into a byte array.
+     *

+     * This method iterates through the input vector, converting each {@code double} element into its 64-bit long
+     * representation. It then serializes each long into eight bytes, placing them sequentially into the resulting byte
+     * array. The final array's length will be {@code 1 + 8 * doubleVector.size()}.
+     * @param doubleVector the vector of {@code double} precision numbers to convert. Must not be null.
      * @return a new byte array representing the serialized vector data. This array is never null.
      */
     @Nonnull
-    static byte[] bytesFromVector(final Vector<Half> vector) {
-        final byte[] vectorBytes = new byte[2 * vector.size()];
-        for (int i = 0; i < vector.size(); i ++) {
-            final byte[] componentBytes = bytesFromShort(Half.halfToShortBits(vector.getComponent(i)));
-            final int indexTimesTwo = i << 1;
-            vectorBytes[indexTimesTwo] = componentBytes[0];
-            vectorBytes[indexTimesTwo + 1] = componentBytes[1];
+    static byte[] bytesFromVector(final Vector.DoubleVector doubleVector) {
+        final byte[] vectorBytes = new byte[1 + 8 * doubleVector.size()];
+        vectorBytes[0] = (byte)doubleVector.precisionShift();
+        for (int i = 0; i < doubleVector.size(); i ++) {
+            final byte[] componentBytes = bytesFromLong(Double.doubleToLongBits(doubleVector.getComponent(i)));
+            final int offset = 1 + (i << 3);
+            vectorBytes[offset] = componentBytes[0];
+            vectorBytes[offset + 1] = componentBytes[1];
+            vectorBytes[offset + 2] = componentBytes[2];
+            vectorBytes[offset + 3] = componentBytes[3];
+            vectorBytes[offset + 4] = componentBytes[4];
+            vectorBytes[offset + 5] = componentBytes[5];
+            vectorBytes[offset + 6] = componentBytes[6];
+            vectorBytes[offset + 7] = componentBytes[7];
         }
         return vectorBytes;
     }
@@ -317,12 +387,10 @@ static byte[] bytesFromVector(final Vector<Half> vector) {
      * byte at {@code offset} is treated as the most significant byte (MSB), and the byte at {@code offset + 1} is the
      * least significant byte (LSB).
      * @param bytes the source byte array from which to read the short.
-     * @param offset the starting index in the byte array. This must be an even number
-     *        and ensure that {@code offset + 1} is a valid index.
+     * @param offset the starting index in the byte array.
      * @return the short value constructed from the two bytes.
      */
     static short shortFromBytes(final byte[] bytes, final int offset) {
-        Verify.verify(offset % 2 == 0);
         int high = bytes[offset] & 0xFF; // Convert to unsigned int
         int low = bytes[offset + 1] & 0xFF;
 
@@ -343,4 +411,45 @@ static byte[] bytesFromShort(final short value) {
         result[1] = (byte) (value & 0xFF); // low byte second
         return result;
     }
+
+    /**
+     * Constructs a long from eight bytes in a byte array in big-endian order.
+     *

+     * This method reads eight consecutive bytes from the {@code bytes} array, starting at the given {@code offset}. The
+     * byte array is treated as big-endian.
+     * @param bytes the source byte array from which to read the long.
+     * @param offset the starting index in the byte array.
+     * @return the long value constructed from the eight bytes.
+     */
+    private static long longFromBytes(final byte[] bytes, final int offset) {
+        return ((bytes[offset]     & 0xFFL) << 56) |
+               ((bytes[offset + 1] & 0xFFL) << 48) |
+               ((bytes[offset + 2] & 0xFFL) << 40) |
+               ((bytes[offset + 3] & 0xFFL) << 32) |
+               ((bytes[offset + 4] & 0xFFL) << 24) |
+               ((bytes[offset + 5] & 0xFFL) << 16) |
+               ((bytes[offset + 6] & 0xFFL) << 8) |
+               ((bytes[offset + 7] & 0xFFL));
+    }
+
+    /**
+     * Converts a {@code long} value into an 8-element byte array.
+     *

+     * The conversion is performed in big-endian byte order.
+     * @param value the {@code long} value to be converted.
+     * @return a new 8-element byte array representing the long value in big-endian order.
+     */
+    @Nonnull
+    private static byte[] bytesFromLong(final long value) {
+        byte[] result = new byte[8];
+        result[0] = (byte)(value >>> 56);
+        result[1] = (byte)(value >>> 48);
+        result[2] = (byte)(value >>> 40);
+        result[3] = (byte)(value >>> 32);
+        result[4] = (byte)(value >>> 24);
+        result[5] = (byte)(value >>> 16);
+        result[6] = (byte)(value >>> 8);
+        result[7] = (byte)value;
+        return result;
+    }
 }
diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
index 395159b629..a2ad52b2fe 100644
--- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
+++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java
@@ -46,14 +46,18 @@
  * where {@code R} is a subtype of {@link Number}. It includes common operations and functionalities like size,
  * component access, equality checks, and conversions. Concrete implementations must provide specific logic for
  * data type conversions and raw data representation.
- * @param <R> the type of the numbers stored in this vector, which must extend {@link Number}.
  */
-public abstract class Vector<R extends Number> {
+public abstract class Vector {
     @Nonnull
-    protected R[] data;
+    final double[] data;
+
     @Nonnull
     protected Supplier<Integer> hashCodeSupplier;
+
+    @Nonnull
+    private final Supplier<byte[]> toRawDataSupplier;
+
     /**
      * Constructs a new Vector with the given data.
      *

@@ -61,12 +65,13 @@ public abstract class Vector<R extends Number> {
      * defensive copy. Therefore, any subsequent modifications to the input array will be reflected in this vector's
      * state. The contract of this constructor is that callers do not modify {@code data} after calling the constructor.
      * We do not want to copy the array here for performance reasons.
-     * @param data the array of elements for this vector; must not be {@code null}.
+     * @param data the components of this vector; must not be {@code null}.
      * @throws NullPointerException if the provided {@code data} array is null.
      */
-    public Vector(@Nonnull final R[] data) {
+    public Vector(@Nonnull final double[] data) {
         this.data = data;
         this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode);
+        this.toRawDataSupplier = Suppliers.memoize(this::computeRawData);
     }
 
     /**
@@ -88,8 +93,7 @@ public int size() {
      * @throws IndexOutOfBoundsException if the {@code dimension} is negative or
      *         greater than or equal to the number of dimensions of this object.
      */
-    @Nonnull
-    R getComponent(int dimension) {
+    double getComponent(int dimension) {
         return data[dimension];
     }
 
     /**
@@ -101,7 +105,7 @@ R getComponent(int dimension) {
-     * @return the data array of type {@code R[]}, never {@code null}.
+     * @return the array of primitive {@code double} components, never {@code null}.
      */
     @Nonnull
-    public R[] getData() {
+    public double[] getData() {
         return data;
     }
 
     /**
@@ -113,7 +117,19 @@ public R[] getData() {
      * @return a non-null byte array containing the raw data.
      */
     @Nonnull
-    public abstract byte[] getRawData();
+    public byte[] getRawData() {
+        return toRawDataSupplier.get();
+    }
+
+    /**
+     * Computes the raw byte data representation of this object.
+     *

+     * This method provides a direct, unprocessed view of the object's underlying data. The format of the byte array is
+     * implementation-specific and should be documented by the concrete class that implements this method.
+     * @return a non-null byte array containing the raw data.
+     */
+    @Nonnull
+    protected abstract byte[] computeRawData();
 
     /**
      * Converts this object into a {@code Vector} of {@link Half} precision floating-point numbers.
@@ -125,7 +141,7 @@
      * object.
      */
     @Nonnull
-    public abstract Vector<Half> toHalfVector();
+    public abstract HalfVector toHalfVector();
 
     /**
      * Converts this vector into a {@link DoubleVector}.
@@ -140,10 +156,19 @@
     public abstract DoubleVector toDoubleVector();
 
     /**
-     * Returns the number of digits to the right of the decimal point.
-     * @return the precision, which is the number of digits to the right of the decimal point.
+     * Returns the number of bytes used for the serialization of this vector per component.
+     * @return the component size, i.e. the number of bytes used for the serialization of this vector per component.
      */
-    public abstract int precision();
+    public int precision() {
+        return (1 << precisionShift());
+    }
+
+    /**
+     * Returns the number of bits by which {@code 1} must be shifted left to obtain {@link #precision()}, i.e. the
+     * number of bytes used per component when this vector is serialized.
+     * @return the number of bits by which {@code 1} must be shifted left to obtain {@link #precision()}
+     */
+    public abstract int precisionShift();
 
     /**
      * Compares this vector to the specified object for equality.
@@ -159,7 +184,7 @@ public boolean equals(final Object o) {
         if (!(o instanceof Vector)) {
             return false;
         }
-        final Vector<?> vector = (Vector<?>)o;
+        final Vector vector = (Vector)o;
         return Objects.deepEquals(data, vector.data);
     }
 
@@ -208,11 +233,11 @@ public String toString(final int limitDimensions) {
         Verify.verify(limitDimensions > 0);
         if (limitDimensions < data.length) {
             return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions))
-                    .map(String::valueOf)
+                    .mapToObj(String::valueOf)
                     .collect(Collectors.joining(",")) + ", ...]";
         } else {
             return "[" + Arrays.stream(data)
-                    .map(String::valueOf)
+                    .mapToObj(String::valueOf)
                    .collect(Collectors.joining(",")) + "]";
         }
     }
@@ -221,21 +246,22 @@
      * A vector class encoding a vector over half components. Conversion to {@link DoubleVector} is supported and
      * memoized.
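As the class comment notes, conversions between precisions are memoized; a small sketch of the resulting behavior, assuming the constructors shown in this patch:

    // Conversions are computed once via Suppliers.memoize and then reused.
    Vector.DoubleVector d = new Vector.DoubleVector(new double[] {1.0, 2.0, 3.0});
    Vector.HalfVector h = d.toHalfVector();           // computed on first access
    assert h.toDoubleVector() == h.toDoubleVector();  // the memoized instance is reused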
*/ - public static class HalfVector extends Vector { + public static class HalfVector extends Vector { @Nonnull private final Supplier toDoubleVectorSupplier; - @Nonnull - private final Supplier toRawDataSupplier; - public HalfVector(@Nonnull final Half[] data) { + public HalfVector(@Nonnull final Half[] halfData) { + this(computeDoubleData(halfData)); + } + + public HalfVector(@Nonnull final double[] data) { super(data); this.toDoubleVectorSupplier = Suppliers.memoize(this::computeDoubleVector); - this.toRawDataSupplier = Suppliers.memoize(this::computeRawData); } @Nonnull @Override - public Vector toHalfVector() { + public HalfVector toHalfVector() { return this; } @@ -245,34 +271,29 @@ public DoubleVector toDoubleVector() { return toDoubleVectorSupplier.get(); } - @Override - public int precision() { - return 16; - } - @Nonnull public DoubleVector computeDoubleVector() { - Double[] result = new Double[data.length]; - for (int i = 0; i < data.length; i ++) { - result[i] = data[i].doubleValue(); - } - return new DoubleVector(result); + return new DoubleVector(data); } - @Nonnull @Override - public byte[] getRawData() { - return toRawDataSupplier.get(); + public int precisionShift() { + return 1; } @Nonnull - private byte[] computeRawData() { + @Override + protected byte[] computeRawData() { return StorageAdapter.bytesFromVector(this); } @Nonnull - public static HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes) { - return StorageAdapter.vectorFromBytes(vectorBytes); + private static double[] computeDoubleData(@Nonnull Half[] halfData) { + double[] result = new double[halfData.length]; + for (int i = 0; i < halfData.length; i ++) { + result[i] = halfData[i].doubleValue(); + } + return result; } } @@ -280,11 +301,15 @@ public static HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes) * A vector class encoding a vector over double components. Conversion to {@link HalfVector} is supported and * memoized. 
*/ - public static class DoubleVector extends Vector { + public static class DoubleVector extends Vector { @Nonnull private final Supplier toHalfVectorSupplier; - public DoubleVector(@Nonnull final Double[] data) { + public DoubleVector(@Nonnull final Double[] doubleData) { + this(computeDoubleData(doubleData)); + } + + public DoubleVector(@Nonnull final double[] data) { super(data); this.toHalfVectorSupplier = Suppliers.memoize(this::computeHalfVector); } @@ -295,31 +320,35 @@ public HalfVector toHalfVector() { return toHalfVectorSupplier.get(); } - @Nonnull - public HalfVector computeHalfVector() { - Half[] result = new Half[data.length]; - for (int i = 0; i < data.length; i ++) { - result[i] = Half.valueOf(data[i]); - } - return new HalfVector(result); - } - @Nonnull @Override public DoubleVector toDoubleVector() { return this; } + @Nonnull + public HalfVector computeHalfVector() { + return new HalfVector(data); + } + @Override - public int precision() { - return 64; + public int precisionShift() { + return 3; } @Nonnull @Override - public byte[] getRawData() { - // TODO - throw new UnsupportedOperationException("not implemented yet"); + protected byte[] computeRawData() { + return StorageAdapter.bytesFromVector(this); + } + + @Nonnull + private static double[] computeDoubleData(@Nonnull Double[] doubleData) { + double[] result = new double[doubleData.length]; + for (int i = 0; i < doubleData.length; i ++) { + result[i] = doubleData[i]; + } + return result; } } @@ -329,16 +358,15 @@ public byte[] getRawData() { * This static utility method provides a convenient way to compute the distance by handling the conversion of * generic {@code Vector} objects to primitive {@code double} arrays. The actual distance computation is then * delegated to the provided {@link Metric} instance. - * @param the type of the numbers in the vectors, which must extend {@link Number}. * @param metric the {@link Metric} to use for the distance calculation. * @param vector1 the first vector. * @param vector2 the second vector. * @return the calculated distance between the two vectors as a {@code double}. */ - public static double distance(@Nonnull Metric metric, - @Nonnull final Vector vector1, - @Nonnull final Vector vector2) { - return metric.distance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); + public static double distance(@Nonnull Metric metric, + @Nonnull final Vector vector1, + @Nonnull final Vector vector2) { + return metric.distance(vector1.getData(), vector2.getData()); } /** @@ -347,36 +375,16 @@ public static double distance(@Nonnull Metric metric, * This utility method converts the input vectors, which can contain any {@link Number} type, into primitive double * arrays. It then delegates the actual distance computation to the {@code comparativeDistance} method of the * provided {@link Metric} object. - * @param the type of the numbers in the vectors, which must extend {@link Number}. * @param metric the {@link Metric} to use for the distance calculation. Must not be null. * @param vector1 the first vector for the comparison. Must not be null. * @param vector2 the second vector for the comparison. Must not be null. * @return the calculated comparative distance as a {@code double}. * @throws NullPointerException if {@code metric}, {@code vector1}, or {@code vector2} is null. 
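Tying the serialization pieces together, a hedged round-trip sketch: the first byte of the raw form is the precision shift, so a 3-dimensional double vector occupies 1 + 3 * 8 = 25 bytes (this mirrors the round-trip tests later in this patch):

    // Serialize a vector and decode it again via the StorageAdapter helpers.
    Vector.DoubleVector vector = new Vector.DoubleVector(new double[] {1.0, 2.0, 3.0});
    byte[] raw = vector.getRawData();                  // raw[0] == 3, then three big-endian longs
    Vector roundTrip = StorageAdapter.vectorFromBytes(raw);
    assert roundTrip.equals(vector);                   // decodes to an equal DoubleVector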
*/ - static double comparativeDistance(@Nonnull Metric metric, - @Nonnull final Vector vector1, - @Nonnull final Vector vector2) { - return metric.comparativeDistance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); - } - - /** - * Creates a {@code Vector} instance from its byte representation. - *

- * This method deserializes a byte array into a vector object. The precision parameter is crucial for correctly - * interpreting the byte data. Currently, this implementation only supports 16-bit precision, which corresponds to a - * {@code HalfVector}. - * @param bytes the non-null byte array representing the vector. - * @param precision the precision of the vector's elements in bits (e.g., 16 for half-precision floats). - * @return a new {@code Vector} instance created from the byte array. - * @throws UnsupportedOperationException if the specified {@code precision} is not yet supported. - */ - public static Vector fromBytes(@Nonnull final byte[] bytes, int precision) { - if (precision == 16) { - return HalfVector.halfVectorFromBytes(bytes); - } - // TODO - throw new UnsupportedOperationException("not implemented yet"); + static double comparativeDistance(@Nonnull Metric metric, + @Nonnull final Vector vector1, + @Nonnull final Vector vector2) { + return metric.comparativeDistance(vector1.getData(), vector2.getData()); } /** diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index ae31057195..ffa0012181 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -29,7 +29,6 @@ import com.apple.foundationdb.test.TestSubspaceExtension; import com.apple.foundationdb.tuple.Tuple; import com.apple.test.Tags; -import com.christianheina.langx.half4j.Half; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -98,7 +97,7 @@ public void setUpDb() { db = dbExtension.getDatabase(); } - static Stream randomSeeds() { + private static Stream randomSeeds() { return LongStream.generate(() -> new Random().nextLong()) .limit(5) .boxed(); @@ -106,7 +105,7 @@ static Stream randomSeeds() { @ParameterizedTest(name = "seed={0}") @MethodSource("randomSeeds") - public void testCompactSerialization(final Long seed) { + public void testCompactSerialization(final long seed) { final Random random = new Random(seed); final CompactStorageAdapter storageAdapter = new CompactStorageAdapter(HNSW.DEFAULT_CONFIG, CompactNode.factory(), rtSubspace.getSubspace(), @@ -145,7 +144,7 @@ public void testCompactSerialization(final Long seed) { @ParameterizedTest(name = "seed={0}") @MethodSource("randomSeeds") - public void testInliningSerialization(final Long seed) { + public void testInliningSerialization(final long seed) { final Random random = new Random(seed); final InliningStorageAdapter storageAdapter = new InliningStorageAdapter(HNSW.DEFAULT_CONFIG, InliningNode.factory(), rtSubspace.getSubspace(), @@ -211,7 +210,7 @@ public void testBasicInsert(final long seed, final boolean useInlining, final bo OnWriteListener.NOOP, onReadListener); final int k = 10; - final HalfVector queryVector = createRandomVector(random, dimensions); + final HalfVector queryVector = VectorTest.createRandomHalfVector(random, dimensions); final TreeSet nodesOrderedByDistance = new TreeSet<>(Comparator.comparing(NodeReferenceWithDistance::getDistance)); @@ -219,7 +218,7 @@ public void testBasicInsert(final long seed, final boolean useInlining, final bo i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, tr -> { final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); - final HalfVector dataVector = 
createRandomVector(random, dimensions); + final HalfVector dataVector = VectorTest.createRandomHalfVector(random, dimensions); final double distance = Vector.comparativeDistance(metric, dataVector, queryVector); final NodeReferenceWithDistance nodeReferenceWithDistance = new NodeReferenceWithDistance(primaryKey, dataVector, distance); @@ -424,9 +423,9 @@ public void testSIFTInsertSmallUsingBatchAPI() throws Exception { public void testManyRandomVectors() { final Random random = new Random(); for (long l = 0L; l < 3000000; l ++) { - final HalfVector randomVector = createRandomVector(random, 768); + final HalfVector randomVector = VectorTest.createRandomHalfVector(random, 768); final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector); - final Vector roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple); + final Vector roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple); Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), randomVector, roundTripVector); Assertions.assertEquals(randomVector, roundTripVector); } @@ -453,7 +452,7 @@ private Node createRandomCompactNode(@Nonnull final Random random neighborsBuilder.add(createRandomNodeReference(random)); } - return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + return nodeFactory.create(primaryKey, VectorTest.createRandomHalfVector(random, dimensionality), neighborsBuilder.build()); } @Nonnull @@ -467,7 +466,7 @@ private Node createRandomInliningNode(@Nonnull final Ra neighborsBuilder.add(createRandomNodeReferenceWithVector(random, dimensionality)); } - return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + return nodeFactory.create(primaryKey, VectorTest.createRandomHalfVector(random, dimensionality), neighborsBuilder.build()); } @Nonnull @@ -477,7 +476,7 @@ private NodeReference createRandomNodeReference(@Nonnull final Random random) { @Nonnull private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull final Random random, final int dimensionality) { - return new NodeReferenceWithVector(createRandomPrimaryKey(random), createRandomVector(random, dimensionality)); + return new NodeReferenceWithVector(createRandomPrimaryKey(random), VectorTest.createRandomHalfVector(random, dimensionality)); } @Nonnull @@ -490,16 +489,6 @@ private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic return Tuple.from(nextIdAtomic.getAndIncrement()); } - @Nonnull - private HalfVector createRandomVector(@Nonnull final Random random, final int dimensionality) { - final Half[] components = new Half[dimensionality]; - for (int d = 0; d < dimensionality; d ++) { - // don't ask - components[d] = HNSWHelpers.halfValueOf(random.nextDouble()); - } - return new HalfVector(components); - } - private static class TestOnReadListener implements OnReadListener { final Map nodeCountByLayer; final Map sumMByLayer; diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java index 610c47c226..78df74a7e4 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/MetricTest.java @@ -41,8 +41,8 @@ public void setUp() { @Test public void manhattanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { // Arrange - Double[] vector1 = {1.0, 2.5, -3.0}; - Double[] vector2 = {1.0, 2.5, 
-3.0}; + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; double expectedDistance = 0.0; // Act @@ -55,8 +55,8 @@ public void manhattanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { @Test public void manhattanMetricDistanceWithPositiveValueVectorsShouldReturnCorrectDistanceTest() { // Arrange - Double[] vector1 = {1.0, 2.0, 3.0}; - Double[] vector2 = {4.0, 5.0, 6.0}; + double[] vector1 = {1.0, 2.0, 3.0}; + double[] vector2 = {4.0, 5.0, 6.0}; double expectedDistance = 9.0; // |1-4| + |2-5| + |3-6| = 3 + 3 + 3 // Act @@ -69,8 +69,8 @@ public void manhattanMetricDistanceWithPositiveValueVectorsShouldReturnCorrectDi @Test public void euclideanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { // Arrange - Double[] vector1 = {1.0, 2.5, -3.0}; - Double[] vector2 = {1.0, 2.5, -3.0}; + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; double expectedDistance = 0.0; // Act @@ -83,8 +83,8 @@ public void euclideanMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { @Test public void euclideanMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { // Arrange - Double[] vector1 = {1.0, 2.0}; - Double[] vector2 = {4.0, 6.0}; + double[] vector1 = {1.0, 2.0}; + double[] vector2 = {4.0, 6.0}; double expectedDistance = 5.0; // sqrt((1-4)^2 + (2-6)^2) = sqrt(9 + 16) = 5.0 // Act @@ -97,8 +97,8 @@ public void euclideanMetricDistanceWithDifferentPositiveVectorsShouldReturnCorre @Test public void euclideanSquareMetricDistanceWithIdenticalVectorsShouldReturnZeroTest() { // Arrange - Double[] vector1 = {1.0, 2.5, -3.0}; - Double[] vector2 = {1.0, 2.5, -3.0}; + double[] vector1 = {1.0, 2.5, -3.0}; + double[] vector2 = {1.0, 2.5, -3.0}; double expectedDistance = 0.0; // Act @@ -111,8 +111,8 @@ public void euclideanSquareMetricDistanceWithIdenticalVectorsShouldReturnZeroTes @Test public void euclideanSquareMetricDistanceWithDifferentPositiveVectorsShouldReturnCorrectDistanceTest() { // Arrange - Double[] vector1 = {1.0, 2.0}; - Double[] vector2 = {4.0, 6.0}; + double[] vector1 = {1.0, 2.0}; + double[] vector2 = {4.0, 6.0}; double expectedDistance = 25.0; // (1-4)^2 + (2-6)^2 = 9 + 16 = 25.0 // Act @@ -125,8 +125,8 @@ public void euclideanSquareMetricDistanceWithDifferentPositiveVectorsShouldRetur @Test public void cosineMetricDistanceWithIdenticalVectorsReturnsZeroTest() { // Arrange - Double[] vector1 = {5.0, 3.0, -2.0}; - Double[] vector2 = {5.0, 3.0, -2.0}; + double[] vector1 = {5.0, 3.0, -2.0}; + double[] vector2 = {5.0, 3.0, -2.0}; double expectedDistance = 0.0; // Act @@ -139,8 +139,8 @@ public void cosineMetricDistanceWithIdenticalVectorsReturnsZeroTest() { @Test public void cosineMetricDistanceWithOrthogonalVectorsReturnsOneTest() { // Arrange - Double[] vector1 = {1.0, 0.0}; - Double[] vector2 = {0.0, 1.0}; + double[] vector1 = {1.0, 0.0}; + double[] vector2 = {0.0, 1.0}; double expectedDistance = 1.0; // Act @@ -152,8 +152,8 @@ public void cosineMetricDistanceWithOrthogonalVectorsReturnsOneTest() { @Test public void dotProductMetricComparativeDistanceWithPositiveVectorsTest() { - Double[] vector1 = {1.0, 2.0, 3.0}; - Double[] vector2 = {4.0, 5.0, 6.0}; + double[] vector1 = {1.0, 2.0, 3.0}; + double[] vector2 = {4.0, 5.0, 6.0}; double expected = -32.0; double actual = dotProductMetric.comparativeDistance(vector1, vector2); @@ -163,8 +163,8 @@ public void dotProductMetricComparativeDistanceWithPositiveVectorsTest() { @Test public void dotProductMetricComparativeDistanceWithOrthogonalVectorsReturnsZeroTest() { 
- Double[] vector1 = {1.0, 0.0}; - Double[] vector2 = {0.0, 1.0}; + double[] vector1 = {1.0, 0.0}; + double[] vector2 = {0.0, 1.0}; double expected = -0.0; double actual = dotProductMetric.comparativeDistance(vector1, vector2); diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java new file mode 100644 index 0000000000..fa7f27db21 --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/VectorTest.java @@ -0,0 +1,79 @@ +/* + * VectorTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.util.Random; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +public class VectorTest { + private static Stream randomSeeds() { + return LongStream.generate(() -> new Random().nextLong()) + .limit(5) + .boxed(); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + void testSerializationDeserializationHalfVector(final long seed) { + final Random random = new Random(seed); + final Vector.HalfVector randomVector = createRandomHalfVector(random, 128); + final Vector deserializedVector = StorageAdapter.vectorFromBytes(randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(Vector.HalfVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @ParameterizedTest(name = "seed={0}") + @MethodSource("randomSeeds") + void testSerializationDeserializationDoubleVector(final long seed) { + final Random random = new Random(seed); + final Vector.DoubleVector randomVector = createRandomDoubleVector(random, 128); + final Vector deserializedVector = StorageAdapter.vectorFromBytes(randomVector.getRawData()); + Assertions.assertThat(deserializedVector).isInstanceOf(Vector.DoubleVector.class); + Assertions.assertThat(deserializedVector).isEqualTo(randomVector); + } + + @Nonnull + static Vector.HalfVector createRandomHalfVector(@Nonnull final Random random, final int dimensionality) { + final Half[] components = new Half[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // don't ask + components[d] = HNSWHelpers.halfValueOf(random.nextDouble()); + } + return new Vector.HalfVector(components); + } + + @Nonnull + static Vector.DoubleVector createRandomDoubleVector(@Nonnull final Random random, final int dimensionality) { + final double[] components = new double[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // don't ask + components[d] = random.nextDouble(); + } + return new 
Vector.DoubleVector(components); + } +} From dc370afa6f664b89ccdeac5831cd430651abc45d Mon Sep 17 00:00:00 2001 From: Normen Seemann Date: Tue, 23 Sep 2025 16:34:37 +0200 Subject: [PATCH 10/10] removed efSearch from HNSW --- .../apple/foundationdb/async/hnsw/HNSW.java | 31 ++++--------------- .../foundationdb/async/hnsw/HNSWTest.java | 8 +++++ 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java index 47ddf7117a..a1875d4988 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -91,7 +91,6 @@ public class HNSW { public static final int DEFAULT_M = 16; public static final int DEFAULT_M_MAX = DEFAULT_M; public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; - public static final int DEFAULT_EF_SEARCH = 100; public static final int DEFAULT_EF_CONSTRUCTION = 200; public static final boolean DEFAULT_EXTEND_CANDIDATES = false; public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; @@ -123,7 +122,6 @@ public static class Config { private final int m; private final int mMax; private final int mMax0; - private final int efSearch; private final int efConstruction; private final boolean extendCandidates; private final boolean keepPrunedConnections; @@ -135,14 +133,13 @@ protected Config() { this.m = DEFAULT_M; this.mMax = DEFAULT_M_MAX; this.mMax0 = DEFAULT_M_MAX_0; - this.efSearch = DEFAULT_EF_SEARCH; this.efConstruction = DEFAULT_EF_CONSTRUCTION; this.extendCandidates = DEFAULT_EXTEND_CANDIDATES; this.keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; } protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining, - final int m, final int mMax, final int mMax0, final int efSearch, final int efConstruction, + final int m, final int mMax, final int mMax0, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections) { this.random = random; this.metric = metric; @@ -150,7 +147,6 @@ protected Config(@Nonnull final Random random, @Nonnull final Metric metric, fin this.m = m; this.mMax = mMax; this.mMax0 = mMax0; - this.efSearch = efSearch; this.efConstruction = efConstruction; this.extendCandidates = extendCandidates; this.keepPrunedConnections = keepPrunedConnections; @@ -182,10 +178,6 @@ public int getMMax0() { return mMax0; } - public int getEfSearch() { - return efSearch; - } - public int getEfConstruction() { return efConstruction; } @@ -201,15 +193,15 @@ public boolean isKeepPrunedConnections() { @Nonnull public ConfigBuilder toBuilder() { return new ConfigBuilder(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), - getEfSearch(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); } @Override @Nonnull public String toString() { return "Config[metric=" + getMetric() + "isUseInlining" + isUseInlining() + "M=" + getM() + - " , MMax=" + getMMax() + " , MMax0=" + getMMax0() + ", efSearch=" + getEfSearch() + - ", efConstruction=" + getEfConstruction() + ", isExtendCandidates=" + isExtendCandidates() + + " , MMax=" + getMMax() + " , MMax0=" + getMMax0() + ", efConstruction=" + getEfConstruction() + + ", isExtendCandidates=" + isExtendCandidates() + ", isKeepPrunedConnections=" + isKeepPrunedConnections() + "]"; } } @@ -230,7 
+222,6 @@ public static class ConfigBuilder { private int m = DEFAULT_M; private int mMax = DEFAULT_M_MAX; private int mMax0 = DEFAULT_M_MAX_0; - private int efSearch = DEFAULT_EF_SEARCH; private int efConstruction = DEFAULT_EF_CONSTRUCTION; private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES; private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; @@ -239,7 +230,7 @@ public ConfigBuilder() { } public ConfigBuilder(@Nonnull final Random random, @Nonnull final Metric metric, final boolean useInlining, - final int m, final int mMax, final int mMax0, final int efSearch, final int efConstruction, + final int m, final int mMax, final int mMax0, final int efConstruction, final boolean extendCandidates, final boolean keepPrunedConnections) { this.random = random; this.metric = metric; @@ -247,7 +238,6 @@ public ConfigBuilder(@Nonnull final Random random, @Nonnull final Metric metric, this.m = m; this.mMax = mMax; this.mMax0 = mMax0; - this.efSearch = efSearch; this.efConstruction = efConstruction; this.extendCandidates = extendCandidates; this.keepPrunedConnections = keepPrunedConnections; @@ -314,15 +304,6 @@ public ConfigBuilder setMMax0(final int mMax0) { return this; } - public int getEfSearch() { - return efSearch; - } - - public ConfigBuilder setEfSearch(final int efSearch) { - this.efSearch = efSearch; - return this; - } - public int getEfConstruction() { return efConstruction; } @@ -351,7 +332,7 @@ public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnection } public Config build() { - return new Config(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), getEfSearch(), + return new Config(getRandom(), getMetric(), isUseInlining(), getM(), getMMax(), getMMax0(), getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); } } diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java index ffa0012181..6f9515d8e9 100644 --- a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java @@ -68,6 +68,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -255,6 +256,13 @@ public void testBasicInsert(final long seed, final boolean useInlining, final bo TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getBytesReadByLayer(), String.format(Locale.ROOT, "%.2f", recall * 100.0d)); + + final Set usedIds = + LongStream.range(0, 1000) + .boxed() + .collect(Collectors.toSet()); + + hnsw.scanLayer(db, 0, 100, node -> Assertions.assertTrue(usedIds.remove(node.getPrimaryKey().getLong(0)))); } private int basicInsertBatch(final HNSW hnsw, final int batchSize,
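With efSearch removed from the configuration, the search budget now travels with each query; a minimal usage sketch against the API as it stands after this patch (the db, hnsw, and queryVector identifiers are assumed to be set up as in the tests above):

    // k nearest neighbors with a per-query efSearch budget.
    final int k = 10;
    final int efSearch = 100;
    final var results =
            db.readAsync(tr -> hnsw.kNearestNeighborsSearch(tr, k, efSearch, queryVector)).join();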