Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 36 additions & 43 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,45 @@ filtered_texts = semhash.self_filter_outliers().selected
representative_texts = semhash.self_find_representative().selected
```

The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L30). This object stores the deduplicated corpus, a set of duplicate object (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result. Examples of how these functions can be used can be found in the [usage](#usage) section.
The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate object (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.

The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#106). This object stores the found outliers/representative samples.
The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#179). This object stores the found outliers/representative samples.

For both the `DeduplicationResult` and `FilterResult` objects, you can easily view the filtered records with the `selected` attribute (e.g. to view outliers: `outliers = semhash.self_filter_outliers().filtered`)

### Inspecting Deduplication Results

The `DeduplicationResult` object provides powerful tools for understanding and refining your deduplication:

```python
from datasets import load_dataset
from semhash import SemHash

# Load and deduplicate a dataset
texts = load_dataset("ag_news", split="train")["text"]
semhash = SemHash.from_records(records=texts)
result = semhash.self_deduplicate()

# Access deduplicated and duplicate records
deduplicated_texts = result.selected
duplicate_texts = result.filtered

# View deduplication statistics
print(f"Duplicate ratio: {result.duplicate_ratio}")
print(f"Exact duplicate ratio: {result.exact_duplicate_ratio}")

# Find edge cases to tune your threshold
least_similar = result.get_least_similar_from_duplicates(n=5)

# Adjust threshold without re-deduplicating
result.rethreshold(0.95)

# View each kept record with its duplicate cluster
for item in result.selected_with_duplicates:
print(f"Kept: {item.record}")
print(f"Duplicates: {item.duplicates}") # List of (duplicate_text, similarity_score)
```

## Main Features

- **Fast**: SemHash uses [model2vec](https://github.com/MinishLab/model2vec) to embed texts and [vicinity](https://github.com/MinishLab/vicinity) to perform similarity search, making it extremely fast.
Expand Down Expand Up @@ -231,47 +265,6 @@ representative_records = semhash.self_find_representative().selected

</details>

<details>
<summary> DeduplicationResult functionality </summary>
<br>

The `DeduplicationResult` object returned by the `deduplicate` and `self_deduplicate` functions contains several useful functions to inspect the deduplication result. The following code snippet shows how to use these functions:

```python
from datasets import load_dataset
from semhash import SemHash

# Load a dataset to deduplicate
texts = load_dataset("ag_news", split="train")["text"]

# Initialize a SemHash instance
semhash = SemHash.from_records(records=texts)

# Deduplicate the texts
deduplication_result = semhash.self_deduplicate()

# Check the deduplicated texts
deduplication_result.selected
# Check the duplicates
deduplication_result.filtered
# See what percentage of the texts were duplicates
deduplication_result.duplicate_ratio
# See what percentage of the texts were exact duplicates
deduplication_result.exact_duplicate_ratio

# Get the least similar text from the duplicates. This is useful for finding the right threshold for deduplication.
least_similar = deduplication_result.get_least_similar_from_duplicates()

# Rethreshold the duplicates. This allows you to instantly rethreshold the duplicates with a new threshold without having to re-deduplicate the texts.
deduplication_result.rethreshold(0.95)

# View selected records along with their duplicates.
# This is the opposite of the `filtered` attribute, which shows for every duplicate the record that caused it.
deduplication_result.selected_with_duplicates
```

</details>

<details>
<summary> Using custom encoders </summary>
<br>
Expand Down
8 changes: 6 additions & 2 deletions semhash/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Generic, TypeAlias, TypeVar

from frozendict import frozendict
Expand Down Expand Up @@ -123,14 +124,17 @@ def rethreshold(self, threshold: float) -> None:
"""Rethreshold the duplicates."""
if self.threshold > threshold:
raise ValueError("Threshold is smaller than the given value.")
for dup in self.filtered:
# Invalidate cached property before modifying data
self.__dict__.pop("selected_with_duplicates", None)
# Rethreshold duplicates and move records without duplicates to selected
for dup in list(self.filtered):
dup._rethreshold(threshold)
if not dup.duplicates:
self.filtered.remove(dup)
self.selected.append(dup.record)
self.threshold = threshold

@property
@cached_property
def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]:
"""
For every kept record, return the duplicates that were removed along with their similarity scores.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,48 @@ def test_selected_with_duplicates_removes_internal_duplicates() -> None:
# The duplicate row must appear only once
assert len(duplicate_list) == 1
assert duplicate_list[0][0] == filtered


def test_selected_with_duplicates_caching() -> None:
"""Test that selected_with_duplicates is properly cached."""
d = DeduplicationResult(
selected=["original"],
filtered=[
DuplicateRecord("duplicate_1", False, [("original", 0.9)]),
DuplicateRecord("duplicate_2", False, [("original", 0.8)]),
],
threshold=0.8,
)

# First access should compute and cache
result1 = d.selected_with_duplicates
# Second access should return the cached result (same object)
result2 = d.selected_with_duplicates
assert result1 is result2


def test_selected_with_duplicates_cache_invalidation_on_rethreshold() -> None:
"""Test that rethreshold invalidates the selected_with_duplicates cache."""
d = DeduplicationResult(
selected=["original"],
filtered=[
DuplicateRecord("duplicate_1", False, [("original", 0.9)]),
DuplicateRecord("duplicate_2", False, [("original", 0.8)]),
DuplicateRecord("duplicate_3", False, [("original", 0.7)]),
],
threshold=0.7,
)

# Access before rethreshold
result1 = d.selected_with_duplicates
assert len(result1[0].duplicates) == 3

# Rethreshold should invalidate cache
d.rethreshold(0.85)

# Access after rethreshold should give new result
result2 = d.selected_with_duplicates
assert len(result2[0].duplicates) == 1
assert result2[0].duplicates[0][0] == "duplicate_1"
# Results should be different objects
assert result1 is not result2