Skip to content

Commit e6c1308

Browse files
authored
Merge branch 'main' into main
2 parents 9e85d88 + 8666268 commit e6c1308

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from attr import dataclass
1515
from langchain.schema import BaseMessage, HumanMessage
1616
from openai import OpenAI
17-
from PIL import Image
17+
from PIL import Image, ImageDraw
1818

1919
from agentlab.analyze import inspect_results
2020
from agentlab.experiments.exp_utils import RESULTS_DIR
@@ -530,9 +530,47 @@ def wrapper(*args, **kwargs):
530530
return decorator
531531

532532

533+
def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
534+
"""
535+
If action is a coordinate action, try to render it on the screenshot.
536+
537+
e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
538+
539+
Args:
540+
screenshot: The screenshot to tag.
541+
action: The action to tag the screenshot with.
542+
543+
Returns:
544+
The tagged screenshot.
545+
546+
Raises:
547+
ValueError: If the action parsing fails.
548+
"""
549+
if action.startswith("mouse_click"):
550+
try:
551+
coords = action[action.index("(") + 1 : action.index(")")].split(",")
552+
coords = [c.strip() for c in coords]
553+
if len(coords) != 2:
554+
raise ValueError(f"Invalid coordinate format: {coords}")
555+
if coords[0].startswith("x="):
556+
coords[0] = coords[0][2:]
557+
if coords[1].startswith("y="):
558+
coords[1] = coords[1][2:]
559+
x, y = float(coords[0].strip()), float(coords[1].strip())
560+
draw = ImageDraw.Draw(screenshot)
561+
radius = 5
562+
draw.ellipse(
563+
(x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
564+
)
565+
except (ValueError, IndexError) as e:
566+
warning(f"Failed to parse action '{action}': {e}")
567+
return screenshot
568+
569+
533570
def update_screenshot(som_or_not: str):
534571
global info
535-
return get_screenshot(info, som_or_not=som_or_not)
572+
action = info.exp_result.steps_info[info.step].action
573+
return tag_screenshot_with_action(get_screenshot(info, som_or_not=som_or_not), action)
536574

537575

538576
def get_screenshot(info: Info, step: int = None, som_or_not: str = "Raw Screenshots"):
@@ -549,6 +587,9 @@ def update_screenshot_pair(som_or_not: str):
549587
global info
550588
s1 = get_screenshot(info, info.step, som_or_not)
551589
s2 = get_screenshot(info, info.step + 1, som_or_not)
590+
591+
if s1 is not None:
592+
s1 = tag_screenshot_with_action(s1, info.exp_result.steps_info[info.step].action)
552593
return s1, s2
553594

554595

0 commit comments

Comments
 (0)