Skip to content

Commit 515ba8b

Browse files
authored
Merge pull request #90 from PathOnAI/improve_mcts
Improve mcts
2 parents bee4f6c + 60bc5cd commit 515ba8b

File tree

8 files changed

+288
-127
lines changed

8 files changed

+288
-127
lines changed

visual-tree-search-app/components/MessageLogPanel.tsx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ interface MessageLogPanelProps {
4343

4444
interface ParsedMessage {
4545
type?: string;
46+
info?: string;
4647
content?: string;
4748
status?: string;
4849
message?: string;
@@ -95,6 +96,8 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({ messages, messagesEnd
9596
const getCardStyle = (type: string) => {
9697
switch (type) {
9798
// System Status Messages
99+
case 'server_connection':
100+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
98101
case 'start_search':
99102
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
100103
case 'connection_established':
@@ -170,6 +173,8 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({ messages, messagesEnd
170173

171174
const getIcon = (message: ParsedMessage) => {
172175
switch (message.type) {
176+
case 'server_connection':
177+
return <Globe className="h-4 w-4 text-green-500 animate-pulse" />;
173178
case 'start_search':
174179
return <Target className="h-4 w-4 text-blue-500" />;
175180
case 'connection_established':
@@ -259,6 +264,8 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({ messages, messagesEnd
259264
const getIconBgColor = (type: string) => {
260265
switch (type) {
261266
// System Status Messages
267+
case 'server_connection':
268+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
262269
case 'start_search':
263270
return "bg-blue-100 dark:bg-blue-800/30 text-blue-600 dark:text-blue-400";
264271
case 'connection_established':
@@ -351,6 +358,16 @@ const MessageLogPanel: React.FC<MessageLogPanelProps> = ({ messages, messagesEnd
351358

352359
const formatMessageContent = (message: ParsedMessage) => {
353360
switch (message.type) {
361+
case 'server_connection':
362+
return (
363+
<div className="flex items-center gap-2 animate-fadeIn">
364+
{getIcon(message)}
365+
<div className="animate-slideIn">
366+
<div className="text-green-600 dark:text-green-400">{message.info}</div>
367+
</div>
368+
</div>
369+
);
370+
354371
case 'start_search':
355372
return (
356373
<div className="flex items-center gap-2 animate-fadeIn">

visual-tree-search-app/components/MessageLogPanelLATS.tsx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ interface MessageLogPanelProps {
4848

4949
interface ParsedMessage {
5050
type?: string;
51+
info?: string;
5152
content?: string;
5253
status?: string;
5354
message?: string;
@@ -77,6 +78,7 @@ interface ParsedMessage {
7778
step_name?: string;
7879
iteration?: number;
7980
session_id?: string;
81+
node_action?: string;
8082
}
8183

8284
interface PathStep {
@@ -108,6 +110,8 @@ const MessageLogPanelLATS: React.FC<MessageLogPanelProps> = ({ messages, message
108110
const getCardStyle = (type: string) => {
109111
switch (type) {
110112
// System Status Messages
113+
case 'server_connection':
114+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
111115
case 'start_search':
112116
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
113117
case 'connection_established':
@@ -205,6 +209,8 @@ const MessageLogPanelLATS: React.FC<MessageLogPanelProps> = ({ messages, message
205209

206210
const getIcon = (message: ParsedMessage) => {
207211
switch (message.type) {
212+
case 'server_connection':
213+
return <Globe className="h-4 w-4 text-green-500 animate-pulse" />;
208214
case 'start_search':
209215
return <Target className="h-4 w-4 text-blue-500" />;
210216
case 'connection_established':
@@ -318,6 +324,8 @@ const MessageLogPanelLATS: React.FC<MessageLogPanelProps> = ({ messages, message
318324
const getIconBgColor = (type: string) => {
319325
switch (type) {
320326
// System Status Messages
327+
case 'server_connection':
328+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
321329
case 'start_search':
322330
return "bg-blue-100 dark:bg-blue-800/30 text-blue-600 dark:text-blue-400";
323331
case 'connection_established':
@@ -433,6 +441,15 @@ const MessageLogPanelLATS: React.FC<MessageLogPanelProps> = ({ messages, message
433441

434442
const formatMessageContent = (message: ParsedMessage) => {
435443
switch (message.type) {
444+
case 'server_connection':
445+
return (
446+
<div className="flex items-center gap-2 animate-fadeIn">
447+
{getIcon(message)}
448+
<div className="animate-slideIn">
449+
<div className="text-green-600 dark:text-green-400">{message.info}</div>
450+
</div>
451+
</div>
452+
);
436453
case 'start_search':
437454
return (
438455
<div className="flex items-center gap-2 animate-fadeIn">

visual-tree-search-app/components/MessageLogPanelMCTS.tsx

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ interface MessageLogPanelProps {
4848

4949
interface ParsedMessage {
5050
type?: string;
51+
info?: string;
5152
content?: string;
5253
status?: string;
5354
message?: string;
@@ -108,6 +109,10 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
108109
const getCardStyle = (type: string) => {
109110
switch (type) {
110111
// System Status Messages
112+
case 'reflection_backtracking':
113+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
114+
case 'server_connection':
115+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
111116
case 'start_search':
112117
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
113118
case 'connection_established':
@@ -205,6 +210,10 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
205210

206211
const getIcon = (message: ParsedMessage) => {
207212
switch (message.type) {
213+
case 'reflection_backtracking':
214+
return <Brain className="h-4 w-4 text-blue-500" />;
215+
case 'server_connection':
216+
return <Globe className="h-4 w-4 text-green-500 animate-pulse" />;
208217
case 'start_search':
209218
return <Target className="h-4 w-4 text-blue-500" />;
210219
case 'connection_established':
@@ -318,6 +327,8 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
318327
const getIconBgColor = (type: string) => {
319328
switch (type) {
320329
// System Status Messages
330+
case 'reflection_backtracking':
331+
return "bg-gradient-to-r from-blue-50 to-blue-100 dark:from-blue-900/20 dark:to-blue-800/20 border-blue-200 dark:border-blue-800";
321332
case 'start_search':
322333
return "bg-blue-100 dark:bg-blue-800/30 text-blue-600 dark:text-blue-400";
323334
case 'connection_established':
@@ -433,6 +444,15 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
433444

434445
const formatMessageContent = (message: ParsedMessage) => {
435446
switch (message.type) {
447+
case 'server_connection':
448+
return (
449+
<div className="flex items-center gap-2 animate-fadeIn">
450+
{getIcon(message)}
451+
<div className="animate-slideIn">
452+
<div className="text-green-600 dark:text-green-400">{message.info}</div>
453+
</div>
454+
</div>
455+
);
436456
case 'start_search':
437457
return (
438458
<div className="flex items-center gap-2 animate-fadeIn">
@@ -538,6 +558,32 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
538558
</div>
539559
);
540560

561+
case 'reflection_backtracking':
562+
return (
563+
<div className="flex items-center gap-2 animate-fadeIn">
564+
{getIcon(message)}
565+
<div className="animate-slideIn">
566+
<div className="text-emerald-600 dark:text-emerald-400">
567+
Reflecting & backtracking | Node: {message.description}
568+
</div>
569+
{message.path && message.path.length > 0 && (
570+
<div className="mt-1">
571+
{message.path.map((step: PathStep, index: number) => (
572+
<div
573+
key={index}
574+
className="flex items-start gap-1 text-xs text-slate-500 dark:text-slate-400 animate-fadeIn"
575+
style={{ animationDelay: `${index * 100}ms` }}
576+
>
577+
<ArrowRight className="h-3 w-3 mt-0.5" />
578+
{step.natural_language_description}
579+
</div>
580+
))}
581+
</div>
582+
)}
583+
</div>
584+
</div>
585+
);
586+
541587
case 'search_complete':
542588
return (
543589
<div className="flex items-center gap-2 animate-fadeIn">
@@ -654,6 +700,7 @@ const MessageLogPanelMCTS: React.FC<MessageLogPanelProps> = ({ messages, message
654700
</div>
655701
);
656702

703+
657704
default:
658705
return (
659706
<div className="flex items-center gap-2 animate-fadeIn">

visual-tree-search-app/pages/LATSAgent.tsx

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const LATSAgent = () => {
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
3535
num_simulations: 1,
36-
iterations: 1
36+
iterations: 2
3737
});
3838

3939
const [sessionId, setSessionId] = useState<string | null>(null);
@@ -87,10 +87,17 @@ const LATSAgent = () => {
8787
wsRef.current = new WebSocket(wsUrl);
8888

8989
wsRef.current.onopen = () => {
90-
logMessage('Connected to LATS WebSocket server');
90+
const connectionMessage = {
91+
type: "server_connection",
92+
info: 'Connecting to LATS WebSocket server'
93+
};
94+
95+
wsRef.current?.send(JSON.stringify(connectionMessage));
96+
logMessage(connectionMessage, 'incoming');
97+
// logMessage('Connected to LATS WebSocket server');
9198
setConnected(true);
9299

93-
const request = {
100+
const searchRequest = {
94101
type: "start_search",
95102
agent_type: "LATSAgent",
96103
starting_url: searchParams.startingUrl,
@@ -101,8 +108,8 @@ const LATSAgent = () => {
101108
iterations: searchParams.iterations
102109
};
103110

104-
wsRef.current?.send(JSON.stringify(request));
105-
logMessage(request, 'outgoing');
111+
wsRef.current?.send(JSON.stringify(searchRequest));
112+
logMessage(searchRequest, 'outgoing');
106113
};
107114

108115
wsRef.current.onmessage = (event) => {
@@ -119,19 +126,31 @@ const LATSAgent = () => {
119126
};
120127

121128
wsRef.current.onclose = () => {
122-
logMessage('Disconnected from WebSocket server');
129+
const closeMessage = {
130+
type: "server_connection",
131+
info: 'Disconnected from WebSocket server'
132+
};
133+
logMessage(closeMessage, 'incoming');
123134
setConnected(false);
124135
setIsSearching(false);
125136
wsRef.current = null;
126137
};
127138

128139
wsRef.current.onerror = (error) => {
129-
logMessage(`WebSocket error: ${error instanceof Error ? error.message : String(error)}`);
140+
const errorMessage = {
141+
type: "server_connection",
142+
info: `WebSocket error: ${error instanceof Error ? error.message : String(error)}`
143+
};
144+
logMessage(errorMessage, 'incoming');
130145
setConnected(false);
131146
setIsSearching(false);
132147
};
133148
} catch (error) {
134-
logMessage(`Failed to connect: ${error instanceof Error ? error.message : String(error)}`);
149+
const errorMessage = {
150+
type: "server_connection",
151+
info: `Failed to connect: ${error instanceof Error ? error.message : String(error)}`
152+
};
153+
logMessage(errorMessage, 'incoming');
135154
setConnected(false);
136155
setIsSearching(false);
137156
}
@@ -164,9 +183,17 @@ const LATSAgent = () => {
164183
if (!response.ok) {
165184
throw new Error(`Failed to terminate session: ${response.statusText}`);
166185
}
167-
logMessage(`Session ${sessionId} terminated successfully`);
186+
const terminateMessage = {
187+
type: "server_connection",
188+
info: `Session ${sessionId} terminated successfully`
189+
};
190+
logMessage(terminateMessage, 'incoming');
168191
} catch (error) {
169-
logMessage(`Failed to terminate session: ${error instanceof Error ? error.message : String(error)}`);
192+
const errorMessage = {
193+
type: "server_connection",
194+
info: `Failed to terminate session: ${error instanceof Error ? error.message : String(error)}`
195+
};
196+
logMessage(errorMessage, 'incoming');
170197
}
171198
}
172199

visual-tree-search-app/pages/MCTSAgent.tsx

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const MCTSAgent = () => {
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
3535
set_prior_value: false,
36-
iterations: 1
36+
iterations: 2
3737
});
3838

3939
const [sessionId, setSessionId] = useState<string | null>(null);
@@ -87,7 +87,11 @@ const MCTSAgent = () => {
8787
wsRef.current = new WebSocket(wsUrl);
8888

8989
wsRef.current.onopen = () => {
90-
logMessage('Connected to MCTS WebSocket server');
90+
const connectionMessage = {
91+
type: "server_connection",
92+
info: 'Connecting to MCTS WebSocket server'
93+
};
94+
logMessage(connectionMessage, 'incoming');
9195
setConnected(true);
9296

9397
const request = {
@@ -120,19 +124,31 @@ const MCTSAgent = () => {
120124
};
121125

122126
wsRef.current.onclose = () => {
123-
logMessage('Disconnected from WebSocket server');
127+
const closeMessage = {
128+
type: "server_connection",
129+
info: 'Disconnected from WebSocket server'
130+
};
131+
logMessage(closeMessage, 'incoming');
124132
setConnected(false);
125133
setIsSearching(false);
126134
wsRef.current = null;
127135
};
128136

129137
wsRef.current.onerror = (error) => {
130-
logMessage(`WebSocket error: ${error instanceof Error ? error.message : String(error)}`);
138+
const errorMessage = {
139+
type: "server_connection",
140+
info: `WebSocket error: ${error instanceof Error ? error.message : String(error)}`
141+
};
142+
logMessage(errorMessage, 'incoming');
131143
setConnected(false);
132144
setIsSearching(false);
133145
};
134146
} catch (error) {
135-
logMessage(`Failed to connect: ${error instanceof Error ? error.message : String(error)}`);
147+
const errorMessage = {
148+
type: "server_connection",
149+
info: `Failed to connect: ${error instanceof Error ? error.message : String(error)}`
150+
};
151+
logMessage(errorMessage, 'incoming');
136152
setConnected(false);
137153
setIsSearching(false);
138154
}
@@ -165,9 +181,17 @@ const MCTSAgent = () => {
165181
if (!response.ok) {
166182
throw new Error(`Failed to terminate session: ${response.statusText}`);
167183
}
168-
logMessage(`Session ${sessionId} terminated successfully`);
184+
const terminateMessage = {
185+
type: "server_connection",
186+
info: `Session ${sessionId} terminated successfully`
187+
};
188+
logMessage(terminateMessage, 'incoming');
169189
} catch (error) {
170-
logMessage(`Failed to terminate session: ${error instanceof Error ? error.message : String(error)}`);
190+
const errorMessage = {
191+
type: "server_connection",
192+
info: `Failed to terminate session: ${error instanceof Error ? error.message : String(error)}`
193+
};
194+
logMessage(errorMessage, 'incoming');
171195
}
172196
}
173197

0 commit comments

Comments
 (0)