Skip to content

Commit bee4f6c

Browse files
authored
Merge pull request #89 from PathOnAI/improve_mcts
Improve mcts
2 parents f2f35fe + 6978bd5 commit bee4f6c

File tree

6 files changed

+78
-17
lines changed

6 files changed

+78
-17
lines changed

visual-tree-search-app/components/ControlPanelMCTS.tsx

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ import { Button } from "@/components/ui/button";
33
import { Input } from "@/components/ui/input";
44
import { Label } from "@/components/ui/label";
55
import { Info, ChevronDown, ChevronUp } from "lucide-react";
6+
import { Checkbox } from "@/components/ui/checkbox";
67

78
interface SearchParams {
89
startingUrl: string;
910
goal: string;
11+
maxDepth: number;
1012
iterations: number;
13+
set_prior_value: boolean;
1114
}
1215

1316
interface ControlPanelProps {
@@ -105,6 +108,19 @@ const ControlPanelMCTS: React.FC<ControlPanelProps> = ({
105108
/>
106109
</div>
107110

111+
<div className="space-y-2">
112+
<Label htmlFor="maxDepth" className="text-slate-700 dark:text-slate-300 font-medium">Max Depth</Label>
113+
<Input
114+
id="maxDepth"
115+
type="number"
116+
min={1}
117+
max={10}
118+
value={searchParams.maxDepth}
119+
onChange={(e) => handleParamChange('maxDepth', parseInt(e.target.value))}
120+
className="border-slate-300 dark:border-slate-600 focus:ring-cyan-500 focus:border-cyan-500"
121+
/>
122+
</div>
123+
108124
<div className="space-y-2">
109125
<Label htmlFor="iterations" className="text-slate-700 dark:text-slate-300 font-medium">Iterations</Label>
110126
<Input
@@ -118,6 +134,26 @@ const ControlPanelMCTS: React.FC<ControlPanelProps> = ({
118134
/>
119135
</div>
120136
</div>
137+
138+
{/* Add prior_value checkbox */}
139+
<div className="mt-4">
140+
<div className="flex items-center space-x-2">
141+
<Checkbox
142+
id="set_prior_value"
143+
checked={searchParams.set_prior_value}
144+
onCheckedChange={(checked) => handleParamChange('set_prior_value', checked === true)}
145+
/>
146+
<Label
147+
htmlFor="set_prior_value"
148+
className="text-slate-700 dark:text-slate-300 font-medium cursor-pointer"
149+
>
150+
Use Prior Value
151+
</Label>
152+
</div>
153+
<p className="mt-1 ml-6 text-xs text-slate-500 dark:text-slate-400">
154+
When enabled, RMCTS will use an LLM to generate an initial value as a prior value for each newly generated node.
155+
</p>
156+
</div>
121157
</div>
122158
)}
123159
</div>

visual-tree-search-app/pages/MCTSAgent.tsx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ interface SearchParams {
99
startingUrl: string;
1010
goal: string;
1111
maxDepth: number;
12-
num_simulations: number;
1312
iterations: number;
13+
set_prior_value: boolean;
1414
}
1515

1616
interface Message {
@@ -32,7 +32,7 @@ const MCTSAgent = () => {
3232
startingUrl: 'http://xwebarena.pathonai.org:7770/',
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
35-
num_simulations: 1,
35+
set_prior_value: false,
3636
iterations: 1
3737
});
3838

@@ -96,8 +96,11 @@ const MCTSAgent = () => {
9696
starting_url: searchParams.startingUrl,
9797
goal: searchParams.goal,
9898
search_algorithm: "mcts",
99-
iterations: searchParams.iterations
99+
iterations: searchParams.iterations,
100+
set_prior_value: searchParams.set_prior_value,
101+
max_depth: searchParams.maxDepth
100102
};
103+
console.log(request);
101104

102105
wsRef.current?.send(JSON.stringify(request));
103106
logMessage(request, 'outgoing');
@@ -141,7 +144,9 @@ const MCTSAgent = () => {
141144
starting_url: searchParams.startingUrl,
142145
goal: searchParams.goal,
143146
search_algorithm: "mcts",
144-
iterations: searchParams.iterations
147+
iterations: searchParams.iterations,
148+
set_prior_value: searchParams.set_prior_value,
149+
max_depth: searchParams.maxDepth
145150
};
146151

147152
wsRef.current?.send(JSON.stringify(request));

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,10 +486,11 @@ async def node_evaluation(self, node: LATSNode) -> None:
486486
## TODO: check the logic of updating value/ reward, is the input value?
487487
def backpropagate(self, node: LATSNode, value: float) -> None:
488488
while node:
489-
node.visits += 1
490-
# Calculate running average: newAvg = oldAvg + (value - oldAvg) / newCount
491-
node.value += (value - node.value) / node.visits
492-
node = node.parent
489+
if node.depth != 0:
490+
node.visits += 1
491+
# Calculate running average: newAvg = oldAvg + (value - oldAvg) / newCount
492+
node.value += (value - node.value) / node.visits
493+
node = node.parent
493494

494495
# shared
495496
async def simulation(self, node: LATSNode, websocket=None) -> tuple[float, LATSNode]:

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,17 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
290290
print_entire_tree(self.root_node)
291291

292292

293+
# optional: prior value
294+
if self.config.set_prior_value:
295+
await self.websocket_step_start(step=2, step_name="node_children_evaluation", websocket=websocket)
296+
await self.node_children_evaluation(selected_node)
297+
tree_data = self._get_tree_data()
298+
if websocket:
299+
await self.websocket_tree_update(type="tree_update_node_children_evaluation", websocket=websocket, tree_data=tree_data)
300+
else:
301+
print("after evaluation")
302+
print_entire_tree(self.root_node)
303+
293304
# Step 3: simulation using the current node, (generate a path using the current node, and score the path)
294305
# TODO: implement simulation using openai
295306
print(f"{GREEN}Step 3: Simulation{RESET}")
@@ -337,14 +348,15 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
337348
print(f"{GREEN}Step 5: Backpropagation{RESET}")
338349
await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
339350
for node in path:
340-
old_value = node.value
341-
node.visits += 1
342-
node.value += (score - node.value) / node.visits
343-
# consiste with lats backpropagation
344-
#node.value = (node.value * (node.visits - 1) + score) / node.visits
345-
print(f"Node {node.action}:")
346-
print(f" Visits: {node.visits}")
347-
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
351+
if node != self.root_node:
352+
old_value = node.value
353+
node.visits += 1
354+
node.value += (score - node.value) / node.visits
355+
# consiste with lats backpropagation
356+
#node.value = (node.value * (node.visits - 1) + score) / node.visits
357+
print(f"Node {node.action}:")
358+
print(f" Visits: {node.visits}")
359+
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
348360
# add websocket information, just use websocket here
349361
# if websocket:
350362
# await websocket.send_json({

visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ class AgentConfig:
2525
num_simulations: int = 1
2626
account_reset: bool = True
2727

28+
# for LATS
2829
simulation_score: float = 0.75
30+
31+
# for MCTS
2932
reflection_score: float = 0.75
33+
set_prior_value: bool = False
3034

3135
# Features
3236
features: List[str] = field(default_factory=lambda: ['axtree'])

visual-tree-search-backend/app/api/routes/tree_search_websocket.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
7979
storage_state = message.get("storage_state", "app/api/shopping.json")
8080
iterations = message.get("iterations", 1) # Extract iterations parameter
8181
num_simulations=message.get("num_simulations", 1)
82+
set_prior_value = message.get("set_prior_value", False)
8283

8384
# Send status update
8485
await websocket.send_json({
@@ -95,8 +96,10 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
9596
storage_state=storage_state,
9697
headless=False,
9798
iterations=iterations,
98-
num_simulations=num_simulations
99+
num_simulations=num_simulations,
100+
set_prior_value=set_prior_value
99101
)
102+
print(config)
100103

101104
# Send status update
102105
await websocket.send_json({

0 commit comments

Comments
 (0)