Skip to content

Commit 8fcb3a6

Browse files
committed
RaBitQ works inside of HNSW
1 parent cb4b44a commit 8fcb3a6

File tree

3 files changed

+79
-246
lines changed

3 files changed

+79
-246
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 72 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ public class HNSW {
9292

9393
public static final int MAX_CONCURRENT_NODE_READS = 16;
9494
public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3;
95-
public static final int MAX_CONCURRENT_SEARCHES = 10;
9695
public static final long DEFAULT_RANDOM_SEED = 0L;
9796
@Nonnull public static final Metric DEFAULT_METRIC = Metric.EUCLIDEAN_METRIC;
9897
public static final boolean DEFAULT_USE_INLINING = false;
@@ -1290,193 +1289,78 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
12901289
.thenCompose(nodeReference ->
12911290
insertIntoLayers(transaction, storageTransform, quantizer, newPrimaryKey,
12921291
transformedNewVector, nodeReference, lMax, insertionLayer))
1293-
.thenCompose(ignored -> {
1294-
if (getConfig().isUseRaBitQ() && !accessInfo.canUseRaBitQ()) {
1295-
if (shouldSampleVector()) {
1296-
StorageAdapter.appendSampledVector(transaction, getSubspace(),
1297-
1, transformedNewVector, onWriteListener);
1298-
}
1299-
if (shouldMaintainStats()) {
1300-
return StorageAdapter.consumeSampledVectors(transaction, getSubspace(),
1301-
50, onReadListener)
1302-
.thenApply(sampledVectors -> {
1303-
RealVector partialVector = null;
1304-
int partialCount = 0;
1305-
for (final AggregatedVector sampledVector : sampledVectors) {
1306-
partialVector = partialVector == null
1307-
? sampledVector.getPartialVector()
1308-
: partialVector.add(sampledVector.getPartialVector());
1309-
partialCount += sampledVector.getPartialCount();
1310-
}
1311-
if (partialCount > 0) {
1312-
StorageAdapter.appendSampledVector(transaction, getSubspace(),
1313-
partialCount, partialVector, onWriteListener);
1314-
if (logger.isTraceEnabled()) {
1315-
logger.trace("updated stats with numVectors={}, partialCount={}, partialVector={}",
1316-
sampledVectors.size(), partialCount, partialVector);
1317-
}
1318-
1319-
if (partialCount >= getConfig().getStatsThreshold()) {
1320-
final long rotatorSeed = random.nextLong();
1321-
final FhtKacRotator rotator = new FhtKacRotator(rotatorSeed, getConfig().getNumDimensions(), 10);
1322-
1323-
final RealVector centroid =
1324-
partialVector.multiply(1.0d / partialCount);
1325-
final RealVector transformedCentroid = rotator.applyInvert(centroid);
1326-
1327-
final var transformedEntryNodeVector =
1328-
rotator.applyInvert(currentAccessInfo.getEntryNodeReference()
1329-
.getVector()).subtract(transformedCentroid);
1330-
1331-
final AccessInfo newAccessInfo =
1332-
new AccessInfo(currentAccessInfo.getEntryNodeReference().withVector(transformedEntryNodeVector),
1333-
rotatorSeed, transformedCentroid);
1334-
StorageAdapter.writeAccessInfo(transaction, getSubspace(), newAccessInfo, onWriteListener);
1335-
StorageAdapter.removeAllSampledVectors(transaction, getSubspace());
1336-
if (logger.isTraceEnabled()) {
1337-
logger.trace("established rotatorSeed={}, centroid with count={}, centroid={}",
1338-
rotatorSeed, partialCount, transformedCentroid);
1339-
}
1340-
}
1341-
}
1342-
return null;
1343-
});
1344-
}
1345-
}
1346-
return AsyncUtil.DONE;
1347-
});
1292+
.thenCompose(ignored ->
1293+
addToStats(transaction, currentAccessInfo, transformedNewVector));
13481294
}).thenCompose(ignored -> AsyncUtil.DONE);
13491295
}
13501296

1351-
/**
1352-
* Inserts a batch of nodes into the HNSW graph asynchronously.
1353-
*
1354-
* <p>This method orchestrates the batch insertion of nodes into the HNSW graph structure.
1355-
* For each node in the input {@code batch}, it first assigns a random layer based on the configured
1356-
* probability distribution. The batch is then sorted in descending order of these assigned layers to
1357-
* ensure higher-layer nodes are processed first, which can optimize subsequent insertions by providing
1358-
* better entry points.</p>
1359-
*
1360-
* <p>The insertion logic proceeds in two main asynchronous stages:
1361-
* <ol>
1362-
* <li><b>Search Phase:</b> For each node to be inserted, the method concurrently performs a greedy search
1363-
* from the graph's main entry point down to the node's target layer. This identifies the nearest neighbors
1364-
* at each level, which will serve as entry points for the insertion phase.</li>
1365-
* <li><b>Insertion Phase:</b> The method then iterates through the nodes and inserts each one into the graph
1366-
* from its target layer downwards, connecting it to its nearest neighbors. If a node's assigned layer is
1367-
* higher than the current maximum layer of the graph, it becomes the new main entry point.</li>
1368-
* </ol>
1369-
* All underlying storage operations are performed within the context of the provided {@link Transaction}.</p>
1370-
*
1371-
* @param transaction the transaction to use for all storage operations; must not be {@code null}
1372-
* @param batch a {@code List} of {@link NodeReferenceWithVector} objects to insert; must not be {@code null}
1373-
*
1374-
* @return a {@link CompletableFuture} that completes with {@code null} when the entire batch has been inserted
1375-
*/
13761297
@Nonnull
1377-
public CompletableFuture<Void> insertBatch(@Nonnull final Transaction transaction,
1378-
@Nonnull List<NodeReferenceWithVector> batch) {
1379-
// determine the layer each item should be inserted at
1380-
final List<NodeReferenceWithLayer> batchWithLayers = Lists.newArrayListWithCapacity(batch.size());
1381-
for (final NodeReferenceWithVector current : batch) {
1382-
batchWithLayers.add(
1383-
new NodeReferenceWithLayer(current.getPrimaryKey(), current.getVector(), insertionLayer()));
1384-
}
1385-
// sort the layers in reverse order
1386-
batchWithLayers.sort(Comparator.comparing(NodeReferenceWithLayer::getLayer).reversed());
1387-
1388-
return StorageAdapter.fetchAccessInfo(getConfig(), transaction, getSubspace(), getOnReadListener())
1389-
.thenCompose(accessInfo -> {
1390-
final int lMax =
1391-
accessInfo == null ? -1 : accessInfo.getEntryNodeReference().getLayer();
1392-
1393-
final AffineOperator storageTransform = storageTransform(accessInfo);
1394-
final Quantizer quantizer = quantizer(accessInfo);
1395-
final Estimator estimator = quantizer.estimator();
1396-
1397-
return forEach(batchWithLayers,
1398-
item -> {
1399-
if (lMax == -1) {
1400-
return CompletableFuture.completedFuture(null);
1298+
private CompletableFuture<Void> addToStats(@Nonnull final Transaction transaction,
1299+
@Nonnull final AccessInfo currentAccessInfo,
1300+
@Nonnull final RealVector transformedNewVector) {
1301+
if (getConfig().isUseRaBitQ() && !currentAccessInfo.canUseRaBitQ()) {
1302+
if (shouldSampleVector()) {
1303+
StorageAdapter.appendSampledVector(transaction, getSubspace(),
1304+
1, transformedNewVector, onWriteListener);
1305+
}
1306+
if (shouldMaintainStats()) {
1307+
return StorageAdapter.consumeSampledVectors(transaction, getSubspace(),
1308+
50, onReadListener)
1309+
.thenApply(sampledVectors -> {
1310+
final AggregatedVector aggregatedSampledVector =
1311+
aggregateVectors(sampledVectors);
1312+
1313+
if (aggregatedSampledVector != null) {
1314+
final int partialCount = aggregatedSampledVector.getPartialCount();
1315+
final RealVector partialVector = aggregatedSampledVector.getPartialVector();
1316+
StorageAdapter.appendSampledVector(transaction, getSubspace(),
1317+
partialCount, partialVector, onWriteListener);
1318+
if (logger.isTraceEnabled()) {
1319+
logger.trace("updated stats with numVectors={}, partialCount={}, partialVector={}",
1320+
sampledVectors.size(), partialCount, partialVector);
14011321
}
14021322

1403-
final EntryNodeReference entryNodeReference = accessInfo.getEntryNodeReference();
1404-
1405-
final RealVector itemVector = item.getVector();
1406-
final RealVector transformedItemVector = storageTransform.applyInvert(itemVector);
1407-
1408-
final int itemL = item.getLayer();
1409-
1410-
final NodeReferenceWithDistance initialNodeReference =
1411-
new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(),
1412-
entryNodeReference.getVector(),
1413-
estimator.distance(transformedItemVector, entryNodeReference.getVector()));
1414-
1415-
return forLoop(lMax, initialNodeReference,
1416-
layer -> layer > itemL,
1417-
layer -> layer - 1,
1418-
(layer, previousNodeReference) -> {
1419-
final StorageAdapter<? extends NodeReference> storageAdapter = getStorageAdapterForLayer(layer);
1420-
return greedySearchLayer(storageAdapter, transaction, storageTransform,
1421-
estimator, previousNodeReference, layer, transformedItemVector);
1422-
}, executor);
1423-
}, MAX_CONCURRENT_SEARCHES, getExecutor())
1424-
.thenCompose(searchEntryReferences ->
1425-
forLoop(0, accessInfo == null ? null : accessInfo.getEntryNodeReference(),
1426-
index -> index < batchWithLayers.size(),
1427-
index -> index + 1,
1428-
(index, currentEntryNodeReference) -> {
1429-
final NodeReferenceWithLayer item = batchWithLayers.get(index);
1430-
final Tuple itemPrimaryKey = item.getPrimaryKey();
1431-
final RealVector itemVector = item.getVector();
1432-
final int itemL = item.getLayer();
1433-
1434-
final EntryNodeReference newEntryNodeReference;
1435-
final int currentLMax;
1436-
1437-
if (accessInfo == null) {
1438-
// this is the first node
1439-
writeLonelyNodes(quantizer, transaction, itemPrimaryKey, itemVector, itemL, -1);
1440-
newEntryNodeReference =
1441-
new EntryNodeReference(itemPrimaryKey, itemVector, itemL);
1442-
StorageAdapter.writeAccessInfo(transaction, getSubspace(),
1443-
new AccessInfo(newEntryNodeReference, -1L, null), getOnWriteListener());
1444-
if (logger.isTraceEnabled()) {
1445-
logger.trace("written initial entry node reference for batch with key={} on layer={}", itemPrimaryKey, itemL);
1446-
}
1447-
1448-
return CompletableFuture.completedFuture(newEntryNodeReference);
1449-
} else {
1450-
currentLMax = currentEntryNodeReference.getLayer();
1451-
if (itemL > currentLMax) {
1452-
writeLonelyNodes(quantizer, transaction, itemPrimaryKey, itemVector, itemL, lMax);
1453-
newEntryNodeReference =
1454-
new EntryNodeReference(itemPrimaryKey, itemVector, itemL);
1455-
StorageAdapter.writeAccessInfo(transaction, getSubspace(),
1456-
accessInfo.withNewEntryNodeReference(newEntryNodeReference),
1457-
getOnWriteListener());
1458-
if (logger.isTraceEnabled()) {
1459-
logger.trace("written higher entry node reference for batch with key={} on layer={}", itemPrimaryKey, itemL);
1460-
}
1461-
} else {
1462-
// entry node stays the same
1463-
newEntryNodeReference = accessInfo.getEntryNodeReference();
1464-
}
1465-
}
1466-
1467-
if (logger.isTraceEnabled()) {
1468-
logger.trace("entry node read for batch with key {} at layer {}",
1469-
currentEntryNodeReference.getPrimaryKey(), currentLMax);
1470-
}
1471-
1472-
final var currentSearchEntry =
1473-
searchEntryReferences.get(index);
1323+
if (partialCount >= getConfig().getStatsThreshold()) {
1324+
final long rotatorSeed = random.nextLong();
1325+
final FhtKacRotator rotator =
1326+
new FhtKacRotator(rotatorSeed, getConfig().getNumDimensions(), 10);
1327+
1328+
final RealVector centroid =
1329+
partialVector.multiply(1.0d / partialCount);
1330+
final RealVector transformedCentroid = rotator.applyInvert(centroid);
1331+
1332+
final var transformedEntryNodeVector =
1333+
rotator.applyInvert(currentAccessInfo.getEntryNodeReference()
1334+
.getVector()).subtract(transformedCentroid);
1335+
1336+
final AccessInfo newAccessInfo =
1337+
new AccessInfo(currentAccessInfo.getEntryNodeReference().withVector(transformedEntryNodeVector),
1338+
rotatorSeed, transformedCentroid);
1339+
StorageAdapter.writeAccessInfo(transaction, getSubspace(), newAccessInfo, onWriteListener);
1340+
StorageAdapter.removeAllSampledVectors(transaction, getSubspace());
1341+
if (logger.isTraceEnabled()) {
1342+
logger.trace("established rotatorSeed={}, centroid with count={}, centroid={}",
1343+
rotatorSeed, partialCount, transformedCentroid);
1344+
}
1345+
}
1346+
}
1347+
return null;
1348+
});
1349+
}
1350+
}
1351+
return AsyncUtil.DONE;
1352+
}
14741353

1475-
return insertIntoLayers(transaction, storageTransform, quantizer,
1476-
itemPrimaryKey, itemVector, currentSearchEntry, lMax, itemL)
1477-
.thenApply(ignored -> newEntryNodeReference);
1478-
}, getExecutor()));
1479-
}).thenCompose(ignored -> AsyncUtil.DONE);
1354+
@Nullable
1355+
AggregatedVector aggregateVectors(@Nonnull final Iterable<AggregatedVector> vectors) {
1356+
RealVector partialVector = null;
1357+
int partialCount = 0;
1358+
for (final AggregatedVector vector : vectors) {
1359+
partialVector = partialVector == null
1360+
? vector.getPartialVector() : partialVector.add(vector.getPartialVector());
1361+
partialCount += vector.getPartialCount();
1362+
}
1363+
return partialCount == 0 ? null : new AggregatedVector(partialCount, partialVector);
14801364
}
14811365

14821366
/**
@@ -2097,7 +1981,13 @@ private boolean shouldSampleVector() {
20971981
}
20981982

20991983
private boolean shouldMaintainStats() {
2100-
return random.nextDouble() < getConfig().getMaintainStatsProbability();
1984+
return shouldMaintainStats(1);
1985+
}
1986+
1987+
private boolean shouldMaintainStats(final int batchSize) {
1988+
// pBatch = 1 - (1 - p)^n == -expm1(n * log1p(-p))
1989+
double pBatch = -Math.expm1(batchSize * Math.log1p(-getConfig().getMaintainStatsProbability()));
1990+
return random.nextDouble() < pBatch;
21011991
}
21021992

21031993
private static class NodeReferenceWithLayer extends NodeReferenceWithVector {

0 commit comments

Comments
 (0)