@@ -22,91 +22,91 @@ async def run(self, websocket=None) -> list[LATSNode]:
22
22
return best_node
23
23
24
24
async def lats_search (self , websocket = None ):
25
- terminal_nodes = []
25
+ terminal_nodes = []
26
26
27
- for i in range (self .config .iterations ):
28
- await self .websocket_iteration_start (i , websocket = websocket )
29
-
30
- print (f"Iteration { i } /{ self .config .iterations } ..." )
27
+ for i in range (self .config .iterations ):
28
+ await self .websocket_iteration_start (i , websocket = websocket )
29
+
30
+ print (f"Iteration { i } /{ self .config .iterations } ..." )
31
31
32
- # Step 1: Node Selection
33
- ## TODO: move websocket node selection into node_selection method
34
- print (f"{ GREEN } Step 1: node selection{ RESET } " )
35
- await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
36
- node = await self .node_selection (self .root_node )
37
- await self .websocket_node_selection (node , websocket = websocket )
32
+ # Step 1: Node Selection
33
+ ## TODO: move websocket node selection into node_selection method
34
+ print (f"{ GREEN } Step 1: node selection{ RESET } " )
35
+ await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
36
+ node = await self .node_selection (self .root_node )
37
+ await self .websocket_node_selection (node , websocket = websocket )
38
38
39
- if node is None :
40
- print ("All paths lead to terminal nodes with reward 0. Ending search." )
41
- break
39
+ if node is None :
40
+ print ("All paths lead to terminal nodes with reward 0. Ending search." )
41
+ break
42
42
43
- # Step 2: Node Expansion
44
- print (f"{ GREEN } Step 2: node expansion{ RESET } " )
45
- await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
46
- await self .node_expansion (node , websocket )
47
- if node is None :
48
- # all the nodes are terminal, stop the search
49
- print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
50
- break
51
- tree_data = self ._get_tree_data ()
52
- if websocket :
53
- await self .websocket_tree_update (type = "tree_update_node_expansion" , websocket = websocket , tree_data = tree_data )
54
- else :
55
- print_entire_tree (self .root_node )
43
+ # Step 2: Node Expansion
44
+ print (f"{ GREEN } Step 2: node expansion{ RESET } " )
45
+ await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
46
+ await self .node_expansion (node , websocket )
47
+ if node is None :
48
+ # all the nodes are terminal, stop the search
49
+ print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
50
+ break
51
+ tree_data = self ._get_tree_data ()
52
+ if websocket :
53
+ await self .websocket_tree_update (type = "tree_update_node_expansion" , websocket = websocket , tree_data = tree_data )
54
+ else :
55
+ print_entire_tree (self .root_node )
56
56
57
57
58
- # Step 3: Evaluation
59
- print (f"{ GREEN } Step 3: node chilren evaluation{ RESET } " )
60
- await self .websocket_step_start (step = 3 , step_name = "node_children_evaluation" , websocket = websocket )
61
- await self .node_children_evaluation (node )
62
- tree_data = self ._get_tree_data ()
63
- if websocket :
64
- await self .websocket_tree_update (type = "tree_update_node_children_evaluation" , websocket = websocket , tree_data = tree_data )
65
- else :
66
- print ("after evaluation" )
67
- print_entire_tree (self .root_node )
58
+ # Step 3: Evaluation
59
+ print (f"{ GREEN } Step 3: node chilren evaluation{ RESET } " )
60
+ await self .websocket_step_start (step = 3 , step_name = "node_children_evaluation" , websocket = websocket )
61
+ await self .node_children_evaluation (node )
62
+ tree_data = self ._get_tree_data ()
63
+ if websocket :
64
+ await self .websocket_tree_update (type = "tree_update_node_children_evaluation" , websocket = websocket , tree_data = tree_data )
65
+ else :
66
+ print ("after evaluation" )
67
+ print_entire_tree (self .root_node )
68
68
69
69
70
- # Step 4: Simulation
71
- print (f"{ GREEN } Step 4: simulation{ RESET } " )
72
- await self .websocket_step_start (step = 4 , step_name = "simulation" , websocket = websocket )
73
- selected_node = max (node .children , key = lambda child : child .value )
74
- await self .websocket_node_selection (selected_node , websocket = websocket , type = "node_selected_for_simulation" )
75
- reward , terminal_node = await self .simulation (selected_node , websocket = websocket )
76
- terminal_nodes .append (terminal_node )
77
- await self .websocket_simulation_result (reward , terminal_node , websocket = websocket )
70
+ # Step 4: Simulation
71
+ print (f"{ GREEN } Step 4: simulation{ RESET } " )
72
+ await self .websocket_step_start (step = 4 , step_name = "simulation" , websocket = websocket )
73
+ selected_node = max (node .children , key = lambda child : child .value )
74
+ await self .websocket_node_selection (selected_node , websocket = websocket , type = "node_selected_for_simulation" )
75
+ reward , terminal_node = await self .simulation (selected_node , websocket = websocket )
76
+ terminal_nodes .append (terminal_node )
77
+ await self .websocket_simulation_result (reward , terminal_node , websocket = websocket )
78
78
79
- if reward == 1 :
80
- await self .websocket_search_complete ("success" , reward , terminal_node .get_trajectory (), websocket = websocket )
81
- return terminal_node
79
+ if reward == 1 :
80
+ await self .websocket_search_complete ("success" , reward , terminal_node .get_trajectory (), websocket = websocket )
81
+ return terminal_node
82
82
83
- # Step 5: Backpropagation
84
- print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
85
- await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
86
- self .backpropagate (terminal_node , reward )
87
- tree_data = self ._get_tree_data ()
88
- if websocket :
89
- await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
90
- else :
91
- print ("after backpropagation" )
92
- print_entire_tree (self .root_node )
83
+ # Step 5: Backpropagation
84
+ print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
85
+ await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
86
+ self .backpropagate (terminal_node , reward )
87
+ tree_data = self ._get_tree_data ()
88
+ if websocket :
89
+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
90
+ else :
91
+ print ("after backpropagation" )
92
+ print_entire_tree (self .root_node )
93
93
94
- # Find best node
95
- all_nodes_list = collect_all_nodes (self .root_node )
96
- all_nodes_list .extend (terminal_nodes )
97
-
98
- ## temp change: if reward is the same, choose the deeper node
99
- best_child = max (all_nodes_list , key = lambda x : (x .reward , x .depth ))
94
+ # Find best node
95
+ all_nodes_list = collect_all_nodes (self .root_node )
96
+ all_nodes_list .extend (terminal_nodes )
97
+
98
+ ## temp change: if reward is the same, choose the deeper node
99
+ best_child = max (all_nodes_list , key = lambda x : (x .reward , x .depth ))
100
+
101
+ if best_child .value >= 0.75 :
102
+ print ("Successful trajectory found" )
103
+ await self .websocket_search_complete ("success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
104
+ else :
105
+ print ("Unsuccessful trajectory found" )
106
+ await self .websocket_search_complete ("partial_success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
107
+ await self .playwright_manager .close ()
100
108
101
- if best_child .value >= 0.75 :
102
- print ("Successful trajectory found" )
103
- await self .websocket_search_complete ("success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
104
- else :
105
- print ("Unsuccessful trajectory found" )
106
- await self .websocket_search_complete ("partial_success" , best_child .value , best_child .get_trajectory (), websocket = websocket )
107
- await self .playwright_manager .close ()
108
-
109
- return best_child if best_child is not None else self .root_node
109
+ return best_child if best_child is not None else self .root_node
110
110
111
111
async def node_selection (self , node : LATSNode , websocket = None ) -> Optional [LATSNode ]:
112
112
if node .is_terminal :
0 commit comments