Skip to content

[Feature] Add support for Qwen3 Reranker with Sequence Classifier head #698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions backends/candle/src/models/flash_qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen3Config};
use crate::models::{Model, Qwen3Config, Qwen3ClassificationHead};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
Expand Down Expand Up @@ -288,6 +288,7 @@ pub struct FlashQwen3Model {
cos_cache: Tensor,
sin_cache: Tensor,
pool: Pool,
classification_head: Option<Qwen3ClassificationHead>,
pub device: Device,

span: tracing::Span,
Expand All @@ -304,11 +305,13 @@ impl FlashQwen3Model {
candle::bail!("FlashQwen3 requires DType::F16")
}

let pool = match model_type {
let (pool, classification_head) = match model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for Qwen3")
// Load classification head before the vb is modified
let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?);
(Pool::Cls, classification_head) // Use CLS pooling for classification
}
ModelType::Embedding(pool) => pool,
ModelType::Embedding(pool) => (pool, None),
};

// The Qwen3-Reranker models contain the `model` key
Expand Down Expand Up @@ -351,6 +354,7 @@ impl FlashQwen3Model {
cos_cache,
sin_cache,
pool,
classification_head,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
Expand Down Expand Up @@ -512,4 +516,16 @@ impl Model for FlashQwen3Model {
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

fn predict(&self, batch: Batch) -> Result<Tensor> {
match &self.classification_head {
None => candle::bail!("`predict` is not implemented for this model"),
Some(classification_head) => {
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
let pooled_embeddings =
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
classification_head.forward(&pooled_embeddings)
}
}
}
}
71 changes: 68 additions & 3 deletions backends/candle/src/models/qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -24,6 +25,7 @@ pub struct Qwen3Config {
pub sliding_window: Option<usize>,
pub use_sliding_window: bool,
pub eos_token_id: usize,
pub id2label: Option<HashMap<String, String>>,
}

struct Qwen3Attention {
Expand Down Expand Up @@ -375,13 +377,62 @@ impl Qwen3Layer {
}
}

pub struct Qwen3ClassificationHead {
classifier: Linear,
span: tracing::Span,
}

impl Qwen3ClassificationHead {
pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
let n_classes = match &config.id2label {
None => candle::bail!("`id2label` must be set for classifier models"),
Some(id2label) => id2label.len(),
};

// Try different common classification head layer names
// The tomaarsen/Qwen3-Reranker models have score.weight at the top level with no bias
let classifier = if let Ok(weight) = vb.get((n_classes, config.hidden_size), "score.weight") {
// No bias for score layer in converted Qwen3 rerankers
Linear::new(weight, None, None)
} else if let (Ok(weight), Ok(bias)) = (
vb.pp("classifier").get((n_classes, config.hidden_size), "weight"),
vb.pp("classifier").get(n_classes, "bias")
) {
Linear::new(weight, Some(bias), None)
} else if let (Ok(weight), Ok(bias)) = (
vb.pp("score").get((n_classes, config.hidden_size), "weight"),
vb.pp("score").get(n_classes, "bias")
) {
Linear::new(weight, Some(bias), None)
} else if let (Ok(weight), Ok(bias)) = (
vb.get((n_classes, config.hidden_size), "classifier.weight"),
vb.get(n_classes, "classifier.bias")
) {
Linear::new(weight, Some(bias), None)
} else {
candle::bail!("Could not find classification head weights. Tried: score.weight, classifier.weight");
};

Ok(Self {
classifier,
span: tracing::span!(tracing::Level::TRACE, "classification_head"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.classifier.forward(hidden_states)
}
}

pub struct Qwen3Model {
embeddings: Embedding,
layers: Vec<Qwen3Layer>,
norm: RMSNorm,
rotary_cache: (Tensor, Tensor),
rotary_dim: usize,
pool: Pool,
classification_head: Option<Qwen3ClassificationHead>,
num_attention_heads: usize,
pad_token_id: u32,

Expand All @@ -393,11 +444,12 @@ pub struct Qwen3Model {

impl Qwen3Model {
pub fn load(vb: VarBuilder, config: &Qwen3Config, model_type: ModelType) -> Result<Self> {
let pool = match model_type {
let (pool, classification_head) = match model_type {
ModelType::Classifier => {
candle::bail!("`classifier` model type is not supported for Qwen3")
let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?);
(Pool::LastToken, classification_head)
}
ModelType::Embedding(pool) => pool,
ModelType::Embedding(pool) => (pool, None),
};

// The Qwen3-Reranker models contain the `model` key
Expand Down Expand Up @@ -436,6 +488,7 @@ impl Qwen3Model {
rotary_cache,
rotary_dim,
pool,
classification_head,
pad_token_id: config.eos_token_id as u32,
num_attention_heads: config.num_attention_heads,
dtype: vb.dtype(),
Expand Down Expand Up @@ -700,4 +753,16 @@ impl Model for Qwen3Model {
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}

fn predict(&self, batch: Batch) -> Result<Tensor> {
match &self.classification_head {
None => candle::bail!("`predict` is not implemented for this model"),
Some(classification_head) => {
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
let pooled_embeddings =
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
classification_head.forward(&pooled_embeddings)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
source: backends/candle/tests/test_qwen3.rs
assertion_line: 86
expression: predictions_single
---
- - 2.0719934

41 changes: 39 additions & 2 deletions backends/candle/tests/test_qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod common;

use crate::common::{sort_embeddings, SnapshotEmbeddings};
use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores};
use anyhow::Result;
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher};
use text_embeddings_backend_candle::CandleBackend;
use text_embeddings_backend_core::{Backend, ModelType, Pool};

Expand Down Expand Up @@ -50,3 +50,40 @@ fn test_qwen3() -> Result<()> {

Ok(())
}

#[test]
#[serial_test::serial]
fn test_qwen3_reranker() -> Result<()> {
let model_root = download_artifacts("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", None, None)?;
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
&model_root,
"float32".to_string(),
ModelType::Classifier,
None,
)?;

let input_single = batch(
vec![tokenizer
.encode(
"What is Deep Learning?",
true,
)
.unwrap()],
[0].to_vec(),
vec![],
);

let predictions: Vec<Vec<f32>> = backend
.predict(input_single)?
.into_iter()
.map(|(_, v)| v)
.collect();
let predictions_single = SnapshotScores::from(predictions);

let matcher = relative_matcher();
insta::assert_yaml_snapshot!("qwen3_reranker_single", predictions_single, &matcher);

Ok(())
}
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod download;
pub mod infer;
pub mod queue;
pub mod templates;
pub mod tokenization;

use text_embeddings_backend::BackendError;
Expand Down
110 changes: 110 additions & 0 deletions core/src/templates.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use std::fmt::Write;

/// Template formatter for models that require structured prompts
pub trait TemplateFormatter {
/// Format a query-document pair for reranking
fn format_rerank(
&self,
query: &str,
document: &str,
instruction: Option<&str>,
) -> String;
}

/// Qwen3 reranker template formatter
pub struct Qwen3RerankerTemplate {
default_instruction: String,
}

impl Qwen3RerankerTemplate {
pub fn new() -> Self {
Self {
default_instruction: "Select only the Documents that are semantically similar to the Query.".to_string(),
}
}
}

impl TemplateFormatter for Qwen3RerankerTemplate {
fn format_rerank(
&self,
query: &str,
document: &str,
instruction: Option<&str>,
) -> String {
let instruction = instruction.unwrap_or(&self.default_instruction);

let mut result = String::with_capacity(512);

// System prompt
result.push_str("<|im_start|>system\n");
result.push_str("Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n");

// User prompt with instruction, query, and document
result.push_str("<|im_start|>user\n");
write!(&mut result, "<Instruct>: {}\n", instruction).unwrap();
write!(&mut result, "<Query>: {}\n", query).unwrap();
write!(&mut result, "<Document>: {}", document).unwrap();
result.push_str("<|im_end|>\n");

// Assistant prompt to trigger reasoning
result.push_str("<|im_start|>assistant\n");
result.push_str("<think>\n\n</think>\n\n");

result
}
}

/// Check if a model requires template formatting
pub fn requires_template(model_name: &str) -> bool {
// Check if this is a Qwen3 sequence classification model
model_name.contains("Qwen3") && model_name.contains("seq-cls")
}

/// Get the appropriate template formatter for a model
pub fn get_template_formatter(model_name: &str) -> Option<Box<dyn TemplateFormatter + Send + Sync>> {
if requires_template(model_name) {
Some(Box::new(Qwen3RerankerTemplate::new()))
} else {
None
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_qwen3_template() {
let template = Qwen3RerankerTemplate::new();
let formatted = template.format_rerank(
"What is Deep Learning?",
"Deep Learning is a branch of machine learning",
None,
);

assert!(formatted.contains("<|im_start|>system"));
assert!(formatted.contains("<Query>: What is Deep Learning?"));
assert!(formatted.contains("<Document>: Deep Learning is a branch of machine learning"));
assert!(formatted.contains("<think>"));
}

#[test]
fn test_custom_instruction() {
let template = Qwen3RerankerTemplate::new();
let formatted = template.format_rerank(
"test query",
"test doc",
Some("Custom instruction"),
);

assert!(formatted.contains("<Instruct>: Custom instruction"));
}

#[test]
fn test_requires_template() {
assert!(requires_template("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"));
assert!(requires_template("Qwen3-Something-seq-cls"));
assert!(!requires_template("BAAI/bge-reranker"));
assert!(!requires_template("Qwen3-Embed"));
}
}
4 changes: 4 additions & 0 deletions proto/tei.proto
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ message RerankRequest {
bool raw_scores = 4;
bool return_text = 5;
TruncationDirection truncation_direction = 6;
optional string instruction = 7;
optional bool use_template = 8;
}

message RerankStreamRequest{
Expand All @@ -163,6 +165,8 @@ message RerankStreamRequest{
// The server will only consider the first value
bool return_text = 5;
TruncationDirection truncation_direction = 6;
optional string instruction = 7;
optional bool use_template = 8;
}

message Rank {
Expand Down
Loading