@@ -141,10 +141,10 @@ async def evaluate_selected_path(self, path) -> None:
141
141
"feedback" : n .feedback
142
142
})
143
143
144
- # Score the trajectory
145
- # TODO: if node is terminal, score is 0?
146
- # if node.is_terminal:
147
- # score = 0
144
+ ## fix for MCTS agent only
145
+ if len ( trajectory ) == 0 :
146
+ score = 0
147
+ return score
148
148
prompt = create_llm_prompt (trajectory , self .goal )
149
149
print (f"prompt: { prompt } " )
150
150
result = score_trajectory_with_openai (
@@ -230,8 +230,10 @@ async def reflection_backtracking(self, path) -> List[LATSNode]:
230
230
print ("Suggested improvements:" )
231
231
for improvement in reflection_result ["suggested_improvements" ]:
232
232
print (f"- { improvement } " )
233
+ print (f"current_node: { current_node .action } " )
234
+ print (f"current_node: { current_node .natural_language_description } " )
233
235
234
- return path
236
+ return path , current_node
235
237
236
238
async def mcts_search (self , websocket = None ) -> Optional [LATSNode ]:
237
239
best_score = float ('-inf' )
@@ -249,17 +251,22 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
249
251
# "node selection" combines selection and partial simulation
250
252
print (f"{ GREEN } Step 1: Node Selection{ RESET } " )
251
253
await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
252
- node = await self .node_selection (self .root_node , websocket )
254
+ selected_node = await self .node_selection (self .root_node , websocket )
255
+ tree_data = self ._get_tree_data ()
256
+ if websocket :
257
+ await self .websocket_tree_update (type = "tree_update_node_selection" , websocket = websocket , tree_data = tree_data )
258
+ else :
259
+ print_entire_tree (self .root_node )
253
260
254
- if node is None :
261
+ if selected_node is None :
255
262
logger .warning ("All paths lead to terminal nodes. Ending search." )
256
263
break
257
264
258
265
# Step 2: Node Expansion
259
266
print (f"{ GREEN } Step 2: Node Expansion{ RESET } " )
260
267
await self .websocket_step_start (step = 2 , step_name = "node_expansion" , websocket = websocket )
261
- await self .node_expansion (node , websocket )
262
- if node is None :
268
+ await self .node_expansion (selected_node , websocket )
269
+ if selected_node is None :
263
270
# all the nodes are terminal, stop the search
264
271
print (f"{ RED } All nodes are terminal, stopping search{ RESET } " )
265
272
break
@@ -274,29 +281,34 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
274
281
# TODO: implement simulation using openai
275
282
print (f"{ GREEN } Step 3: Simulation{ RESET } " )
276
283
await self .websocket_step_start (step = 3 , step_name = "simulation" , websocket = websocket )
277
- path = self .get_path_to_root (node )
284
+ path = self .get_path_to_root (selected_node )
278
285
score = await self .evaluate_selected_path (path )
279
286
# change to reward later?
280
287
if score > best_score :
281
288
best_score = score
282
289
best_path = path
290
+ best_node = selected_node
283
291
print (f"\n New best path found!" )
284
- print (f"Previous best score: { best_score :.3f} " )
285
- print (f"New best score: { score :.3f} " )
292
+ print (f"best score: { best_score :.3f} " )
293
+ print (f"best node: { best_node .action } " )
294
+ print (f"best node: { best_node .natural_language_description } " )
295
+ print (f"best path: { best_path } " )
286
296
287
297
288
298
## Step 4: reflection backtracking
289
299
print (f"{ GREEN } Step 4: Reflection Backtracking{ RESET } " )
290
300
await self .websocket_step_start (step = 4 , step_name = "reflection_backtracking" , websocket = websocket )
291
301
if score >= self .config .reflection_score :
292
302
# Convert path to serializable trajectory
293
- trajectory = [node .action for node in path if node .action is not None ]
294
- await self .websocket_search_complete ("success" , score , trajectory , websocket = websocket )
295
- return node
303
+ # trajectory = [node.action for node in path if node.action is not None]
304
+ await self .websocket_search_complete ("success" , score , selected_node . get_trajectory () , websocket = websocket )
305
+ return selected_node
296
306
297
307
print (f"path: { path } " )
298
- path = await self .reflection_backtracking (path )
308
+ path , current_node = await self .reflection_backtracking (path )
299
309
print (f"path: { path } " )
310
+ print (f"current_node: { current_node .action } " )
311
+ print (f"current_node: { current_node .natural_language_description } " )
300
312
301
313
# Step 5: backpropagation
302
314
print (f"{ GREEN } Step 5: Backpropagation{ RESET } " )
@@ -308,8 +320,12 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
308
320
print (f"Node { node .action } :" )
309
321
print (f" Visits: { node .visits } " )
310
322
print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
323
+ if websocket :
324
+ await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
325
+ else :
326
+ print_entire_tree (self .root_node )
311
327
if best_node :
312
328
# Convert node to serializable trajectory
313
- trajectory = [n .action for n in self .get_path_to_root (best_node ) if n .action is not None ]
329
+ # trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
314
330
await self .websocket_search_complete ("partial_success" , best_node .value , best_node .get_trajectory (), websocket = websocket )
315
331
return best_node
0 commit comments