diff --git a/README.md b/README.md index 477e8e7..0c6d08c 100644 --- a/README.md +++ b/README.md @@ -280,7 +280,7 @@ from semhash import SemHash texts = load_dataset("ag_news", split="train")["text"] # Load an embedding model (in this example, a multilingual model) -model = StaticModel.from_pretrained("minishlab/M2V_multilingual_output") +model = StaticModel.from_pretrained("minishlab/potion-multilingual-128M") # Initialize a SemHash with the model and custom encoder semhash = SemHash.from_records(records=texts, model=model) @@ -318,21 +318,23 @@ deduplicated_texts = semhash.self_deduplicate() Using custom ANN backends
-The following code snippet shows how to use a custom ANN backend and custom args with SemHash: +By default, we use [USearch](https://github.com/unum-cloud/USearch) as the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use a flat/exact-matching backend, you can set `ann_backend=Backend.BASIC` in the SemHash constructor: ```python -from datasets import load_dataset from semhash import SemHash from vicinity import Backend -# Load a dataset to deduplicate -texts = load_dataset("ag_news", split="train")["text"] +semhash = SemHash.from_records(records=texts, ann_backend=Backend.BASIC) +``` -# Initialize a SemHash with the model and custom ann backend and custom args -semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50) +Any backend from [Vicinity](https://github.com/MinishLab/vicinity) can be used with SemHash. The following code snippet shows how to use [FAISS](https://github.com/facebookresearch/faiss) with a custom `nlist` parameter: -# Deduplicate the texts -deduplicated_texts = semhash.self_deduplicate() +```python +from datasets import load_dataset +from semhash import SemHash +from vicinity import Backend + +semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50) ``` For the full list of supported ANN backends and args, see the [Vicinity docs](https://github.com/MinishLab/vicinity/tree/main?tab=readme-ov-file#supported-backends). @@ -398,11 +400,6 @@ representative_texts = semhash.self_find_representative().selected ``` -NOTE: By default, we use the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use the flat/exact-matching backend, you can set `use_ann=False` in the SemHash constructor: - -```python -semhash = SemHash.from_records(records=texts, use_ann=False) -``` @@ -410,7 +407,7 @@ semhash = SemHash.from_records(records=texts, use_ann=False) We've benchmarked SemHash on a variety of datasets to measure the deduplication performance and speed. The benchmarks were run with the following setup: - The benchmarks were all run on CPU -- The benchmarks were all run with `use_ann=True` +- The benchmarks were all run with the default ANN backend (usearch) - The used encoder is the default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M)). - The timings include the encoding time, index building time, and deduplication time. ### Train Deduplication Benchmark diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_benchmarks.py index 128b01b..bb36a52 100644 --- a/benchmarks/run_benchmarks.py +++ b/benchmarks/run_benchmarks.py @@ -45,7 +45,7 @@ def main() -> None: # noqa: C901 # Build the SemHash instance build_start = perf_counter() - semhash = SemHash.from_records(model=model, use_ann=True, records=train_records, columns=columns) + semhash = SemHash.from_records(model=model, records=train_records, columns=columns) build_end = perf_counter() build_time = build_end - build_start # Time how long it takes to deduplicate the train set diff --git a/semhash/semhash.py b/semhash/semhash.py index c655d7d..1061ae8 100644 --- a/semhash/semhash.py +++ b/semhash/semhash.py @@ -45,7 +45,6 @@ def from_records( cls, records: Sequence[Record], columns: Sequence[str] | None = None, - use_ann: bool = True, model: Encoder | None = None, ann_backend: Backend | str = Backend.USEARCH, **kwargs: Any, @@ -57,9 +56,8 @@ def from_records( :param records: A list of records (strings or dictionaries). :param columns: Columns to featurize if records are dictionaries. - :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True. :param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M). - :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH. + :param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH. :param **kwargs: Any additional keyword arguments to pass to the Vicinity index. :return: A SemHash instance with a fitted vicinity index. """ @@ -90,11 +88,10 @@ def from_records( embeddings = featurize(deduplicated_records, columns, model) # Build the Vicinity index - backend = ann_backend if use_ann else Backend.BASIC index = Index.from_vectors_and_items( vectors=embeddings, items=items, - backend_type=backend, + backend_type=ann_backend, **kwargs, ) @@ -107,7 +104,6 @@ def from_embeddings( records: Sequence[Record], model: Encoder, columns: Sequence[str] | None = None, - use_ann: bool = True, ann_backend: Backend | str = Backend.USEARCH, **kwargs: Any, ) -> SemHash: @@ -121,8 +117,7 @@ def from_embeddings( :param model: The Encoder model used for creating the embeddings. :param columns: Columns to use if records are dictionaries. If None and records are strings, defaults to ["text"]. - :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True. - :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH. + :param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH. :param **kwargs: Any additional keyword arguments to pass to the Vicinity index. :return: A SemHash instance with a fitted vicinity index. :raises ValueError: If the number of embeddings doesn't match the number of records. @@ -157,9 +152,8 @@ def from_embeddings( deduplicated_embeddings = embeddings[embedding_indices] # Create the index - backend_type = ann_backend if use_ann else Backend.BASIC index = Index.from_vectors_and_items( - vectors=deduplicated_embeddings, items=items, backend_type=backend_type, **kwargs + vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs ) return cls(index=index, model=model, columns=columns, was_string=was_string) diff --git a/tests/conftest.py b/tests/conftest.py index ff55fb4..69eebe8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,12 +8,6 @@ def model() -> StaticModel: return StaticModel.from_pretrained("tests/data/test_model") -@pytest.fixture(params=[True, False], ids=["use_ann=True", "use_ann=False"]) -def use_ann(request: pytest.FixtureRequest) -> bool: - """Whether to use approximate nearest neighbors or not.""" - return request.param - - @pytest.fixture def train_texts() -> list[str]: """A list of train texts for testing outlier and representative filtering.""" diff --git a/tests/test_semhash.py b/tests/test_semhash.py index 5923e3c..9c68c17 100644 --- a/tests/test_semhash.py +++ b/tests/test_semhash.py @@ -6,7 +6,7 @@ from semhash.utils import Encoder -def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None: +def test_single_dataset_deduplication(model: Encoder) -> None: """Test single dataset deduplication.""" # No duplicates texts = [ @@ -14,7 +14,7 @@ def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None: "The master sword can seal the darkness.", "Ganondorf has invaded Hyrule!", ] - semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=texts, model=model) deduplicated_texts = semhash.self_deduplicate().selected assert deduplicated_texts == texts @@ -25,12 +25,12 @@ def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None: "It's dangerous to go alone!", # Exact duplicate "It's not safe to go alone!", # Semantically similar ] - semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=texts, model=model) deduplicated_texts = semhash.self_deduplicate(0.7).selected assert deduplicated_texts == ["It's dangerous to go alone!"] -def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None: +def test_multi_dataset_deduplication(model: Encoder) -> None: """Test deduplication across two datasets.""" # No duplicates texts1 = [ @@ -43,7 +43,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None: "Zelda is the princess of Hyrule.", "Ganon is the king of thieves.", ] - semhash = SemHash.from_records(texts1, columns=None, use_ann=use_ann, model=model) + semhash = SemHash.from_records(texts1, columns=None, model=model) deduplicated_texts = semhash.deduplicate(texts2).selected assert deduplicated_texts == texts2 @@ -57,7 +57,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None: assert deduplicated_texts == [] -def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: +def test_single_dataset_deduplication_multicolumn(model: Encoder) -> None: """Test single dataset deduplication with multi-column records.""" records = [ {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, @@ -72,7 +72,6 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) semhash = SemHash.from_records( records, columns=["question", "context", "answer"], - use_ann=use_ann, model=model, ) deduplicated = semhash.self_deduplicate(threshold=0.7) @@ -83,7 +82,7 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) ] -def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: +def test_multi_dataset_deduplication_multicolumn(model: Encoder) -> None: """Test multi dataset deduplication with multi-column records.""" train_records = [ {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, @@ -101,7 +100,6 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) semhash = SemHash.from_records( train_records, columns=["question", "context", "answer"], - use_ann=use_ann, model=model, ) deduplicated = semhash.deduplicate(test_records).selected @@ -110,17 +108,17 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) ] -def test_from_records_without_columns(use_ann: bool, model: Encoder) -> None: +def test_from_records_without_columns(model: Encoder) -> None: """Test fitting without specifying columns.""" records = [ {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, ] with pytest.raises(ValueError): - SemHash.from_records(records, columns=None, use_ann=use_ann, model=model) + SemHash.from_records(records, columns=None, model=model) -def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -> None: +def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None: """Test deduplicating with only exact duplicates.""" texts1 = [ "It's dangerous to go alone!", @@ -132,7 +130,7 @@ def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) - "It's dangerous to go alone!", "It's dangerous to go alone!", ] - semhash = SemHash.from_records(texts1, use_ann=use_ann, model=model) + semhash = SemHash.from_records(texts1, model=model) deduplicated = semhash.self_deduplicate() assert deduplicated.selected == ["It's dangerous to go alone!"] @@ -140,9 +138,9 @@ def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) - assert deduplicated.selected == [] -def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: +def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None: """Test the self_find_representative method.""" - semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=train_texts, model=model) result = semhash.self_find_representative( candidate_limit=5, selection_size=3, @@ -157,18 +155,18 @@ def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: li }, "Expected representatives to be blueberry, pineapple, and grape" -def test_find_representative(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: +def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: """Test the find_representative method.""" - semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=train_texts, model=model) result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5) assert len(result.selected) == 3, "Expected 3 representatives" selected = {r["text"] for r in result.selected} assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple" -def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: +def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: """Test the filter_outliers method.""" - semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=train_texts, model=model) result = semhash.filter_outliers(records=test_texts, outlier_percentage=0.2) assert len(result.filtered) == 2, "Expected 2 outliers" assert len(result.selected) == len(test_texts) - 2 @@ -176,9 +174,9 @@ def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str], assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane" -def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: +def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None: """Test the self_filter_outliers method.""" - semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) + semhash = SemHash.from_records(records=train_texts, model=model) result = semhash.self_filter_outliers(outlier_percentage=0.1) assert len(result.filtered) == 2, "Expected 2 outliers" assert len(result.selected) == len(train_texts) - 2 @@ -216,7 +214,7 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None: assert result_empty.scores_filtered == [] -def test_from_embeddings(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: +def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None: """Test from_embeddings constructor with validation and comparison to from_records.""" # Test validation: mismatched shapes with pytest.raises(ValueError, match="Number of embeddings"): @@ -224,12 +222,10 @@ def test_from_embeddings(use_ann: bool, model: Encoder, train_texts: list[str]) SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model) # Test that from_embeddings behaves same as from_records - semhash_from_records = SemHash.from_records(records=train_texts, model=model, use_ann=use_ann) + semhash_from_records = SemHash.from_records(records=train_texts, model=model) embeddings = model.encode(train_texts) - semhash_from_embeddings = SemHash.from_embeddings( - embeddings=embeddings, records=train_texts, model=model, use_ann=use_ann - ) + semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model) # Both should give same deduplication results result1 = semhash_from_records.self_deduplicate(threshold=0.95)