Skip to content

Commit 4e61b2a

Browse files
committed
Merge branch 'main' into Frontend-UI
2 parents f730959 + 237687e commit 4e61b2a

File tree

1 file changed

+56
-56
lines changed
  • visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents

1 file changed

+56
-56
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
openai_client = OpenAI()
3535

3636

37-
class BaseAgent:
37+
class BaseAgent:
3838
# no need to pass an initial playwright_manager to the agent class
3939
def __init__(
4040
self,
@@ -94,7 +94,7 @@ def get_path_to_root(self, node: LATSNode) -> List[LATSNode]:
9494
def _get_tree_data(self):
9595
nodes = collect_all_nodes(self.root_node)
9696
tree_data = []
97-
97+
9898
for node in nodes:
9999
node_data = {
100100
"id": id(node),
@@ -109,21 +109,21 @@ def _get_tree_data(self):
109109
# "reward": node.reward
110110
}
111111
tree_data.append(node_data)
112-
112+
113113
return tree_data
114-
114+
115115
## TODO: newly added, debug needed
116116
async def remove_simulated_trajectory(self, starting_node, terminal_node: LATSNode, websocket=None):
117117
# to be implemented
118118
trajectory_data = []
119119
path = []
120-
120+
121121
# Collect path from terminal to root
122122
current = terminal_node
123123
while current is not starting_node:
124124
path.append(current)
125125
current = current.parent
126-
126+
127127
# Process nodes in order from root to terminal
128128
for level, node in enumerate(reversed(path)):
129129
node_data = {
@@ -140,20 +140,20 @@ async def remove_simulated_trajectory(self, starting_node, terminal_node: LATSNo
140140
"is_terminal_node": node == terminal_node
141141
}
142142
trajectory_data.append(node_data)
143-
143+
144144
await self.websocket_simulation_removed(trajectory_data, websocket=websocket)
145145
pass
146-
146+
147147
def _get_trajectory_data(self, terminal_node: LATSNode):
148148
trajectory_data = []
149149
path = []
150-
150+
151151
# Collect path from terminal to root
152152
current = terminal_node
153153
while current is not None:
154154
path.append(current)
155155
current = current.parent
156-
156+
157157
# Process nodes in order from root to terminal
158158
for level, node in enumerate(reversed(path)):
159159
node_data = {
@@ -170,15 +170,15 @@ def _get_trajectory_data(self, terminal_node: LATSNode):
170170
"is_terminal_node": node == terminal_node
171171
}
172172
trajectory_data.append(node_data)
173-
174-
return trajectory_data
173+
174+
return trajectory_data
175175

176176

177177
async def _reset_browser(self, websocket=None) -> Optional[str]:
178178
await self.playwright_manager.close()
179-
179+
180180
## reset account using api-based account reset
181-
if self.config.account_reset:
181+
if self.config.account_reset:
182182
try:
183183
# Use aiohttp instead of curl
184184
async with aiohttp.ClientSession() as session:
@@ -204,7 +204,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
204204
"reason": error_msg,
205205
"timestamp": datetime.utcnow().isoformat()
206206
})
207-
207+
208208
except Exception as e:
209209
print(f"Error during account reset: {e}")
210210
if websocket:
@@ -231,7 +231,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
231231
session_id = None
232232
live_browser_url = None
233233
await page.goto(self.starting_url, wait_until="networkidle")
234-
234+
235235
# Send success message if websocket is provided
236236
if websocket:
237237
if self.config.storage_state:
@@ -252,7 +252,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
252252
"session_id": session_id,
253253
"timestamp": datetime.utcnow().isoformat()
254254
})
255-
255+
256256
return live_browser_url, session_id
257257
except Exception as e:
258258
print(f"Error setting up browser: {e}")
@@ -264,7 +264,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
264264
"timestamp": datetime.utcnow().isoformat()
265265
})
266266
return None, None
267-
267+
268268
# TODO: if no websocket, print the json data
269269
# TODO: do we need node expansion data?
270270
# TODO: there are four types of websocket messages; do we need more types of websocket messages?
@@ -311,7 +311,7 @@ async def websocket_tree_update(self, type, tree_data, websocket=None):
311311
})
312312
else:
313313
print(f"{type} updated: {tree_data}")
314-
314+
315315
async def websocket_node_created(self, child, node, websocket=None):
316316
if websocket:
317317
await websocket.send_json({
@@ -327,7 +327,7 @@ async def websocket_node_created(self, child, node, websocket=None):
327327
print(f"Node parent: {GREEN}{id(node)}{RESET}")
328328
print(f"Node action: {GREEN}{child.action}{RESET}")
329329
print(f"Node description: {GREEN}{child.natural_language_description}{RESET}")
330-
330+
331331
## node simulated
332332
## message log and d3 visualization add different information
333333
async def websocket_node_simulated(self, child, node, websocket=None):
@@ -371,7 +371,7 @@ async def websocket_simulation_result(self, reward, terminal_node, websocket=Non
371371
else:
372372
print(f"Simulation reward: {GREEN}{reward}{RESET}")
373373
print(f"Simulation terminal node: {GREEN}{terminal_node}{RESET}")
374-
374+
375375
async def websocket_search_complete(self, status, score, path, websocket=None):
376376
if websocket:
377377
await websocket.send_json({
@@ -385,11 +385,11 @@ async def websocket_search_complete(self, status, score, path, websocket=None):
385385
print(f"Search complete: {GREEN}{status}{RESET}")
386386
print(f"Search score: {GREEN}{score}{RESET}")
387387
print(f"Search path: {GREEN}{path}{RESET}")
388-
388+
389389
# shared, but not implemented here: BFS, DFS and LATS each have their own node selection logic
390390
async def node_selection(self, node, websocket = None):
    """Select the next node to work on; must be overridden by subclasses.

    The base agent deliberately provides no selection strategy.

    Args:
        node: The node from which selection starts.
        websocket: Optional websocket for progress messages; unused here.

    Raises:
        NotImplementedError: Always, because the base class has no
            selection logic of its own.
    """
    # The original body was the bare expression `NotImplemented`, which is a
    # no-op: awaiting the coroutine silently produced None. Raising makes a
    # missing override fail loudly at the call site instead.
    raise NotImplementedError(
        f"{type(self).__name__} must implement node_selection()"
    )
392-
392+
393393

394394
async def node_expansion(self, node: LATSNode, websocket = None) -> None:
395395
if websocket:
@@ -440,10 +440,10 @@ async def node_expansion(self, node: LATSNode, websocket = None) -> None:
440440
"timestamp": datetime.utcnow().isoformat()
441441
})
442442

443-
443+
444444
# node evaluation
445445
# change the node evaluation to use the new prompt
446-
async def node_children_evaluation(self, node: LATSNode, websocket=None) -> None:
446+
async def node_children_evaluation(self, node: LATSNode, websocket = None) -> None:
447447
if websocket:
448448
await websocket.send_json({
449449
"type": "evaluation_start",
@@ -480,7 +480,7 @@ async def node_children_evaluation(self, node: LATSNode, websocket=None) -> None
480480
child.value = score
481481
# child.reward = score
482482

483-
async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
483+
async def node_evaluation(self, node: LATSNode, websocket = None) -> None:
484484
"""Evaluate the current node and assign its score."""
485485
if websocket:
486486
node_info = {
@@ -500,7 +500,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
500500
try:
501501
# Get the path from root to this node
502502
path = self.get_path_to_root(node)
503-
503+
504504
# Create trajectory for scoring (skip root node)
505505
trajectory = []
506506
for n in path[1:]: # Skip root node
@@ -509,7 +509,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
509509
"action": n.action,
510510
"feedback": n.feedback
511511
})
512-
512+
513513
try:
514514
# Score the trajectory
515515
# if node.is_terminal:
@@ -520,21 +520,21 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
520520
else:
521521
prompt = create_llm_prompt(trajectory, self.goal)
522522
result = score_trajectory_with_openai(
523-
prompt,
524-
openai_client,
523+
prompt,
524+
openai_client,
525525
model=self.config.evaluation_model
526526
)
527527
score = result["overall_score"]
528-
528+
529529
except Exception as e:
530530
error_msg = f"Error scoring node {id(node)}: {str(e)}"
531531
print(error_msg)
532532
score = float('-inf')
533-
533+
534534
# Assign the score to the node
535535
node.value = score
536536
# node.reward = score
537-
537+
538538
if websocket:
539539
await websocket.send_json({
540540
"type": "node_evaluation_complete",
@@ -548,7 +548,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
548548
except Exception as e:
549549
error_msg = f"Error in node evaluation: {str(e)}"
550550
print(error_msg)
551-
551+
552552
# shared
553553
## TODO: check the logic of updating value/reward; is the input the value?
554554
def backpropagate(self, node: LATSNode, value: float) -> None:
@@ -569,26 +569,26 @@ async def simulation(self, node: LATSNode, websocket=None) -> tuple[float, LATSN
569569
print("print the entire tree")
570570
print_entire_tree(self.root_node)
571571
return await self.rollout(node, websocket=websocket)
572-
572+
573573
# refactor simulation, rollout, send_completion_request methods
574574
# TODO: check, score as reward and then update value of the starting node?
575575
async def rollout(self, node: LATSNode, websocket=None)-> tuple[float, LATSNode]:
576576
# Reset browser state
577577
live_browser_url, session_id = await self._reset_browser(websocket)
578578
path = self.get_path_to_root(node)
579-
579+
580580
print("execute path")
581581
# Execute path
582582

583583
messages = []
584584
trajectory = []
585-
585+
586586
for n in path[1:]: # Skip root node
587587
success = await playwright_step_execution(
588-
n,
589-
self.goal,
590-
self.playwright_manager,
591-
is_replay=False,
588+
n,
589+
self.goal,
590+
self.playwright_manager,
591+
is_replay=False,
592592
log_folder=self.config.log_folder
593593
)
594594
if not success:
@@ -617,18 +617,18 @@ async def rollout(self, node: LATSNode, websocket=None)-> tuple[float, LATSNode]
617617

618618
messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]]
619619
goal_finished, confidence_score = goal_finished_evaluator(
620-
messages,
621-
openai_client,
622-
self.goal,
620+
messages,
621+
openai_client,
622+
self.goal,
623623
page_info['screenshot']
624624
)
625625
print("evaluating")
626-
626+
627627
score = confidence_score if goal_finished else 0
628628
await self.remove_simulated_trajectory(starting_node=node, terminal_node=terminal_node, websocket=websocket)
629629

630630
return score, terminal_node
631-
631+
632632

633633
# TODO: decide whether to keep the tree update
634634
async def send_completion_request(self, plan, depth, node, trajectory=[], websocket=None):
@@ -661,7 +661,7 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
661661
)
662662
next_action = updated_actions[0]
663663
retry_count = self.config.retry_count if hasattr(self.config, 'retry_count') else 1 # Default retries if not set
664-
664+
665665
for attempt in range(retry_count):
666666
try:
667667
# Convert action to Python code
@@ -673,7 +673,7 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
673673
extracted_number = parse_function_args(function_args)
674674
element = await locate_element(page, extracted_number)
675675
next_action["element"] = element
676-
676+
677677
# Execute action
678678
await execute_action(next_action, self.action_set, page, context, self.goal, page_info['interactive_elements'],
679679
self.config.log_folder)
@@ -730,9 +730,9 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
730730
# Reset browser and get live URL
731731
live_browser_url, session_id = await self._reset_browser(websocket)
732732
path = self.get_path_to_root(node)
733-
733+
734734
# Execute path
735-
for n in path[1:]: # Skip root node
735+
for n in path[1:]: # Skip root node
736736
success = await playwright_step_execution(
737737
n,
738738
self.goal,
@@ -743,7 +743,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
743743
if not success:
744744
n.is_terminal = True
745745
return []
746-
746+
747747
if not n.feedback:
748748
n.feedback = await generate_feedback(
749749
self.goal,
@@ -757,11 +757,11 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
757757

758758
messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]]
759759

760-
760+
761761
next_actions = await extract_top_actions(
762762
[{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]],
763763
self.goal,
764-
self.images,
764+
self.images,
765765
page_info,
766766
self.action_set,
767767
openai_client,
@@ -788,7 +788,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
788788
})
789789
return []
790790
continue
791-
791+
792792
page = await self.playwright_manager.get_page()
793793
code, function_calls = self.action_set.to_python_code(action["action"])
794794

@@ -803,5 +803,5 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
803803
children.append(action)
804804

805805
if not children:
806-
node.is_terminal = True
807-
return children
806+
node.is_terminal = True
807+
return children

0 commit comments

Comments
 (0)