Skip to content

Commit

Permalink
Update Test2DThreshold to control for averages instead of worst-case statistics (#391)
Browse files (browse the repository at this point in the history)

Update Test2DThreshold to control for averages instead of worst-case quantities
  • Loading branch information
marianotepper authored Jan 23, 2025
1 parent 6cc68db commit dcad8fd
Showing 1 changed file with 14 additions and 7 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.github.jbellis.jvector.LuceneTestCase;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
Expand All @@ -35,18 +34,18 @@ public class Test2DThreshold extends LuceneTestCase {
@Test
public void testThreshold10k() throws IOException {
for (int i = 0; i < 10; i++) {
testThreshold(10_000, 16);
testThreshold(10_000, 16, 0.85f, 0.9f);
}
}

@Test
public void testThreshold20k() throws IOException {
for (int i = 0; i < 10; i++) {
testThreshold(20_000, 24);
testThreshold(20_000, 24, 0.75f, 0.95f);
}
}

public void testThreshold(int graphSize, int maxDegree) throws IOException {
public void testThreshold(int graphSize, int maxDegree, float visitedRatioThreshold, float recallThreshold) throws IOException {
var R = getRandom();

// build index
Expand All @@ -57,16 +56,24 @@ public void testThreshold(int graphSize, int maxDegree) throws IOException {

// test raw vectors
var searcher = new GraphSearcher(onHeapGraph);
for (int i = 0; i < 10; i++) {

int nQueries = 100;
float meanVisitedRatio = 0;
float meanRecall = 0;

for (int i = 0; i < nQueries; i++) {
TestParams tp = createTestParams(vectors);

var sf = ravv.rerankerFor(tp.q, VectorSimilarityFunction.EUCLIDEAN);
var result = searcher.search(new SearchScoreProvider(sf), vectors.length, tp.th, Bits.ALL);

assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th;
assert result.getNodes().length >= 0.85 * tp.exactCount : "returned " + result.getNodes().length + " nodes for threshold " + tp.th + " out of " + tp.exactCount;
meanVisitedRatio += ((float) result.getVisitedCount()) / (vectors.length * nQueries);
meanRecall += ((float) result.getNodes().length) / (tp.exactCount * nQueries);
}

assert meanVisitedRatio < visitedRatioThreshold : "visited " + meanVisitedRatio * 100 + "% of the vectors, which is more than " + visitedRatioThreshold * 100 + "%";
assert meanRecall > recallThreshold : "the recall is too low: " + meanRecall + " < " + recallThreshold;

// test compressed
// FIXME see https://github.com/jbellis/jvector/issues/254
// Path outputPath = Files.createTempFile("graph", ".jvector");
Expand Down

0 comments on commit dcad8fd

Please sign in to comment.