diff --git a/src/encoding.rs b/src/encoding.rs index 60257e7..6af13f9 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -1329,7 +1329,7 @@ impl StreamableParser { /// If `parse_recipient_and_type` is true, tries to parse recipient and content_type from /// whitespace-separated tokens (normal header parsing). If false, treats all remaining /// text after extracting channel as content (for malformed messages). - fn parse_header_from_string( + pub(crate) fn parse_header_from_string( &self, mut header_string: String, role: Option, @@ -1352,6 +1352,19 @@ impl StreamableParser { new_header.push_str(&header_string[..idx]); new_header.push_str(&after_marker[channel_end..]); header_string = new_header; + + // Trim extraneous channel markers, which are sometimes emittted + // with smaller models in multi-turn conversations + while let Some(extra_idx) = header_string.find(channel_marker) { + let after = &header_string[extra_idx + channel_marker.len()..]; + let end = after + .find(|c: char| c.is_whitespace() || c == '<') + .unwrap_or(after.len()); + let mut cleaned = String::new(); + cleaned.push_str(&header_string[..extra_idx]); + cleaned.push_str(&after[end..]); + header_string = cleaned; + } } } diff --git a/src/tests.rs b/src/tests.rs index 7aba934..a518598 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -652,6 +652,10 @@ fn assert_tokens_eq(tokenizer: &CoreBPE, expected: &[Rank], actual: &[Rank]) { } } +fn parsed_header_to_json(header: &crate::encoding::ParsedHeader) -> serde_json::Value { + serde_json::to_value(header).expect("ParsedHeader should be JSON-serializable") +} + #[test] fn test_streamable_parser_tool_call_with_constrain_adjacent() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); @@ -674,6 +678,131 @@ fn test_streamable_parser_tool_call_with_constrain_adjacent() { ); } +#[test] +fn test_parse_header_from_string_extracts_channel_recipient_content_type() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = + "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json".to_string(); + + let (header, remaining) = parser + .parse_header_from_string(header_string, None, true) + .unwrap(); + + assert_eq!( + parsed_header_to_json(&header), + json!({ + "author": { "role": "assistant" }, + "recipient": "functions.get_weather", + "channel": "commentary", + "content_type": "<|constrain|>json", + }) + ); + assert_eq!(remaining, None); +} + +#[test] +fn test_parse_header_from_string_extracts_channel_recipient_content_type_extra_channel() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = + "assistant<|channel|>commentary to=functions.get_weather<|channel|>commentary <|constrain|>json".to_string(); + + let (header, remaining) = parser + .parse_header_from_string(header_string, None, true) + .unwrap(); + + assert_eq!( + parsed_header_to_json(&header), + json!({ + "author": { "role": "assistant" }, + "recipient": "functions.get_weather", + "channel": "commentary", + "content_type": "<|constrain|>json", + }) + ); + assert_eq!(remaining, None); +} + +#[test] +fn test_parse_header_from_string_extracts_channel_recipient_content_type_extra_channels() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = + "assistant<|channel|>commentary to=functions.get_weather<|channel|>analysis <|channel|>commentary <|channel|>final".to_string(); + + let (header, remaining) = parser + .parse_header_from_string(header_string, None, true) + .unwrap(); + + assert_eq!( + parsed_header_to_json(&header), + json!({ + "author": { "role": "assistant" }, + "recipient": "functions.get_weather", + "channel": "commentary", + "content_type": null, + }) + ); + assert_eq!(remaining, None); +} + +#[test] +fn test_parse_header_from_string_channel_marker_without_value_errors() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = "assistant<|channel|> to=foo".to_string(); + + let result = parser.parse_header_from_string(header_string, None, true); + + assert!(result.is_err()); +} + +#[test] +fn test_parse_header_from_string_extra_channel_no_recipient() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = "assistant<|channel|>commentary<|channel|>analysis".to_string(); + + let (header, remaining) = parser + .parse_header_from_string(header_string, None, true) + .unwrap(); + + assert_eq!( + parsed_header_to_json(&header), + json!({ + "author": { "role": "assistant" }, + "recipient": null, + "channel": "commentary", + "content_type": null, + }) + ); + assert_eq!(remaining, None); +} + +#[test] +fn test_parse_header_from_string_extra_channel_adjacent_to_constrain() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let parser = StreamableParser::new(encoding, None).unwrap(); + let header_string = + "assistant<|channel|>commentary to=foo<|channel|>analysis<|constrain|>json".to_string(); + + let (header, remaining) = parser + .parse_header_from_string(header_string, None, true) + .unwrap(); + + assert_eq!( + parsed_header_to_json(&header), + json!({ + "author": { "role": "assistant" }, + "recipient": "foo", + "channel": "commentary", + "content_type": "<|constrain|>json", + }) + ); + assert_eq!(remaining, None); +} + #[test] fn test_missing_message_token_requires_non_strict_mode() { let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();