diff --git a/src/encoding.rs b/src/encoding.rs index 60257e7..ac40fac 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -192,28 +192,20 @@ impl HarmonyEncoding { let render_options = RenderOptions { conversation_has_function_tools: has_function_tools, }; - let last_assistant_is_final = messages - .iter() - .rev() - .find_map(|msg| { - (msg.author.role == Role::Assistant) - .then(|| msg.channel.as_deref() == Some("final")) - }) - .unwrap_or(false); let should_drop_analysis = - config.is_some_and(|c| c.auto_drop_analysis && last_assistant_is_final); + config.is_some_and(|c| c.auto_drop_analysis); - let first_final_idx = messages + let last_final_idx = messages .iter() - .position(|msg| msg.channel.as_deref() == Some("final")); + .rposition(|msg| msg.channel.as_deref() == Some("final")); let result = messages .iter() .enumerate() .filter(|(idx, msg)| { !(should_drop_analysis - && first_final_idx.is_some_and(|first| *idx < first) + && last_final_idx.is_some_and(|last| *idx < last) && msg.channel.as_deref() == Some("analysis")) }) .try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options))); diff --git a/src/tests.rs b/src/tests.rs index 7aba934..8115d5b 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -886,3 +886,74 @@ fn test_parse_completion_with_invalid_content_token_errors_on_eos() { .with_channel("analysis"); assert_eq!(parsed_message, &expected_message); } + +#[test] +fn test_multi_turn_auto_drop_analysis() { + let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap(); + let expected_output = load_test_data("../test-data/test_multi_turn_auto_drop_analysis.txt"); + + let convo = Conversation::from_messages([ + Message::from_role_and_content( + Role::Developer, + DeveloperContent::new().with_instructions( + "You are a helpful assistant that analyzes code and provides detailed feedback.", + ), + ), + Message::from_role_and_content( + Role::User, + "Can you help me optimize this Python function?\n\n\ + def fibonacci(n):\n\ + if n <= 1:\n\ + return n\n\ + return fibonacci(n-1) + fibonacci(n-2)", + ), + // Turn 1: analysis + final + Message::from_role_and_content( + Role::Assistant, + "The user provided a recursive Fibonacci implementation. O(2^n) complexity.", + ) + .with_channel("analysis"), + Message::from_role_and_content( + Role::Assistant, + "This recursive Fibonacci has exponential time complexity. \ + Would you like me to show optimized versions?", + ) + .with_channel("final"), + Message::from_role_and_content(Role::User, "Yes, and benchmark them"), + // Turn 2: analysis + final + Message::from_role_and_content( + Role::Assistant, + "User wants benchmarks. I should run some Python code to compare performance.", + ) + .with_channel("analysis"), + Message::from_role_and_content(Role::Assistant, "I'll benchmark both versions for you.") + .with_channel("final"), + Message::from_role_and_content(Role::User, "Run the benchmark for n=30"), + // Turn 3: analysis + tool call (no final after) + Message::from_role_and_content( + Role::Assistant, + "I need to execute Python code to run the benchmark for n=30.", + ) + .with_channel("analysis"), + Message::from_role_and_content( + Role::Assistant, + r#"{"code": "import timeit\n\ndef fib_recursive(n):\n if n <= 1: return n\n return fib_recursive(n-1) + fib_recursive(n-2)\n\ndef fib_iter(n):\n if n <= 1: return n\n a, b = 0, 1\n for _ in range(2, n+1): a, b = b, a+b\n return b\n\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}"#, + ) + .with_channel("commentary") + .with_recipient("functions.python") + .with_content_type("json"), + ]); + + let tokens = encoding + .render_conversation_for_completion( + &convo, + Role::Assistant, + Some(&crate::encoding::RenderConversationConfig { + auto_drop_analysis: true, + }), + ) + .unwrap(); + + let decoded = encoding.tokenizer.decode_utf8(&tokens).unwrap(); + assert_eq!(decoded, expected_output); +} \ No newline at end of file diff --git a/test-data/test_multi_turn_auto_drop_analysis.txt b/test-data/test_multi_turn_auto_drop_analysis.txt new file mode 100644 index 0000000..79b62c2 --- /dev/null +++ b/test-data/test_multi_turn_auto_drop_analysis.txt @@ -0,0 +1,6 @@ +<|start|>developer<|message|>You are a helpful assistant that analyzes code and provides detailed feedback.<|end|><|start|>user<|message|>Can you help me optimize this Python function? + +def fibonacci(n): + if n <= 1: + return n + return fibonacci(n-1) + fibonacci(n-2)<|end|><|start|>assistant<|channel|>final<|message|>This recursive Fibonacci has exponential time complexity. Would you like me to show optimized versions?<|end|><|start|>user<|message|>Yes, and benchmark them<|end|><|start|>assistant<|channel|>final<|message|>I'll benchmark both versions for you.<|end|><|start|>user<|message|>Run the benchmark for n=30<|end|><|start|>assistant<|channel|>analysis<|message|>I need to execute Python code to run the benchmark for n=30.<|end|><|start|>assistant to=functions.python<|channel|>commentary<|message|>{"code": "import timeit\n\ndef fib_recursive(n):\n if n <= 1: return n\n return fib_recursive(n-1) + fib_recursive(n-2)\n\ndef fib_iter(n):\n if n <= 1: return n\n a, b = 0, 1\n for _ in range(2, n+1): a, b = b, a+b\n return b\n\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}<|call|><|start|>assistant diff --git a/tests/test_harmony.py b/tests/test_harmony.py index dbb9925..84ddebb 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -1244,3 +1244,77 @@ def test_streamable_parser_tricky_utf8_decoding(): # Ensure if we're accumulating content deltas we still get the full utf-8 text assert "".join(content_deltas) == tricky_utf8_text + + +def test_multi_turn_auto_drop_analysis(): + """ + In multi-turn conversations with auto_drop_analysis=True, + all analysis messages before the last final message should be dropped. + + This test ensures that we use last_final_idx instead of first_final_idx + when determining which analysis messages to drop. + """ + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + expected_output = ( + (ROOT_DIR / "test-data" / "test_multi_turn_auto_drop_analysis.txt") + .read_text(encoding="utf-8") + .rstrip() + ) + + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions( + "You are a helpful assistant that analyzes code and provides detailed feedback." + ), + ), + Message.from_role_and_content( + Role.USER, + "Can you help me optimize this Python function?\n\n" + "def fibonacci(n):\n" + " if n <= 1:\n" + " return n\n" + " return fibonacci(n-1) + fibonacci(n-2)", + ), + # Turn 1: analysis + final + Message.from_role_and_content( + Role.ASSISTANT, + "The user provided a recursive Fibonacci implementation. O(2^n) complexity.", + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + "This recursive Fibonacci has exponential time complexity. " + "Would you like me to show optimized versions?", + ).with_channel("final"), + Message.from_role_and_content(Role.USER, "Yes, and benchmark them"), + # Turn 2: analysis + final + Message.from_role_and_content( + Role.ASSISTANT, + "User wants benchmarks. I should run some Python code to compare performance.", + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'll benchmark both versions for you." + ).with_channel("final"), + Message.from_role_and_content(Role.USER, "Run the benchmark for n=30"), + # Turn 3: analysis + tool call (no final after) + Message.from_role_and_content( + Role.ASSISTANT, + "I need to execute Python code to run the benchmark for n=30.", + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"code": "import timeit\\n\\ndef fib_recursive(n):\\n if n <= 1: return n\\n return fib_recursive(n-1) + fib_recursive(n-2)\\n\\ndef fib_iter(n):\\n if n <= 1: return n\\n a, b = 0, 1\\n for _ in range(2, n+1): a, b = b, a+b\\n return b\\n\\nprint(timeit.timeit(lambda: fib_recursive(30), number=1))\\nprint(timeit.timeit(lambda: fib_iter(30), number=1000))"}', + ) + .with_channel("commentary") + .with_recipient("functions.python") + .with_content_type("json"), + ] + ) + + tokens = encoding.render_conversation_for_completion( + convo, Role.ASSISTANT, RenderConversationConfig(auto_drop_analysis=True) + ) + + assert encoding.decode_utf8(tokens) == expected_output