Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions marimo/_utils/cell_matching.py
Original file line number Diff line number Diff line change
def similarity_score(s1: str, s2: str) -> float:
    """Score how different two strings are using their shared prefix and suffix.

    The score is the number of characters (counted over both strings) that
    are not covered by the common prefix or the common suffix. Lower means
    more similar; identical strings score ``0.0``.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        ``len(s1) + len(s2) - 2 * (common_prefix + common_suffix)`` as a float.
    """
    len1 = len(s1)
    len2 = len(s2)
    shortest = min(len1, len2)

    # Length of the common prefix.
    prefix_len = 0
    while prefix_len < shortest and s1[prefix_len] == s2[prefix_len]:
        prefix_len += 1

    # Length of the common suffix. The scan is capped at the characters of
    # the shorter string that the prefix has not already consumed, so a
    # character is never counted as both prefix and suffix (which would
    # otherwise allow negative scores, e.g. "aba" vs "abba").
    max_suffix = shortest - prefix_len
    suffix_len = 0
    while (
        suffix_len < max_suffix
        and s1[len1 - 1 - suffix_len] == s2[len2 - 1 - suffix_len]
    ):
        suffix_len += 1

    # Inverse similarity: a shorter common affix means a higher score.
    return len1 + len2 - 2.0 * (prefix_len + suffix_len)


def group_lookup(
    ids: Sequence[CellId_t], codes: Sequence[str]
) -> dict[str, list[tuple[int, CellId_t]]]:
    """Group cell ids by their code.

    Args:
        ids: Cell ids, positionally aligned with ``codes``.
        codes: Code string for each cell.

    Returns:
        A mapping from each distinct code to the ``(position, cell_id)``
        pairs that carry it, in order of appearance. Pairing stops at the
        shorter of the two inputs (``zip`` semantics).
    """
    lookup: dict[str, list[tuple[int, CellId_t]]] = {}
    for idx, (cell_id, code) in enumerate(zip(ids, codes)):
        # Exactly one append per cell; duplicates of the same code accumulate
        # under the same key.
        lookup.setdefault(code, []).append((idx, cell_id))
    return lookup


def extract_order(
    codes: list[str], lookup: dict[str, list[tuple[int, CellId_t]]]
) -> list[list[int]]:
    """Assign a contiguous run of indices to each code, one per duplicate.

    Args:
        codes: Codes to assign index runs to, in order.
        lookup: Mapping from code to its ``(position, cell_id)`` entries;
            only the number of entries per code is used here.

    Returns:
        One list per entry in ``codes``. The i-th list holds
        ``len(lookup[codes[i]])`` consecutive integers, continuing from where
        the previous list ended (starting at 0).

    Raises:
        KeyError: If a code is missing from ``lookup``.
    """
    offset = 0
    # Build sequentially: each slot is a fresh list, so there is no shared
    # ``[[]] * n`` aliasing and no ``None`` placeholder to overwrite.
    order: list[list[int]] = []
    for code in codes:
        dupes = len(lookup[code])
        # range(offset, offset) is empty, so zero duplicates yields [].
        order.append(list(range(offset, offset + dupes)))
        offset += dupes
    return order

Expand All @@ -60,7 +75,9 @@ def get_unique(
codes: Sequence[str], available: dict[str, list[tuple[int, CellId_t]]]
) -> list[str]:
# Order matters, required opposed to using set()
seen = set(codes) - set(available.keys())
available_keys = set(available.keys())
# Use a seen set that starts with all codes not in available_keys
seen = set(codes) - available_keys
unique_codes = []
for code in codes:
if code not in seen:
Expand All @@ -72,9 +89,15 @@ def get_unique(
def pop_local(available: list[tuple[int, CellId_t]], idx: int) -> CellId_t:
    """Remove and return the cell id whose recorded index is closest to ``idx``.

    Args:
        available: ``(index, cell_id)`` candidates; mutated in place by
            removing the chosen entry.
        idx: Target index to match against.

    Returns:
        The cell id of the removed entry.

    Raises:
        ValueError: If ``available`` is empty.
    """
    # NB. ``min`` returns the first minimal element, so the entry at the
    # lower list position wins when two candidates are equidistant.
    best_pos, _ = min(
        enumerate(available), key=lambda item: abs(item[1][0] - idx)
    )
    return available.pop(best_pos)[1]


Expand All @@ -92,8 +115,11 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]:
# Step 1: Subtract row minima
for i in range(n):
min_value = min(score_matrix[i])
score_matrix_i = score_matrix[i]
for j in range(n):
score_matrix[i][j] -= min_value
score_matrix_i[j] -= min_value

# Step 2: Subtract column minima

# Step 2: Subtract column minima
for j in range(n):
Expand All @@ -107,9 +133,10 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]:

# Find independent zeros
for i in range(n):
score_matrix_i = score_matrix[i]
for j in range(n):
if (
score_matrix[i][j] == 0
score_matrix_i[j] == 0
and row_assignment[i] == -1
and col_assignment[j] == -1
):
Expand All @@ -124,36 +151,43 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]:

# Find minimum uncovered value
min_uncovered = float("inf")
for i in range(n):
for j in range(n):
if row_assignment[i] == -1 and col_assignment[j] == -1:
min_uncovered = min(min_uncovered, score_matrix[i][j])
# Pre-calc masks for uncovered
uncovered_rows = [i for i, v in enumerate(row_assignment) if v == -1]
uncovered_cols = [j for j, v in enumerate(col_assignment) if v == -1]
for i in uncovered_rows:
score_matrix_i = score_matrix[i]
for j in uncovered_cols:
if score_matrix_i[j] < min_uncovered:
min_uncovered = score_matrix_i[j]

if min_uncovered == float("inf"):
break

# Update matrix
for i in range(n):
score_matrix_i = score_matrix[i]
for j in range(n):
if row_assignment[i] == -1 and col_assignment[j] == -1:
score_matrix[i][j] -= min_uncovered
score_matrix_i[j] -= min_uncovered
elif row_assignment[i] != -1 and col_assignment[j] != -1:
score_matrix[i][j] += min_uncovered
score_matrix_i[j] += min_uncovered

# Try to find new assignments
for i in range(n):
for i in uncovered_rows:
score_matrix_i = score_matrix[i]
if row_assignment[i] == -1:
for j in range(n):
if score_matrix[i][j] == 0 and col_assignment[j] == -1:
for j in uncovered_cols:
if score_matrix_i[j] == 0 and col_assignment[j] == -1:
row_assignment[i] = j
col_assignment[j] = i
break

# Convert to result format
result = [-1] * n
for i in range(n):
if row_assignment[i] != -1:
result[row_assignment[i]] = i
a = row_assignment[i]
if a != -1:
result[a] = i

return result

Expand Down Expand Up @@ -240,21 +274,30 @@ def filter_and_backfill() -> list[CellId_t]:

# Pad the scores matrix to ensure it is square
n = max(len(next_codes) - filled, len(prev_codes) - filled)
scores = [[0.0] * n for _ in range(n)]
# Allocate as one contiguous list to avoid repeated small-object instantiation
scores = [None] * n
for i in range(n):
scores[i] = [0.0] * n

# Fill matrix, accounting for dupes
# Fill matrix, accounting for dupes
for i, code in enumerate(added_code):
o_i = next_order[i]
for j, prev_code in enumerate(deleted_code):
score = similarity_score(prev_code, code)
for x in next_order[i]:
for y in prev_order[j]:
o_j = prev_order[j]
# Restructure to use for-indexes instead of nested for-loops for C-speedup
for x in o_i:
for y in o_j:
# NB. transposed indices for Hungarian
scores[y][x] = score

# Use Hungarian algorithm to find the best matching
matches = _hungarian_algorithm(scores)
for idx, code in enumerate(next_codes):
if result[idx] is None:
match_idx = next_order[next_inverse[code]].pop(0)
o = next_order[next_inverse[code]]
match_idx = o.pop(0)
if match_idx != -1 and matches[match_idx] in inverse_order:
prev_idx = inverse_order[matches[match_idx]]
prev_code = deleted_code[prev_idx]
Expand Down