1
1
use std:: collections:: {
2
2
HashMap ,
3
+ HashSet ,
3
4
VecDeque ,
4
5
} ;
5
6
use std:: sync:: Arc ;
@@ -32,7 +33,6 @@ use super::hooks::{
32
33
} ;
33
34
use super :: message:: {
34
35
AssistantMessage ,
35
- AssistantToolUse ,
36
36
ToolUseResult ,
37
37
ToolUseResultBlock ,
38
38
UserMessage ,
@@ -60,7 +60,6 @@ use crate::api_client::model::{
60
60
ToolInputSchema ,
61
61
ToolResult ,
62
62
ToolResultContentBlock ,
63
- ToolResultStatus ,
64
63
ToolSpecification ,
65
64
ToolUse ,
66
65
UserInputMessage ,
@@ -347,7 +346,7 @@ impl ConversationState {
347
346
}
348
347
}
349
348
350
- self . enforce_tool_use_history_invariants ( true ) ;
349
+ self . enforce_tool_use_history_invariants ( ) ;
351
350
}
352
351
353
352
/// Here we also need to make sure that the tool result corresponds to one of the tools
@@ -362,105 +361,52 @@ impl ConversationState {
362
361
/// intervention here is to substitute the ambiguous, partial name with a dummy.
363
362
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
364
363
/// substitute the non-existent tool name with a dummy.
365
- pub fn enforce_tool_use_history_invariants ( & mut self , last_only : bool ) {
366
- let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
367
- // We need to first determine what the range of interest is. There are two places where we
368
- // would call this function:
369
- // 1. When there are changes to the list of available tools, in which case we comb through the
370
- // entire conversation
371
- // 2. When we send a message, in which case we only examine the most recent entry
372
- let ( tool_use_results, mut tool_uses) = if last_only {
373
- if let ( Some ( ( _, AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) ) , Some ( user_msg) ) = (
374
- self . history
375
- . range_mut ( self . valid_history_range . 0 ..self . valid_history_range . 1 )
376
- . last ( ) ,
377
- & mut self . next_message ,
378
- ) {
379
- let tool_use_results = user_msg
380
- . tool_use_results ( )
381
- . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) ) ;
382
- let tool_uses = tool_uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
383
- ( tool_use_results, tool_uses)
384
- } else {
385
- ( Vec :: new ( ) , Vec :: new ( ) )
386
- }
387
- } else {
388
- let tool_use_results = self . next_message . as_ref ( ) . map_or ( Vec :: new ( ) , |user_msg| {
389
- user_msg
390
- . tool_use_results ( )
391
- . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) )
392
- } ) ;
393
- self . history
394
- . iter_mut ( )
395
- . filter_map ( |( user_msg, asst_msg) | {
396
- if let ( Some ( tool_use_results) , AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) =
397
- ( user_msg. tool_use_results ( ) , asst_msg)
398
- {
399
- Some ( ( tool_use_results, tool_uses) )
400
- } else {
401
- None
402
- }
364
+ pub fn enforce_tool_use_history_invariants ( & mut self ) {
365
+ let tool_names: HashSet < _ > = self
366
+ . tools
367
+ . values ( )
368
+ . flat_map ( |tools| {
369
+ tools. iter ( ) . map ( |tool| match tool {
370
+ Tool :: ToolSpecification ( tool_specification) => tool_specification. name . as_str ( ) ,
403
371
} )
404
- . fold (
405
- ( tool_use_results, Vec :: < & mut AssistantToolUse > :: new ( ) ) ,
406
- |( mut tool_use_results, mut tool_uses) , ( results, uses) | {
407
- let mut results = results. iter ( ) . collect :: < Vec < _ > > ( ) ;
408
- let mut uses = uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
409
- tool_use_results. append ( & mut results) ;
410
- tool_uses. append ( & mut uses) ;
411
- ( tool_use_results, tool_uses)
412
- } ,
413
- )
414
- } ;
372
+ } )
373
+ . filter ( |name| * name != DUMMY_TOOL_NAME )
374
+ . collect ( ) ;
375
+
376
+ for ( _, assistant) in & mut self . history {
377
+ if let AssistantMessage :: ToolUse { ref mut tool_uses, .. } = assistant {
378
+ for tool_use in tool_uses {
379
+ if tool_names. contains ( tool_use. name . as_str ( ) ) {
380
+ continue ;
381
+ }
415
382
416
- // Replace tool uses associated with tools that does not exist / no longer exists with
417
- // dummy (i.e. put them to sleep / dormant)
418
- for result in tool_use_results {
419
- let tool_use_id = result. tool_use_id . as_str ( ) ;
420
- let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
421
- if let Some ( tool_use) = corresponding_tool_use {
422
- if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
423
- // If this tool matches of the tools in our list, this is not our
424
- // concern, error or not.
425
- continue ;
426
- }
427
- if let ToolResultStatus :: Error = result. status {
428
- // case 2 and 3
429
- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
430
- tool_use. args = serde_json:: json!( { } ) ;
431
- } else {
432
- // case 1
433
- let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
434
- // We should be able to find a match but if not we'll just treat it as
435
- // a dummy and move on
436
- if let Some ( full_name) = full_name {
437
- tool_use. name = ( * full_name) . to_string ( ) ;
438
- } else {
439
- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
440
- tool_use. args = serde_json:: json!( { } ) ;
383
+ if tool_names. contains ( tool_use. orig_name . as_str ( ) ) {
384
+ tool_use. name = tool_use. orig_name . clone ( ) ;
385
+ tool_use. args = tool_use. orig_args . clone ( ) ;
386
+ continue ;
441
387
}
442
- }
443
- }
444
- }
445
388
446
- // Revive tools that were previously dormant if they now corresponds to one of the tools in
447
- // our list of available tools. Note that this check only works because tn_map does NOT
448
- // contain names of native tools.
449
- for tool_use in tool_uses {
450
- if tool_use. name == DUMMY_TOOL_NAME
451
- && tool_use
452
- . orig_name
453
- . as_ref ( )
454
- . is_some_and ( |name| tool_name_list. contains ( & ( * name) . as_str ( ) ) )
455
- {
456
- tool_use. name = tool_use
457
- . orig_name
458
- . as_ref ( )
459
- . map_or ( DUMMY_TOOL_NAME . to_string ( ) , |name| name. clone ( ) ) ;
460
- tool_use. args = tool_use
461
- . orig_args
462
- . as_ref ( )
463
- . map_or ( serde_json:: json!( { } ) , |args| args. clone ( ) ) ;
389
+ let names: Vec < & str > = tool_names
390
+ . iter ( )
391
+ . filter_map ( |name| {
392
+ if name. ends_with ( & tool_use. name ) {
393
+ Some ( * name)
394
+ } else {
395
+ None
396
+ }
397
+ } )
398
+ . collect ( ) ;
399
+
400
+ // There's only one tool use matching, so we can just replace it with the
401
+ // found name.
402
+ if names. len ( ) == 1 {
403
+ tool_use. name = ( * names. first ( ) . unwrap ( ) ) . to_string ( ) ;
404
+ continue ;
405
+ }
406
+
407
+ // Otherwise, we have to replace it with a dummy.
408
+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
409
+ }
464
410
}
465
411
}
466
412
}
@@ -514,8 +460,8 @@ impl ConversationState {
514
460
. expect ( "unable to construct conversation state" )
515
461
}
516
462
517
- pub async fn update_state ( & mut self ) {
518
- let needs_update = self . tool_manager . has_new_stuff . load ( Ordering :: Acquire ) ;
463
+ pub async fn update_state ( & mut self , force_update : bool ) {
464
+ let needs_update = self . tool_manager . has_new_stuff . load ( Ordering :: Acquire ) || force_update ;
519
465
if !needs_update {
520
466
return ;
521
467
}
@@ -540,12 +486,13 @@ impl ConversationState {
540
486
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
541
487
// here as well because when it's being called in [Self::enforce_conversation_invariants]
542
488
// it is only checking the last entry.
543
- self . enforce_tool_use_history_invariants ( false ) ;
489
+ self . enforce_tool_use_history_invariants ( ) ;
544
490
}
545
491
546
492
/// Returns a conversation state representation which reflects the exact conversation to send
547
493
/// back to the model.
548
494
pub async fn backend_conversation_state ( & mut self , run_hooks : bool , quiet : bool ) -> BackendConversationState < ' _ > {
495
+ self . update_state ( false ) . await ;
549
496
self . enforce_conversation_invariants ( ) ;
550
497
551
498
// Run hooks and add to conversation start and next user message.
0 commit comments