Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions semhash/datamodels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
import warnings
from collections import defaultdict
from collections.abc import Hashable, Sequence
from dataclasses import dataclass, field
Expand Down Expand Up @@ -65,35 +64,13 @@ class DeduplicationResult(Generic[Record]):
filtered: List of DuplicateRecord objects containing details about duplicates of an original record.
threshold: The similarity threshold used for deduplication.
columns: Columns used for deduplication.
deduplicated: Deprecated, use selected instead.
duplicates: Deprecated, use filtered instead.

"""

selected: list[Record] = field(default_factory=list)
filtered: list[DuplicateRecord] = field(default_factory=list)
threshold: float = field(default=0.9)
columns: Sequence[str] | None = field(default=None)
deduplicated: list[Record] = field(default_factory=list) # Deprecated
duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated

def __post_init__(self) -> None:
"""Initialize deprecated fields and warn about deprecation."""
# Warn only when a caller actually supplied a non-empty deprecated field.
if self.deduplicated or self.duplicates:
warnings.warn(
"'deduplicated' and 'duplicates' fields are deprecated and will be removed in a future release. Use 'selected' and 'filtered' instead.",
DeprecationWarning,
# NOTE(review): stacklevel=2 attributes the warning to whatever called
# __post_init__; if that caller is a generated dataclass __init__, pointing at
# the user's call site would need stacklevel=3 -- confirm.
stacklevel=2,
)

# NOTE(review): indentation is not visible in this excerpt; the checks below are
# presumably at method level rather than nested under the `if` above -- confirm.
# Forward values passed through the deprecated names to the new fields, but only
# when the new field is empty, so an explicitly given selected/filtered wins.
if not self.selected and self.deduplicated:
self.selected = self.deduplicated
if not self.filtered and self.duplicates:
self.filtered = self.duplicates
# Mirror the new fields back onto the deprecated names when those are empty.
# Plain assignment: both names then refer to the same list object, not a copy.
if not self.deduplicated:
self.deduplicated = self.selected
if not self.duplicates:
self.duplicates = self.filtered

@property
def duplicate_ratio(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion semhash/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def map_deduplication_result_to_strings(result: DeduplicationResult, columns: Se
"""Convert the record and duplicates in each DuplicateRecord back to strings if self.was_string is True."""
deduplicated_str = [dict_to_string(r, columns) for r in result.selected]
mapped = []
for dup_rec in result.duplicates:
for dup_rec in result.filtered:
record_as_str = dict_to_string(dup_rec.record, columns)
duplicates_as_str = [(dict_to_string(r, columns), score) for r, score in dup_rec.duplicates]
mapped.append(
Expand Down
6 changes: 3 additions & 3 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def deduplicate(
# If no records are left after removing exact duplicates, return early
if not dict_records:
return DeduplicationResult(
deduplicated=[], duplicates=duplicate_records, threshold=threshold, columns=self.columns
selected=[], filtered=duplicate_records, threshold=threshold, columns=self.columns
)

# Compute embeddings for the new records
Expand All @@ -221,7 +221,7 @@ def deduplicate(
)

result = DeduplicationResult(
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
selected=deduplicated_records, filtered=duplicate_records, threshold=threshold, columns=self.columns
)

if self._was_string:
Expand Down Expand Up @@ -290,7 +290,7 @@ def self_deduplicate(
seen_items.update(frozen_items)

result = DeduplicationResult(
deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
selected=deduplicated_records, filtered=duplicate_records, threshold=threshold, columns=self.columns
)

if self._was_string:
Expand Down
21 changes: 0 additions & 21 deletions tests/test_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,6 @@ def test_rethreshold_exception() -> None:
d.rethreshold(0.6)


def test_deprecation_deduplicated_duplicates() -> None:
    """
    Test deprecation warnings for deduplicated and duplicates fields.

    While the package version is below 0.4, building a DeduplicationResult from the
    deprecated fields must emit a DeprecationWarning and populate `selected` and
    `filtered`. From 0.4 onwards the test fails with a reminder to remove the fields.
    """
    # Compare major/minor numerically: as plain strings "0.10.0" sorts before "0.4.0",
    # which would keep the removal reminder from ever firing.
    # Assumes the first two dot-separated version components are plain integers.
    release = tuple(int(part) for part in semhash.version.__version__.split(".")[:2])
    if release >= (0, 4):
        raise ValueError("deprecate `deduplicated` and `duplicates` fields in `DeduplicationResult`")

    with pytest.warns(DeprecationWarning):
        d = DeduplicationResult(
            deduplicated=["a", "b", "c"],
            duplicates=[
                DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
                DuplicateRecord("e", False, [("z", 0.8)]),
            ],
            threshold=0.8,
        )

    # The expected values are built separately from the inputs so the comparison
    # checks equality of content rather than identity of the same list object.
    assert d.selected == ["a", "b", "c"]
    assert d.filtered == [
        DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
        DuplicateRecord("e", False, [("z", 0.8)]),
    ]


def test_selected_with_duplicates_strings() -> None:
"""Test selected_with_duplicates for strings."""
d = DeduplicationResult(
Expand Down