Improve use of ScoreTracker to avoid wasteful searching for very large k (#387)

This improves upon #384 by making the quantile estimation more lightweight. It models the recent scores as a Normal distribution and uses incremental updates to track the sufficient statistics of its mean and variance; quantiles are then computed from these statistics.
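As a hedged illustration of the idea (this snippet is not from the commit; the statistics are made-up values), the quantile estimate reduces to a single expression over the tracked mean and variance. Commons-math3, which ScoreTracker already uses for StatUtils, confirms that 1.75 standard deviations above the mean sits at roughly the 96th percentile of a Normal distribution:

import org.apache.commons.math3.distribution.NormalDistribution;

public class QuantileFromStatistics {
    public static void main(String[] args) {
        // hypothetical tracked statistics for a window of recent similarity scores
        double mean = 0.82;
        double variance = 0.0009; // std = 0.03
        final double SIGMA_FACTOR = 1.75;

        // quantile estimate under the Normal assumption: mean + z * sigma
        double quantile = mean + SIGMA_FACTOR * Math.sqrt(variance);
        System.out.printf("estimated quantile: %.4f%n", quantile); // 0.8725

        // Phi(1.75) ~= 0.96, i.e. the estimate approximates the 96th percentile
        NormalDistribution stdNormal = new NormalDistribution(0, 1);
        System.out.printf("Phi(1.75) = %.4f%n", stdNormal.cumulativeProbability(1.75));
    }
}

Compared to #384's call to StatUtils.percentile over the whole window on every check, this is O(1) work per tracked score, at the cost of assuming the recent scores are roughly Normal.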

---------

Co-authored-by: Jonathan Ellis <[email protected]>
marianotepper and jbellis authored Jan 17, 2025
1 parent 9613109 commit 7cbb2e1
Showing 2 changed files with 100 additions and 4 deletions.
@@ -269,7 +269,7 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float threshold
// track scores to predict when we are done with threshold queries
var scoreTracker = threshold > 0
? new ScoreTracker.TwoPhaseTracker(threshold)
- : PRUNE ? new ScoreTracker.TwoPhaseTracker(1.0) : new ScoreTracker.NoOpTracker();
+ : PRUNE ? new ScoreTracker.RelaxedMonotonicityTracker(rerankK) : new ScoreTracker.NoOpTracker();
VectorFloat<?> similarities = null;

// add evicted results from the last call back to the candidates
@@ -16,7 +16,6 @@

package io.github.jbellis.jvector.graph;

- import io.github.jbellis.jvector.util.AbstractLongHeap;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import org.apache.commons.math3.stat.StatUtils;

@@ -93,8 +92,105 @@ public boolean shouldStop() {
// (paper suggests using the median of recent scores, but experimentally that is too prone to false positives.
// 90th does seem to be enough, but 99th doesn't result in much extra work, so we'll be conservative)
double windowMedian = StatUtils.percentile(recentScores, 99);
- double worstBest = sortableIntToFloat((int) bestScores.top());
- return windowMedian < worstBest && windowMedian < threshold;
+ double worstBestScore = sortableIntToFloat((int) bestScores.top());
+ return windowMedian < worstBestScore && windowMedian < threshold;
}
}

/**
 * Follows the methodology of section 3.1 in "VBase: Unifying Online Vector Similarity Search
 * and Relational Queries via Relaxed Monotonicity" to determine when we've left phase 1
 * (finding the local maximum) and entered phase 2 (mostly just finding worse options).
 * To compute quantiles quickly, we treat the distribution of the data as Normal,
 * track its mean and variance, and compute quantiles from them as
 * mean + SIGMA_FACTOR * sqrt(variance).
 * Empirically, SIGMA_FACTOR=1.75 seems to work reasonably well
 * (approximately the 96th percentile of the Normal distribution).
 */
class RelaxedMonotonicityTracker implements ScoreTracker {
static final double SIGMA_FACTOR = 1.75;

// a sliding window of recent scores
private final double[] recentScores;
private int recentEntryIndex;

// Heap of the best scores seen so far
BoundedLongHeap bestScores;

// observation count
private int observationCount;

// the sample mean
private double mean;

// the sample variance multiplied by n-1
private double dSquared;

/**
* Constructor
* @param bestScoresTracked the number of tracked scores used to estimate if we are unlikely to improve
* the results anymore. An empirical rule of thumb is bestScoresTracked=rerankK.
*/
RelaxedMonotonicityTracker(int bestScoresTracked) {
// A quick empirical study yields that the number of recent scores
// that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2)
int factor = (int) Math.round(Math.sqrt(bestScoresTracked / 2.0));
this.recentScores = new double[200 * factor];
this.bestScores = new BoundedLongHeap(bestScoresTracked);
this.mean = 0;
this.dSquared = 0;
}

@Override
public void track(float score) {
bestScores.push(floatToSortableInt(score));
observationCount++;

// The updates of the sufficient statistics follow
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online
// and
// https://nestedsoftware.com/2019/09/26/incremental-average-and-standard-deviation-with-sliding-window-470k.176143.html
if (observationCount <= this.recentScores.length) {
// if the buffer is not full yet, use standard Welford method
var meanDelta = (score - this.mean) / observationCount;
var newMean = this.mean + meanDelta;

var dSquaredDelta = ((score - newMean) * (score - this.mean));
var newDSquared = this.dSquared + dSquaredDelta;

this.mean = newMean;
this.dSquared = newDSquared;
} else {
// once the buffer is full, adjust Welford method for window size
var oldScore = recentScores[recentEntryIndex];
var meanDelta = (score - oldScore) / this.recentScores.length;
var newMean = this.mean + meanDelta;

var dSquaredDelta = ((score - oldScore) * (score - newMean + oldScore - this.mean));
var newDSquared = this.dSquared + dSquaredDelta;

this.mean = newMean;
this.dSquared = newDSquared;
}
recentScores[recentEntryIndex] = score;
recentEntryIndex = (recentEntryIndex + 1) % this.recentScores.length;
}

@Override
public boolean shouldStop() {
// don't stop if we don't have enough data points
if (observationCount < this.recentScores.length) {
return false;
}

// We're in phase 2 if the q-th percentile of the recent scores evaluated,
// mean + SIGMA_FACTOR * sqrt(variance),
// is lower than the worst of the best scores seen.
// (paper suggests using the median of recent scores, but experimentally that is too prone to false positives)
double std = Math.sqrt(this.dSquared / (this.recentScores.length - 1));
double windowPercentile = this.mean + SIGMA_FACTOR * std;
double worstBestScore = sortableIntToFloat((int) bestScores.top());
return windowPercentile < worstBestScore;
}
}
}
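For intuition on the windowed variant of Welford's algorithm used in track() above, the following standalone sketch (not part of the commit; the window size and scores are made up) applies the same two update rules and checks the incremental statistics against a direct recomputation over the buffer. Note that the real tracker sizes its window as 200 * round(sqrt(bestScoresTracked / 2)), e.g. 2,000 entries for bestScoresTracked = 200.

import java.util.Random;

public class WindowedWelfordCheck {
    public static void main(String[] args) {
        int window = 8; // toy window size
        double[] buf = new double[window];
        int idx = 0, n = 0;
        double mean = 0, dSquared = 0;

        Random rng = new Random(42);
        for (int i = 0; i < 1000; i++) {
            double score = rng.nextDouble();
            n++;
            if (n <= window) {
                // standard Welford update while the buffer fills
                double newMean = mean + (score - mean) / n;
                dSquared += (score - newMean) * (score - mean);
                mean = newMean;
            } else {
                // sliding-window adjustment: replace the oldest sample's contribution
                double old = buf[idx];
                double newMean = mean + (score - old) / window;
                dSquared += (score - old) * (score - newMean + old - mean);
                mean = newMean;
            }
            buf[idx] = score;
            idx = (idx + 1) % window;
        }

        // direct recomputation over the current buffer contents
        double directMean = 0;
        for (double v : buf) directMean += v;
        directMean /= window;
        double directDSq = 0;
        for (double v : buf) directDSq += (v - directMean) * (v - directMean);

        System.out.printf("incremental: mean=%.6f var=%.6f%n", mean, dSquared / (window - 1));
        System.out.printf("direct:      mean=%.6f var=%.6f%n", directMean, directDSq / (window - 1));
    }
}

The two output lines agree to floating-point precision. shouldStop() then derives the sample standard deviation as sqrt(dSquared / (length - 1)) and halts the search once mean + SIGMA_FACTOR * std falls below the worst score retained in the bestScores heap.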
