34
34
openai_client = OpenAI ()
35
35
36
36
37
- class BaseAgent :
37
+ class BaseAgent :
38
38
# no need to pass an initial playwright_manager to the agent class
39
39
def __init__ (
40
40
self ,
@@ -94,7 +94,7 @@ def get_path_to_root(self, node: LATSNode) -> List[LATSNode]:
94
94
def _get_tree_data (self ):
95
95
nodes = collect_all_nodes (self .root_node )
96
96
tree_data = []
97
-
97
+
98
98
for node in nodes :
99
99
node_data = {
100
100
"id" : id (node ),
@@ -109,21 +109,21 @@ def _get_tree_data(self):
109
109
# "reward": node.reward
110
110
}
111
111
tree_data .append (node_data )
112
-
112
+
113
113
return tree_data
114
-
114
+
115
115
## TODO: newly added, debug needed
116
116
async def remove_simulated_trajectory (self , starting_node , terminal_node : LATSNode , websocket = None ):
117
117
# to be implemented
118
118
trajectory_data = []
119
119
path = []
120
-
120
+
121
121
# Collect path from terminal to root
122
122
current = terminal_node
123
123
while current is not starting_node :
124
124
path .append (current )
125
125
current = current .parent
126
-
126
+
127
127
# Process nodes in order from root to terminal
128
128
for level , node in enumerate (reversed (path )):
129
129
node_data = {
@@ -140,20 +140,20 @@ async def remove_simulated_trajectory(self, starting_node, terminal_node: LATSNo
140
140
"is_terminal_node" : node == terminal_node
141
141
}
142
142
trajectory_data .append (node_data )
143
-
143
+
144
144
await self .websocket_simulation_removed (trajectory_data , websocket = websocket )
145
145
pass
146
-
146
+
147
147
def _get_trajectory_data (self , terminal_node : LATSNode ):
148
148
trajectory_data = []
149
149
path = []
150
-
150
+
151
151
# Collect path from terminal to root
152
152
current = terminal_node
153
153
while current is not None :
154
154
path .append (current )
155
155
current = current .parent
156
-
156
+
157
157
# Process nodes in order from root to terminal
158
158
for level , node in enumerate (reversed (path )):
159
159
node_data = {
@@ -170,15 +170,15 @@ def _get_trajectory_data(self, terminal_node: LATSNode):
170
170
"is_terminal_node" : node == terminal_node
171
171
}
172
172
trajectory_data .append (node_data )
173
-
174
- return trajectory_data
173
+
174
+ return trajectory_data
175
175
176
176
177
177
async def _reset_browser (self , websocket = None ) -> Optional [str ]:
178
178
await self .playwright_manager .close ()
179
-
179
+
180
180
## reset account using api-based account reset
181
- if self .config .account_reset :
181
+ if self .config .account_reset :
182
182
try :
183
183
# Use aiohttp instead of curl
184
184
async with aiohttp .ClientSession () as session :
@@ -204,7 +204,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
204
204
"reason" : error_msg ,
205
205
"timestamp" : datetime .utcnow ().isoformat ()
206
206
})
207
-
207
+
208
208
except Exception as e :
209
209
print (f"Error during account reset: { e } " )
210
210
if websocket :
@@ -231,7 +231,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
231
231
session_id = None
232
232
live_browser_url = None
233
233
await page .goto (self .starting_url , wait_until = "networkidle" )
234
-
234
+
235
235
# Send success message if websocket is provided
236
236
if websocket :
237
237
if self .config .storage_state :
@@ -252,7 +252,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
252
252
"session_id" : session_id ,
253
253
"timestamp" : datetime .utcnow ().isoformat ()
254
254
})
255
-
255
+
256
256
return live_browser_url , session_id
257
257
except Exception as e :
258
258
print (f"Error setting up browser: { e } " )
@@ -264,7 +264,7 @@ async def _reset_browser(self, websocket=None) -> Optional[str]:
264
264
"timestamp" : datetime .utcnow ().isoformat ()
265
265
})
266
266
return None , None
267
-
267
+
268
268
# TODO: if no websocket, print the json data
269
269
# TODO: do we need node expansion data?
270
270
# TODO: there are four types of websocket messages; do we need more types of websocket messages?
@@ -311,7 +311,7 @@ async def websocket_tree_update(self, type, tree_data, websocket=None):
311
311
})
312
312
else :
313
313
print (f"{ type } updated: { tree_data } " )
314
-
314
+
315
315
async def websocket_node_created (self , child , node , websocket = None ):
316
316
if websocket :
317
317
await websocket .send_json ({
@@ -327,7 +327,7 @@ async def websocket_node_created(self, child, node, websocket=None):
327
327
print (f"Node parent: { GREEN } { id (node )} { RESET } " )
328
328
print (f"Node action: { GREEN } { child .action } { RESET } " )
329
329
print (f"Node description: { GREEN } { child .natural_language_description } { RESET } " )
330
-
330
+
331
331
## node simulated
332
332
## message log and d3 visualization add different information
333
333
async def websocket_node_simulated (self , child , node , websocket = None ):
@@ -371,7 +371,7 @@ async def websocket_simulation_result(self, reward, terminal_node, websocket=Non
371
371
else :
372
372
print (f"Simulation reward: { GREEN } { reward } { RESET } " )
373
373
print (f"Simulation terminal node: { GREEN } { terminal_node } { RESET } " )
374
-
374
+
375
375
async def websocket_search_complete (self , status , score , path , websocket = None ):
376
376
if websocket :
377
377
await websocket .send_json ({
@@ -385,11 +385,11 @@ async def websocket_search_complete(self, status, score, path, websocket=None):
385
385
print (f"Search complete: { GREEN } { status } { RESET } " )
386
386
print (f"Search score: { GREEN } { score } { RESET } " )
387
387
print (f"Search path: { GREEN } { path } { RESET } " )
388
-
388
+
389
389
# shared interface, not implemented here: BFS, DFS and LATS each have their own node selection logic
390
390
async def node_selection(self, node, websocket=None):
    """Abstract hook for choosing the next node; subclasses must override it.

    The base agent provides no selection strategy of its own: each concrete
    search strategy is expected to supply its own implementation.

    Args:
        node: The node from which selection would start (unused here).
        websocket: Optional websocket handle (unused here).

    Raises:
        NotImplementedError: Always, because the base class has no
            selection logic.
    """
    # The previous body was the bare expression `NotImplemented`, which is a
    # no-op: the coroutine silently returned None and a missing override went
    # unnoticed. Raising makes the missing override fail loudly at the call.
    raise NotImplementedError(
        "node_selection must be implemented by the concrete agent subclass"
    )
392
-
392
+
393
393
394
394
async def node_expansion (self , node : LATSNode , websocket = None ) -> None :
395
395
if websocket :
@@ -440,10 +440,10 @@ async def node_expansion(self, node: LATSNode, websocket = None) -> None:
440
440
"timestamp" : datetime .utcnow ().isoformat ()
441
441
})
442
442
443
-
443
+
444
444
# node evaluation
445
445
# change the node evaluation to use the new prompt
446
- async def node_children_evaluation (self , node : LATSNode , websocket = None ) -> None :
446
+ async def node_children_evaluation (self , node : LATSNode , websocket = None ) -> None :
447
447
if websocket :
448
448
await websocket .send_json ({
449
449
"type" : "evaluation_start" ,
@@ -480,7 +480,7 @@ async def node_children_evaluation(self, node: LATSNode, websocket=None) -> None
480
480
child .value = score
481
481
# child.reward = score
482
482
483
- async def node_evaluation (self , node : LATSNode , websocket = None ) -> None :
483
+ async def node_evaluation (self , node : LATSNode , websocket = None ) -> None :
484
484
"""Evaluate the current node and assign its score."""
485
485
if websocket :
486
486
node_info = {
@@ -500,7 +500,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
500
500
try :
501
501
# Get the path from root to this node
502
502
path = self .get_path_to_root (node )
503
-
503
+
504
504
# Create trajectory for scoring (skip root node)
505
505
trajectory = []
506
506
for n in path [1 :]: # Skip root node
@@ -509,7 +509,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
509
509
"action" : n .action ,
510
510
"feedback" : n .feedback
511
511
})
512
-
512
+
513
513
try :
514
514
# Score the trajectory
515
515
# if node.is_terminal:
@@ -520,21 +520,21 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
520
520
else :
521
521
prompt = create_llm_prompt (trajectory , self .goal )
522
522
result = score_trajectory_with_openai (
523
- prompt ,
524
- openai_client ,
523
+ prompt ,
524
+ openai_client ,
525
525
model = self .config .evaluation_model
526
526
)
527
527
score = result ["overall_score" ]
528
-
528
+
529
529
except Exception as e :
530
530
error_msg = f"Error scoring node { id (node )} : { str (e )} "
531
531
print (error_msg )
532
532
score = float ('-inf' )
533
-
533
+
534
534
# Assign the score to the node
535
535
node .value = score
536
536
# node.reward = score
537
-
537
+
538
538
if websocket :
539
539
await websocket .send_json ({
540
540
"type" : "node_evaluation_complete" ,
@@ -548,7 +548,7 @@ async def node_evaluation(self, node: LATSNode, websocket=None) -> None:
548
548
except Exception as e :
549
549
error_msg = f"Error in node evaluation: { str (e )} "
550
550
print (error_msg )
551
-
551
+
552
552
# shared
553
553
## TODO: check the logic of updating value/ reward, is the input value?
554
554
def backpropagate (self , node : LATSNode , value : float ) -> None :
@@ -569,26 +569,26 @@ async def simulation(self, node: LATSNode, websocket=None) -> tuple[float, LATSN
569
569
print ("print the entire tree" )
570
570
print_entire_tree (self .root_node )
571
571
return await self .rollout (node , websocket = websocket )
572
-
572
+
573
573
# refactor simulation, rollout, send_completion_request methods
574
574
# TODO: check, score as reward and then update value of the starting node?
575
575
async def rollout (self , node : LATSNode , websocket = None )-> tuple [float , LATSNode ]:
576
576
# Reset browser state
577
577
live_browser_url , session_id = await self ._reset_browser (websocket )
578
578
path = self .get_path_to_root (node )
579
-
579
+
580
580
print ("execute path" )
581
581
# Execute path
582
582
583
583
messages = []
584
584
trajectory = []
585
-
585
+
586
586
for n in path [1 :]: # Skip root node
587
587
success = await playwright_step_execution (
588
- n ,
589
- self .goal ,
590
- self .playwright_manager ,
591
- is_replay = False ,
588
+ n ,
589
+ self .goal ,
590
+ self .playwright_manager ,
591
+ is_replay = False ,
592
592
log_folder = self .config .log_folder
593
593
)
594
594
if not success :
@@ -617,18 +617,18 @@ async def rollout(self, node: LATSNode, websocket=None)-> tuple[float, LATSNode]
617
617
618
618
messages = [{"role" : "user" , "content" : f"Action is: { n .action } " } for n in path [1 :]]
619
619
goal_finished , confidence_score = goal_finished_evaluator (
620
- messages ,
621
- openai_client ,
622
- self .goal ,
620
+ messages ,
621
+ openai_client ,
622
+ self .goal ,
623
623
page_info ['screenshot' ]
624
624
)
625
625
print ("evaluating" )
626
-
626
+
627
627
score = confidence_score if goal_finished else 0
628
628
await self .remove_simulated_trajectory (starting_node = node , terminal_node = terminal_node , websocket = websocket )
629
629
630
630
return score , terminal_node
631
-
631
+
632
632
633
633
# TODO: decide whether to keep the tree update
634
634
async def send_completion_request (self , plan , depth , node , trajectory = [], websocket = None ):
@@ -661,7 +661,7 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
661
661
)
662
662
next_action = updated_actions [0 ]
663
663
retry_count = self .config .retry_count if hasattr (self .config , 'retry_count' ) else 1 # Default retries if not set
664
-
664
+
665
665
for attempt in range (retry_count ):
666
666
try :
667
667
# Convert action to Python code
@@ -673,7 +673,7 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
673
673
extracted_number = parse_function_args (function_args )
674
674
element = await locate_element (page , extracted_number )
675
675
next_action ["element" ] = element
676
-
676
+
677
677
# Execute action
678
678
await execute_action (next_action , self .action_set , page , context , self .goal , page_info ['interactive_elements' ],
679
679
self .config .log_folder )
@@ -730,9 +730,9 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
730
730
# Reset browser and get live URL
731
731
live_browser_url , session_id = await self ._reset_browser (websocket )
732
732
path = self .get_path_to_root (node )
733
-
733
+
734
734
# Execute path
735
- for n in path [1 :]: # Skip root node
735
+ for n in path [1 :]: # Skip root node
736
736
success = await playwright_step_execution (
737
737
n ,
738
738
self .goal ,
@@ -743,7 +743,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
743
743
if not success :
744
744
n .is_terminal = True
745
745
return []
746
-
746
+
747
747
if not n .feedback :
748
748
n .feedback = await generate_feedback (
749
749
self .goal ,
@@ -757,11 +757,11 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
757
757
758
758
messages = [{"role" : "user" , "content" : f"Action is: { n .action } " } for n in path [1 :]]
759
759
760
-
760
+
761
761
next_actions = await extract_top_actions (
762
762
[{"natural_language_description" : n .natural_language_description , "action" : n .action , "feedback" : n .feedback } for n in path [1 :]],
763
763
self .goal ,
764
- self .images ,
764
+ self .images ,
765
765
page_info ,
766
766
self .action_set ,
767
767
openai_client ,
@@ -788,7 +788,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
788
788
})
789
789
return []
790
790
continue
791
-
791
+
792
792
page = await self .playwright_manager .get_page ()
793
793
code , function_calls = self .action_set .to_python_code (action ["action" ])
794
794
@@ -803,5 +803,5 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
803
803
children .append (action )
804
804
805
805
if not children :
806
- node .is_terminal = True
807
- return children
806
+ node .is_terminal = True
807
+ return children
0 commit comments