26 commits
8adaf2e - Add Qwen3 model support (nyo16, Oct 5, 2025)
0499d71 - Add last token pooling support for Qwen3-Embedding models (nyo16, Oct 5, 2025)
1d92e9e - Add Qwen3 embedding architecture and instruction prompts support (nyo16, Oct 5, 2025)
47c337d - Add .lexical/ to gitignore and IEx usage guide (nyo16, Oct 5, 2025)
6f68d8f - mix format and rebuilding lock (nyo16, Oct 5, 2025)
5641a4f - Add Qwen3-Reranker support and example (nyo16, Oct 5, 2025)
fa592c3 - Organize Qwen3 examples into dedicated folder (nyo16, Oct 5, 2025)
8208efd - Address PR review feedback for Qwen3 support (Oct 6, 2025)
1f24cc6 - Fix Qwen3 layer naming for Layers.Transformer.blocks (Oct 6, 2025)
cb181f3 - Map qwen3 model type to :qwen2 tokenizer type (Oct 6, 2025)
1651488 - Add comprehensive Qwen3 notebook with examples (Oct 6, 2025)
c02c295 - Add instruction format to embeddings example in Qwen3 notebook (Oct 6, 2025)
bd19c79 - Add Qwen3 model tests with reference values (Oct 6, 2025)
8d787ee - Fix Qwen3 embedding pooling to use attention mask instead of pad_toke… (nyo16, Oct 10, 2025)
a1923e1 - Add :for_reranker architecture for Qwen3 (nyo16, Oct 10, 2025)
0f271b5 - Address PR #423 review comments: simple fixes (nyo16, Nov 7, 2025)
66e2a1b - Update lib/bumblebee/text/pre_trained_tokenizer.ex (nyo16, Nov 7, 2025)
81285e7 - Merge branch 'qwen3-dense-support' of github.com:nyo16/bumblebee into… (nyo16, Nov 7, 2025)
1e189b8 - Merge branch 'main' into qwen3-dense-support (nyo16, Nov 7, 2025)
cc92ccc - Rename text_reranking to text_reranking_qwen3 (nyo16, Nov 7, 2025)
660ef1b - Remove :for_reranker architecture, use :for_causal_language_modeling (nyo16, Nov 7, 2025)
9fccfaa - Fix syntax error and document :last_token_pooling option (Nov 16, 2025)
b289b75 - Make query_norm and key_norm always functions (Nov 16, 2025)
7604f42 - Fix duplicate rotary_embedding key in transformer blocks (Nov 16, 2025)
7a7eb93 - Update Qwen3 tests to use bumblebee-testing models (Nov 16, 2025)
bd4f915 - run formatter (Nov 16, 2025)
5 changes: 5 additions & 0 deletions lib/bumblebee.ex
@@ -178,6 +178,10 @@ defmodule Bumblebee do
"Phi3ForCausalLM" => {Bumblebee.Text.Phi3, :for_causal_language_modeling},
"Phi3ForSequenceClassification" => {Bumblebee.Text.Phi3, :for_sequence_classification},
"Phi3ForTokenClassification" => {Bumblebee.Text.Phi3, :for_token_classification},
"Qwen3Model" => {Bumblebee.Text.Qwen3, :base},
"Qwen3ForCausalLM" => {Bumblebee.Text.Qwen3, :for_causal_language_modeling},
"Qwen3ForSequenceClassification" => {Bumblebee.Text.Qwen3, :for_sequence_classification},
"Qwen3ForEmbedding" => {Bumblebee.Text.Qwen3, :for_embedding},
Member review comment: There is no Qwen3ForEmbedding in HF transformers, so we can remove this, and the :for_embedding architecture.

"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -258,6 +262,7 @@ defmodule Bumblebee do
"mbart" => :mbart,
"phi" => :code_gen,
"phi3" => :llama,
"qwen3" => :qwen2,
"roberta" => :roberta,
"smollm3" => :smollm3,
"t5" => :t5,
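With these mappings in place, a Qwen3 checkpoint loads through the usual Bumblebee entry points. The sketch below is illustrative only; the repository name "Qwen/Qwen3-0.6B" is an assumption for illustration, not something pinned down by this diff.

    # Hypothetical usage once the Qwen3 mappings are registered.
    {:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-0.6B"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-0.6B"})

    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
    Nx.Serving.run(serving, "Elixir is a functional language that")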
41 changes: 38 additions & 3 deletions lib/bumblebee/layers/transformer.ex
@@ -53,7 +53,9 @@ defmodule Bumblebee.Layers.Transformer do
:layer_norm,
:block_type,
:attention_window_size,
:scale_attention_weights
:scale_attention_weights,
:query_norm,
:key_norm
]

opts =
@@ -330,7 +332,9 @@ defmodule Bumblebee.Layers.Transformer do
layer_norm: [],
attention_window_size: nil,
scale_attention_weights: true,
rotary_embedding: nil
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
])

name = opts[:name]
@@ -360,6 +364,8 @@
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]

ffn_fun =
case ffn do
@@ -418,6 +424,8 @@
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
name: join(name, "self_attention")
)

@@ -703,6 +711,14 @@

* `:max_positions` - the maximum number of distinct positions

* `:query_norm` - a function that applies normalization to the query
projection before rotary embedding. The function should accept two
arguments: the input and a name for the layer. Defaults to `nil`

* `:key_norm` - a function that applies normalization to the key
projection before rotary embedding. The function should accept two
arguments: the input and a name for the layer. Defaults to `nil`

* `:name` - the prefix for layer names

## References
@@ -734,7 +750,9 @@
key_use_bias: true,
value_use_bias: true,
output_use_bias: true,
rotary_embedding: nil
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
])

attention_mask = opts[:attention_mask]
@@ -752,6 +770,8 @@
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]

query_use_bias = opts[:query_use_bias]
key_use_bias = opts[:key_use_bias]
@@ -791,6 +811,21 @@
)
|> Layers.split_heads(num_key_value_heads)

# Apply query and key normalization if configured (before rotary embedding)
query =
if query_norm do
query_norm.(query, join(name, "query_norm"))
else
query
end

key =
if key_norm do
key_norm.(key, join(name, "key_norm"))
else
key
end

{query, key} =
case rotary_embedding do
opts when is_list(opts) ->
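The new `:query_norm` and `:key_norm` options take two-arity functions that receive the projected tensor and a layer name. Below is a minimal sketch of how a model implementation such as Qwen3 might wire them up, assuming `Bumblebee.Layers.rms_norm/2` accepts `:epsilon` and `:name` options as it does elsewhere in Bumblebee; the epsilon value and layer names are placeholders.

    # Two-arity norm functions: input tensor and layer name.
    query_norm = fn hidden_state, name ->
      Bumblebee.Layers.rms_norm(hidden_state, epsilon: 1.0e-6, name: name)
    end

    key_norm = fn hidden_state, name ->
      Bumblebee.Layers.rms_norm(hidden_state, epsilon: 1.0e-6, name: name)
    end

    # These are then passed through the block options, alongside the model's
    # other settings (number of blocks, heads, hidden size, ffn, and so on):
    #
    #     Layers.Transformer.blocks(hidden_state,
    #       ...,
    #       query_norm: query_norm,
    #       key_norm: key_norm,
    #       name: "decoder.blocks"
    #     )

As the diff above shows, the norms are applied to the per-head query and key projections before the rotary embedding.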
46 changes: 46 additions & 0 deletions lib/bumblebee/text.ex
@@ -385,6 +385,9 @@ defmodule Bumblebee.Text do
Note that we currently assume that the CLS token is the first token
in the sequence

* `:last_token_pooling` - takes the embedding for the last non-padding
token in each sequence

By default no pooling is applied

* `:embedding_processor` - a post-processing step to apply to the
@@ -444,6 +447,49 @@ defmodule Bumblebee.Text do
defdelegate text_embedding(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextEmbedding

@type text_reranking_qwen3_input :: {String.t(), String.t()} | [{String.t(), String.t()}]
@type text_reranking_qwen3_output :: %{
scores: text_reranking_qwen3_score() | list(text_reranking_qwen3_score())
}
@type text_reranking_qwen3_score :: %{score: number(), query: String.t(), document: String.t()}

@doc """
Builds a serving for text reranking with Qwen3 reranker models.

The serving expects input in one of the following formats:

* `{query, document}` - a tuple with query and document text
* `[{query1, doc1}, {query2, doc2}, ...]` - a list of query-document pairs

## Options

See `Bumblebee.Text.TextRerankingQwen3.text_reranking_qwen3/3` for available options.
Member review comment: The TextRerankingQwen3 module is private (@moduledoc false), so we should have all the docs and options here.

## Examples

{:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})

serving = Bumblebee.Text.text_reranking_qwen3(model_info, tokenizer)

query = "What is the capital of France?"
documents = [
"Paris is the capital of France.",
"Berlin is the capital of Germany."
]

pairs = Enum.map(documents, &{query, &1})
Nx.Serving.run(serving, pairs)

"""
@spec text_reranking_qwen3(
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
keyword()
) :: Nx.Serving.t()
defdelegate text_reranking_qwen3(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextRerankingQwen3

@type fill_mask_input :: String.t()
@type fill_mask_output :: %{predictions: list(fill_mask_prediction())}
@type fill_mask_prediction :: %{score: number(), token: String.t()}
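The new `:last_token_pooling` value pairs with the text embedding serving. A hedged example follows, assuming the pooling strategy is selected via the serving's `:output_pool` option (as with the existing `:mean_pooling`) and that "Qwen/Qwen3-Embedding-0.6B" is the intended checkpoint; both are assumptions for illustration.

    {:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})

    serving =
      Bumblebee.Text.text_embedding(model_info, tokenizer,
        # Take the embedding of the last non-padding token, then L2-normalize.
        output_pool: :last_token_pooling,
        embedding_processor: :l2_norm
      )

    Nx.Serving.run(serving, "What is the capital of France?")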
7 changes: 7 additions & 0 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -200,6 +200,13 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
},
default_template_options: [language_token: "eng_Latn"]
},
qwen2: %{
special_tokens: %{
unk: "<|endoftext|>",
eos: "<|endoftext|>",
pad: "<|endoftext|>"
}
},
roberta: %{
special_tokens: %{
bos: "<s>",
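Together with the `"qwen3" => :qwen2` mapping in `lib/bumblebee.ex`, this entry gives a loaded Qwen3 tokenizer `<|endoftext|>` as its unk/eos/pad token. A small sketch; the repository name is again an assumption for illustration.

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})

    # Batched inputs are padded with the configured pad token ("<|endoftext|>").
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["short prompt", "a somewhat longer prompt"])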