diff --git a/README.md b/README.md
index 477e8e7..0c6d08c 100644
--- a/README.md
+++ b/README.md
@@ -280,7 +280,7 @@ from semhash import SemHash
texts = load_dataset("ag_news", split="train")["text"]
# Load an embedding model (in this example, a multilingual model)
-model = StaticModel.from_pretrained("minishlab/M2V_multilingual_output")
+model = StaticModel.from_pretrained("minishlab/potion-multilingual-128M")
# Initialize a SemHash with the model and custom encoder
semhash = SemHash.from_records(records=texts, model=model)
@@ -318,21 +318,23 @@ deduplicated_texts = semhash.self_deduplicate()
Using custom ANN backends
-The following code snippet shows how to use a custom ANN backend and custom args with SemHash:
+By default, we use [USearch](https://github.com/unum-cloud/USearch) as the ANN (approximate nearest neighbors) backend for deduplication. We recommend keeping this default: recall on smaller datasets is ~100%, and ANN is needed for larger datasets (>1M samples), which would otherwise take too long to deduplicate. If you want to use a flat/exact-matching backend instead, pass `ann_backend=Backend.BASIC` when constructing SemHash (e.g. via `SemHash.from_records`):
```python
-from datasets import load_dataset
from semhash import SemHash
from vicinity import Backend
-# Load a dataset to deduplicate
-texts = load_dataset("ag_news", split="train")["text"]
+semhash = SemHash.from_records(records=texts, ann_backend=Backend.BASIC)
+```
-# Initialize a SemHash with the model and custom ann backend and custom args
-semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50)
+Any backend from [Vicinity](https://github.com/MinishLab/vicinity) can be used with SemHash. The following code snippet shows how to use [FAISS](https://github.com/facebookresearch/faiss) with a custom `nlist` parameter:
-# Deduplicate the texts
-deduplicated_texts = semhash.self_deduplicate()
+```python
+from datasets import load_dataset
+from semhash import SemHash
+from vicinity import Backend
+
+semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50)
```
For the full list of supported ANN backends and args, see the [Vicinity docs](https://github.com/MinishLab/vicinity/tree/main?tab=readme-ov-file#supported-backends).
@@ -398,11 +400,6 @@ representative_texts = semhash.self_find_representative().selected
```
-NOTE: By default, we use the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use the flat/exact-matching backend, you can set `use_ann=False` in the SemHash constructor:
-
-```python
-semhash = SemHash.from_records(records=texts, use_ann=False)
-```
@@ -410,7 +407,7 @@ semhash = SemHash.from_records(records=texts, use_ann=False)
We've benchmarked SemHash on a variety of datasets to measure the deduplication performance and speed. The benchmarks were run with the following setup:
- The benchmarks were all run on CPU
-- The benchmarks were all run with `use_ann=True`
+- The benchmarks were all run with the default ANN backend (USearch)
- The used encoder is the default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M)).
- The timings include the encoding time, index building time, and deduplication time.
### Train Deduplication Benchmark
diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_benchmarks.py
index 128b01b..bb36a52 100644
--- a/benchmarks/run_benchmarks.py
+++ b/benchmarks/run_benchmarks.py
@@ -45,7 +45,7 @@ def main() -> None: # noqa: C901
# Build the SemHash instance
build_start = perf_counter()
- semhash = SemHash.from_records(model=model, use_ann=True, records=train_records, columns=columns)
+ semhash = SemHash.from_records(model=model, records=train_records, columns=columns)
build_end = perf_counter()
build_time = build_end - build_start
# Time how long it takes to deduplicate the train set
diff --git a/semhash/semhash.py b/semhash/semhash.py
index c655d7d..1061ae8 100644
--- a/semhash/semhash.py
+++ b/semhash/semhash.py
@@ -45,7 +45,6 @@ def from_records(
cls,
records: Sequence[Record],
columns: Sequence[str] | None = None,
- use_ann: bool = True,
model: Encoder | None = None,
ann_backend: Backend | str = Backend.USEARCH,
**kwargs: Any,
@@ -57,9 +56,8 @@ def from_records(
:param records: A list of records (strings or dictionaries).
:param columns: Columns to featurize if records are dictionaries.
- :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
:param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
- :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
+ :param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH.
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
:return: A SemHash instance with a fitted vicinity index.
"""
@@ -90,11 +88,10 @@ def from_records(
embeddings = featurize(deduplicated_records, columns, model)
# Build the Vicinity index
- backend = ann_backend if use_ann else Backend.BASIC
index = Index.from_vectors_and_items(
vectors=embeddings,
items=items,
- backend_type=backend,
+ backend_type=ann_backend,
**kwargs,
)
@@ -107,7 +104,6 @@ def from_embeddings(
records: Sequence[Record],
model: Encoder,
columns: Sequence[str] | None = None,
- use_ann: bool = True,
ann_backend: Backend | str = Backend.USEARCH,
**kwargs: Any,
) -> SemHash:
@@ -121,8 +117,7 @@ def from_embeddings(
:param model: The Encoder model used for creating the embeddings.
:param columns: Columns to use if records are dictionaries. If None and records are strings,
defaults to ["text"].
- :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
- :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
+ :param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH.
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
:return: A SemHash instance with a fitted vicinity index.
:raises ValueError: If the number of embeddings doesn't match the number of records.
@@ -157,9 +152,8 @@ def from_embeddings(
deduplicated_embeddings = embeddings[embedding_indices]
# Create the index
- backend_type = ann_backend if use_ann else Backend.BASIC
index = Index.from_vectors_and_items(
- vectors=deduplicated_embeddings, items=items, backend_type=backend_type, **kwargs
+ vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs
)
return cls(index=index, model=model, columns=columns, was_string=was_string)
diff --git a/tests/conftest.py b/tests/conftest.py
index ff55fb4..69eebe8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,12 +8,6 @@ def model() -> StaticModel:
return StaticModel.from_pretrained("tests/data/test_model")
-@pytest.fixture(params=[True, False], ids=["use_ann=True", "use_ann=False"])
-def use_ann(request: pytest.FixtureRequest) -> bool:
- """Whether to use approximate nearest neighbors or not."""
- return request.param
-
-
@pytest.fixture
def train_texts() -> list[str]:
"""A list of train texts for testing outlier and representative filtering."""
diff --git a/tests/test_semhash.py b/tests/test_semhash.py
index 5923e3c..9c68c17 100644
--- a/tests/test_semhash.py
+++ b/tests/test_semhash.py
@@ -6,7 +6,7 @@
from semhash.utils import Encoder
-def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
+def test_single_dataset_deduplication(model: Encoder) -> None:
"""Test single dataset deduplication."""
# No duplicates
texts = [
@@ -14,7 +14,7 @@ def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
"The master sword can seal the darkness.",
"Ganondorf has invaded Hyrule!",
]
- semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=texts, model=model)
deduplicated_texts = semhash.self_deduplicate().selected
assert deduplicated_texts == texts
@@ -25,12 +25,12 @@ def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
"It's dangerous to go alone!", # Exact duplicate
"It's not safe to go alone!", # Semantically similar
]
- semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=texts, model=model)
deduplicated_texts = semhash.self_deduplicate(0.7).selected
assert deduplicated_texts == ["It's dangerous to go alone!"]
-def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
+def test_multi_dataset_deduplication(model: Encoder) -> None:
"""Test deduplication across two datasets."""
# No duplicates
texts1 = [
@@ -43,7 +43,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
"Zelda is the princess of Hyrule.",
"Ganon is the king of thieves.",
]
- semhash = SemHash.from_records(texts1, columns=None, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(texts1, columns=None, model=model)
deduplicated_texts = semhash.deduplicate(texts2).selected
assert deduplicated_texts == texts2
@@ -57,7 +57,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
assert deduplicated_texts == []
-def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None:
+def test_single_dataset_deduplication_multicolumn(model: Encoder) -> None:
"""Test single dataset deduplication with multi-column records."""
records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
@@ -72,7 +72,6 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
semhash = SemHash.from_records(
records,
columns=["question", "context", "answer"],
- use_ann=use_ann,
model=model,
)
deduplicated = semhash.self_deduplicate(threshold=0.7)
@@ -83,7 +82,7 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
]
-def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None:
+def test_multi_dataset_deduplication_multicolumn(model: Encoder) -> None:
"""Test multi dataset deduplication with multi-column records."""
train_records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
@@ -101,7 +100,6 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
semhash = SemHash.from_records(
train_records,
columns=["question", "context", "answer"],
- use_ann=use_ann,
model=model,
)
deduplicated = semhash.deduplicate(test_records).selected
@@ -110,17 +108,17 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
]
-def test_from_records_without_columns(use_ann: bool, model: Encoder) -> None:
+def test_from_records_without_columns(model: Encoder) -> None:
"""Test fitting without specifying columns."""
records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
{"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"},
]
with pytest.raises(ValueError):
- SemHash.from_records(records, columns=None, use_ann=use_ann, model=model)
+ SemHash.from_records(records, columns=None, model=model)
-def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -> None:
+def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None:
"""Test deduplicating with only exact duplicates."""
texts1 = [
"It's dangerous to go alone!",
@@ -132,7 +130,7 @@ def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -
"It's dangerous to go alone!",
"It's dangerous to go alone!",
]
- semhash = SemHash.from_records(texts1, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(texts1, model=model)
deduplicated = semhash.self_deduplicate()
assert deduplicated.selected == ["It's dangerous to go alone!"]
@@ -140,9 +138,9 @@ def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -
assert deduplicated.selected == []
-def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
+def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_find_representative method."""
- semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.self_find_representative(
candidate_limit=5,
selection_size=3,
@@ -157,18 +155,18 @@ def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: li
}, "Expected representatives to be blueberry, pineapple, and grape"
-def test_find_representative(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
+def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the find_representative method."""
- semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple"
-def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
+def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the filter_outliers method."""
- semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.filter_outliers(records=test_texts, outlier_percentage=0.2)
assert len(result.filtered) == 2, "Expected 2 outliers"
assert len(result.selected) == len(test_texts) - 2
@@ -176,9 +174,9 @@ def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str],
assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane"
-def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
+def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_filter_outliers method."""
- semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
+ semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.self_filter_outliers(outlier_percentage=0.1)
assert len(result.filtered) == 2, "Expected 2 outliers"
assert len(result.selected) == len(train_texts) - 2
@@ -216,7 +214,7 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
assert result_empty.scores_filtered == []
-def test_from_embeddings(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
+def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
"""Test from_embeddings constructor with validation and comparison to from_records."""
# Test validation: mismatched shapes
with pytest.raises(ValueError, match="Number of embeddings"):
@@ -224,12 +222,10 @@ def test_from_embeddings(use_ann: bool, model: Encoder, train_texts: list[str])
SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model)
# Test that from_embeddings behaves same as from_records
- semhash_from_records = SemHash.from_records(records=train_texts, model=model, use_ann=use_ann)
+ semhash_from_records = SemHash.from_records(records=train_texts, model=model)
embeddings = model.encode(train_texts)
- semhash_from_embeddings = SemHash.from_embeddings(
- embeddings=embeddings, records=train_texts, model=model, use_ann=use_ann
- )
+ semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model)
# Both should give same deduplication results
result1 = semhash_from_records.self_deduplicate(threshold=0.95)