Skip to content

Commit 3167877

Browse files
authored
Merge pull request #83 from PathOnAI/fix-search-complete-eval
1) fix search complete websocket message, 2) empty trajectory, score = 0
2 parents 85da7ee + 5108781 commit 3167877

File tree

2 files changed

+34
-17
lines changed

2 files changed

+34
-17
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,10 @@ async def evaluate_selected_path(self, path) -> None:
141141
"feedback": n.feedback
142142
})
143143

144-
# Score the trajectory
145-
# TODO: if node is terminal, score is 0?
146-
# if node.is_terminal:
147-
# score = 0
144+
## fix for MCTS agent only
145+
if len(trajectory) == 0:
146+
score = 0
147+
return score
148148
prompt = create_llm_prompt(trajectory, self.goal)
149149
print(f"prompt: {prompt}")
150150
result = score_trajectory_with_openai(
@@ -230,8 +230,10 @@ async def reflection_backtracking(self, path) -> List[LATSNode]:
230230
print("Suggested improvements:")
231231
for improvement in reflection_result["suggested_improvements"]:
232232
print(f"- {improvement}")
233+
print(f"current_node: {current_node.action}")
234+
print(f"current_node: {current_node.natural_language_description}")
233235

234-
return path
236+
return path, current_node
235237

236238
async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
237239
best_score = float('-inf')
@@ -249,17 +251,22 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
249251
# "node selection" combines selection and partial simulation
250252
print(f"{GREEN}Step 1: Node Selection{RESET}")
251253
await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
252-
node = await self.node_selection(self.root_node, websocket)
254+
selected_node = await self.node_selection(self.root_node, websocket)
255+
tree_data = self._get_tree_data()
256+
if websocket:
257+
await self.websocket_tree_update(type="tree_update_node_selection", websocket=websocket, tree_data=tree_data)
258+
else:
259+
print_entire_tree(self.root_node)
253260

254-
if node is None:
261+
if selected_node is None:
255262
logger.warning("All paths lead to terminal nodes. Ending search.")
256263
break
257264

258265
# Step 2: Node Expansion
259266
print(f"{GREEN}Step 2: Node Expansion{RESET}")
260267
await self.websocket_step_start(step=2, step_name="node_expansion", websocket=websocket)
261-
await self.node_expansion(node, websocket)
262-
if node is None:
268+
await self.node_expansion(selected_node, websocket)
269+
if selected_node is None:
263270
# all the nodes are terminal, stop the search
264271
print(f"{RED}All nodes are terminal, stopping search{RESET}")
265272
break
@@ -274,29 +281,34 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
274281
# TODO: implement simulation using openai
275282
print(f"{GREEN}Step 3: Simulation{RESET}")
276283
await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
277-
path = self.get_path_to_root(node)
284+
path = self.get_path_to_root(selected_node)
278285
score = await self.evaluate_selected_path(path)
279286
# change to reward later?
280287
if score > best_score:
281288
best_score = score
282289
best_path = path
290+
best_node = selected_node
283291
print(f"\nNew best path found!")
284-
print(f"Previous best score: {best_score:.3f}")
285-
print(f"New best score: {score:.3f}")
292+
print(f"best score: {best_score:.3f}")
293+
print(f"best node: {best_node.action}")
294+
print(f"best node: {best_node.natural_language_description}")
295+
print(f"best path: {best_path}")
286296

287297

288298
## Step 4: reflection backtracking
289299
print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
290300
await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
291301
if score >= self.config.reflection_score:
292302
# Convert path to serializable trajectory
293-
trajectory = [node.action for node in path if node.action is not None]
294-
await self.websocket_search_complete("success", score, trajectory, websocket=websocket)
295-
return node
303+
# trajectory = [node.action for node in path if node.action is not None]
304+
await self.websocket_search_complete("success", score, selected_node.get_trajectory(), websocket=websocket)
305+
return selected_node
296306

297307
print(f"path: {path}")
298-
path = await self.reflection_backtracking(path)
308+
path, current_node = await self.reflection_backtracking(path)
299309
print(f"path: {path}")
310+
print(f"current_node: {current_node.action}")
311+
print(f"current_node: {current_node.natural_language_description}")
300312

301313
# Step 5: backpropagation
302314
print(f"{GREEN}Step 5: Backpropagation{RESET}")
@@ -308,8 +320,12 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
308320
print(f"Node {node.action}:")
309321
print(f" Visits: {node.visits}")
310322
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
323+
if websocket:
324+
await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
325+
else:
326+
print_entire_tree(self.root_node)
311327
if best_node:
312328
# Convert node to serializable trajectory
313-
trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
329+
# trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
314330
await self.websocket_search_complete("partial_success", best_node.value, best_node.get_trajectory(), websocket=websocket)
315331
return best_node

visual-tree-search-backend/test/test-tree-search-ws-mcts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
# Tree/Path updates
4444
'tree_update': '\033[96m', # Cyan
45+
'tree_update_node_selection': '\033[96m', # Cyan
4546
'tree_update_node_expansion': '\033[96m', # Cyan
4647
'tree_update_node_evaluation': '\033[96m', # Cyan
4748
'tree_update_node_children_evaluation': '\033[96m', # Cyan

0 commit comments

Comments
 (0)