Skip to content

Commit

Permalink
Model runners
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Mar 5, 2025
1 parent 5cb32c3 commit 5611d79
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 39 deletions.
108 changes: 83 additions & 25 deletions olmocr/bench/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import glob
import importlib
import os
from functools import partial

from tqdm import tqdm

Expand Down Expand Up @@ -40,8 +41,48 @@ def parse_method_arg(method_arg):
return name, kwargs, folder_name


async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
"""Process PDFs with both sync and async functions"""
# Wrapper to run synchronous functions in the event loop
async def run_sync_in_executor(func, *args, **kwargs):
    """Await a blocking callable by dispatching it to the default thread-pool executor.

    Binds *args/**kwargs up front with functools.partial (run_in_executor does
    not forward keyword arguments itself) and returns the callable's result.
    """
    bound = partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(None, bound)


async def process_pdf(pdf_path, method, kwargs, output_path, is_async, page_num=1):
    """Convert a single PDF page to markdown and persist the result.

    Args:
        pdf_path: Path to the input PDF file.
        method: Conversion callable; awaited directly when is_async is True,
            otherwise executed on the default thread-pool executor.
        kwargs: Extra keyword arguments forwarded to the method.
        output_path: File to write the markdown (or an empty marker file) to.
        is_async: Whether `method` is a coroutine function.
        page_num: Page of the PDF to convert (default 1, matching the
            previous hard-coded behavior).

    Returns:
        True when markdown was produced and written; False when the method
        returned None or raised, in which case an empty file is written so
        the page is counted as an error rather than silently skipped in evals.
    """
    try:
        if is_async:
            # Await the coroutine-based converter directly.
            markdown = await method(pdf_path, page_num=page_num, **kwargs)
        else:
            # Run the blocking converter off the event loop.
            markdown = await run_sync_in_executor(method, pdf_path, page_num=page_num, **kwargs)

        if markdown is None:
            print(f"Warning, did not get output for {os.path.basename(output_path)}")
            # Write blank to this file, so that it's marked as an error and not just skipped in evals
            with open(output_path, "w") as out_f:
                out_f.write("")
            return False

        # Write the markdown to the output file
        with open(output_path, "w") as out_f:
            out_f.write(markdown)

        return True
    except Exception as ex:
        print(f"Exception {str(ex)} occurred while processing {os.path.basename(output_path)}")
        # Write blank to this file, so that it's marked as an error and not just skipped in evals
        with open(output_path, "w") as out_f:
            out_f.write("")
        return False


async def process_pdfs(config, pdf_directory, data_directory, repeats, force, max_parallel=None):
"""
Process PDFs using asyncio for both sync and async methods,
limiting the number of concurrent tasks to max_parallel.
"""
for candidate in config.keys():
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
folder_name = config[candidate]["folder_name"]
Expand All @@ -55,35 +96,51 @@ async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
all_pdfs = glob.glob(os.path.join(pdf_directory, "*.pdf"))
all_pdfs.sort()

for pdf_path in tqdm(all_pdfs, desc=candidate):
# Prepare all tasks
tasks = []
task_descriptions = {}

for pdf_path in all_pdfs:
base_name = os.path.basename(pdf_path).replace(".pdf", "")

for i in range(1, repeats + 1):
output_filename = f"{base_name}_{i}.md"
output_path = os.path.join(candidate_output_dir, output_filename)

if os.path.exists(output_path) and not force:
print(f"Skipping {base_name}_{i} for {candidate}, file already exists")
print("Rerun with --force flag to force regeneration")
continue

try:
if is_async:
# Run async function
markdown = await method(pdf_path, page_num=1, **kwargs)
else:
# Run synchronous function
markdown = method(pdf_path, page_num=1, **kwargs)
except Exception as ex:
print(f"Exception {str(ex)} occurred while processing {base_name}_{i}")
markdown = None

if markdown is None:
print(f"Warning, did not get output for {base_name}_{i}")
continue

with open(output_path, "w") as out_f:
out_f.write(markdown)

task = process_pdf(pdf_path, method, kwargs, output_path, is_async)
tasks.append(task)
task_descriptions[id(task)] = f"{base_name}_{i} ({candidate})"

# Process tasks with semaphore to limit concurrency
semaphore = asyncio.Semaphore(max_parallel or 1) # Default to 1 if not specified

async def process_with_semaphore(task):
async with semaphore:
return await task

# Wrap each task with the semaphore
limited_tasks = [process_with_semaphore(task) for task in tasks]

# Process tasks with progress bar
if limited_tasks:
completed = 0
with tqdm(total=len(limited_tasks), desc=f"Processing {candidate}") as pbar:
for task in asyncio.as_completed(limited_tasks):
try:
result = await task
if result:
completed += 1
except Exception as e:
print(f"Task failed: {e}")
finally:
pbar.update(1)

print(f"Completed {completed} out of {len(limited_tasks)} tasks for {candidate}")


if __name__ == "__main__":
Expand All @@ -98,6 +155,7 @@ async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
parser.add_argument("--dir", type=str, default=os.path.join(os.path.dirname(__file__), "sample_data"), help="Path to the data folder in which to save outputs, pdfs should be in /pdfs folder within it.")
parser.add_argument("--force", action="store_true", default=False, help="Force regenerating of output files, even if they already exist")
parser.add_argument("--parallel", type=int, default=10, help="Maximum number of concurrent tasks")
args = parser.parse_args()

# Mapping of method names to a tuple: (module path, function name)
Expand Down Expand Up @@ -125,5 +183,5 @@ async def process_pdfs(config, pdf_directory, data_directory, repeats, force):
data_directory = args.dir
pdf_directory = os.path.join(data_directory, "pdfs")

# Run the async process function
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats, args.force))
# Run the async process function with the parallel argument
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats, args.force, args.parallel))
7 changes: 2 additions & 5 deletions olmocr/bench/runners/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,14 @@ async def run_server(pdf_path: str, page_num: int = 1, server: str = "localhost:
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=request)

print(response.status_code)
response.raise_for_status()
data = response.json()

print(data)
choice = data["choices"][0]
print(choice)
assert choice["finish_reason"] == "stop", "Response from server did not finish with finish_reason stop as expected, this is probably going to lead to bad data"

if response_template == "json":
data = choice["message"]["content"]
page_data = json.loads(page_data)
page_data = json.loads(choice["message"]["content"])
page_response = PageResponse(**page_data)
return page_response.natural_text
elif response_template == "plain":
Expand Down
22 changes: 13 additions & 9 deletions olmocr/bench/scripts/convert_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,16 @@ create_conda_env() {
}

# Function to start sglang server with OpenAI API for a specific model
# Now accepting additional arguments after the model name
start_sglang_server() {
model_name=$1
shift # Remove the first argument (model_name) from the argument list

echo "Starting sglang server for model: $model_name"
echo "Additional arguments: $@"

# Start the server in the background and save the PID
python -m sglang.launch_server --model $model_name --chat-template qwen2-vl &
# Start the server in the background with all remaining arguments and save the PID
python -m sglang.launch_server --model $model_name $@ &
SERVER_PID=$!

# Check if the server process is running
Expand Down Expand Up @@ -121,19 +125,19 @@ source activate olmocr
# For each model, start server, run benchmark, then stop server

# olmocr_base_temp0_1
start_sglang_server "allenai/olmOCR-7B-0225-preview"
python -m olmocr.bench.convert server:name=olmocr_base_temp0_1:model=allenai/olmOCR-7B-0225-preview:temperature=0.1:response_template=json --repeats 5
python -m olmocr.bench.convert server:name=olmocr_base_temp0_8:model=allenai/olmOCR-7B-0225-preview:temperature=0.8:response_template=json --repeats 5
start_sglang_server "allenai/olmOCR-7B-0225-preview" --mem-fraction-static 0.7
python -m olmocr.bench.convert server:name=olmocr_base_temp0_1:model=allenai/olmOCR-7B-0225-preview:temperature=0.1:prompt_template=fine_tune:response_template=json --repeats 5 --parallel 20
python -m olmocr.bench.convert server:name=olmocr_base_temp0_8:model=allenai/olmOCR-7B-0225-preview:temperature=0.8:prompt_template=fine_tune:response_template=json --repeats 5 --parallel 20
stop_sglang_server

# qwen2_vl_7b
start_sglang_server "Qwen/Qwen2-VL-7B-Instruct"
python -m olmocr.bench.convert server:name=qwen2_vl_7b:model=Qwen/Qwen2-VL-7B-Instruct:temperature=0.1:response_template=plain --repeats 5
start_sglang_server "Qwen/Qwen2-VL-7B-Instruct" --mem-fraction-static 0.7
python -m olmocr.bench.convert server:name=qwen2_vl_7b:model=Qwen/Qwen2-VL-7B-Instruct:temperature=0.1:prompt_template=full:response_template=plain --repeats 5 --parallel 20
stop_sglang_server

# qwen25_vl_7b
start_sglang_server "Qwen/Qwen2.5-VL-7B-Instruct"
python -m olmocr.bench.convert server:name=qwen25_vl_7b:model=Qwen/Qwen2.5-VL-7B-Instruct:temperature=0.1:response_template=plain --repeats 5
start_sglang_server "Qwen/Qwen2.5-VL-7B-Instruct" --mem-fraction-static 0.7
python -m olmocr.bench.convert server:name=qwen25_vl_7b:model=Qwen/Qwen2.5-VL-7B-Instruct:temperature=0.1:prompt_template=full:response_template=plain --repeats 5 --parallel 20
stop_sglang_server

# Create and activate mineru environment
Expand Down

0 comments on commit 5611d79

Please sign in to comment.