@@ -7,6 +7,7 @@ use std::sync::Arc;
77use std:: time:: { Duration , Instant } ;
88
99use async_trait:: async_trait;
10+ use cortex_common:: DEFAULT_BATCH_TIMEOUT_SECS ;
1011use futures:: future:: join_all;
1112use serde:: { Deserialize , Serialize } ;
1213use serde_json:: { Value , json} ;
@@ -337,14 +338,30 @@ impl ToolHandler for BatchToolHandler {
337338 // Validate calls
338339 self . validate_calls ( & args. calls ) ?;
339340
341+ // Determine overall batch timeout (wraps around entire parallel execution)
342+ let batch_timeout_secs = args. timeout_secs . unwrap_or ( DEFAULT_BATCH_TIMEOUT_SECS ) ;
343+ let batch_timeout = Duration :: from_secs ( batch_timeout_secs) ;
344+
340345 // Determine per-tool timeout (prevents single tool from blocking others)
341346 let tool_timeout_secs = args. tool_timeout_secs . unwrap_or ( DEFAULT_TOOL_TIMEOUT_SECS ) ;
342347 let tool_timeout = Duration :: from_secs ( tool_timeout_secs) ;
343348
344- // Execute all tools in parallel
345- let batch_result = self
346- . execute_parallel ( args. calls , context, tool_timeout)
347- . await ;
349+ // Execute all tools in parallel with overall batch timeout
350+ let batch_result = match timeout (
351+ batch_timeout,
352+ self . execute_parallel ( args. calls , context, tool_timeout) ,
353+ )
354+ . await
355+ {
356+ Ok ( result) => result,
357+ Err ( _) => {
358+ // Batch-level timeout exceeded
359+ return Ok ( ToolResult :: error ( format ! (
360+ "Batch execution timed out after {}s. Consider using a longer timeout_secs or reducing the number of tools." ,
361+ batch_timeout_secs
362+ ) ) ) ;
363+ }
364+ } ;
348365
349366 // Format output
350367 let output = self . format_result ( & batch_result) ;
@@ -668,4 +685,58 @@ mod tests {
668685 elapsed. as_millis( )
669686 ) ;
670687 }
688+
689+ #[ tokio:: test]
690+ async fn test_batch_timeout ( ) {
691+ // Create an executor with a slow tool
692+ struct SlowExecutor ;
693+
694+ #[ async_trait]
695+ impl BatchToolExecutor for SlowExecutor {
696+ async fn execute_tool (
697+ & self ,
698+ _name : & str ,
699+ _arguments : Value ,
700+ _context : & ToolContext ,
701+ ) -> Result < ToolResult > {
702+ // Sleep longer than batch timeout
703+ tokio:: time:: sleep ( Duration :: from_secs ( 5 ) ) . await ;
704+ Ok ( ToolResult :: success ( "Done" ) )
705+ }
706+
707+ fn has_tool ( & self , _name : & str ) -> bool {
708+ true
709+ }
710+ }
711+
712+ let executor: Arc < dyn BatchToolExecutor > = Arc :: new ( SlowExecutor ) ;
713+ let handler = BatchToolHandler :: new ( executor) ;
714+ let context = ToolContext :: new ( PathBuf :: from ( "." ) ) ;
715+
716+ // Use a very short batch timeout (1 second) to test timeout behavior
717+ let args = json ! ( {
718+ "calls" : [
719+ { "tool" : "SlowTool" , "arguments" : { } }
720+ ] ,
721+ "timeout_secs" : 1
722+ } ) ;
723+
724+ let start = Instant :: now ( ) ;
725+ let result = handler. execute ( args, & context) . await ;
726+ let elapsed = start. elapsed ( ) ;
727+
728+ assert ! ( result. is_ok( ) ) ;
729+ let tool_result = result. unwrap ( ) ;
730+
731+ // Should timeout quickly (around 1 second)
732+ assert ! (
733+ elapsed. as_secs( ) < 3 ,
734+ "Batch should have timed out quickly, but took {}s" ,
735+ elapsed. as_secs( )
736+ ) ;
737+
738+ // Should have timed out
739+ assert ! ( !tool_result. success) ;
740+ assert ! ( tool_result. output. contains( "timed out" ) ) ;
741+ }
671742}
0 commit comments