Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ least_similar = deduplication_result.get_least_similar_from_duplicates()

# Rethreshold the duplicates. This allows you to instantly rethreshold the duplicates with a new threshold without having to re-deduplicate the texts.
deduplication_result.rethreshold(0.95)

# View selected records along with their duplicates.
# This is the opposite of the `filtered` attribute, which shows for every duplicate the record that caused it.
deduplication_result.selected_with_duplicates
```

</details>
Expand Down
65 changes: 62 additions & 3 deletions semhash/datamodels.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from typing import Any, Generic, Hashable, Sequence, TypeVar

from frozendict import frozendict
from typing_extensions import TypeAlias

from semhash.utils import to_frozendict

Record = TypeVar("Record", str, dict[str, str])
Record = TypeVar("Record", str, dict[str, Any])
DuplicateList: TypeAlias = list[tuple[Record, float]]


@dataclass
Expand All @@ -20,13 +29,29 @@ class DuplicateRecord(Generic[Record]):

record: Record
exact: bool
duplicates: list[tuple[Record, float]] = field(default_factory=list)
duplicates: DuplicateList = field(default_factory=list)

def _rethreshold(self, threshold: float) -> None:
"""Rethreshold the duplicates."""
self.duplicates = [(d, score) for d, score in self.duplicates if score >= threshold]


@dataclass
class SelectedWithDuplicates(Generic[Record]):
"""
A record that has been selected along with its duplicates.

Attributes
----------
record: The original record being selected.
duplicates: List of tuples consisting of duplicate records and their associated scores.

"""

record: Record
duplicates: DuplicateList = field(default_factory=list)


@dataclass
class DeduplicationResult(Generic[Record]):
"""
Expand All @@ -37,6 +62,7 @@ class DeduplicationResult(Generic[Record]):
selected: List of deduplicated records after removing duplicates.
filtered: List of DuplicateRecord objects containing details about duplicates of an original record.
threshold: The similarity threshold used for deduplication.
columns: Columns used for deduplication.
deduplicated: Deprecated, use selected instead.
duplicates: Deprecated, use filtered instead.

Expand All @@ -45,6 +71,7 @@ class DeduplicationResult(Generic[Record]):
selected: list[Record] = field(default_factory=list)
filtered: list[DuplicateRecord] = field(default_factory=list)
threshold: float = field(default=0.9)
columns: Sequence[str] | None = field(default=None)
deduplicated: list[Record] = field(default_factory=list) # Deprecated
duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated

Expand Down Expand Up @@ -102,6 +129,38 @@ def rethreshold(self, threshold: float) -> None:
self.selected.append(dup.record)
self.threshold = threshold

@property
def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]:
"""
For every kept record, return the duplicates that were removed along with their similarity scores.

:return: A list of tuples where each tuple contains a kept record
and a list of its duplicates with their similarity scores.
"""

def _to_hashable(record: Record) -> frozendict[str, str] | str:
"""Convert a record to a hashable representation."""
if isinstance(record, dict) and self.columns is not None:
# Convert dict to frozendict for immutability and hashability
return to_frozendict(record, set(self.columns))
return str(record)

# Build a mapping from original-record to [(duplicate, score), …]
buckets: defaultdict[Hashable, DuplicateList] = defaultdict(list)
for duplicate_record in self.filtered:
for original_record, score in duplicate_record.duplicates:
buckets[_to_hashable(original_record)].append((duplicate_record.record, float(score)))

result: list[SelectedWithDuplicates[Record]] = []
for selected in self.selected:
# Get the list of duplicates for the selected record
raw_list = buckets.get(_to_hashable(selected), [])
# Ensure we don't have duplicates in the list
deduped = {_to_hashable(rec): (rec, score) for rec, score in raw_list}
result.append(SelectedWithDuplicates(record=selected, duplicates=list(deduped.values())))

return result
Comment thread
Pringled marked this conversation as resolved.


@dataclass
class FilterResult(Generic[Record]):
Expand Down
7 changes: 0 additions & 7 deletions semhash/records.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from typing import Sequence

from frozendict import frozendict

from semhash.datamodels import DeduplicationResult, DuplicateRecord


def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]:
"""Convert a record to a frozendict."""
return frozendict({k: record.get(k, "") for k in columns})


def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str:
r"""
Turn a record into a single string.
Expand Down
12 changes: 7 additions & 5 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
from semhash.index import Index
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings, to_frozendict
from semhash.utils import Encoder, compute_candidate_limit
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
from semhash.utils import Encoder, compute_candidate_limit, to_frozendict


class SemHash(Generic[Record]):
Expand Down Expand Up @@ -190,7 +190,9 @@ def deduplicate(

# If no records are left after removing exact duplicates, return early
if not dict_records:
return DeduplicationResult(deduplicated=[], duplicates=duplicate_records, threshold=threshold)
return DeduplicationResult(
deduplicated=[], duplicates=duplicate_records, threshold=threshold, columns=self.columns
)

# Compute embeddings for the new records
embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
Expand All @@ -212,7 +214,7 @@ def deduplicate(
)

result = DeduplicationResult(
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
)

if self._was_string:
Expand Down Expand Up @@ -281,7 +283,7 @@ def self_deduplicate(
seen_items.update(frozen_items)

result = DeduplicationResult(
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
)

if self._was_string:
Expand Down
6 changes: 6 additions & 0 deletions semhash/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Protocol, Sequence, Union

import numpy as np
from frozendict import frozendict


class Encoder(Protocol):
Expand All @@ -21,6 +22,11 @@ def encode(
... # pragma: no cover


def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]:
"""Convert a record to a frozendict."""
return frozendict({k: record.get(k, "") for k in columns})


def compute_candidate_limit(
total: int,
selection_size: int,
Expand Down
111 changes: 107 additions & 4 deletions tests/test_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import semhash
import semhash.version
from semhash.datamodels import DeduplicationResult, DuplicateRecord
from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates


def test_deduplication_scoring() -> None:
Expand All @@ -27,13 +27,13 @@ def test_deduplication_scoring_exact() -> None:

def test_deduplication_scoring_exact_empty() -> None:
"""Test the deduplication scoring."""
d = DeduplicationResult([], [], 0.8)
d = DeduplicationResult([], [], 0.8, columns=["text"])
assert d.exact_duplicate_ratio == 0.0


def test_deduplication_scoring_empty() -> None:
"""Test the deduplication scoring."""
d = DeduplicationResult([], [], 0.8)
d = DeduplicationResult([], [], 0.8, columns=["text"])
assert d.duplicate_ratio == 0.0


Expand Down Expand Up @@ -64,7 +64,7 @@ def test_get_least_similar_from_duplicates() -> None:

def test_get_least_similar_from_duplicates_empty() -> None:
"""Test getting the least similar duplicates."""
d = DeduplicationResult([], [], 0.8)
d = DeduplicationResult([], [], 0.8, columns=["text"])
assert d.get_least_similar_from_duplicates(1) == []


Expand Down Expand Up @@ -116,3 +116,106 @@ def test_deprecation_deduplicated_duplicates() -> None:
DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
DuplicateRecord("e", False, [("z", 0.8)]),
]


def test_selected_with_duplicates_strings() -> None:
"""Test selected_with_duplicates for strings."""
d = DeduplicationResult(
selected=["original"],
filtered=[
DuplicateRecord("duplicate_1", False, [("original", 0.9)]),
DuplicateRecord("duplicate_2", False, [("original", 0.8)]),
],
threshold=0.8,
)

expected = [
SelectedWithDuplicates(
record="original",
duplicates=[("duplicate_1", 0.9), ("duplicate_2", 0.8)],
)
]
assert d.selected_with_duplicates == expected


def test_selected_with_duplicates_dicts() -> None:
"""Test selected_with_duplicates for dicts."""
selected = {"id": 0, "text": "hello"}
d = DeduplicationResult(
selected=[selected],
filtered=[
DuplicateRecord({"id": 1, "text": "hello"}, True, [(selected, 1.0)]),
DuplicateRecord({"id": 2, "text": "helllo"}, False, [(selected, 0.1)]),
],
threshold=0.8,
columns=["text"],
)

items = d.selected_with_duplicates
assert len(items) == 1
kept = items[0].record
dups = items[0].duplicates
assert kept == selected
assert {r["id"] for r, _ in dups} == {1, 2}


def test_selected_with_duplicates_multi_column() -> None:
"""Test selected_with_duplicates for multi-columns."""
selected = {"text": "hello", "text2": "world"}
d = DeduplicationResult(
selected=[selected],
filtered=[
DuplicateRecord({"text": "hello", "text2": "world"}, True, [(selected, 1.0)]),
DuplicateRecord({"text": "helllo", "text2": "world"}, False, [(selected, 0.1)]),
],
threshold=0.8,
columns=["text", "text2"],
)

items = d.selected_with_duplicates
assert len(items) == 1
kept = items[0].record
assert kept == selected


def test_selected_with_duplicates_unhashable_values() -> None:
"""Test selected_with_duplicates with unhashable values in records."""
selected = {"text": "hello", "a": [1, 2, 3]} # list -> unhashable value
filtered = {"text": "hello", "a": [1, 2, 3], "flag": True}

d = DeduplicationResult(
selected=[selected],
filtered=[DuplicateRecord(filtered, exact=False, duplicates=[(selected, 1.0)])],
threshold=0.8,
columns=["text"],
Comment thread
Pringled marked this conversation as resolved.
)

items = d.selected_with_duplicates
assert items == [SelectedWithDuplicates(record=selected, duplicates=[(filtered, 1.0)])]


def test_selected_with_duplicates_removes_internal_duplicates() -> None:
"""Test that selected_with_duplicates removes internal duplicates that have the same hash."""
selected = {"id": 0, "text": "hello"}
filtered = {"id": 1, "text": "hello"}

d = DeduplicationResult(
selected=[selected],
filtered=[
DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.95)]),
DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.90)]),
],
threshold=0.8,
columns=["text"],
)

items = d.selected_with_duplicates
assert len(items) == 1

selected_record = items[0].record
duplicate_list = items[0].duplicates
# Should keep the kept record unchanged
assert selected_record == selected
# The duplicate row must appear only once
assert len(duplicate_list) == 1
assert duplicate_list[0][0] == filtered
Loading