Skip to content

Commit 60bc5cd

Browse files
committed
log message: reflection & backtracking
1 parent 9a90faa commit 60bc5cd

File tree

4 files changed

+108
-72
lines changed

4 files changed

+108
-72
lines changed

visual-tree-search-app/components/MessageLogPanelMCTS.tsx

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
109109
const getCardStyle = (type: string) => {
110110
switch (type) {
111111
// System Status Messages
112+
case 'reflection_backtracking':
113+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
112114
case 'server_connection':
113115
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
114116
case 'start_search':
@@ -208,6 +210,8 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
208210

209211
const getIcon = (message: ParsedMessage) => {
210212
switch (message.type) {
213+
case 'reflection_backtracking':
214+
return <Brain className="h-4 w-4 text-blue-500" />;
211215
case 'server_connection':
212216
return <Globe className="h-4 w-4 text-green-500 animate-pulse" />;
213217
case 'start_search':
@@ -323,6 +327,8 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
323327
const getIconBgColor = (type: string) => {
324328
switch (type) {
325329
// System Status Messages
330+
case 'reflection_backtracking':
331+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
326332
case 'start_search':
327333
return "bg-blue-100 dark:bg-blue-800/30 text-blue-600 dark:text-blue-400";
328334
case 'connection_established':
@@ -552,6 +558,32 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
552558
</div>
553559
);
554560

561+
case 'reflection_backtracking':
562+
return (
563+
<div className="flex items-center gap-2 animate-fadeIn">
564+
{getIcon(message)}
565+
<div className="animate-slideIn">
566+
<div className="text-emerald-600 dark:text-emerald-400">
567+
Reflecting & backtracking | Node: {message.description}
568+
</div>
569+
{message.path && message.path.length > 0 && (
570+
<div className="mt-1">
571+
{message.path.map((step: PathStep, index: number) => (
572+
<div
573+
key={index}
574+
className="flex items-start gap-1 text-xs text-slate-500 dark:text-slate-400 animate-fadeIn"
575+
style={{ animationDelay: `${index * 100}ms` }}
576+
>
577+
<ArrowRight className="h-3 w-3 mt-0.5" />
578+
{step.natural_language_description}
579+
</div>
580+
))}
581+
</div>
582+
)}
583+
</div>
584+
</div>
585+
);
586+
555587
case 'search_complete':
556588
return (
557589
<div className="flex items-center gap-2 animate-fadeIn">
@@ -668,6 +700,7 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
668700
</div>
669701
);
670702

703+
671704
default:
672705
return (
673706
<div className="flex items-center gap-2 animate-fadeIn">

visual-tree-search-app/pages/LATSAgent.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const LATSAgent = () => {
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
3535
num_simulations: 1,
36-
iterations: 1
36+
iterations: 2
3737
});
3838

3939
const [sessionId, setSessionId] = useState<string | null>(null);

visual-tree-search-app/pages/MCTSAgent.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const MCTSAgent = () => {
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
3535
set_prior_value: false,
36-
iterations: 1
36+
iterations: 2
3737
});
3838

3939
const [sessionId, setSessionId] = useState<string | null>(null);

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,13 @@ async def websocket_reflection_backtracking(self, path, selected_node, websocket
165165
if websocket:
166166
await websocket.send_json({
167167
"type": "reflection_backtracking",
168-
"path": [node.action for node in path if node.action is not None],
168+
"path": [{
169+
"natural_language_description": node.natural_language_description,
170+
"action": node.action} for node in path if node.action is not None],
169171
"node_id": id(selected_node),
170-
"node_parent_id": id(selected_node.parent),
171-
"node_action": selected_node.action,
172-
"node_description": selected_node.natural_language_description,
172+
"parent_id": id(selected_node.parent),
173+
"action": selected_node.action,
174+
"description": selected_node.natural_language_description,
173175
"trajectory": selected_node.get_trajectory()
174176
})
175177

@@ -304,80 +306,81 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
304306

305307
# Step 3: simulation using the current node, (generate a path using the current node, and score the path)
306308
# TODO: implement simulation using openai
307-
print(f"{GREEN}Step 3: Simulation{RESET}")
308-
await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
309-
path = self.get_path_to_root(selected_node)
310-
# here score is the reward
311-
score = await self.evaluate_selected_path(path)
312-
# change to reward later?
313-
if score > best_score:
314-
best_score = score
315-
best_path = path
316-
best_node = selected_node
317-
print(f"\nNew best path found!")
318-
print(f"best score: {best_score:.3f}")
319-
print(f"best node: {best_node.action}")
320-
print(f"best node: {best_node.natural_language_description}")
321-
print(f"best path: {best_path}")
322-
323-
# add websocket information, just use websocket here
324-
if websocket:
325-
await self.websocket_simulation_result(score, selected_node, websocket=websocket)
309+
if selected_node != self.root_node:
310+
print(f"{GREEN}Step 3: Simulation{RESET}")
311+
await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
312+
path = self.get_path_to_root(selected_node)
313+
# here score is the reward
314+
score = await self.evaluate_selected_path(path)
315+
# change to reward later?
316+
if score > best_score:
317+
best_score = score
318+
best_path = path
319+
best_node = selected_node
320+
print(f"\nNew best path found!")
321+
print(f"best score: {best_score:.3f}")
322+
print(f"best node: {best_node.action}")
323+
print(f"best node: {best_node.natural_language_description}")
324+
print(f"best path: {best_path}")
326325

326+
# add websocket information, just use websocket here
327+
if websocket:
328+
await self.websocket_simulation_result(score, selected_node, websocket=websocket)
327329

328-
## Step 4: reflection backtracking
329-
print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
330-
await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
331-
if score >= self.config.reflection_score:
332-
# Convert path to serializable trajectory
333-
# trajectory = [node.action for node in path if node.action is not None]
334-
await self.websocket_search_complete("success", score, selected_node.get_trajectory(), websocket=websocket)
335-
await self.playwright_manager.close()
336-
return selected_node
337330

338-
print(f"path: {path}")
339-
path, current_node = await self.reflection_backtracking(path)
340-
print(f"path: {path}")
341-
print(f"current_node: {current_node.action}")
342-
print(f"current_node: {current_node.natural_language_description}")
331+
## Step 4: reflection backtracking
332+
print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
333+
await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
334+
if score >= self.config.reflection_score:
335+
# Convert path to serializable trajectory
336+
# trajectory = [node.action for node in path if node.action is not None]
337+
await self.websocket_search_complete("success", score, selected_node.get_trajectory(), websocket=websocket)
338+
await self.playwright_manager.close()
339+
return selected_node
343340

344-
# add websocket information, just use websocket here
345-
if websocket:
346-
await self.websocket_reflection_backtracking(path, current_node, websocket=websocket)
341+
print(f"path: {path}")
342+
path, current_node = await self.reflection_backtracking(path)
343+
print(f"path: {path}")
344+
print(f"current_node: {current_node.action}")
345+
print(f"current_node: {current_node.natural_language_description}")
347346

348-
# Step 5: backpropagation
349-
print(f"{GREEN}Step 5: Backpropagation{RESET}")
350-
await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
351-
for node in path:
352-
if node != self.root_node:
353-
old_value = node.value
354-
node.visits += 1
355-
node.value += (score - node.value) / node.visits
356-
# consiste with lats backpropagation
357-
#node.value = (node.value * (node.visits - 1) + score) / node.visits
358-
print(f"Node {node.action}:")
359-
print(f" Visits: {node.visits}")
360-
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
361347
# add websocket information, just use websocket here
362-
# if websocket:
363-
# await websocket.send_json({
364-
# "type": "backpropagation",
365-
# "node_id": id(node),
366-
# "node_parent_id": id(node.parent),
367-
# "node_action": node.action,
368-
# "node_value": node.value,
369-
# "node_visits": node.visits,
370-
# "node_old_value": old_value,
371-
# "node_description": node.natural_language_description,
372-
# })
348+
if websocket:
349+
await self.websocket_reflection_backtracking(path, current_node, websocket=websocket)
373350

374-
tree_data = self._get_tree_data()
375-
print_entire_tree(self.root_node)
376-
print(tree_data)
377-
if websocket:
378-
await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
379-
else:
351+
# Step 5: backpropagation
352+
print(f"{GREEN}Step 5: Backpropagation{RESET}")
353+
await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
354+
for node in path:
355+
if node != self.root_node:
356+
old_value = node.value
357+
node.visits += 1
358+
node.value += (score - node.value) / node.visits
359+
# consiste with lats backpropagation
360+
#node.value = (node.value * (node.visits - 1) + score) / node.visits
361+
print(f"Node {node.action}:")
362+
print(f" Visits: {node.visits}")
363+
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
364+
# add websocket information, just use websocket here
365+
# if websocket:
366+
# await websocket.send_json({
367+
# "type": "backpropagation",
368+
# "node_id": id(node),
369+
# "node_parent_id": id(node.parent),
370+
# "node_action": node.action,
371+
# "node_value": node.value,
372+
# "node_visits": node.visits,
373+
# "node_old_value": old_value,
374+
# "node_description": node.natural_language_description,
375+
# })
376+
377+
tree_data = self._get_tree_data()
380378
print_entire_tree(self.root_node)
379+
print(tree_data)
380+
if websocket:
381+
await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
382+
else:
383+
print_entire_tree(self.root_node)
381384
if best_node:
382385
# Convert node to serializable trajectory
383386
# trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]

0 commit comments

Comments
 (0)