34
34
openai_client = OpenAI ()
35
35
36
36
37
- ## TODO: remove account reset websocket message
38
- ## browser setup message, ok to leave there in the _reset_browser method
39
-
40
-
41
37
class BaseAgent :
42
38
# no need to pass an initial playwright_manager to the agent class
43
39
def __init__ (
@@ -381,6 +377,10 @@ async def websocket_search_complete(self, status, score, path, websocket=None):
381
377
"path" : path ,
382
378
"timestamp" : datetime .utcnow ().isoformat ()
383
379
})
380
+ else :
381
+ print (f"Search complete: { GREEN } { status } { RESET } " )
382
+ print (f"Search score: { GREEN } { score } { RESET } " )
383
+ print (f"Search path: { GREEN } { path } { RESET } " )
384
384
385
385
# shared, not implemented, BFS, DFS and LATS has its own node selection logic
386
386
async def node_selection (self , node , websocket = None ):
@@ -485,31 +485,19 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
485
485
node = node .parent
486
486
487
487
# shared
488
- async def simulation (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 , websocket = None ) -> tuple [float , LATSNode ]:
488
+ async def simulation (self , node : LATSNode , websocket = None ) -> tuple [float , LATSNode ]:
489
489
depth = node .depth
490
+ num_simulations = self .config .num_simulations
491
+ max_depth = self .config .max_depth
490
492
print ("print the trajectory" )
491
493
print_trajectory (node )
492
494
print ("print the entire tree" )
493
495
print_entire_tree (self .root_node )
494
- # if websocket:
495
- # tree_data = self._get_tree_data()
496
- # await self.websocket_tree_update(type="tree_update_simulation", tree_data=tree_data, websocket=websocket)
497
- # await websocket.send_json({
498
- # "type": "tree_update",
499
- # "tree": tree_data,
500
- # "timestamp": datetime.utcnow().isoformat()
501
- # })
502
- # trajectory_data = self._get_trajectory_data(node)
503
- # await websocket.send_json({
504
- # "type": "trajectory_update",
505
- # "trajectory": trajectory_data,
506
- # "timestamp": datetime.utcnow().isoformat()
507
- # })
508
- return await self .rollout (node , max_depth = max_depth , websocket = websocket )
496
+ return await self .rollout (node , websocket = websocket )
509
497
510
498
# refactor simulation, rollout, send_completion_request methods
511
499
# TODO: check, score as reward and then update value of the starting node?
512
- async def rollout (self , node : LATSNode , max_depth : int = 2 , websocket = None )-> tuple [float , LATSNode ]:
500
+ async def rollout (self , node : LATSNode , websocket = None )-> tuple [float , LATSNode ]:
513
501
# Reset browser state
514
502
await self ._reset_browser ()
515
503
path = self .get_path_to_root (node )
@@ -540,23 +528,14 @@ async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tu
540
528
"action" : n .action ,
541
529
"feedback" : n .feedback
542
530
})
543
- ## call the prompt agent
544
531
print ("current depth: " , len (path ) - 1 )
545
532
print ("max depth: " , self .config .max_depth )
546
533
547
- ## find a better name for this
548
534
trajectory , terminal_node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory , websocket = websocket )
549
535
print ("print the trajectory" )
550
536
print_trajectory (terminal_node )
551
537
print ("print the entire tree" )
552
538
print_entire_tree (self .root_node )
553
- # if websocket:
554
- # trajectory_data = self._get_trajectory_data(node)
555
- # await websocket.send_json({
556
- # "type": "trajectory_update",
557
- # "trajectory": trajectory_data,
558
- # "timestamp": datetime.utcnow().isoformat()
559
- # })
560
539
561
540
page = await self .playwright_manager .get_page ()
562
541
page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
@@ -583,12 +562,6 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
583
562
print ("print the entire tree" )
584
563
print_entire_tree (self .root_node )
585
564
if websocket :
586
- # tree_data = self._get_tree_data()
587
- # await websocket.send_json({
588
- # "type": "tree_update",
589
- # "tree": tree_data,
590
- # "timestamp": datetime.utcnow().isoformat()
591
- # })
592
565
trajectory_data = self ._get_trajectory_data (node )
593
566
await websocket .send_json ({
594
567
"type" : "trajectory_update" ,
@@ -684,15 +657,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
684
657
path = self .get_path_to_root (node )
685
658
686
659
# Execute path
687
- for n in path [1 :]: # Skip root node
688
- # if websocket:
689
- # await websocket.send_json({
690
- # "type": "replaying_action",
691
- # "node_id": id(n),
692
- # "action": n.action,
693
- # "timestamp": datetime.utcnow().isoformat()
694
- # })
695
-
660
+ for n in path [1 :]: # Skip root node
696
661
success = await playwright_step_execution (
697
662
n ,
698
663
self .goal ,
@@ -702,12 +667,6 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
702
667
)
703
668
if not success :
704
669
n .is_terminal = True
705
- # if websocket:
706
- # await websocket.send_json({
707
- # "type": "replay_failed",
708
- # "node_id": id(n),
709
- # "timestamp": datetime.utcnow().isoformat()
710
- # })
711
670
return []
712
671
713
672
if not n .feedback :
@@ -716,26 +675,13 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
716
675
n .natural_language_description ,
717
676
self .playwright_manager ,
718
677
)
719
- # if websocket:
720
- # await websocket.send_json({
721
- # "type": "feedback_generated",
722
- # "node_id": id(n),
723
- # "feedback": n.feedback,
724
- # "timestamp": datetime.utcnow().isoformat()
725
- # })
726
678
727
679
time .sleep (3 )
728
680
page = await self .playwright_manager .get_page ()
729
681
page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
730
682
731
683
messages = [{"role" : "user" , "content" : f"Action is: { n .action } " } for n in path [1 :]]
732
-
733
- # if websocket:
734
- # await websocket.send_json({
735
- # "type": "generating_actions",
736
- # "node_id": id(node),
737
- # "timestamp": datetime.utcnow().isoformat()
738
- # })
684
+
739
685
740
686
next_actions = await extract_top_actions (
741
687
[{"natural_language_description" : n .natural_language_description , "action" : n .action , "feedback" : n .feedback } for n in path [1 :]],
@@ -779,23 +725,8 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
779
725
action ["element" ] = element
780
726
except Exception as e :
781
727
action ["element" ] = None
782
- # if websocket:
783
- # await websocket.send_json({
784
- # "type": "element_location_failed",
785
- # "action": action["action"],
786
- # "error": str(e),
787
- # "timestamp": datetime.utcnow().isoformat()
788
- # })
789
728
children .append (action )
790
729
791
730
if not children :
792
- node .is_terminal = True
793
- # if websocket:
794
- # await websocket.send_json({
795
- # "type": "node_terminal",
796
- # "node_id": id(node),
797
- # "reason": "no_valid_actions",
798
- # "timestamp": datetime.utcnow().isoformat()
799
- # })
800
-
731
+ node .is_terminal = True
801
732
return children
0 commit comments