Skip to content

Commit cc6be5b

Browse files
committed
fix(provider): apply MergeSystemMessages for vLLM provider
vLLM is a built-in OpenAI-response provider in provider.json, but the `MergeSystemMessages` predicate in the openai pipeline only matched NVIDIA. vLLM rejects requests where the system message is not first (per #3128: 'System message must be at the beginning.'), so add vLLM to the same predicate using the existing `from_str` pattern that handles providers without a `ProviderId` constant. Closes #3128
1 parent 3964f1b commit cc6be5b

2 files changed

Lines changed: 62 additions & 4 deletions

File tree

crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use crate::dto::openai::{Message, MessageContent, Request, Role};
55
/// Merges all system messages into a single system message at the beginning of
66
/// the messages array.
77
///
8-
/// Some providers (e.g. NVIDIA) reject requests with multiple system messages
9-
/// or system messages that are not positioned at the start of the conversation.
8+
/// Some providers (e.g. NVIDIA, vLLM) reject requests with multiple system
9+
/// messages or system messages that are not positioned at the start of the
10+
/// conversation.
1011
pub struct MergeSystemMessages;
1112

1213
impl Transformer for MergeSystemMessages {

crates/forge_app/src/dto/openai/transformers/pipeline.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ impl Transformer for ProviderPipeline<'_> {
8282

8383
let xai_compat = MakeXaiCompat.when(move |_| provider.id == ProviderId::XAI);
8484

85-
let ensure_system_first =
86-
MergeSystemMessages.when(move |_| provider.id == ProviderId::NVIDIA);
85+
let vllm = ProviderId::from_str("vllm").unwrap();
86+
let ensure_system_first = MergeSystemMessages
87+
.when(move |_| provider.id == ProviderId::NVIDIA || provider.id == vllm);
8788

8889
let trim_tool_call_ids = TrimToolCallIds.when(move |_| provider.id == ProviderId::OPENAI);
8990

@@ -171,6 +172,7 @@ mod tests {
171172

172173
use super::*;
173174
use crate::domain::{ModelSource, ProviderResponse};
175+
use crate::dto::openai::{Message, MessageContent, Role};
174176

175177
// Test helper functions
176178
fn make_credential(provider_id: ProviderId, key: &str) -> Option<forge_domain::AuthCredential> {
@@ -247,6 +249,21 @@ mod tests {
247249
}
248250
}
249251

252+
fn vllm(key: &str) -> Provider<Url> {
253+
let id = ProviderId::from_str("vllm").unwrap();
254+
Provider {
255+
id: id.clone(),
256+
provider_type: Default::default(),
257+
response: Some(ProviderResponse::OpenAI),
258+
url: Url::parse("http://localhost:8000/v1/chat/completions").unwrap(),
259+
auth_methods: vec![forge_domain::AuthMethod::ApiKey],
260+
url_params: vec![],
261+
credential: make_credential(id, key),
262+
custom_headers: None,
263+
models: Some(ModelSource::Hardcoded(vec![])),
264+
}
265+
}
266+
250267
fn xai(key: &str) -> Provider<Url> {
251268
Provider {
252269
id: ProviderId::XAI,
@@ -369,6 +386,21 @@ mod tests {
369386
}
370387
}
371388

389+
fn message(role: Role, content: &str) -> Message {
390+
Message {
391+
role,
392+
content: Some(MessageContent::Text(content.to_string())),
393+
name: None,
394+
tool_call_id: None,
395+
tool_calls: None,
396+
reasoning_details: None,
397+
reasoning_text: None,
398+
reasoning_opaque: None,
399+
reasoning_content: None,
400+
extra_content: None,
401+
}
402+
}
403+
372404
#[test]
373405
fn test_supports_open_router_params() {
374406
assert!(supports_open_router_params(&forge("forge")));
@@ -450,6 +482,31 @@ mod tests {
450482
assert_eq!(actual.reasoning, None);
451483
}
452484

485+
#[test]
486+
fn test_vllm_provider_merges_system_messages() {
487+
let provider = vllm("vllm");
488+
let fixture = Request::default().messages(vec![
489+
message(Role::User, "hello"),
490+
message(Role::System, "you are helpful"),
491+
message(Role::Assistant, "hi"),
492+
message(Role::System, "be concise"),
493+
]);
494+
495+
let mut pipeline = ProviderPipeline::new(&provider);
496+
let actual = pipeline.transform(fixture).messages.unwrap();
497+
498+
assert_eq!(actual.len(), 3);
499+
assert_eq!(actual[0].role, Role::System);
500+
assert_eq!(
501+
actual[0].content.as_ref(),
502+
Some(&MessageContent::Text(
503+
"you are helpful\n\nbe concise".to_string()
504+
))
505+
);
506+
assert_eq!(actual[1].role, Role::User);
507+
assert_eq!(actual[2].role, Role::Assistant);
508+
}
509+
453510
#[test]
454511
fn test_openai_provider_trims_tool_call_ids() {
455512
let provider = openai("openai");

0 commit comments

Comments
 (0)