Skip to content

Commit

Permalink
Nice tables support
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Mar 4, 2025
1 parent 3a0bcb6 commit 004486f
Showing 1 changed file with 142 additions and 67 deletions.
209 changes: 142 additions & 67 deletions olmocr/bench/tests.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,107 @@
import json
import re
import numpy as np
from bs4 import BeautifulSoup

from dataclasses import asdict, dataclass
from enum import Enum
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Any

from fuzzysearch import find_near_matches
from rapidfuzz import fuzz


# Pipe table: header row, delimiter row, then one or more body rows.
#   * Cells use [^|\n] so a single "row" can never run across several lines.
#   * Only horizontal whitespace is allowed between rows, so a blank line ends
#     the table instead of gluing two neighbouring tables together.
#   * Delimiter cells may be padded with spaces ("| --- | :-: |").
#   * The final body row may end at end-of-input instead of a newline.
_MD_TABLE_RE = re.compile(
    r"(\|(?:[^|\n]*\|)+)[ \t\r]*\n"
    r"[ \t]*\|(?:[ \t]*[:-]+[ \t]*\|)+[ \t\r]*\n"
    r"((?:[ \t]*\|(?:[^|\n]*\|)+[ \t\r]*(?:\n|$))+)"
)


def _split_markdown_row(row: str) -> List[str]:
    """Split one ``| a | b |`` row into a list of stripped cell strings."""
    cells = [cell.strip() for cell in row.strip().split("|")]
    # The leading and trailing pipes produce empty first/last fragments.
    if cells and cells[0] == "":
        cells = cells[1:]
    if cells and cells[-1] == "":
        cells = cells[:-1]
    return cells


def parse_markdown_tables(md_content: str) -> List[np.ndarray]:
    """
    Extract and parse all markdown pipe tables from the provided content.

    A table is a header row, a delimiter row (dashes/colons, optionally padded
    with spaces) and at least one body row, each starting and ending with ``|``.
    The delimiter row is not included in the result. Escaped pipes (``\\|``)
    inside cells are not treated specially.

    Args:
        md_content: The markdown content containing tables
    Returns:
        A list of 2-D numpy string arrays, one per table, with the header as
        row 0. Short rows are right-padded with empty strings so every row has
        the same number of columns.
    """
    parsed_tables = []

    for table_match in _MD_TABLE_RE.finditer(md_content):
        # The delimiter row sits between the two capture groups, so it never
        # appears here. No content-based "looks like dashes" filtering is done:
        # that would discard real data rows such as "| -1 | -2 |".
        rows = [table_match.group(1)] + table_match.group(2).strip().split("\n")
        table_data = [_split_markdown_row(row) for row in rows]

        # numpy needs a rectangular grid; pad ragged rows on the right.
        max_cols = max(len(row) for row in table_data)
        padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
        parsed_tables.append(np.array(padded_data))

    return parsed_tables


def parse_html_tables(html_content: str) -> List[np.ndarray]:
    """
    Extract and parse all HTML tables from the provided content.

    Each ``<table>`` (including nested ones) yields its own array. Rows and
    cells are attributed only to the table/row that directly owns them, so a
    table nested inside a cell does not add extra rows or columns to the outer
    table; its text still appears in the text of the outer cell containing it.
    ``rowspan``/``colspan`` are not expanded.

    Args:
        html_content: The HTML content containing tables
    Returns:
        A list of 2-D numpy string arrays, one per table that has at least one
        row. Short rows are right-padded with empty strings.
    """
    soup = BeautifulSoup(html_content, "html.parser")

    parsed_tables = []

    for table in soup.find_all("table"):
        table_data = []

        for row in table.find_all("tr"):
            # find_all() searches all descendants; skip rows whose nearest
            # enclosing <table> is a nested table rather than this one.
            if row.find_parent("table") is not table:
                continue

            # Same reasoning for cells: keep only those owned by this row.
            cells = [
                cell
                for cell in row.find_all(["th", "td"])
                if cell.find_parent("tr") is row
            ]
            table_data.append([cell.get_text().strip() for cell in cells])

        if not table_data:
            continue

        # numpy needs a rectangular grid; pad ragged rows on the right.
        max_cols = max(len(row) for row in table_data)
        padded_data = [row + [""] * (max_cols - len(row)) for row in table_data]
        parsed_tables.append(np.array(padded_data))

    return parsed_tables


class TestType(str, Enum):
PRESENT = "present"
ABSENT = "absent"
Expand Down Expand Up @@ -147,87 +239,68 @@ def run(self, md_content: str) -> Tuple[bool, str]:
return True, ""
return False, (f"Could not find a location where '{self.before[:40]}...' appears before " f"'{self.after[:40]}...'.")





@dataclass
class TableTest(BasePDFTest):
"""
Test to verify certain properties of a table are held, namely that some cells appear relative to other cells correctly
Test to verify certain properties of a table are held, namely that some cells appear relative to other cells correctly
"""
# This is the target cell, which must exist in at least one place in the table
cell: str

# These properties say that the cell immediately up/down/left/right of the target cell has the string specified
up: str=""
down: str=""
left: str=""
right: str=""
up: str = ""
down: str = ""
left: str = ""
right: str = ""

# These properties say that the cell all the way up, or all the way left of the target cell (ex. headings) has the string value specified
top_heading: str=""
left_heading: str=""
top_heading: str = ""
left_heading: str = ""


def __post_init__(self):
super().__post_init__()
if self.type != TestType.TABLE.value:
raise ValidationError(f"Invalid type for TabelText: {self.type}")
raise ValidationError(f"Invalid type for TableTest: {self.type}")

def run(self, md_content: str) -> Tuple[bool, str]:
def run(self, content: str) -> Tuple[bool, str]:
"""
Run the table test on provided markdown content.
Run the table test on provided content.
Finds all markdown tables and checks if any cell matches the target cell
and satisfies the specified relationships (up, down, left, right, headings).
Finds all tables (markdown and/or HTML based on content_type) and checks if any cell
matches the target cell and satisfies the specified relationships.
Args:
md_content: The markdown content containing tables
content: The content containing tables (markdown or HTML)
Returns:
A tuple (passed, explanation) where 'passed' is True if at least one cell
passes all the specified relationships, and 'explanation' provides details when the test fails.
A tuple (passed, explanation) where 'passed' is True if the test passes,
and 'explanation' provides details when the test fails.
"""
# Initialize variables to track tables and results
tables_to_check = []
failed_reasons = []

# Threshold for fuzzy matching derived from max_diffs
threshold = 1.0 - (self.max_diffs / (len(self.cell) if len(self.cell) > 0 else 1))

# Extract all tables from markdown
table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[:-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)'
table_matches = re.finditer(table_pattern, md_content)
failed_reasons = []
# Parse tables based on content_type
md_tables = parse_markdown_tables(content)
tables_to_check.extend(md_tables)

for table_match in table_matches:
# Extract header and body from the table match
header_row = table_match.group(1).strip()
body_rows = table_match.group(2).strip().split('\n')

# Process header and rows to remove leading/trailing |
header_cells = [cell.strip() for cell in header_row.split('|')]
if header_cells[0] == '':
header_cells = header_cells[1:]
if header_cells[-1] == '':
header_cells = header_cells[:-1]

# Process table body rows
table_data = []
for row in [header_row] + body_rows:
if '|' not in row: # Skip separator row
continue

cells = [cell.strip() for cell in row.split('|')]
if cells[0] == '':
cells = cells[1:]
if cells[-1] == '':
cells = cells[:-1]

table_data.append(cells)

# Skip separator row (second row with dashes)
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
table_data = [table_data[0]] + table_data[2:]

# Convert to numpy array for easier manipulation
# First ensure all rows have the same number of columns by padding if necessary
max_cols = max(len(row) for row in table_data)
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
html_tables = parse_html_tables(content)
tables_to_check.extend(html_tables)

# If no tables found, return failure
if not tables_to_check:
return False, f"No tables found in the content at all"

# Check each table
for table_array in tables_to_check:
# Find all cells that match the target cell using fuzzy matching
matches = []
for i in range(table_array.shape[0]):
Expand All @@ -245,39 +318,39 @@ def run(self, md_content: str) -> Tuple[bool, str]:
# Check the relationships for each matching cell
for row_idx, col_idx in matches:
all_relationships_satisfied = True
failed_reasons = []
current_failed_reasons = []

# Check up relationship
if self.up and row_idx > 0:
up_cell = table_array[row_idx - 1, col_idx]
up_similarity = fuzz.ratio(self.up, up_cell) / 100.0
if up_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Cell above '{up_cell}' doesn't match expected '{self.up}' (similarity: {up_similarity:.2f})")
current_failed_reasons.append(f"Cell above '{up_cell}' doesn't match expected '{self.up}' (similarity: {up_similarity:.2f})")

# Check down relationship
if self.down and row_idx < table_array.shape[0] - 1:
down_cell = table_array[row_idx + 1, col_idx]
down_similarity = fuzz.ratio(self.down, down_cell) / 100.0
if down_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Cell below '{down_cell}' doesn't match expected '{self.down}' (similarity: {down_similarity:.2f})")
current_failed_reasons.append(f"Cell below '{down_cell}' doesn't match expected '{self.down}' (similarity: {down_similarity:.2f})")

# Check left relationship
if self.left and col_idx > 0:
left_cell = table_array[row_idx, col_idx - 1]
left_similarity = fuzz.ratio(self.left, left_cell) / 100.0
if left_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Cell to the left '{left_cell}' doesn't match expected '{self.left}' (similarity: {left_similarity:.2f})")
current_failed_reasons.append(f"Cell to the left '{left_cell}' doesn't match expected '{self.left}' (similarity: {left_similarity:.2f})")

# Check right relationship
if self.right and col_idx < table_array.shape[1] - 1:
right_cell = table_array[row_idx, col_idx + 1]
right_similarity = fuzz.ratio(self.right, right_cell) / 100.0
if right_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Cell to the right '{right_cell}' doesn't match expected '{self.right}' (similarity: {right_similarity:.2f})")
current_failed_reasons.append(f"Cell to the right '{right_cell}' doesn't match expected '{self.right}' (similarity: {right_similarity:.2f})")

# Check top heading relationship
if self.top_heading and row_idx > 0:
Expand All @@ -290,12 +363,12 @@ def run(self, md_content: str) -> Tuple[bool, str]:

if not top_heading_cell:
all_relationships_satisfied = False
failed_reasons.append(f"No non-empty top heading found in column {col_idx}")
current_failed_reasons.append(f"No non-empty top heading found in column {col_idx}")
else:
top_similarity = fuzz.ratio(self.top_heading, top_heading_cell) / 100.0
if top_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Top heading '{top_heading_cell}' doesn't match expected '{self.top_heading}' (similarity: {top_similarity:.2f})")
current_failed_reasons.append(f"Top heading '{top_heading_cell}' doesn't match expected '{self.top_heading}' (similarity: {top_similarity:.2f})")

# Check left heading relationship
if self.left_heading and col_idx > 0:
Expand All @@ -308,24 +381,26 @@ def run(self, md_content: str) -> Tuple[bool, str]:

if not left_heading_cell:
all_relationships_satisfied = False
failed_reasons.append(f"No non-empty left heading found in row {row_idx}")
current_failed_reasons.append(f"No non-empty left heading found in row {row_idx}")
else:
left_heading_similarity = fuzz.ratio(self.left_heading, left_heading_cell) / 100.0
if left_heading_similarity < threshold:
all_relationships_satisfied = False
failed_reasons.append(f"Left heading '{left_heading_cell}' doesn't match expected '{self.left_heading}' (similarity: {left_heading_similarity:.2f})")
current_failed_reasons.append(f"Left heading '{left_heading_cell}' doesn't match expected '{self.left_heading}' (similarity: {left_heading_similarity:.2f})")


# If all relationships are satisfied for this cell, the test passes
if all_relationships_satisfied:
return True, ""

else:
failed_reasons.extend(current_failed_reasons)

# If we've gone through all tables and all matching cells and none satisfied all relationships
if not failed_reasons:
return False, f"No cell matching '{self.cell}' found in any table with threshold {threshold}"
else:
return False, f"Found cells matching '{self.cell}' but relationships were not satisfied: {'; '.join(failed_reasons)}"


def load_tests(jsonl_file: str) -> List[BasePDFTest]:
"""
Load tests from a JSONL file.
Expand Down Expand Up @@ -376,4 +451,4 @@ def save_tests(tests: List[BasePDFTest], jsonl_file: str) -> None:
"""
with open(jsonl_file, "w") as file:
for test in tests:
file.write(json.dumps(asdict(test)) + "\n")
file.write(json.dumps(asdict(test)) + "\n")

0 comments on commit 004486f

Please sign in to comment.