Skip to content

Commit

Permalink
Work on image matching
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Mar 6, 2025
1 parent b03d840 commit 4053ea5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
65 changes: 65 additions & 0 deletions olmocr/bench/katex/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

def find_image_match(large_pil, small_pil, device=None) -> tuple[float, int, int]:
"""
Finds the best matching location of a small image inside a large image using 2D convolution.
Returns a matching score and the coordinates of the best match.
Args:
large_pil (PIL.Image): The large image (document).
small_pil (PIL.Image): The small image (patch).
device (str, optional): "cuda" for GPU, "cpu" for CPU, or None for auto-selection.
Returns:
(score, x, y):
- score: Matching score between 0.0 and 1.0, where 1.0 is a perfect match
- x, y: Coordinates of the top-left corner of the best match location
"""

# Auto-select device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Convert images to grayscale and NumPy arrays
large_img = np.array(large_pil.convert("L"), dtype=np.float32) / 255.0
small_img = np.array(small_pil.convert("L"), dtype=np.float32) / 255.0

# Swap things around so large image is actually the largest
if small_img.shape[0] > large_img.shape[0] and small_img.shape[1] > large_img.shape[1]:
small_img, large_img = large_img, small_img

# Convert to PyTorch tensors
large_tensor = torch.tensor(large_img).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W)
small_tensor = torch.tensor(small_img).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, h, w)

# Normalize the template (small image) for proper correlation calculation
# This makes the convolution behave like template matching
small_mean = torch.mean(small_tensor)
small_std = torch.std(small_tensor)
small_normalized = (small_tensor - small_mean) / (small_std + 1e-7)

# Calculate convolution
def conv2d_fn(large, small):
return F.conv2d(large, small, padding="same")

# Perform convolution
result = conv2d_fn(large_tensor, small_normalized)

# Find the max correlation and its position in a single call
# result shape is [1, 1, H, W]
max_val, max_loc = torch.max(result.view(-1), 0)

# Convert flat index to 2D coordinates
result_size = result.squeeze().size()
best_y = (max_loc // result_size[1]).item()
best_x = (max_loc % result_size[1]).item()

# Extract the region from the large image that matches the small image
h, w = small_img.shape

score = (max_val / torch.mean(large_tensor)).item()

return score, best_x, best_y
17 changes: 13 additions & 4 deletions olmocr/bench/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from olmocr.repeatdetect import RepeatDetector
from .katex.render import render_equation
from .katex.compare import find_image_match

class TestType(str, Enum):
PRESENT = "present"
Expand Down Expand Up @@ -432,6 +433,11 @@ def __post_init__(self):
if len(self.math.strip()) == 0:
raise ValidationError(f"Math test must have non-empty math expression")

self.reference_render = render_equation(self.math)

if self.reference_render is None:
raise ValidationError(f"Math equation {self.math} was not able to render")


def run(self, content: str) -> Tuple[bool, str]:
# Store both the search pattern and the full pattern to replace
Expand All @@ -452,20 +458,23 @@ def run(self, content: str) -> Tuple[bool, str]:

# Replace all instances of this pattern with empty strings
modified_content = re.sub(replace_pattern, '', modified_content, flags=re.DOTALL)

print("All equations", equations)

# If an equation in the markdown exactly matches our math string, then that's good enough
# we don't have to do a more expensive comparison
if any(hyp == self.math for hyp in equations):
return True, ""

# If not, then let's render the math equation itself and now compare to each hypothesis
reference_render = render_equation(self.math)

for hypothesis in equations:
hypothesis_render = render_equation(hypothesis)

if not hypothesis_render:
continue

# Now, let's see what the matchup is between the two images
match = find_image_match(hypothesis_render, self.reference_render)
print(f"Match score for {self.math} vs {hypothesis}, {match}")

return False, f"No match found for {self.math} anywhere in content"


Expand Down

0 comments on commit 4053ea5

Please sign in to comment.