-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathgeobench.py
More file actions
407 lines (329 loc) · 15.7 KB
/
geobench.py
File metadata and controls
407 lines (329 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
import os
import json
import math
import pandas as pd
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import datetime
import argparse
import haversine
from dotenv import load_dotenv
from geo2p.canon import are_same_country
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from scripts.parser import parse_response, Guess
# Instructions sent with every image: GeoGuessrBenchmark._evaluate_location
# passes this to self.model.query(). The "country:/lat:/lng:" lines are the
# answer format the response parser (scripts.parser.parse_response) is
# presumably written to extract — confirm against the parser.
SYSTEM_PROMPT = """
You are participating in a geolocation challenge. Based on the provided image:
1. Carefully analyze the image for clues about its location (architecture, signage, vegetation, terrain, etc.)
2. Think step-by-step about what country this is likely to be in and why
3. Estimate the approximate latitude and longitude based on your analysis
Take your time to reason through the evidence. Your final answer MUST include these three lines somewhere in your response:
country: [country name]
lat: [latitude as a decimal number]
lng: [longitude as a decimal number]
You can provide additional reasoning or explanation, but these three specific lines MUST be included.
"""
# Optional add-on instruction for search-enabled models.
# NOTE(review): not referenced anywhere in this file's visible code — confirm
# whether a model uses it, or remove it.
SEARCH = """
You have access to Google Search, which you should use to improve your answer.
"""
# Star import: GeoGuessrBenchmark resolves model classes by name through
# globals()[model], and BaseMultimodalModel is not defined in this file, so
# both presumably come from here.
from models import *
# Populate os.environ from a local .env file so os.getenv(api_key_name) in
# GeoGuessrBenchmark.__init__ can find API keys defined there.
load_dotenv()
@dataclass
class Location:
    """One ground-truth sample from a dataset's ``metadata.json``."""

    image_path: str  # path to the image file (dataset dir + metadata "image_path" when built by the loader)
    country: str  # true country name as recorded in the metadata
    lat: float
    lng: float

    @property
    def coordinates(self) -> Tuple[float, float]:
        """Return the ``(lat, lng)`` pair for distance computations."""
        lat, lng = self.lat, self.lng
        return lat, lng

    @property
    def id(self) -> str:
        """Return the image file name with its directory and extension removed."""
        stem, _extension = os.path.splitext(os.path.basename(self.image_path))
        return stem
@dataclass
class BenchmarkResult:
    """Outcome of evaluating one location: the model's guess plus derived metrics.

    ``distance_km``, ``score`` and ``country_correct`` start as ``None`` and
    are filled in by :meth:`calculate_metrics`.
    """

    location: "Location"
    guess: Optional["Guess"]  # None for refused results (no usable answer)
    distance_km: Optional[float] = None
    score: Optional[int] = None
    country_correct: Optional[bool] = None
    refused: bool = False
    error_message: Optional[str] = None

    def calculate_metrics(self, scale: float) -> None:
        """Fill in distance, score and country correctness for this result.

        Refused results, and results without a guess, get
        ``distance_km=None``, ``score=0`` and ``country_correct=False``.

        Args:
            scale: Map scale factor forwarded unchanged to ``calculate_score``.
        """
        if self.refused or self.guess is None:
            self.distance_km = None
            self.score = 0
            self.country_correct = False
            return
        # haversine() returns a float in kilometres (its default unit), never
        # None, so no None-guard is needed before scoring.
        self.distance_km = haversine.haversine(
            self.location.coordinates,
            self.guess.coordinates
        )
        # calculate_score takes thousands of km (it multiplies by 1e6 to get metres).
        self.score = calculate_score(self.distance_km / 1000, scale)
        self.country_correct = are_same_country(self.location.country, self.guess.country)
class GeoGuessrBenchmark:
    """Run a multimodal model over a geolocation dataset and score its guesses.

    The constructor loads ``<dataset_path>/metadata.json`` (image list plus map
    bounds), derives ``self.scale`` for ``calculate_score`` and instantiates the
    model class named by *model*. Set ``run_folder`` before calling
    :meth:`run_benchmark`: raw model responses are written to
    ``<run_folder>/output/`` and incremental results to ``<run_folder>/results/``.
    """

    def __init__(self,
                 dataset_path: str,
                 model: str = "ClaudeHaiku",
                 api_key: Optional[str] = None,
                 max_retries: int = 3):
        """Load the dataset and instantiate the requested model.

        Args:
            dataset_path: Directory containing ``metadata.json`` and the images.
            model: Name of a ``BaseMultimodalModel`` subclass visible in this
                module's globals.
            api_key: Explicit API key. When omitted and the model class declares
                ``api_key_name``, it is read from that environment variable.
            max_retries: Attempts per location for API/network failures.

        Raises:
            ValueError: If *model* is unknown or not a ``BaseMultimodalModel``
                subclass, or if its API key cannot be found.
        """
        self.dataset_path = dataset_path
        self.locations = self._load_dataset()
        self.results = []
        self.max_retries = max_retries
        self.results_lock = threading.Lock()
        self.run_folder = ""
        # Only the name lookup belongs in the try: a KeyError raised by the
        # model's own constructor must not be reported as an unknown model.
        try:
            model_class = globals()[model]
        except KeyError:
            raise ValueError(f"Unknown model provider: {model}. Make sure the class is defined.") from None
        # globals() also holds modules and functions; issubclass() raises
        # TypeError for non-classes, so check that it is a class first.
        if not (isinstance(model_class, type) and issubclass(model_class, BaseMultimodalModel)):
            raise ValueError(f"{model} is not a valid BaseMultimodalModel class")
        if not api_key and model_class.api_key_name:
            api_key = os.getenv(model_class.api_key_name)
            if not api_key:
                raise ValueError(f"API key {model_class.api_key_name} not found for {model}")
        self.model = model_class(api_key)

    def _load_dataset(self) -> List["Location"]:
        """Read ``metadata.json`` and build one Location per listed image.

        Side effect: sets ``self.scale`` to the haversine distance (km) between
        the dataset's min and max bounds divided by 7.458421. This is the scale
        factor ``calculate_score`` expects.
        """
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
        # (country names may contain non-ASCII characters).
        with open(os.path.join(self.dataset_path, "metadata.json"), "r", encoding="utf-8") as f:
            data = json.load(f)
        bounds = data['bounds']
        min_bound = (bounds['min']['lat'], bounds['min']['lng'])
        max_bound = (bounds['max']['lat'], bounds['max']['lng'])
        self.scale = haversine.haversine(min_bound, max_bound) / 7.458421
        return [
            Location(
                image_path=os.path.join(self.dataset_path, item["image_path"]),
                country=item["country"],
                lat=item["lat"],
                lng=item["lng"]
            )
            for item in data['images']
        ]

    def run_benchmark(self, args) -> Dict:
        """Evaluate the selected locations and return the compiled summary.

        Reads ``args.sample_id``, ``args.samples``, ``args.continue_from`` and
        ``args.parallel`` (the argparse namespace built in ``__main__``). A
        sample id selects a single location. Otherwise ``samples`` draws a
        random subset when it is smaller than the dataset.

        Raises:
            ValueError: Unknown sample id, or ``continue_from`` outside
                ``1..len(selected locations)``.
        """
        locations_to_test = self.locations
        if args.sample_id is not None:
            locations_to_test = [loc for loc in self.locations if loc.id == str(args.sample_id)]
            if not locations_to_test:
                raise ValueError(f"Image ID '{args.sample_id}' not found in dataset")
        elif args.samples and args.samples < len(self.locations):
            import random
            locations_to_test = random.sample(self.locations, args.samples)
        self.results = []
        start_index = 0
        if args.continue_from is not None:
            if 1 <= args.continue_from <= len(locations_to_test):
                start_index = args.continue_from - 1  # Adjust to 0-based index
            else:
                raise ValueError(f"Invalid continue-from value: {args.continue_from}. Must be between 1 and {len(locations_to_test)}")
        # Drop everything before the resume point but keep original indices
        # so progress output and result ordering still refer to them.
        locations_to_process = [(i, loc) for i, loc in enumerate(locations_to_test) if i >= start_index]
        if args.parallel and args.parallel > 1:
            self._parallel_workers = args.parallel
            self._run_parallel(locations_to_process, len(locations_to_test))
        else:
            self._run_sequential(locations_to_process, len(locations_to_test))
        return self._compile_results()

    def _run_sequential(self, locations_to_process, total_count):
        """Evaluate ``(index, location)`` pairs one at a time, saving after each."""
        for i, location in locations_to_process:
            print(f"Testing location: {location.id} ({i+1}/{total_count})")
            result = self._evaluate_location(location)
            self.results.append(result)
            self._print_result(result)
            self._save_incremental_results()

    def _run_parallel(self, locations_to_process, total_count):
        """Evaluate ``(index, location)`` pairs on a thread pool.

        ``self.results`` is rebuilt in original-index order after every
        completion, and incremental results are saved each time. *total_count*
        is unused here; it is kept for parity with :meth:`_run_sequential`.
        NOTE(review): workers call ``self.model.query`` concurrently; that
        assumes the model client is thread-safe.
        """
        results_with_index = []
        num_workers = getattr(self, '_parallel_workers', 1)
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_location = {
                executor.submit(self._evaluate_location, loc): (i, loc)
                for i, loc in locations_to_process
            }
            for completed_count, future in enumerate(as_completed(future_to_location), 1):
                i, location = future_to_location[future]
                try:
                    result = future.result()
                except Exception as e:
                    print(f"Error processing location {location.id}: {str(e)}")
                    result = BenchmarkResult(
                        location=location,
                        guess=None,
                        refused=True,
                        error_message=str(e)
                    )
                with self.results_lock:
                    results_with_index.append((i, result))
                    # Sort and update results list to maintain order
                    results_with_index.sort(key=lambda x: x[0])
                    self.results = [r for _, r in results_with_index]
                    # Save after each completion
                    self._save_incremental_results()
                print(f"Completed location: {location.id} ({completed_count}/{len(locations_to_process)})")
                self._print_result(result)

    def _print_result(self, result):
        """Print a one-line status for *result* (refusal reason or distance/score)."""
        if result.refused:
            print(f" ✗ REFUSED: {result.error_message}")
        else:
            status = "✓" if result.country_correct else "✗"
            distance = f"{result.distance_km:.1f}km" if result.distance_km is not None else "N/A"
            score = result.score if result.score is not None else "0"
            print(f" {status} Distance: {distance}, Score: {score}")

    def _evaluate_location(self, location: "Location") -> "BenchmarkResult":
        """Query the model for one location and score its answer.

        The raw response is written to ``<run_folder>/output/<id>.txt``.
        ValueErrors raised while parsing or scoring are treated as format
        errors:
        - If the message mentions "missing required fields" or "parse", the
          error is returned immediately as a refusal.
        - Any other ValueError re-queries the model.

        Any other exception is retried up to ``max_retries`` attempts in total.
        This covers API and network failures as well as file I/O errors.

        Returns:
            A scored BenchmarkResult, or a refused one carrying the error message.
        """
        # os.path.join keeps the path relative when run_folder is "" instead of
        # producing "/output/" at the filesystem root.
        output_dir = os.path.join(self.run_folder, "output")
        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = self.model.query(location.image_path, SYSTEM_PROMPT, self.run_folder, location.id)
                os.makedirs(output_dir, exist_ok=True)
                with open(os.path.join(output_dir, f"{location.id}.txt"), "w", encoding="utf-8") as f:
                    f.write(response)
                try:
                    guess = parse_response(response)
                    result = BenchmarkResult(location=location, guess=guess)
                    result.calculate_metrics(self.scale)
                    return result
                except ValueError as parse_error:
                    # Don't retry format errors from the LLM
                    print(f" Format error (attempt {attempt+1}): {str(parse_error)}")
                    if "missing required fields" in str(parse_error) or "parse" in str(parse_error):
                        return BenchmarkResult(
                            location=location,
                            guess=None,
                            refused=True,
                            error_message=f"Format error: {str(parse_error)}"
                        )
                    # Any other ValueError (including one raised by
                    # calculate_metrics) falls through and re-queries the model.
                    last_error = f"Format error: {str(parse_error)}"
            except Exception as e:
                error_msg = str(e)
                print(f" API/network error (attempt {attempt+1}): {error_msg}")
                if attempt < self.max_retries - 1:
                    print(" Retrying...")
                    continue
                return BenchmarkResult(
                    location=location,
                    guess=None,
                    refused=True,
                    error_message=error_msg
                )
        # Reached when max_retries < 1, or when every attempt ended in a
        # ValueError that the branch above chose to retry.
        return BenchmarkResult(
            location=location,
            guess=None,
            refused=True,
            error_message=f"Max retries exceeded ({last_error})" if last_error else "Max retries exceeded"
        )

    def _compile_results(self) -> Dict:
        """Aggregate ``self.results`` into a summary dict.

        Distance and score statistics cover non-refused results only; they are
        None when there are none. Success and refusal rates are over all
        results. The per-location objects are included under
        ``"detailed_results"``.
        """
        from statistics import median

        total = len(self.results)
        country_correct = sum(1 for r in self.results if r.country_correct)
        refusals = sum(1 for r in self.results if r.refused)
        valid_results = [r for r in self.results if not r.refused]
        distances = [r.distance_km for r in valid_results]
        scores = [r.score for r in valid_results]
        avg_distance = sum(distances) / len(distances) if distances else None
        avg_score = sum(scores) / len(scores) if scores else None
        # statistics.median averages the two middle values for an even count;
        # sorted(...)[n // 2] would report only the upper one.
        median_distance = median(distances) if distances else None
        median_score = median(scores) if scores else None
        return {
            "model": self.model.name,
            "test": os.path.basename(self.dataset_path),
            "n": total,
            "country_success_rate": country_correct / total if total > 0 else 0,
            "refusal_rate": refusals / total if total > 0 else 0,
            "average_distance_km": avg_distance,
            "average_score": avg_score,
            "median_distance_km": median_distance,
            "median_score": median_score,
            "provider": self.model.provider,
            "detailed_results": self.results
        }

    def _build_records(self) -> List[Dict]:
        """Flatten ``self.results`` into one CSV row (dict) per location.

        Guess and metric columns are added only for non-refused results that
        have a guess.
        """
        records = []
        for r in self.results:
            record = {
                "location_id": r.location.id,
                "country_true": r.location.country,
                "lat_true": r.location.lat,
                "lng_true": r.location.lng,
                "refused": r.refused,
                "error_message": r.error_message
            }
            if not r.refused and r.guess:
                record.update({
                    "country_guess": r.guess.country,
                    "lat_guess": r.guess.lat,
                    "lng_guess": r.guess.lng,
                    "distance_km": r.distance_km,
                    "score": r.score,
                    "country_correct": r.country_correct
                })
            records.append(record)
        return records

    def _write_results(self, output_path: str) -> None:
        """Write ``<output_path>detailed.csv`` and ``<output_path>summary.json``.

        *output_path* is used as a plain string prefix, so pass it with a
        trailing path separator to write inside a directory.
        """
        pd.DataFrame(self._build_records()).to_csv(f"{output_path}detailed.csv", index=False)
        results_dict = self._compile_results()
        with open(f"{output_path}summary.json", "w") as f:
            json.dump({k: v for k, v in results_dict.items() if k != "detailed_results"}, f, indent=2)

    def save_results(self, output_path: str):
        """Write the final summary and per-location CSV using *output_path* as prefix."""
        # os.path.dirname() is "" for a bare prefix, and os.makedirs("") raises.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        self._write_results(output_path)

    def _save_incremental_results(self):
        """Rewrite ``<run_folder>/results/`` with the results gathered so far.

        Does nothing when ``run_folder`` has not been set.
        """
        if not self.run_folder:
            return
        output_path = f"{self.run_folder}/results/"
        os.makedirs(output_path, exist_ok=True)
        self._write_results(output_path)
def calculate_score(distance: float, scale: float) -> int:
    """Convert a guess error into a GeoGuessr-style score.

    Args:
        distance: Guess error in thousands of kilometres. The benchmark passes
            ``distance_km / 1000``, so ``distance * 1000000`` is in metres.
        scale: Map scale factor. The benchmark derives it from the dataset
            bounds' diagonal (km) divided by 7.458421.

    Returns:
        5000 when the error is at most 25 metres. Otherwise
        ``round(5000 * 0.99866017 ** (metres / scale))``, which decays towards
        0 as the error grows.
    """
    metres = distance * 1000000
    return 5000 if metres <= 25 else round(5000 * math.pow(0.99866017, metres / scale))
if __name__ == "__main__":
    # Command-line entry point: run the benchmark, save results under
    # responses/<model>_<dataset>_<timestamp>/, and print a short summary.
    parser = argparse.ArgumentParser(description="GeoGuessr Benchmark Tool")
    parser.add_argument("--dataset", "-d", type=str, default="acw",
                        help="Dataset subfolder to use (default: 'acw')")
    parser.add_argument("--samples", "-n", type=int, default=None,
                        help="Number of samples to test (default: all)")
    parser.add_argument("--sample-id", "-i", type=int, default=None, help="Run a specific sample by ID")
    # NOTE(review): GeoGuessrBenchmark resolves this value with globals()[model],
    # so the default "claude" only works if the models module exports that
    # exact name — confirm.
    parser.add_argument("--model", "-m", type=str, default="claude",
                        help="Model provider to use (default: 'claude')")
    parser.add_argument("--max-retries", type=int, default=3,
                        help="Maximum number of retries for API/network errors (default: 3)")
    parser.add_argument("--continue-from", type=int, default=None,
                        help="Continue from a specific sample number (1-indexed)")
    parser.add_argument("--parallel", "-p", type=int, default=1,
                        help="Number of parallel workers to use (default: 1)")
    args = parser.parse_args()
    dataset_path = f"dataset/{args.dataset}"
    benchmark = GeoGuessrBenchmark(
        dataset_path=dataset_path,
        model=args.model,
        max_retries=args.max_retries
    )
    runtime = datetime.datetime.now().strftime('%Y-%m-%dT%H_%M_%S')
    benchmark.run_folder = f"responses/{benchmark.model.name}_{args.dataset}_{runtime}"
    results = benchmark.run_benchmark(args)
    benchmark.save_results(benchmark.run_folder + "/results/")
    # The averages are None when every sample was refused; formatting None
    # with ":.1f" would raise TypeError after the run already finished.
    avg_distance = results['average_distance_km']
    avg_score = results['average_score']
    print(f"Total samples: {results['n']}")
    print(f"Country success rate: {results['country_success_rate']:.2%}")
    print(f"Average distance: {avg_distance:.1f} km" if avg_distance is not None else "Average distance: N/A")
    print(f"Average score: {avg_score:.1f}" if avg_score is not None else "Average score: N/A")
    print(f"Refusal rate: {results['refusal_rate']:.2%}")