Skip to content

Commit 3aa201e

Browse files
authored
vector index maintainer (#3738)
1 parent 1cb0559 commit 3aa201e

39 files changed

+4313
-241
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ public CompactStorageAdapter(@Nonnull final Config config,
8686
* @param layer the layer of the node to fetch
8787
* @param primaryKey the primary key of the node to fetch
8888
*
89-
* @return a future that will complete with the fetched {@link AbstractNode}
89+
* @return a future that will complete with the fetched {@link AbstractNode} or {@code null} if the node cannot
90+
* be fetched
9091
*
9192
* @throws IllegalStateException if the node cannot be found in the database for the given key
9293
*/
@@ -101,7 +102,7 @@ protected CompletableFuture<AbstractNode<NodeReference>> fetchNodeInternal(@Nonn
101102
return readTransaction.get(keyBytes)
102103
.thenApply(valueBytes -> {
103104
if (valueBytes == null) {
104-
throw new IllegalStateException("cannot fetch node");
105+
return null;
105106
}
106107
return nodeFromRaw(storageTransform, layer, primaryKey, keyBytes, valueBytes);
107108
});

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Config.java

Lines changed: 148 additions & 113 deletions
Large diffs are not rendered by default.

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
import java.util.Objects;
5656
import java.util.PriorityQueue;
5757
import java.util.Queue;
58-
import java.util.Random;
5958
import java.util.Set;
59+
import java.util.SplittableRandom;
6060
import java.util.concurrent.CompletableFuture;
6161
import java.util.concurrent.Executor;
6262
import java.util.concurrent.atomic.AtomicReference;
@@ -89,8 +89,6 @@ public class HNSW {
8989
@Nonnull
9090
private static final Logger logger = LoggerFactory.getLogger(HNSW.class);
9191

92-
@Nonnull
93-
private final Random random;
9492
@Nonnull
9593
private final Subspace subspace;
9694
@Nonnull
@@ -141,7 +139,6 @@ public HNSW(@Nonnull final Subspace subspace,
141139
@Nonnull final Config config,
142140
@Nonnull final OnWriteListener onWriteListener,
143141
@Nonnull final OnReadListener onReadListener) {
144-
this.random = new Random(config.getRandomSeed());
145142
this.subspace = subspace;
146143
this.executor = executor;
147144
this.config = config;
@@ -581,7 +578,7 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) {
581578
return onReadListener.onAsyncRead(
582579
storageAdapter.fetchNode(readTransaction, storageTransform, layer,
583580
nodeReference.getPrimaryKey()))
584-
.thenApply(node -> biMapFunction.apply(nodeReference, node));
581+
.thenApply(node -> biMapFunction.apply(nodeReference, Objects.requireNonNull(node)));
585582
}
586583

587584
/**
@@ -748,19 +745,35 @@ private Quantizer quantizer(@Nullable final AccessInfo accessInfo) {
748745
@Nonnull
749746
public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey,
750747
@Nonnull final RealVector newVector) {
751-
final int insertionLayer = insertionLayer();
748+
final SplittableRandom random = random(newPrimaryKey);
749+
final int insertionLayer = insertionLayer(random);
752750
if (logger.isTraceEnabled()) {
753751
logger.trace("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer);
754752
}
755753

756754
return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener())
757-
.thenCompose(accessInfo -> {
758-
final AccessInfo currentAccessInfo;
755+
.thenCombine(exists(transaction, newPrimaryKey),
756+
(accessInfo, nodeAlreadyExists) -> {
757+
if (nodeAlreadyExists) {
758+
if (logger.isDebugEnabled()) {
759+
logger.debug("new record already exists in HNSW with key={} on layer={}",
760+
newPrimaryKey, insertionLayer);
761+
}
762+
}
763+
return new AccessInfoAndNodeExistence(accessInfo, nodeAlreadyExists);
764+
})
765+
.thenCompose(accessInfoAndNodeExistence -> {
766+
if (accessInfoAndNodeExistence.isNodeExists()) {
767+
return AsyncUtil.DONE;
768+
}
769+
770+
final AccessInfo accessInfo = accessInfoAndNodeExistence.getAccessInfo();
759771
final AffineOperator storageTransform = storageTransform(accessInfo);
760772
final Transformed<RealVector> transformedNewVector = storageTransform.transform(newVector);
761773
final Quantizer quantizer = quantizer(accessInfo);
762774
final Estimator estimator = quantizer.estimator();
763775

776+
final AccessInfo currentAccessInfo;
764777
if (accessInfo == null) {
765778
// this is the first node
766779
writeLonelyNodes(quantizer, transaction, newPrimaryKey, transformedNewVector,
@@ -817,10 +830,24 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
817830
insertIntoLayers(transaction, storageTransform, quantizer, newPrimaryKey,
818831
transformedNewVector, nodeReference, lMax, insertionLayer))
819832
.thenCompose(ignored ->
820-
addToStatsIfNecessary(transaction, currentAccessInfo, transformedNewVector));
833+
addToStatsIfNecessary(random.split(), transaction, currentAccessInfo, transformedNewVector));
821834
}).thenCompose(ignored -> AsyncUtil.DONE);
822835
}
823836

837+
@Nonnull
838+
@VisibleForTesting
839+
CompletableFuture<Boolean> exists(@Nonnull final ReadTransaction readTransaction,
840+
@Nonnull final Tuple primaryKey) {
841+
final StorageAdapter<? extends NodeReference> storageAdapter = getStorageAdapterForLayer(0);
842+
843+
//
844+
// Call fetchNode() to check for the node's existence; we are handing in the identity operator, since we don't
845+
// care about the vector itself at all.
846+
//
847+
return storageAdapter.fetchNode(readTransaction, AffineOperator.identity(), 0, primaryKey)
848+
.thenApply(Objects::nonNull);
849+
}
850+
824851
/**
825852
* Method to keep stats if necessary. Stats need to be kept and maintained when the client would like to use
826853
* e.g. RaBitQ as RaBitQ needs a stable somewhat correct centroid in order to function properly.
@@ -832,21 +859,23 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
832859
* in order to finally compute the centroid if {@link Config#getStatsThreshold()} number of vectors have been
833860
* sampled and aggregated. That centroid is then used to update the access info.
834861
*
862+
* @param random the source of randomness used to decide whether the new vector is sampled and whether stats are maintained
835863
* @param transaction the transaction
836864
* @param currentAccessInfo the current access info that was fetched as part of an insert
837865
* @param transformedNewVector the new vector (in the transformed coordinate system) that may be added
838866
* @return a future that returns {@code null} when completed
839867
*/
840868
@Nonnull
841-
private CompletableFuture<Void> addToStatsIfNecessary(@Nonnull final Transaction transaction,
869+
private CompletableFuture<Void> addToStatsIfNecessary(@Nonnull final SplittableRandom random,
870+
@Nonnull final Transaction transaction,
842871
@Nonnull final AccessInfo currentAccessInfo,
843872
@Nonnull final Transformed<RealVector> transformedNewVector) {
844873
if (getConfig().isUseRaBitQ() && !currentAccessInfo.canUseRaBitQ()) {
845-
if (shouldSampleVector()) {
874+
if (shouldSampleVector(random)) {
846875
StorageAdapter.appendSampledVector(transaction, getSubspace(),
847876
1, transformedNewVector, onWriteListener);
848877
}
849-
if (shouldMaintainStats()) {
878+
if (shouldMaintainStats(random)) {
850879
return StorageAdapter.consumeSampledVectors(transaction, getSubspace(),
851880
50, onReadListener)
852881
.thenApply(sampledVectors -> {
@@ -1512,6 +1541,15 @@ private StorageAdapter<? extends NodeReference> getStorageAdapterForLayer(final
15121541
getOnReadListener());
15131542
}
15141543

1544+
@Nonnull
1545+
private SplittableRandom random(@Nonnull final Tuple primaryKey) {
1546+
if (config.isDeterministicSeeding()) {
1547+
return new SplittableRandom(primaryKey.hashCode());
1548+
} else {
1549+
return new SplittableRandom(System.nanoTime());
1550+
}
1551+
}
1552+
15151553
/**
15161554
* Calculates a random layer for a new element to be inserted.
15171555
* <p>
@@ -1521,20 +1559,20 @@ private StorageAdapter<? extends NodeReference> getStorageAdapterForLayer(final
15211559
* is {@code floor(-ln(u) * lambda)}, where {@code u} is a uniform random
15221560
* number and {@code lambda} is a normalization factor derived from a system
15231561
* configuration parameter {@code M}.
1524-
*
1562+
* @param random the source of randomness from which the uniform random number {@code u} is drawn
15251563
* @return a non-negative integer representing the randomly selected layer.
15261564
*/
1527-
private int insertionLayer() {
1565+
private int insertionLayer(@Nonnull final SplittableRandom random) {
15281566
double lambda = 1.0 / Math.log(getConfig().getM());
15291567
double u = 1.0 - random.nextDouble(); // Avoid log(0)
15301568
return (int) Math.floor(-Math.log(u) * lambda);
15311569
}
15321570

1533-
private boolean shouldSampleVector() {
1571+
private boolean shouldSampleVector(@Nonnull final SplittableRandom random) {
15341572
return random.nextDouble() < getConfig().getSampleVectorStatsProbability();
15351573
}
15361574

1537-
private boolean shouldMaintainStats() {
1575+
private boolean shouldMaintainStats(@Nonnull final SplittableRandom random) {
15381576
return random.nextDouble() < getConfig().getMaintainStatsProbability();
15391577
}
15401578

@@ -1546,4 +1584,24 @@ private static <T> List<T> drain(@Nonnull Queue<T> queue) {
15461584
}
15471585
return resultBuilder.build();
15481586
}
1587+
1588+
private static class AccessInfoAndNodeExistence {
1589+
@Nullable
1590+
private final AccessInfo accessInfo;
1591+
private final boolean nodeExists;
1592+
1593+
public AccessInfoAndNodeExistence(@Nullable final AccessInfo accessInfo, final boolean nodeExists) {
1594+
this.accessInfo = accessInfo;
1595+
this.nodeExists = nodeExists;
1596+
}
1597+
1598+
@Nullable
1599+
public AccessInfo getAccessInfo() {
1600+
return accessInfo;
1601+
}
1602+
1603+
public boolean isNodeExists() {
1604+
return nodeExists;
1605+
}
1606+
}
15491607
}

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import com.apple.foundationdb.async.AsyncUtil;
3030
import com.apple.foundationdb.linear.AffineOperator;
3131
import com.apple.foundationdb.linear.DoubleRealVector;
32-
import com.apple.foundationdb.linear.FloatRealVector;
33-
import com.apple.foundationdb.linear.HalfRealVector;
3432
import com.apple.foundationdb.linear.Quantizer;
3533
import com.apple.foundationdb.linear.RealVector;
3634
import com.apple.foundationdb.linear.Transformed;
@@ -59,7 +57,6 @@
5957
* @param <N> the type of {@link NodeReference} this storage adapter manages
6058
*/
6159
interface StorageAdapter<N extends NodeReference> {
62-
ImmutableList<VectorType> VECTOR_TYPES = ImmutableList.copyOf(VectorType.values());
6360

6461
/**
6562
* Subspace for data.
@@ -199,29 +196,24 @@ static RealVector vectorFromTuple(@Nonnull final Config config, @Nonnull final T
199196
/**
200197
* Creates a {@link RealVector} from a byte array.
201198
* <p>
202-
* This method interprets the input byte array by interpreting the first byte of the array as the precision shift.
203-
* The byte array must have the proper size, i.e. the invariant {@code (bytesLength - 1) % precision == 0} must
204-
* hold.
199+
* This method interprets the first byte of the input byte array as the ordinal of the vector type.
200+
* It then delegates to {@link RealVector#fromBytes(VectorType, byte[])}.
205201
* @param config an HNSW config
206202
* @param vectorBytes the non-null byte array to convert.
207203
* @return a new {@link RealVector} instance created from the byte array.
208-
* @throws com.google.common.base.VerifyException if the length of {@code vectorBytes} does not meet the invariant
209-
* {@code (bytesLength - 1) % precision == 0}
210204
*/
211205
@Nonnull
212206
static RealVector vectorFromBytes(@Nonnull final Config config, @Nonnull final byte[] vectorBytes) {
213207
final byte vectorTypeOrdinal = vectorBytes[0];
214-
switch (fromVectorTypeOrdinal(vectorTypeOrdinal)) {
215-
case HALF:
216-
return HalfRealVector.fromBytes(vectorBytes);
217-
case SINGLE:
218-
return FloatRealVector.fromBytes(vectorBytes);
219-
case DOUBLE:
220-
return DoubleRealVector.fromBytes(vectorBytes);
208+
switch (RealVector.fromVectorTypeOrdinal(vectorTypeOrdinal)) {
221209
case RABITQ:
222210
Verify.verify(config.isUseRaBitQ());
223211
return EncodedRealVector.fromBytes(vectorBytes, config.getNumDimensions(),
224212
config.getRaBitQNumExBits());
213+
case HALF:
214+
case SINGLE:
215+
case DOUBLE:
216+
return RealVector.fromBytes(vectorBytes);
225217
default:
226218
throw new RuntimeException("unable to serialize vector");
227219
}
@@ -251,11 +243,6 @@ static Tuple tupleFromVector(@Nonnull final RealVector vector) {
251243
return Tuple.from(vector.getRawData());
252244
}
253245

254-
@Nonnull
255-
static VectorType fromVectorTypeOrdinal(final int ordinal) {
256-
return VECTOR_TYPES.get(ordinal);
257-
}
258-
259246
@Nonnull
260247
static CompletableFuture<AccessInfo> fetchAccessInfo(@Nonnull final Config config,
261248
@Nonnull final ReadTransaction readTransaction,

fdb-extensions/src/main/java/com/apple/foundationdb/linear/RealVector.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020

2121
package com.apple.foundationdb.linear;
2222

23-
import com.google.common.base.Preconditions;
2423
import com.apple.foundationdb.half.Half;
24+
import com.google.common.base.Preconditions;
25+
import com.google.common.collect.ImmutableList;
2526

2627
import javax.annotation.Nonnull;
2728

@@ -34,6 +35,8 @@
3435
* data type conversions and raw data representation.
3536
*/
3637
public interface RealVector {
38+
ImmutableList<VectorType> VECTOR_TYPES = ImmutableList.copyOf(VectorType.values());
39+
3740
/**
3841
* Returns the number of elements in the vector, i.e. the number of dimensions.
3942
* @return the number of dimensions
@@ -189,4 +192,47 @@ default RealVector multiply(final double scalarFactor) {
189192
}
190193
return withData(result);
191194
}
195+
196+
@Nonnull
197+
static VectorType fromVectorTypeOrdinal(final int ordinal) {
198+
return VECTOR_TYPES.get(ordinal);
199+
}
200+
201+
/**
202+
* Creates a {@link RealVector} from a byte array.
203+
* <p>
204+
* This method interprets the first byte of the input byte array as the ordinal of the vector type.
205+
* It then delegates to {@link #fromBytes(VectorType, byte[])} to do the actual deserialization.
206+
*
207+
* @param vectorBytes the non-null byte array to convert.
208+
* @return a new {@link RealVector} instance created from the byte array.
209+
*/
210+
@Nonnull
211+
static RealVector fromBytes(@Nonnull final byte[] vectorBytes) {
212+
final byte vectorTypeOrdinal = vectorBytes[0];
213+
return fromBytes(fromVectorTypeOrdinal(vectorTypeOrdinal), vectorBytes);
214+
}
215+
216+
/**
217+
* Creates a {@link RealVector} from a byte array.
218+
* <p>
219+
* This implementation dispatches to the actual logic that deserializes a byte array into a vector; that logic is located in
220+
* the respective implementations of {@link RealVector}.
221+
* @param vectorType the vector type of the serialized vector
222+
* @param vectorBytes the non-null byte array to convert.
223+
* @return a new {@link RealVector} instance created from the byte array.
224+
*/
225+
@Nonnull
226+
static RealVector fromBytes(@Nonnull final VectorType vectorType, @Nonnull final byte[] vectorBytes) {
227+
switch (vectorType) {
228+
case HALF:
229+
return HalfRealVector.fromBytes(vectorBytes);
230+
case SINGLE:
231+
return FloatRealVector.fromBytes(vectorBytes);
232+
case DOUBLE:
233+
return DoubleRealVector.fromBytes(vectorBytes);
234+
default:
235+
throw new RuntimeException("unable to deserialize vector");
236+
}
237+
}
192238
}

0 commit comments

Comments
 (0)