Skip to content

Commit 78ab4dc

Browse files
authored
Merge pull request #95 from PathOnAI/fix_is_terminal
fix is_terminal
2 parents bbae3fe + 73d5ec0 commit 78ab4dc

File tree

5 files changed

+59
-105
lines changed

5 files changed

+59
-105
lines changed

visual-tree-search-app/components/TreeVisual.tsx

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -490,6 +490,7 @@ const TreeVisual: React.FC<TreeVisualProps> = ({
490490
tooltipContent += `<div class="mt-2">${nodeInfo.join(' | ')}</div>`;
491491
}
492492

493+
493494
// Add value info if available
494495
if (typeof d.data.value === 'number') {
495496
tooltipContent += `<div>Value: <span class="font-bold">${d.data.value.toFixed(2)}</span></div>`;
@@ -504,6 +505,11 @@ const TreeVisual: React.FC<TreeVisualProps> = ({
504505
if (typeof d.data.depth === 'number') {
505506
tooltipContent += `<div>Depth: <span class="font-bold">${d.data.depth}</span></div>`;
506507
}
508+
509+
// Add is_terminal info if available
510+
if (typeof d.data.is_terminal === 'boolean') {
511+
tooltipContent += `<div>Is Terminal: <span class="font-bold">${d.data.is_terminal ? 'Yes' : 'No'}</span></div>`;
512+
}
507513

508514
const tooltip = d3.select(tooltipRef.current);
509515
tooltip.transition()

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 22 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -398,19 +398,10 @@ async def node_expansion(self, node: LATSNode, websocket = None) -> None:
398398
goal=node.goal,
399399
parent=node
400400
)
401+
if child.depth == self.config.max_depth:
402+
child.is_terminal = True
401403
node.children.append(child)
402404
await self.websocket_node_created(child, node, websocket=websocket)
403-
404-
# Send child creation update if websocket is provided
405-
# if websocket:
406-
# await websocket.send_json({
407-
# "type": "node_created",
408-
# "node_id": id(child),
409-
# "parent_id": id(node),
410-
# "action": child.action,
411-
# "description": child.natural_language_description,
412-
# "timestamp": datetime.utcnow().isoformat()
413-
# })
414405

415406

416407
# node evaluation
@@ -420,17 +411,17 @@ async def node_children_evaluation(self, node: LATSNode) -> None:
420411
print(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}")
421412
for i, child in enumerate(node.children):
422413
print(f"{GREEN}--- evaluating child {i+1}...{RESET}")
423-
if child.is_terminal:
414+
# if child.is_terminal:
415+
# score = 0
416+
# else:
417+
trajectory = child.get_trajectory()
418+
if len(trajectory) == 0:
424419
score = 0
425420
else:
426-
trajectory = child.get_trajectory()
427-
if len(trajectory) == 0:
428-
score = 0
429-
else:
430-
prompt = create_llm_prompt(trajectory, self.goal)
431-
# , child.observation.image
432-
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model)
433-
score = result["overall_score"]
421+
prompt = create_llm_prompt(trajectory, self.goal)
422+
# , child.observation.image
423+
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model)
424+
score = result["overall_score"]
434425
scores.append(score)
435426

436427
for child, score in zip(node.children, scores):
@@ -454,19 +445,19 @@ async def node_evaluation(self, node: LATSNode) -> None:
454445

455446
try:
456447
# Score the trajectory
457-
if node.is_terminal:
448+
# if node.is_terminal:
449+
# score = 0
450+
# else:
451+
if len(trajectory) == 0:
458452
score = 0
459453
else:
460-
if len(trajectory) == 0:
461-
score = 0
462-
else:
463-
prompt = create_llm_prompt(trajectory, self.goal)
464-
result = score_trajectory_with_openai(
465-
prompt,
466-
openai_client,
467-
model=self.config.evaluation_model
468-
)
469-
score = result["overall_score"]
454+
prompt = create_llm_prompt(trajectory, self.goal)
455+
result = score_trajectory_with_openai(
456+
prompt,
457+
openai_client,
458+
model=self.config.evaluation_model
459+
)
460+
score = result["overall_score"]
470461

471462
except Exception as e:
472463
error_msg = f"Error scoring node {id(node)}: {str(e)}"

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,9 @@ def __init__(
8888
self.value = 0.0
8989
self.depth = 0 if parent is None else parent.depth + 1
9090
self.is_terminal = False
91-
# self.reward = 0.0
91+
# A node is terminal when any of the following holds: the goal has been achieved;
92+
# The maximum depth has been reached;
93+
# A failure condition has been triggered.
9294
self.exhausted = False # If all children are terminal
9395
self.em = 0.0 # Exact match, evaluation metric
9496
self.observation: Optional[Observation] = None
@@ -106,6 +108,16 @@ def uct(self) -> float:
106108
return self.value / self.visits + np.sqrt(2 * np.log(self.parent.visits) / self.visits)
107109

108110
def get_best_leaf(self) -> 'LATSNode':
111+
"""
112+
Recursively get the best leaf node from the current node.
113+
114+
The method searches through unfinished (non-terminal) children,
115+
selects the one with the highest UCT score, and continues the
116+
search recursively until a leaf node (with no unfinished children) is reached.
117+
118+
Returns:
119+
LATSNode: The best leaf node for expansion based on UCT.
120+
"""
109121
unfinished_children = [c for c in self.children if not c.is_terminal]
110122
if not unfinished_children:
111123
return self

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -361,19 +361,7 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
361361
print(f"Node {node.action}:")
362362
print(f" Visits: {node.visits}")
363363
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
364-
# add websocket information, just use websocket here
365-
# if websocket:
366-
# await websocket.send_json({
367-
# "type": "backpropagation",
368-
# "node_id": id(node),
369-
# "node_parent_id": id(node.parent),
370-
# "node_action": node.action,
371-
# "node_value": node.value,
372-
# "node_visits": node.visits,
373-
# "node_old_value": old_value,
374-
# "node_description": node.natural_language_description,
375-
# })
376-
364+
377365
tree_data = self._get_tree_data()
378366
print_entire_tree(self.root_node)
379367
print(tree_data)

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py

Lines changed: 17 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -47,17 +47,8 @@ async def bfs(self, websocket=None):
4747
for _ in range(level_size):
4848
current_node = queue.popleft()
4949
queue_set.remove(current_node) # Remove from queue tracking
50-
51-
# Skip if we've already visited this node
52-
if current_node in visited:
53-
continue
54-
5550
visited.add(current_node)
5651

57-
# Skip terminal nodes
58-
if current_node.is_terminal:
59-
continue
60-
6152
# Expand current node if it hasn't been expanded yet and hasn't reached max_depth
6253
# node expansion for the next level
6354
if not current_node.children and current_node.depth < self.config.max_depth:
@@ -78,7 +69,7 @@ async def bfs(self, websocket=None):
7869

7970
# Add unvisited, not-yet-queued children to the queue for the next level if their depth does not exceed max_depth
8071
for child in current_node.children:
81-
if not child.is_terminal and child not in visited and child not in queue_set and child.depth < self.config.max_depth:
72+
if child not in visited and child not in queue_set and child.depth <= self.config.max_depth:
8273
queue.append(child)
8374
queue_set.add(child) # Add to queue tracking
8475

@@ -134,76 +125,50 @@ async def bfs(self, websocket=None):
134125

135126
return None
136127

137-
# TODO: first evaluate, then expansion
138-
async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
139-
stack = [self.root_node]
128+
async def dfs(self, websocket=None):
129+
stack = [self.root_node] # Use a list as a stack
140130
stack_set = {self.root_node} # Track nodes in stack
141131
best_score = float('-inf')
142132
best_path = None
143133
best_node = None
144134
visited = set() # Track visited nodes to avoid cycles
145-
current_path = [] # Track current path for DFS
146-
147-
# # Get the live browser URL during initial setup
148-
# live_browser_url, session_id = await self._reset_browser(websocket)
149-
150135

151136
while stack:
152-
current_node = stack[-1] # Peek at the top node without removing it
153-
154-
# Skip if we've already visited this node
155-
if current_node in visited:
156-
stack.pop()
157-
stack_set.remove(current_node)
158-
if current_path:
159-
current_path.pop() # Remove from current path
160-
continue
161-
137+
# Get the top node from the stack
138+
current_node = stack.pop()
139+
stack_set.remove(current_node) # Remove from stack tracking
162140
visited.add(current_node)
163-
current_path.append(current_node) # Add to current path
164141

165-
# Skip terminal nodes
166-
if current_node.is_terminal:
167-
print(f"Node {id(current_node)} is terminal")
168-
stack.pop()
169-
stack_set.remove(current_node)
170-
current_path.pop() # Remove from current path
171-
continue
172-
173142
# Expand current node if it hasn't been expanded yet and hasn't reached max_depth
174-
# stage 1: node expansion
175143
if not current_node.children and current_node.depth < self.config.max_depth:
176-
## during the node expansion process, reset browser for each node
144+
# Reset browser for each node expansion
177145
live_browser_url, session_id = await self._reset_browser(websocket)
178-
# await self.websocket_step_start(step=1, step_name="node_expansion", websocket=websocket)
179146
await self.websocket_node_selection(current_node, websocket=websocket)
180147
await self.node_expansion(current_node, websocket)
181148
tree_data = self._get_tree_data()
149+
182150
if websocket:
183151
await self.websocket_tree_update(type="tree_update_node_expansion", websocket=websocket, tree_data=tree_data)
184152
else:
185153
print_entire_tree(self.root_node)
186154

187-
# Get the path from root to this node
188-
path = self.get_path_to_root(current_node)
155+
# Node evaluation
189156
await self.node_evaluation(current_node)
190157
tree_data = self._get_tree_data()
191158
if websocket:
192159
await self.websocket_tree_update(type="tree_update_node_evaluation", websocket=websocket, tree_data=tree_data)
193160
else:
194161
print("after evaluation")
195162
print_entire_tree(self.root_node)
196-
path = self.get_path_to_root(current_node)
197163

198-
164+
path = self.get_path_to_root(current_node)
199165
score = current_node.value
200166

201167
# Update best path if this score is better
202168
if score > best_score:
203169
best_score = score
204170
best_path = path
205171
best_node = current_node
206-
207172

208173
print(f"Node {id(current_node)} score: {score}")
209174

@@ -213,23 +178,16 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
213178

214179
# Send completion update if websocket is provided
215180
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
216-
await self.playwright_manager.close()
181+
await self.playwright_manager.close()
182+
217183
return current_node
218-
219-
# Add non-terminal children to stack in reverse order
220-
has_unvisited_children = False
184+
185+
# Add children to stack in reverse order so that the first child is processed first
186+
# This maintains the left-to-right exploration order similar to BFS
221187
for child in reversed(current_node.children):
222-
if not child.is_terminal and child not in visited and child not in stack_set:
188+
if child not in visited and child not in stack_set and child.depth <= self.config.max_depth:
223189
stack.append(child)
224190
stack_set.add(child) # Add to stack tracking
225-
has_unvisited_children = True
226-
break # Only add one child at a time for DFS
227-
228-
# If no unvisited children, remove current node from stack
229-
if not has_unvisited_children:
230-
stack.pop()
231-
stack_set.remove(current_node)
232-
current_path.pop() # Remove from current path
233191

234192
# If we've exhausted all nodes and haven't found a perfect solution,
235193
# return the best path we found
@@ -249,5 +207,4 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
249207
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
250208
await self.playwright_manager.close()
251209

252-
return None
253-
210+
return None

0 commit comments

Comments
 (0)