Skip to content

Commit e576fbe

Browse files
authored
Merge pull request #63 from PathOnAI/refactoring-cleanup
Refactoring cleanup
2 parents 8399efd + 644d74f commit e576fbe

File tree

8 files changed

+38
-96
lines changed

8 files changed

+38
-96
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
visual-tree-search-backend/app/api/test_logs/*
12
visual-tree-search-backend/log/*
23
log/*
34
shopping.json

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 12 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,6 @@
3434
openai_client = OpenAI()
3535

3636

37-
## TODO: remove account reset websocket message
38-
## browser setup message, ok to leave there in the _reset_browser method
39-
40-
4137
class BaseAgent:
4238
# no need to pass an initial playwright_manager to the agent class
4339
def __init__(
@@ -381,6 +377,10 @@ async def websocket_search_complete(self, status, score, path, websocket=None):
381377
"path": path,
382378
"timestamp": datetime.utcnow().isoformat()
383379
})
380+
else:
381+
print(f"Search complete: {GREEN}{status}{RESET}")
382+
print(f"Search score: {GREEN}{score}{RESET}")
383+
print(f"Search path: {GREEN}{path}{RESET}")
384384

385385
# shared, not implemented: BFS, DFS and LATS each have their own node selection logic
386386
async def node_selection(self, node, websocket = None):
@@ -485,31 +485,19 @@ def backpropagate(self, node: LATSNode, value: float) -> None:
485485
node = node.parent
486486

487487
# shared
488-
async def simulation(self, node: LATSNode, max_depth: int = 2, num_simulations=1, websocket=None) -> tuple[float, LATSNode]:
488+
async def simulation(self, node: LATSNode, websocket=None) -> tuple[float, LATSNode]:
489489
depth = node.depth
490+
num_simulations = self.config.num_simulations
491+
max_depth = self.config.max_depth
490492
print("print the trajectory")
491493
print_trajectory(node)
492494
print("print the entire tree")
493495
print_entire_tree(self.root_node)
494-
# if websocket:
495-
# tree_data = self._get_tree_data()
496-
# await self.websocket_tree_update(type="tree_update_simulation", tree_data=tree_data, websocket=websocket)
497-
# await websocket.send_json({
498-
# "type": "tree_update",
499-
# "tree": tree_data,
500-
# "timestamp": datetime.utcnow().isoformat()
501-
# })
502-
# trajectory_data = self._get_trajectory_data(node)
503-
# await websocket.send_json({
504-
# "type": "trajectory_update",
505-
# "trajectory": trajectory_data,
506-
# "timestamp": datetime.utcnow().isoformat()
507-
# })
508-
return await self.rollout(node, max_depth=max_depth, websocket=websocket)
496+
return await self.rollout(node, websocket=websocket)
509497

510498
# refactor simulation, rollout, send_completion_request methods
511499
# TODO: check, score as reward and then update value of the starting node?
512-
async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tuple[float, LATSNode]:
500+
async def rollout(self, node: LATSNode, websocket=None)-> tuple[float, LATSNode]:
513501
# Reset browser state
514502
await self._reset_browser()
515503
path = self.get_path_to_root(node)
@@ -540,23 +528,14 @@ async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tu
540528
"action": n.action,
541529
"feedback": n.feedback
542530
})
543-
## call the prompt agent
544531
print("current depth: ", len(path) - 1)
545532
print("max depth: ", self.config.max_depth)
546533

547-
## find a better name for this
548534
trajectory, terminal_node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory, websocket=websocket)
549535
print("print the trajectory")
550536
print_trajectory(terminal_node)
551537
print("print the entire tree")
552538
print_entire_tree(self.root_node)
553-
# if websocket:
554-
# trajectory_data = self._get_trajectory_data(node)
555-
# await websocket.send_json({
556-
# "type": "trajectory_update",
557-
# "trajectory": trajectory_data,
558-
# "timestamp": datetime.utcnow().isoformat()
559-
# })
560539

561540
page = await self.playwright_manager.get_page()
562541
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
@@ -583,12 +562,6 @@ async def send_completion_request(self, plan, depth, node, trajectory=[], websoc
583562
print("print the entire tree")
584563
print_entire_tree(self.root_node)
585564
if websocket:
586-
# tree_data = self._get_tree_data()
587-
# await websocket.send_json({
588-
# "type": "tree_update",
589-
# "tree": tree_data,
590-
# "timestamp": datetime.utcnow().isoformat()
591-
# })
592565
trajectory_data = self._get_trajectory_data(node)
593566
await websocket.send_json({
594567
"type": "trajectory_update",
@@ -684,15 +657,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
684657
path = self.get_path_to_root(node)
685658

686659
# Execute path
687-
for n in path[1:]: # Skip root node
688-
# if websocket:
689-
# await websocket.send_json({
690-
# "type": "replaying_action",
691-
# "node_id": id(n),
692-
# "action": n.action,
693-
# "timestamp": datetime.utcnow().isoformat()
694-
# })
695-
660+
for n in path[1:]: # Skip root node
696661
success = await playwright_step_execution(
697662
n,
698663
self.goal,
@@ -702,12 +667,6 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
702667
)
703668
if not success:
704669
n.is_terminal = True
705-
# if websocket:
706-
# await websocket.send_json({
707-
# "type": "replay_failed",
708-
# "node_id": id(n),
709-
# "timestamp": datetime.utcnow().isoformat()
710-
# })
711670
return []
712671

713672
if not n.feedback:
@@ -716,26 +675,13 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
716675
n.natural_language_description,
717676
self.playwright_manager,
718677
)
719-
# if websocket:
720-
# await websocket.send_json({
721-
# "type": "feedback_generated",
722-
# "node_id": id(n),
723-
# "feedback": n.feedback,
724-
# "timestamp": datetime.utcnow().isoformat()
725-
# })
726678

727679
time.sleep(3)
728680
page = await self.playwright_manager.get_page()
729681
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
730682

731683
messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]]
732-
733-
# if websocket:
734-
# await websocket.send_json({
735-
# "type": "generating_actions",
736-
# "node_id": id(node),
737-
# "timestamp": datetime.utcnow().isoformat()
738-
# })
684+
739685

740686
next_actions = await extract_top_actions(
741687
[{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]],
@@ -779,23 +725,8 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
779725
action["element"] = element
780726
except Exception as e:
781727
action["element"] = None
782-
# if websocket:
783-
# await websocket.send_json({
784-
# "type": "element_location_failed",
785-
# "action": action["action"],
786-
# "error": str(e),
787-
# "timestamp": datetime.utcnow().isoformat()
788-
# })
789728
children.append(action)
790729

791730
if not children:
792-
node.is_terminal = True
793-
# if websocket:
794-
# await websocket.send_json({
795-
# "type": "node_terminal",
796-
# "node_id": id(node),
797-
# "reason": "no_valid_actions",
798-
# "timestamp": datetime.utcnow().isoformat()
799-
# })
800-
731+
node.is_terminal = True
801732
return children

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ async def run(self, websocket=None) -> list[LATSNode]:
1919

2020
best_node = await self.lats_search(websocket)
2121
print_trajectory(best_node)
22+
return best_node
2223

2324
async def lats_search(self, websocket=None):
2425
terminal_nodes = []
@@ -71,11 +72,12 @@ async def lats_search(self, websocket=None):
7172
await self.websocket_step_start(step=4, step_name="simulation", websocket=websocket)
7273
selected_node = max(node.children, key=lambda child: child.value)
7374
await self.websocket_node_selection(selected_node, websocket=websocket, type="node_selected_for_simulation")
74-
reward, terminal_node = await self.simulation(selected_node, max_depth=self.config.max_depth, num_simulations=1, websocket=websocket)
75+
reward, terminal_node = await self.simulation(selected_node, websocket=websocket)
7576
terminal_nodes.append(terminal_node)
7677
await self.websocket_simulation_result(reward, terminal_node, websocket=websocket)
7778

7879
if reward == 1:
80+
await self.websocket_search_complete("success", reward, terminal_node.get_trajectory(), websocket=websocket)
7981
return terminal_node
8082

8183
# Step 5: Backpropagation
@@ -96,10 +98,12 @@ async def lats_search(self, websocket=None):
9698
## temp change: if reward is the same, choose the deeper node
9799
best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth))
98100

99-
if best_child.reward == 1:
101+
if best_child.value >= 0.75:
100102
print("Successful trajectory found")
103+
await self.websocket_search_complete("success", best_child.value, best_child.get_trajectory(), websocket=websocket)
101104
else:
102105
print("Unsuccessful trajectory found")
106+
await self.websocket_search_complete("partial_success", best_child.value, best_child.get_trajectory(), websocket=websocket)
103107
await self.playwright_manager.close()
104108

105109
return best_child if best_child is not None else self.root_node

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -109,27 +109,27 @@ async def bfs(self, websocket=None):
109109
print(f"Found satisfactory solution with score {score}")
110110

111111
# Send completion update if websocket is provided
112-
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=None)
112+
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
113113

114-
return [{"action": node.action} for node in path[1:]]
114+
return current_node
115115

116116
# If we've exhausted all nodes and haven't found a perfect solution,
117117
# return the best path we found
118118
if best_path:
119119
print(f"Returning best path found with score {best_score}")
120120

121121
# Send completion update if websocket is provided
122-
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=None)
122+
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
123123

124-
return [{"action": node.action} for node in best_path[1:]]
124+
return best_node
125125

126126
# If no path was found at all
127127
print("No valid path found")
128128

129129
# Send failure update if websocket is provided
130-
await self.websocket_search_complete("failure", 0, None, websocket=None)
130+
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
131131

132-
return []
132+
return None
133133

134134
# TODO: first evaluate, then expansion
135135
async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
@@ -209,8 +209,8 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
209209
print(f"Found satisfactory solution with score {score}")
210210

211211
# Send completion update if websocket is provided
212-
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=None)
213-
return [{"action": node.action} for node in path[1:]]
212+
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
213+
return current_node
214214

215215
# Add non-terminal children to stack in reverse order
216216
has_unvisited_children = False
@@ -233,15 +233,15 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
233233
print(f"Returning best path found with score {best_score}")
234234

235235
# Send completion update if websocket is provided
236-
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=None)
236+
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
237237

238-
return [{"action": node.action} for node in best_path[1:]]
238+
return best_node
239239

240240
# If no path was found at all
241241
print("No valid path found")
242242

243243
# Send failure update if websocket is provided
244-
await self.websocket_search_complete("failure", 0, None, websocket=None)
244+
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
245245

246-
return []
246+
return None
247247

visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ class AgentConfig:
2222
branching_factor: int = 5
2323
iterations: int = 1
2424
max_depth: int = 3
25-
num_simulations: int = 100
25+
num_simulations: int = 1
2626
account_reset: bool = True
2727

2828
# Features
Binary file not shown.

visual-tree-search-backend/test/test-tree-search-ws-lats.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,9 @@ async def connect_and_test_search(
124124
color = COLORS.get(msg_type, COLORS['reset'])
125125
print(f"\nWebSocket message - Type: {color}{msg_type}{COLORS['reset']}")
126126
print(f"Raw message: {json.dumps(data, indent=2)}")
127+
128+
if msg_type == "search_complete":
129+
break
127130

128131
except websockets.exceptions.ConnectionClosed:
129132
logger.warning("WebSocket connection closed")

visual-tree-search-backend/test/test-tree-search-ws-simple.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,9 @@ async def connect_and_test_search(
116116
color = COLORS.get(msg_type, COLORS['reset'])
117117
print(f"\nWebSocket message - Type: {color}{msg_type}{COLORS['reset']}")
118118
print(f"Raw message: {json.dumps(data, indent=2)}")
119+
120+
if msg_type == "search_complete":
121+
break
119122

120123
except websockets.exceptions.ConnectionClosed:
121124
logger.warning("WebSocket connection closed")

0 commit comments

Comments
 (0)