Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ from semhash import SemHash
texts = load_dataset("ag_news", split="train")["text"]

# Load an embedding model (in this example, a multilingual model)
model = StaticModel.from_pretrained("minishlab/M2V_multilingual_output")
model = StaticModel.from_pretrained("minishlab/potion-multilingual-128M")

# Initialize a SemHash with the model and custom encoder
semhash = SemHash.from_records(records=texts, model=model)
Expand Down Expand Up @@ -318,21 +318,23 @@ deduplicated_texts = semhash.self_deduplicate()
<summary> Using custom ANN backends </summary>
<br>

The following code snippet shows how to use a custom ANN backend and custom args with SemHash:
By default, we use [USearch](https://github.com/unum-cloud/USearch) as the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use a flat/exact-matching backend, you can set `ann_backend=Backend.BASIC` in the SemHash constructor:

```python
from datasets import load_dataset
from semhash import SemHash
from vicinity import Backend

# Load a dataset to deduplicate
texts = load_dataset("ag_news", split="train")["text"]
semhash = SemHash.from_records(records=texts, ann_backend=Backend.BASIC)
```

# Initialize a SemHash with the model and custom ann backend and custom args
semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50)
Any backend from [Vicinity](https://github.com/MinishLab/vicinity) can be used with SemHash. The following code snippet shows how to use [FAISS](https://github.com/facebookresearch/faiss) with a custom `nlist` parameter:

# Deduplicate the texts
deduplicated_texts = semhash.self_deduplicate()
```python
from datasets import load_dataset
from semhash import SemHash
from vicinity import Backend

semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50)
```

For the full list of supported ANN backends and args, see the [Vicinity docs](https://github.com/MinishLab/vicinity/tree/main?tab=readme-ov-file#supported-backends).
Expand Down Expand Up @@ -398,19 +400,14 @@ representative_texts = semhash.self_find_representative().selected
```
</details>

NOTE: By default, we use the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use the flat/exact-matching backend, you can set `use_ann=False` in the SemHash constructor:

```python
semhash = SemHash.from_records(records=texts, use_ann=False)
```



## Benchmarks

We've benchmarked SemHash on a variety of datasets to measure the deduplication performance and speed. The benchmarks were run with the following setup:
- The benchmarks were all run on CPU
- The benchmarks were all run with `use_ann=True`
- The benchmarks were all run with the default ANN backend (usearch)
- The used encoder is the default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M)).
- The timings include the encoding time, index building time, and deduplication time.
### Train Deduplication Benchmark
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/run_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main() -> None: # noqa: C901

# Build the SemHash instance
build_start = perf_counter()
semhash = SemHash.from_records(model=model, use_ann=True, records=train_records, columns=columns)
semhash = SemHash.from_records(model=model, records=train_records, columns=columns)
build_end = perf_counter()
build_time = build_end - build_start
# Time how long it takes to deduplicate the train set
Expand Down
14 changes: 4 additions & 10 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def from_records(
cls,
records: Sequence[Record],
columns: Sequence[str] | None = None,
use_ann: bool = True,
model: Encoder | None = None,
ann_backend: Backend | str = Backend.USEARCH,
**kwargs: Any,
Expand All @@ -57,9 +56,8 @@ def from_records(

:param records: A list of records (strings or dictionaries).
:param columns: Columns to featurize if records are dictionaries.
:param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
:param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
:param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
:param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH.
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
:return: A SemHash instance with a fitted vicinity index.
"""
Expand Down Expand Up @@ -90,11 +88,10 @@ def from_records(
embeddings = featurize(deduplicated_records, columns, model)

# Build the Vicinity index
backend = ann_backend if use_ann else Backend.BASIC
index = Index.from_vectors_and_items(
vectors=embeddings,
items=items,
backend_type=backend,
backend_type=ann_backend,
**kwargs,
)

Expand All @@ -107,7 +104,6 @@ def from_embeddings(
records: Sequence[Record],
model: Encoder,
columns: Sequence[str] | None = None,
use_ann: bool = True,
ann_backend: Backend | str = Backend.USEARCH,
**kwargs: Any,
) -> SemHash:
Expand All @@ -121,8 +117,7 @@ def from_embeddings(
:param model: The Encoder model used for creating the embeddings.
:param columns: Columns to use if records are dictionaries. If None and records are strings,
defaults to ["text"].
:param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
:param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
:param ann_backend: (Optional) The ANN backend to use. Defaults to Backend.USEARCH.
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
:return: A SemHash instance with a fitted vicinity index.
:raises ValueError: If the number of embeddings doesn't match the number of records.
Expand Down Expand Up @@ -157,9 +152,8 @@ def from_embeddings(
deduplicated_embeddings = embeddings[embedding_indices]

# Create the index
backend_type = ann_backend if use_ann else Backend.BASIC
index = Index.from_vectors_and_items(
vectors=deduplicated_embeddings, items=items, backend_type=backend_type, **kwargs
vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs
)

return cls(index=index, model=model, columns=columns, was_string=was_string)
Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ def model() -> StaticModel:
return StaticModel.from_pretrained("tests/data/test_model")


@pytest.fixture(params=[True, False], ids=["use_ann=True", "use_ann=False"])
def use_ann(request: pytest.FixtureRequest) -> bool:
"""Whether to use approximate nearest neighbors or not."""
return request.param


@pytest.fixture
def train_texts() -> list[str]:
"""A list of train texts for testing outlier and representative filtering."""
Expand Down
48 changes: 22 additions & 26 deletions tests/test_semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from semhash.utils import Encoder


def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
def test_single_dataset_deduplication(model: Encoder) -> None:
"""Test single dataset deduplication."""
# No duplicates
texts = [
"It's dangerous to go alone!",
"The master sword can seal the darkness.",
"Ganondorf has invaded Hyrule!",
]
semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=texts, model=model)
deduplicated_texts = semhash.self_deduplicate().selected

assert deduplicated_texts == texts
Expand All @@ -25,12 +25,12 @@ def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
"It's dangerous to go alone!", # Exact duplicate
"It's not safe to go alone!", # Semantically similar
]
semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=texts, model=model)
deduplicated_texts = semhash.self_deduplicate(0.7).selected
assert deduplicated_texts == ["It's dangerous to go alone!"]


def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
def test_multi_dataset_deduplication(model: Encoder) -> None:
"""Test deduplication across two datasets."""
# No duplicates
texts1 = [
Expand All @@ -43,7 +43,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
"Zelda is the princess of Hyrule.",
"Ganon is the king of thieves.",
]
semhash = SemHash.from_records(texts1, columns=None, use_ann=use_ann, model=model)
semhash = SemHash.from_records(texts1, columns=None, model=model)
deduplicated_texts = semhash.deduplicate(texts2).selected
assert deduplicated_texts == texts2

Expand All @@ -57,7 +57,7 @@ def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None:
assert deduplicated_texts == []


def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None:
def test_single_dataset_deduplication_multicolumn(model: Encoder) -> None:
"""Test single dataset deduplication with multi-column records."""
records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
Expand All @@ -72,7 +72,6 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
semhash = SemHash.from_records(
records,
columns=["question", "context", "answer"],
use_ann=use_ann,
model=model,
)
deduplicated = semhash.self_deduplicate(threshold=0.7)
Expand All @@ -83,7 +82,7 @@ def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
]


def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None:
def test_multi_dataset_deduplication_multicolumn(model: Encoder) -> None:
"""Test multi dataset deduplication with multi-column records."""
train_records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
Expand All @@ -101,7 +100,6 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
semhash = SemHash.from_records(
train_records,
columns=["question", "context", "answer"],
use_ann=use_ann,
model=model,
)
deduplicated = semhash.deduplicate(test_records).selected
Expand All @@ -110,17 +108,17 @@ def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder)
]


def test_from_records_without_columns(use_ann: bool, model: Encoder) -> None:
def test_from_records_without_columns(model: Encoder) -> None:
"""Test fitting without specifying columns."""
records = [
{"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"},
{"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"},
]
with pytest.raises(ValueError):
SemHash.from_records(records, columns=None, use_ann=use_ann, model=model)
SemHash.from_records(records, columns=None, model=model)


def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -> None:
def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None:
"""Test deduplicating with only exact duplicates."""
texts1 = [
"It's dangerous to go alone!",
Expand All @@ -132,17 +130,17 @@ def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -
"It's dangerous to go alone!",
"It's dangerous to go alone!",
]
semhash = SemHash.from_records(texts1, use_ann=use_ann, model=model)
semhash = SemHash.from_records(texts1, model=model)
deduplicated = semhash.self_deduplicate()
assert deduplicated.selected == ["It's dangerous to go alone!"]

deduplicated = semhash.deduplicate(texts2)
assert deduplicated.selected == []


def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_find_representative method."""
semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.self_find_representative(
candidate_limit=5,
selection_size=3,
Expand All @@ -157,28 +155,28 @@ def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: li
}, "Expected representatives to be blueberry, pineapple, and grape"


def test_find_representative(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the find_representative method."""
semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple"


def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the filter_outliers method."""
semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.filter_outliers(records=test_texts, outlier_percentage=0.2)
assert len(result.filtered) == 2, "Expected 2 outliers"
assert len(result.selected) == len(test_texts) - 2
filtered = {r["text"] for r in result.filtered}
assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane"


def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_filter_outliers method."""
semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
semhash = SemHash.from_records(records=train_texts, model=model)
result = semhash.self_filter_outliers(outlier_percentage=0.1)
assert len(result.filtered) == 2, "Expected 2 outliers"
assert len(result.selected) == len(train_texts) - 2
Expand Down Expand Up @@ -216,20 +214,18 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
assert result_empty.scores_filtered == []


def test_from_embeddings(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
"""Test from_embeddings constructor with validation and comparison to from_records."""
# Test validation: mismatched shapes
with pytest.raises(ValueError, match="Number of embeddings"):
wrong_embeddings = model.encode(["apple", "banana"])
SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model)

# Test that from_embeddings behaves same as from_records
semhash_from_records = SemHash.from_records(records=train_texts, model=model, use_ann=use_ann)
semhash_from_records = SemHash.from_records(records=train_texts, model=model)

embeddings = model.encode(train_texts)
semhash_from_embeddings = SemHash.from_embeddings(
embeddings=embeddings, records=train_texts, model=model, use_ann=use_ann
)
semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model)

# Both should give same deduplication results
result1 = semhash_from_records.self_deduplicate(threshold=0.95)
Expand Down