diff --git a/README.md b/README.md index 0b553f2..20dea4f 100644 --- a/README.md +++ b/README.md @@ -264,6 +264,10 @@ least_similar = deduplication_result.get_least_similar_from_duplicates() # Rethreshold the duplicates. This allows you to instantly rethreshold the duplicates with a new threshold without having to re-deduplicate the texts. deduplication_result.rethreshold(0.95) + +# View selected records along with their duplicates. +# This is the opposite of the `filtered` attribute, which shows for every duplicate the record that caused it. +deduplication_result.selected_with_duplicates ``` diff --git a/semhash/datamodels.py b/semhash/datamodels.py index 4c5f32a..2539561 100644 --- a/semhash/datamodels.py +++ b/semhash/datamodels.py @@ -1,8 +1,17 @@ +from __future__ import annotations + import warnings +from collections import defaultdict from dataclasses import dataclass, field -from typing import Generic, TypeVar +from typing import Any, Generic, Hashable, Sequence, TypeVar + +from frozendict import frozendict +from typing_extensions import TypeAlias + +from semhash.utils import to_frozendict -Record = TypeVar("Record", str, dict[str, str]) +Record = TypeVar("Record", str, dict[str, Any]) +DuplicateList: TypeAlias = list[tuple[Record, float]] @dataclass @@ -20,13 +29,29 @@ class DuplicateRecord(Generic[Record]): record: Record exact: bool - duplicates: list[tuple[Record, float]] = field(default_factory=list) + duplicates: DuplicateList = field(default_factory=list) def _rethreshold(self, threshold: float) -> None: """Rethreshold the duplicates.""" self.duplicates = [(d, score) for d, score in self.duplicates if score >= threshold] +@dataclass +class SelectedWithDuplicates(Generic[Record]): + """ + A record that has been selected along with its duplicates. + + Attributes + ---------- + record: The original record being selected. + duplicates: List of tuples consisting of duplicate records and their associated scores. + + """ + + record: Record + duplicates: DuplicateList = field(default_factory=list) + + @dataclass class DeduplicationResult(Generic[Record]): """ @@ -37,6 +62,7 @@ class DeduplicationResult(Generic[Record]): selected: List of deduplicated records after removing duplicates. filtered: List of DuplicateRecord objects containing details about duplicates of an original record. threshold: The similarity threshold used for deduplication. + columns: Columns used for deduplication. deduplicated: Deprecated, use selected instead. duplicates: Deprecated, use filtered instead. @@ -45,6 +71,7 @@ class DeduplicationResult(Generic[Record]): selected: list[Record] = field(default_factory=list) filtered: list[DuplicateRecord] = field(default_factory=list) threshold: float = field(default=0.9) + columns: Sequence[str] | None = field(default=None) deduplicated: list[Record] = field(default_factory=list) # Deprecated duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated @@ -102,6 +129,38 @@ def rethreshold(self, threshold: float) -> None: self.selected.append(dup.record) self.threshold = threshold + @property + def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]: + """ + For every kept record, return the duplicates that were removed along with their similarity scores. + + :return: A list of tuples where each tuple contains a kept record + and a list of its duplicates with their similarity scores. + """ + + def _to_hashable(record: Record) -> frozendict[str, str] | str: + """Convert a record to a hashable representation.""" + if isinstance(record, dict) and self.columns is not None: + # Convert dict to frozendict for immutability and hashability + return to_frozendict(record, set(self.columns)) + return str(record) + + # Build a mapping from original-record to [(duplicate, score), …] + buckets: defaultdict[Hashable, DuplicateList] = defaultdict(list) + for duplicate_record in self.filtered: + for original_record, score in duplicate_record.duplicates: + buckets[_to_hashable(original_record)].append((duplicate_record.record, float(score))) + + result: list[SelectedWithDuplicates[Record]] = [] + for selected in self.selected: + # Get the list of duplicates for the selected record + raw_list = buckets.get(_to_hashable(selected), []) + # Ensure we don't have duplicates in the list + deduped = {_to_hashable(rec): (rec, score) for rec, score in raw_list} + result.append(SelectedWithDuplicates(record=selected, duplicates=list(deduped.values()))) + + return result + @dataclass class FilterResult(Generic[Record]): diff --git a/semhash/records.py b/semhash/records.py index 69fd427..0266b99 100644 --- a/semhash/records.py +++ b/semhash/records.py @@ -1,15 +1,8 @@ from typing import Sequence -from frozendict import frozendict - from semhash.datamodels import DeduplicationResult, DuplicateRecord -def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]: - """Convert a record to a frozendict.""" - return frozendict({k: record.get(k, "") for k in columns}) - - def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str: r""" Turn a record into a single string. diff --git a/semhash/semhash.py b/semhash/semhash.py index 09eed40..e2dfa5b 100644 --- a/semhash/semhash.py +++ b/semhash/semhash.py @@ -11,8 +11,8 @@ from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record from semhash.index import Index -from semhash.records import add_scores_to_records, map_deduplication_result_to_strings, to_frozendict -from semhash.utils import Encoder, compute_candidate_limit +from semhash.records import add_scores_to_records, map_deduplication_result_to_strings +from semhash.utils import Encoder, compute_candidate_limit, to_frozendict class SemHash(Generic[Record]): @@ -190,7 +190,9 @@ def deduplicate( # If no records are left after removing exact duplicates, return early if not dict_records: - return DeduplicationResult(deduplicated=[], duplicates=duplicate_records, threshold=threshold) + return DeduplicationResult( + deduplicated=[], duplicates=duplicate_records, threshold=threshold, columns=self.columns + ) # Compute embeddings for the new records embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model) @@ -212,7 +214,7 @@ def deduplicate( ) result = DeduplicationResult( - deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold + deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns ) if self._was_string: @@ -281,7 +283,7 @@ def self_deduplicate( seen_items.update(frozen_items) result = DeduplicationResult( - deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold + deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns ) if self._was_string: diff --git a/semhash/utils.py b/semhash/utils.py index 03224d5..1956012 100644 --- a/semhash/utils.py +++ b/semhash/utils.py @@ -1,6 +1,7 @@ from typing import Any, Protocol, Sequence, Union import numpy as np +from frozendict import frozendict class Encoder(Protocol): @@ -21,6 +22,11 @@ def encode( ... # pragma: no cover +def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]: + """Convert a record to a frozendict.""" + return frozendict({k: record.get(k, "") for k in columns}) + + def compute_candidate_limit( total: int, selection_size: int, diff --git a/tests/test_datamodels.py b/tests/test_datamodels.py index 0472db9..d770672 100644 --- a/tests/test_datamodels.py +++ b/tests/test_datamodels.py @@ -2,7 +2,7 @@ import semhash import semhash.version -from semhash.datamodels import DeduplicationResult, DuplicateRecord +from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates def test_deduplication_scoring() -> None: @@ -27,13 +27,13 @@ def test_deduplication_scoring_exact() -> None: def test_deduplication_scoring_exact_empty() -> None: """Test the deduplication scoring.""" - d = DeduplicationResult([], [], 0.8) + d = DeduplicationResult([], [], 0.8, columns=["text"]) assert d.exact_duplicate_ratio == 0.0 def test_deduplication_scoring_empty() -> None: """Test the deduplication scoring.""" - d = DeduplicationResult([], [], 0.8) + d = DeduplicationResult([], [], 0.8, columns=["text"]) assert d.duplicate_ratio == 0.0 @@ -64,7 +64,7 @@ def test_get_least_similar_from_duplicates() -> None: def test_get_least_similar_from_duplicates_empty() -> None: """Test getting the least similar duplicates.""" - d = DeduplicationResult([], [], 0.8) + d = DeduplicationResult([], [], 0.8, columns=["text"]) assert d.get_least_similar_from_duplicates(1) == [] @@ -116,3 +116,106 @@ def test_deprecation_deduplicated_duplicates() -> None: DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]), DuplicateRecord("e", False, [("z", 0.8)]), ] + + +def test_selected_with_duplicates_strings() -> None: + """Test selected_with_duplicates for strings.""" + d = DeduplicationResult( + selected=["original"], + filtered=[ + DuplicateRecord("duplicate_1", False, [("original", 0.9)]), + DuplicateRecord("duplicate_2", False, [("original", 0.8)]), + ], + threshold=0.8, + ) + + expected = [ + SelectedWithDuplicates( + record="original", + duplicates=[("duplicate_1", 0.9), ("duplicate_2", 0.8)], + ) + ] + assert d.selected_with_duplicates == expected + + +def test_selected_with_duplicates_dicts() -> None: + """Test selected_with_duplicates for dicts.""" + selected = {"id": 0, "text": "hello"} + d = DeduplicationResult( + selected=[selected], + filtered=[ + DuplicateRecord({"id": 1, "text": "hello"}, True, [(selected, 1.0)]), + DuplicateRecord({"id": 2, "text": "helllo"}, False, [(selected, 0.1)]), + ], + threshold=0.8, + columns=["text"], + ) + + items = d.selected_with_duplicates + assert len(items) == 1 + kept = items[0].record + dups = items[0].duplicates + assert kept == selected + assert {r["id"] for r, _ in dups} == {1, 2} + + +def test_selected_with_duplicates_multi_column() -> None: + """Test selected_with_duplicates for multi-columns.""" + selected = {"text": "hello", "text2": "world"} + d = DeduplicationResult( + selected=[selected], + filtered=[ + DuplicateRecord({"text": "hello", "text2": "world"}, True, [(selected, 1.0)]), + DuplicateRecord({"text": "helllo", "text2": "world"}, False, [(selected, 0.1)]), + ], + threshold=0.8, + columns=["text", "text2"], + ) + + items = d.selected_with_duplicates + assert len(items) == 1 + kept = items[0].record + assert kept == selected + + +def test_selected_with_duplicates_unhashable_values() -> None: + """Test selected_with_duplicates with unhashable values in records.""" + selected = {"text": "hello", "a": [1, 2, 3]} # list -> unhashable value + filtered = {"text": "hello", "a": [1, 2, 3], "flag": True} + + d = DeduplicationResult( + selected=[selected], + filtered=[DuplicateRecord(filtered, exact=False, duplicates=[(selected, 1.0)])], + threshold=0.8, + columns=["text"], + ) + + items = d.selected_with_duplicates + assert items == [SelectedWithDuplicates(record=selected, duplicates=[(filtered, 1.0)])] + + +def test_selected_with_duplicates_removes_internal_duplicates() -> None: + """Test that selected_with_duplicates removes internal duplicates that have the same hash.""" + selected = {"id": 0, "text": "hello"} + filtered = {"id": 1, "text": "hello"} + + d = DeduplicationResult( + selected=[selected], + filtered=[ + DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.95)]), + DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.90)]), + ], + threshold=0.8, + columns=["text"], + ) + + items = d.selected_with_duplicates + assert len(items) == 1 + + selected_record = items[0].record + duplicate_list = items[0].duplicates + # Should keep the kept record unchanged + assert selected_record == selected + # The duplicate row must appear only once + assert len(duplicate_list) == 1 + assert duplicate_list[0][0] == filtered