@@ -165,11 +165,13 @@ async def websocket_reflection_backtracking(self, path, selected_node, websocket
165
165
if websocket :
166
166
await websocket .send_json ({
167
167
"type" : "reflection_backtracking" ,
168
- "path" : [node .action for node in path if node .action is not None ],
168
+ "path" : [{
169
+ "natural_language_description" : node .natural_language_description ,
170
+ "action" : node .action } for node in path if node .action is not None ],
169
171
"node_id" : id (selected_node ),
170
- "node_parent_id " : id (selected_node .parent ),
171
- "node_action " : selected_node .action ,
172
- "node_description " : selected_node .natural_language_description ,
172
+ "parent_id " : id (selected_node .parent ),
173
+ "action " : selected_node .action ,
174
+ "description " : selected_node .natural_language_description ,
173
175
"trajectory" : selected_node .get_trajectory ()
174
176
})
175
177
@@ -304,80 +306,81 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
304
306
305
307
# Step 3: simulation using the current node, (generate a path using the current node, and score the path)
306
308
# TODO: implement simulation using openai
307
- print (f"{ GREEN } Step 3: Simulation{ RESET } " )
308
- await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
309
- path = self .get_path_to_root (selected_node )
310
- # here score is the reward
311
- score = await self .evaluate_selected_path (path )
312
- # change to reward later?
313
- if score > best_score :
314
- best_score = score
315
- best_path = path
316
- best_node = selected_node
317
- print (f"\n New best path found!" )
318
- print (f"best score: { best_score :.3f} " )
319
- print (f"best node: { best_node .action } " )
320
- print (f"best node: { best_node .natural_language_description } " )
321
- print (f"best path: { best_path } " )
322
-
323
- # add websocket information, just use websocket here
324
- if websocket :
325
- await self .websocket_simulation_result (score , selected_node , websocket = websocket )
309
+ if selected_node != self .root_node :
310
+ print (f"{ GREEN } Step 3: Simulation{ RESET } " )
311
+ await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
312
+ path = self .get_path_to_root (selected_node )
313
+ # here score is the reward
314
+ score = await self .evaluate_selected_path (path )
315
+ # change to reward later?
316
+ if score > best_score :
317
+ best_score = score
318
+ best_path = path
319
+ best_node = selected_node
320
+ print (f"\n New best path found!" )
321
+ print (f"best score: { best_score :.3f} " )
322
+ print (f"best node: { best_node .action } " )
323
+ print (f"best node: { best_node .natural_language_description } " )
324
+ print (f"best path: { best_path } " )
326
325
326
+ # add websocket information, just use websocket here
327
+ if websocket :
328
+ await self .websocket_simulation_result (score , selected_node , websocket = websocket )
327
329
328
- ## Step 4: reflection backtracking
329
- print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
330
- await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
331
- if score >= self .config .reflection_score :
332
- # Convert path to serializable trajectory
333
- # trajectory = [node.action for node in path if node.action is not None]
334
- await self .websocket_search_complete ("success" , score , selected_node .get_trajectory (), websocket = websocket )
335
- await self .playwright_manager .close ()
336
- return selected_node
337
330
338
- print (f"path: { path } " )
339
- path , current_node = await self .reflection_backtracking (path )
340
- print (f"path: { path } " )
341
- print (f"current_node: { current_node .action } " )
342
- print (f"current_node: { current_node .natural_language_description } " )
331
+ ## Step 4: reflection backtracking
332
+ print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
333
+ await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
334
+ if score >= self .config .reflection_score :
335
+ # Convert path to serializable trajectory
336
+ # trajectory = [node.action for node in path if node.action is not None]
337
+ await self .websocket_search_complete ("success" , score , selected_node .get_trajectory (), websocket = websocket )
338
+ await self .playwright_manager .close ()
339
+ return selected_node
343
340
344
- # add websocket information, just use websocket here
345
- if websocket :
346
- await self .websocket_reflection_backtracking (path , current_node , websocket = websocket )
341
+ print (f"path: { path } " )
342
+ path , current_node = await self .reflection_backtracking (path )
343
+ print (f"path: { path } " )
344
+ print (f"current_node: { current_node .action } " )
345
+ print (f"current_node: { current_node .natural_language_description } " )
347
346
348
- # Step 5: backpropagation
349
- print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
350
- await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
351
- for node in path :
352
- if node != self .root_node :
353
- old_value = node .value
354
- node .visits += 1
355
- node .value += (score - node .value ) / node .visits
356
- # consiste with lats backpropagation
357
- #node.value = (node.value * (node.visits - 1) + score) / node.visits
358
- print (f"Node { node .action } :" )
359
- print (f" Visits: { node .visits } " )
360
- print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
361
347
# add websocket information, just use websocket here
362
- # if websocket:
363
- # await websocket.send_json({
364
- # "type": "backpropagation",
365
- # "node_id": id(node),
366
- # "node_parent_id": id(node.parent),
367
- # "node_action": node.action,
368
- # "node_value": node.value,
369
- # "node_visits": node.visits,
370
- # "node_old_value": old_value,
371
- # "node_description": node.natural_language_description,
372
- # })
348
+ if websocket :
349
+ await self .websocket_reflection_backtracking (path , current_node , websocket = websocket )
373
350
374
- tree_data = self ._get_tree_data ()
375
- print_entire_tree (self .root_node )
376
- print (tree_data )
377
- if websocket :
378
- await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
379
- else :
351
+ # Step 5: backpropagation
352
+ print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
353
+ await self .websocket_step_start (step = 5 , step_name = "backpropagation" , websocket = websocket )
354
+ for node in path :
355
+ if node != self .root_node :
356
+ old_value = node .value
357
+ node .visits += 1
358
+ node .value += (score - node .value ) / node .visits
359
+ # consistent with LATS backpropagation
360
+ #node.value = (node.value * (node.visits - 1) + score) / node.visits
361
+ print (f"Node { node .action } :" )
362
+ print (f" Visits: { node .visits } " )
363
+ print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
364
+ # add websocket information, just use websocket here
365
+ # if websocket:
366
+ # await websocket.send_json({
367
+ # "type": "backpropagation",
368
+ # "node_id": id(node),
369
+ # "node_parent_id": id(node.parent),
370
+ # "node_action": node.action,
371
+ # "node_value": node.value,
372
+ # "node_visits": node.visits,
373
+ # "node_old_value": old_value,
374
+ # "node_description": node.natural_language_description,
375
+ # })
376
+
377
+ tree_data = self ._get_tree_data ()
380
378
print_entire_tree (self .root_node )
379
+ print (tree_data )
380
+ if websocket :
381
+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
382
+ else :
383
+ print_entire_tree (self .root_node )
381
384
if best_node :
382
385
# Convert node to serializable trajectory
383
386
# trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
0 commit comments