Add Qwen3 model support #423
base: main
Conversation
Implements support for the Qwen3 model family, including Qwen3-4B-Instruct.

Key features:
- QK normalization for improved training stability
- Grouped Query Attention (32 query heads, 8 KV heads)
- High RoPE theta (5M) for extended context (262K tokens)
- Support for causal language modeling and sequence classification
- Complete parameter mapping for HuggingFace model loading
- Example scripts demonstrating text generation and chat usage

Tested with Qwen3-4B-Instruct-2507 and generates coherent English output.
I will test it tomorrow with my H200 to be sure that everything is working. With my MacBook the answers seem OK, but the generation is slow.
Implements last token pooling strategy in text_embedding to support Qwen3-Embedding models, which use the last token's hidden state for generating text embeddings.

- Add :last_token_pooling option to text_embedding
- Extract last non-padding token using attention_mask
- Add Qwen3-Embedding-0.6B example demonstrating:
  - Text embedding generation (1024-dim vectors)
  - Semantic similarity computation
  - Instruction-aware embeddings
  - Batch processing

Tested with Qwen3-Embedding-0.6B and produces correct similarity scores.
Implements :for_embedding architecture for Qwen3 models with last token pooling, enabling direct use with Bumblebee.Text.text_embedding/3.

Changes:
- Add :for_embedding architecture to Qwen3 model
- Register Qwen3ForEmbedding in model mappings
- Add instruction prompts example showing Qwen team recommendations
- Update examples to use cleaner serving-based API
- Add .lexical/ to gitignore
- Clean up mix.exs dependencies (remove emlx, nx override)

Examples demonstrate:
- Basic embedding generation (1024-dim vectors)
- Semantic similarity computation
- Instruction-aware prompts (1-5% performance improvement)
- Custom task instructions for code search
- Multilingual embedding support

Tested with Qwen3-Embedding-0.6B, generates correct similarity scores.
Implements document reranking using Qwen3-Reranker models. Rerankers score query-document pairs for relevance, improving retrieval quality in RAG and search applications.

Features:
- Automatic yes/no token detection from tokenizer
- Proper input format with instruction, query, and document
- Softmax-based relevance scoring (0-1 range)
- Support for custom task instructions

Example demonstrates:
- Basic query-document scoring
- Custom instructions for code search
- Reranking search results (top-k selection)

Results show correct ranking:
- Relevant docs score 0.99+
- Irrelevant docs score near 0.0
- Custom instructions work for domain-specific tasks

Works with Qwen3-Reranker-0.6B/4B/8B models.
Move all Qwen3-related examples and documentation into examples/qwen3/ for better organization and discoverability.

Changes:
- Create examples/qwen3/ directory
- Move qwen3.exs, qwen3_embedding.exs, qwen3_embedding_prompts.exs, qwen3_reranker.exs
- Move QWEN3_IEX_GUIDE.md to examples/qwen3/
- Update examples/README.md to reference qwen3/ subdirectory

All examples now accessible under examples/qwen3/ with consistent structure.
I was interested in getting a Qwen3 vision model working, like https://huggingface.co/huihui-ai/Huihui-MiniCPM-V-4_5-abliterated
- Remove .lexical/ from project gitignore (should be in global gitignore)
- Add :qwen2 tokenizer type with correct Qwen3 special tokens
- Refactor QK normalization to use generalized approach:
  - Add :query_norm and :key_norm options to Layers.Transformer
  - Apply normalization after head splitting, before rotary embedding
  - Update Qwen3 to use Layers.Transformer.blocks instead of custom implementation
  - Remove ~200 lines of custom decoder/attention code
- Remove standalone examples directory per review feedback

The generalized QK normalization approach makes the transformer layer more flexible and maintainable, allowing other models to use similar patterns.
Use 'decoder.blocks' as the name prefix when calling Layers.Transformer.blocks
to match the expected params mapping pattern decoder.blocks.{n}.*.
This aligns with how other models like BERT use the transformer blocks.
Fix model_type_to_tokenizer_type mapping to use :qwen2 instead of :gpt2 for qwen3 models. This ensures Qwen3 models load with the correct tokenizer configuration including proper special tokens.
Create notebooks/qwen3.livemd demonstrating:
- Text generation using Qwen3-4B-Instruct-2507
- Embeddings using Qwen3-Embedding-0.6B with similarity examples
- Reranking using Qwen3-Reranker-0.6B with query-document scoring

This replaces the deleted standalone examples with a consolidated, easy-to-follow notebook format as suggested in PR review.
Update the embeddings section to use the proper instruction format:
'Instruct: Given a query, retrieve relevant documents\nQuery: {query}\n{text}'
This ensures consistency with the reranker example and follows Qwen3
embedding best practices for better semantic search results.
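For reference, a minimal sketch of building that string in Elixir (the task, query, and document values below are placeholders, not part of the PR):

```elixir
# Placeholder values; only the template shape matters here
task = "Given a query, retrieve relevant documents"
query = "What is the capital of China?"
text = "The capital of China is Beijing."

embedding_input = "Instruct: #{task}\nQuery: #{query}\n#{text}"
```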
Add comprehensive test suite for Qwen3 using tiny-random/qwen3:
- Test :base architecture with QK normalization enabled
- Test :for_causal_language_modeling with logits verification
- Test :for_sequence_classification (shape only, random params)
- Test :for_embedding architecture

Reference values generated from tiny-random/qwen3 model predictions. All tests pass successfully (4 tests, 0 failures).
Generation looking good!

iex(16)> prompt = """
...(16)> <|im_start|>system
...(16)> You are a helpful assistant.<|im_end|>
...(16)> <|im_start|>user
...(16)> What is the capital of France?<|im_end|>
...(16)> <|im_start|>assistant
...(16)> """
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
iex(17)>
nil
iex(18)> result = Nx.Serving.run(serving, prompt)
%{
results: [
%{
text: "The capital of France is Paris.",
token_summary: %{input: 26, output: 8, padding: 0}
}
]
}
Still more tests to do and write!
@jonatanklosko I used a light model with the Qwen3 architecture to write some basic tests, similar to the other PR. Let me know if this is enough.
…n_id

The Qwen3 :for_embedding architecture was incorrectly using pad_token_id to find the last non-padding token for pooling. This caused embeddings to differ from the reference Python implementation.

Root cause:
- Token 151643 serves as both the pad_token_id AND the EOS token
- When tokenizing "hello!", the output is [14990, 0, 151643] with attention_mask [1, 1, 1], meaning all tokens should be attended
- The EOS token at the end is part of the actual sequence, not padding
- Only explicitly added padding tokens have attention_mask = 0

The fix changes the pooling logic to use attention_mask.sum(dim=-1) - 1 to find the last attended token, matching the official HuggingFace implementation's last_token_pool function.

Debugging process:
1. Compared raw transformer hidden states between Python and Elixir
2. Found both were producing identical hidden states (norm ~102 for all tokens)
3. Discovered Python was pooling index 2 (last token) while Elixir pooled index 1 (last non-pad_token_id token)
4. Investigated tokenizer behavior: attention_mask was [1,1,1] not [1,1,0]
5. Confirmed with explicit padding that only added padding has mask = 0
6. Updated architecture to use attention_mask for pooling logic

After fix, embeddings now match Python implementation within bf16 precision:
- Python: [0.00039361, -0.02717206, -0.01105759, ...]
- Elixir: [5.552e-4, -0.027919, -0.011104, ...]

Also updated notebooks/qwen3.livemd to specify architecture: :for_embedding explicitly and remove unnecessary output_pool parameter.
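To make that pooling rule concrete, here is a hedged Nx sketch of the index arithmetic described above (the tensors, shapes, and values are illustrative, not the PR's actual layer code):

```elixir
# Last attended token index = sum(attention_mask) - 1, per sequence
attention_mask = Nx.tensor([[1, 1, 1, 0, 0]])
hidden_state = Nx.iota({1, 5, 4}, type: :f32)

last_index =
  attention_mask
  |> Nx.sum(axes: [-1])
  |> Nx.subtract(1)

# Gather that position from the hidden state, rather than searching for
# pad_token_id (which here doubles as the EOS token)
indices =
  last_index
  |> Nx.reshape({1, 1, 1})
  |> Nx.broadcast({1, 1, 4})

pooled =
  hidden_state
  |> Nx.take_along_axis(indices, axis: 1)
  |> Nx.squeeze(axes: [1])

# pooled has shape {1, 4}: the hidden state of the last attended token (index 2)
```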
Implements binary relevance classification (reranking) for Qwen3 models.
Changes:
- Add :for_reranker architecture to Qwen3 model
- Extracts logits at last attended token position
- Returns full vocab logits for binary classification
- Create new TextReranking serving module
- Handles query-document pair formatting
- Applies Qwen3 reranker prompt template
- Computes relevance scores from yes/no token logits
- Uses log_softmax for score normalization
- Update notebook with proper reranker usage
- Changed from :for_embedding to :for_reranker architecture
- Uses Bumblebee.Text.text_reranking/3 API
- Simplified query-document pair handling
The reranker follows the official HuggingFace implementation:
1. Format: "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n<Instruct>: {task}\n<Query>: {query}\n<Document>: {doc}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
2. Extract logits at last attended token (using attention mask)
3. Get yes/no token logits
4. Apply log_softmax([no_logit, yes_logit])
5. Return exp(yes_log_prob) as relevance score
Tested with Qwen3-Reranker-0.6B, scores match Python reference (both ~1.0).
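A hedged sketch of the final scoring steps (4 and 5) in plain Nx (the logit values are placeholders; the real serving pools the logits at the last attended position and looks up the "yes"/"no" token ids in the tokenizer):

```elixir
# Placeholder logits for the "no" and "yes" tokens at the last attended position
no_logit = 1.2
yes_logit = 4.7

logits = Nx.tensor([no_logit, yes_logit])

# log_softmax, then exp of the "yes" entry == softmax probability of "yes"
log_probs = Nx.subtract(logits, Nx.log(Nx.sum(Nx.exp(logits))))
score = log_probs[1] |> Nx.exp() |> Nx.to_number()

# score is the relevance score in the 0..1 range (~0.97 for these values)
```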
def model(%__MODULE__{architecture: :for_embedding} = spec) do
  inputs = inputs(spec)
Architectures should match differences in the model layers, but the embedding model is also :for_causal_language_modeling, so it should map to :for_causal_language_modeling.
Instead of having the pooling logic here, we should instead add :last_token_pooling as an option to the text embedding serving:
bumblebee/lib/bumblebee/text.ex
Lines 379 to 388 in 79199e0
* `:output_pool` - pooling to apply on top of the model output, in case
  it is not already a pooled embedding. Supported values:

  * `:mean_pooling` - performs a mean across all tokens

  * `:cls_token_pooling` - takes the embedding for the special CLS token.
    Note that we currently assume that the CLS token is the first token
    in the sequence

  By default no pooling is applied
Then the user loads the model as usual (automatically mapped to :for_causal_language_modeling), and when building the serving they will pass output_pool: :last_token_pooling
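If this suggestion lands, usage could look roughly like the sketch below. It is assumption-heavy: the :last_token_pooling value comes from this thread, and loading with architecture: :base (so the output exposes :hidden_state) is a guess at the simplest wiring; the final API may differ.

```elixir
{:ok, model_info} =
  Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"}, architecture: :base)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})

serving =
  Bumblebee.Text.text_embedding(model_info, tokenizer,
    output_attribute: :hidden_state,
    output_pool: :last_token_pooling,
    embedding_processor: :l2_norm
  )

Nx.Serving.run(serving, "hello!")
```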
Sorry, work caught up with me; I will continue the PR this weekend.
@jonatanklosko I managed to get some time.

🧪 Test Environment

Test 1: Basic Text ("hello!")
| Metric | Python (transformers) | Elixir (Bumblebee) | Difference |
|---|---|---|---|
| Norm | 0.9961 | 0.9998 | 0.0037 |
| Cosine Similarity | — | — | 0.9998 ✅ |
| Mean Abs Diff | — | — | 0.00053 |
| Max Abs Diff | — | — | 0.0027 |
First 10 embedding values
| Index | Python | Elixir | Abs Diff |
|---|---|---|---|
| 0 | 0.0004043579 | 0.0005552031 | 0.0001508 |
| 1 | -0.0277099609 | -0.0279187821 | 0.0002088 |
| 2 | -0.0111694336 | -0.0111040613 | 0.0000654 |
| 3 | -0.0184326172 | -0.0174492393 | 0.0009834 |
| 4 | -0.0209960938 | -0.0209390856 | 0.0000570 |
| 5 | 0.0031738281 | 0.0026967004 | 0.0004771 |
| 6 | -0.0356445312 | -0.0361675136 | 0.0005230 |
| 7 | 0.0869140625 | 0.0869289339 | 0.0000149 |
| 8 | -0.0446777344 | -0.0447335020 | 0.0000558 |
| 9 | -0.0195312500 | -0.0200666245 | 0.0005354 |
Test 2: Query with Instruction Format
Text: "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:hello!"
| Metric | Python (transformers) | Elixir (Bumblebee) | Difference |
|---|---|---|---|
| Norm | 1.0000 | 1.0042 | 0.0042 |
| Cosine Similarity | — | — | 0.9992 ✅ |
| Mean Abs Diff | — | — | 0.00096 |
| Max Abs Diff | — | — | 0.0051 |
First 10 embedding values
| Index | Python | Elixir | Abs Diff |
|---|---|---|---|
| 0 | 0.0043945312 | 0.0041164518 | 0.0002781 |
| 1 | -0.0073242188 | -0.0055703474 | 0.0017539 |
| 2 | -0.0057067871 | -0.0059557175 | 0.0002489 |
| 3 | -0.0380859375 | -0.0409192815 | 0.0028333 |
| 4 | -0.0319824219 | -0.0322309434 | 0.0002485 |
| 5 | -0.0294189453 | -0.0298486538 | 0.0004297 |
| 6 | -0.0546875000 | -0.0546524674 | 0.0000350 |
| 7 | 0.0268554688 | 0.0284473095 | 0.0015918 |
| 8 | -0.0693359375 | -0.0681053847 | 0.0012306 |
| 9 | -0.0324707031 | -0.0281670410 | 0.0043037 |
Do you think the small difference is due to bf16 / the language's float implementation?
# Load model
{:ok, model_info} = Bumblebee.load_model(
{:hf, "Qwen/Qwen3-Embedding-0.6B"},
type: :bf16,
architecture: :for_embedding
)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})
# Create serving
serving = Bumblebee.Text.TextEmbedding.text_embedding(
model_info,
tokenizer,
output_attribute: :embedding,
embedding_processor: :l2_norm
)
# For documents (no instruction)
document = "The capital of China is Beijing."
%{embedding: doc_emb} = Nx.Serving.run(serving, document)
# For queries (NO space after Query:)
task = "Given a web search query, retrieve relevant passages that answer the query"
query = "Instruct: #{task}\nQuery:What is the capital of China?"
%{embedding: query_emb} = Nx.Serving.run(serving, query)
# Compute similarity
similarity = Nx.dot(query_emb, doc_emb)

- Fix qwen2 tokenizer special tokens to match HuggingFace defaults
  - Add unk: "<|endoftext|>"
  - Change eos from "<|im_end|>" to "<|endoftext|>"
- Remove :atol specifications from Qwen3 tests
  - Values match Python within default tolerance
- Delete examples/README.md
  - Qwen3-specific documentation to be removed

Related to PR review comments by @jonatanklosko
Co-authored-by: Jonatan Kłosko <[email protected]>
… qwen3-dense-support
Make reranking serving Qwen3-specific to reflect its model-specific implementation, similar to the speech_to_text_whisper pattern.

Changes:
- Rename TextReranking module to TextRerankingQwen3
- Rename text_reranking function to text_reranking_qwen3
- Update Text module delegation and types
- Update documentation and examples

The reranking implementation contains Qwen3-specific logic including:
- Hardcoded Qwen3 instruction format with <|im_start|> tokens
- Yes/no token scoring specific to Qwen3 tokenizer
- Binary relevance classification approach

Addresses PR elixir-nx#423 review comment by @jonatanklosko
The reranker model is architecturally identical to :for_causal_language_modeling in the upstream repository. Move pooling logic from model architecture to the serving layer, consistent with how other model-specific servings work.

Changes:
- Remove :for_reranker architecture from Qwen3
- Update text_reranking_qwen3 serving to:
  - Accept :for_causal_language_modeling models
  - Pool logits to last attended token in serving layer
  - Extract yes/no token scores from pooled logits
- Update documentation and examples to load without architecture override
- Remove :for_reranker from moduledoc

The serving now handles all reranking-specific logic (pooling, yes/no token extraction) while the model remains a standard causal language model.

Addresses PR elixir-nx#423 review comment by @jonatanklosko
- Remove duplicate :scale_attention_weights in transformer block_opts_keys
- Add :last_token_pooling to text embedding documentation
- Update transformer multi_head_attention to only accept functions for query_norm and key_norm (not keyword lists)
- Update Qwen3 to pass normalization functions instead of keyword lists
- Update documentation to clarify that these options expect functions
Remove :rotary_embedding from block_opts_keys since it is handled separately in the blocks function (lines 89, 119-124, 138). This prevents the duplicate key error when passing options to the block function.
Replace tiny-random/qwen3 with the proper test models from the bumblebee-testing org and update expected dimensions to match the tiny-random models:
- Use bumblebee-testing/tiny-random-Qwen3Model (hidden_size: 32)
- Use bumblebee-testing/tiny-random-Qwen3ForCausalLM (vocab_size: 1024)
- Use bumblebee-testing/tiny-random-Qwen3ForSequenceClassification
- Remove specific value assertions since model parameters differ
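For orientation, a hedged sketch of what one of these tests could look like (the repo name and the {1, 10, 32} shape come from this commit and the diff below; the real suite's helpers and reference-value assertions may differ):

```elixir
# Inside an ExUnit test case
{:ok, %{model: model, params: params, spec: _spec}} =
  Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Qwen3Model"})

inputs = %{
  "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
  "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

# hidden_size of the tiny-random model is 32, sequence length is 10
assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
```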
@jonatanklosko finally found some time again! I feel I have addressed all the comments. Sorry for this large PR.
outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
We should still assert against the reference values. Same in all other tests.
# Note: tiny-random model is missing sequence_classification_head parameters,
# so it uses random initialization. We only verify the shape is correct.
assert Nx.shape(outputs.logits) == {1, 2}
The checkpoint has a score layer, which should map to sequence_classification_head just fine.
## Options

See `Bumblebee.Text.TextRerankingQwen3.text_reranking_qwen3/3` for available options.
The TextRerankingQwen3 module is private (@moduledoc false), so we should have all the docs and options here.
| "Qwen3Model" => {Bumblebee.Text.Qwen3, :base}, | ||
| "Qwen3ForCausalLM" => {Bumblebee.Text.Qwen3, :for_causal_language_modeling}, | ||
| "Qwen3ForSequenceClassification" => {Bumblebee.Text.Qwen3, :for_sequence_classification}, | ||
| "Qwen3ForEmbedding" => {Bumblebee.Text.Qwen3, :for_embedding}, |
There is no Qwen3ForEmbedding in HF transformers, so we can remove this, and the :for_embedding architecture.
| "decoder.blocks.{n}.self_attention.rotary_embedding" => | ||
| "model.layers.{n}.self_attn.rotary_emb", |
I recently realised this particular layer mapping is not necessary; I already removed it for all other models.
| "decoder.blocks.{n}.self_attention.rotary_embedding" => | |
| "model.layers.{n}.self_attn.rotary_emb", |
repo = {:hf, "Qwen/Qwen3-Reranker-0.6B"}

{:ok, model_info} =
  Bumblebee.load_model(repo, type: :f32, backend: EXLA.Backend, architecture: :for_reranker)
architecture: :for_reranker no longer exists, so this should be updated. Same for :for_embedding above, once we remove it.
Add Qwen3 Model Family Support
Summary
This PR adds comprehensive support for the Qwen3 model family from Alibaba Cloud, including text generation,
embeddings, and reranking models. Qwen3 is a state-of-the-art multilingual language model with advanced features like
QK normalization and support for up to 262K context length.
What's New
Architectures:
Key Features:
Files Changed
Core Implementation:
Examples:
Documentation:
Testing
Text Generation (Qwen3-4B-Instruct)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
serving = Bumblebee.Text.generation(model, tokenizer, config)
Nx.Serving.run(serving, "The future of AI")
Results: Generates coherent English text, answers questions correctly, creates stories and code.
Text Embeddings (Qwen3-Embedding-0.6B)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"},
architecture: :for_embedding
)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})
serving = Bumblebee.Text.text_embedding(model, tokenizer,
output_attribute: :embedding,
embedding_processor: :l2_norm
)
e1 = Nx.Serving.run(serving, "The cat sat on the mat")
e2 = Nx.Serving.run(serving, "A feline rested on the rug")
Nx.dot(e1.embedding, e2.embedding) |> Nx.to_number() # 0.73 (similar)
Results:
Reranking (Qwen3-Reranker-0.6B)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
# Score query-document relevance
# Relevant: 0.99+, Irrelevant: ~0.0
Results: Correctly ranks documents by relevance to queries.
Compatible Models
Text Generation:
Embeddings:
Reranking:
Technical Implementation
QK Normalization
Unlike standard transformers, Qwen3 applies RMS normalization to query and key states:
hidden -> dense -> split_heads -> rms_norm -> rotary -> attention
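As a rough illustration of the rms_norm step in that pipeline, here is a hedged defn sketch operating on an already head-split query tensor (the shapes, weight, and epsilon are assumptions, not the PR's layer code):

```elixir
defmodule QKNormSketch do
  import Nx.Defn

  # RMS-normalize over the head dimension of a {batch, seq, heads, head_dim}
  # tensor, scaling by a learned per-dimension weight
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)
    variance = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + opts[:epsilon]) * weight
  end
end

# query: {batch, seq_len, num_heads, head_dim}; normalized per head,
# then rotary embedding and attention are applied as usual
query = Nx.iota({1, 4, 2, 8}, type: :f32)
q_norm_weight = Nx.broadcast(1.0, {8})
normalized_query = QKNormSketch.rms_norm(query, q_norm_weight)
```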
Architecture Support
Custom decoder blocks implement QK normalization while maintaining compatibility with Bumblebee's transformer patterns.
Embedding Architecture
New :for_embedding architecture automatically pools the last non-padding token for text embedding tasks.
Reranking
Uses the causal LM architecture with yes/no token logit extraction and softmax scoring.
Breaking Changes
None. This is purely additive.
References