
Fix Qwen3-Embedding batch vs single inference inconsistency #648


Merged

Conversation

lance-miles
Contributor

@lance-miles lance-miles commented Jun 18, 2025

What does this PR do?

Fix Qwen3-Embedding batch vs single inference inconsistency

Fixes #642 (a regression introduced in PR #646)

Problem

PR #646 introduced a regression where Qwen3-Embedding models produced inconsistent embeddings between batch and single-sequence inference for identical inputs. The test backends/candle/tests/test_qwen3.rs was failing with an assertion error on the line:
assert_eq!(embeddings_batch[0], embeddings_single[0]);

Root Cause

The issue stemmed from inconsistent attention bias handling between batch and single sequence processing:

  1. Batch processing: Applied padding and attention bias correctly with causal masking
  2. Single processing: Did not create an attention bias tensor, causing behavioral differences between SDPA and eager attention
  3. Padding inconsistency: The original implementation used right padding, but Qwen3-Embedding requires left padding for proper causal attention

Investigation

Further analysis revealed that Qwen3-Embedding models use causal attention by design (not bidirectional like BERT), requiring:

  • Left padding for batched sequences so their EOS tokens align (see the sketch below)
  • Consistent causal attention masking for both single and batch inference
  • Last-token indexing that accounts for the padding position
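
To make the padding requirement concrete, here is a minimal numpy sketch (hypothetical token IDs, not the actual TEI/candle code) showing why left padding keeps every sequence's EOS token at the same, trivially indexable position:

import numpy as np

PAD, EOS = 0, 2  # hypothetical IDs, for illustration only

# Two sequences of different lengths, each ending with EOS.
seqs = [[11, 12, 13, EOS], [21, 22, EOS]]
max_len = max(len(s) for s in seqs)

# Right padding: the EOS position differs per row, so last-token pooling
# needs a per-sequence index derived from the attention mask.
right = np.array([s + [PAD] * (max_len - len(s)) for s in seqs])

# Left padding: the last column is always EOS, so last-token pooling can
# simply read position max_len - 1 for every row.
left = np.array([[PAD] * (max_len - len(s)) + s for s in seqs])

assert (left[:, -1] == EOS).all()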

Fix

This fix ensures consistent behavior by:

  1. Left Padding Implementation:
    - Pad sequences at the beginning (left) rather than at the end (right)
    - Aligns with Qwen3-Embedding's causal attention requirements
  2. Consistent Attention Bias Creation:
    - Create a causal attention bias for both single and batch processing
    - Apply identical upper-triangular masking in both code paths (see the sketch after this list)
  3. Correct Last Token Indexing:
    - Account for left padding when extracting last-token embeddings
    - Ensure the EOS token is correctly identified for pooling
  4. Causal Attention Masking:
    - Maintain proper causal mask generation that prevents attention to future tokens
    - Consistent with Qwen3-Embedding's architectural requirements
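
The combined mask described in points 2 and 4 can be pictured with a small numpy sketch (an illustration under assumed conventions; the actual implementation operates on candle tensors in backends/candle/src/models/qwen3.rs):

import numpy as np

def attention_bias(seq_len: int, pad_len: int) -> np.ndarray:
    """Additive attention bias: 0.0 where attention is allowed, -inf elsewhere."""
    bias = np.zeros((seq_len, seq_len), dtype=np.float32)
    # Causal mask: position i must not attend to future positions j > i.
    bias[np.triu_indices(seq_len, k=1)] = -np.inf
    # Left padding: no position may attend to the padded prefix.
    bias[:, :pad_len] = -np.inf
    return bias

# A 5-token window whose first 2 positions are left padding.
print(attention_bias(seq_len=5, pad_len=2))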

Changes

  • backends/candle/src/models/qwen3.rs: Updated batch processing logic, attention bias handling, and last token extraction
  • Test snapshots: Updated to reflect correct implementation outputs

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@lance-miles lance-miles marked this pull request as ready for review June 18, 2025 17:16
Contributor

@kozistr kozistr left a comment


Thanks for catching this quickly :)

And I've also verified that the output is identical with Transformers!

Looks good to me!

@lance-miles
Contributor Author

@kozistr Thank you so much for the initial PR and initial fix!!

Tagging @Narsil or @alvarobartt for review

@lance-miles
Contributor Author

I've run the pre-commit hooks and fixed the formatting issues.

Collaborator

@Narsil Narsil left a comment


LGTM

@kozistr kozistr mentioned this pull request Jun 21, 2025
@lance-miles
Contributor Author

@alvarobartt and @Narsil,
Is there anything else you need me to do before merging this one in? Please let me know how I can help!

@alvarobartt
Member

alvarobartt commented Jun 24, 2025

Hey @lance-miles, apologies for the delay!

FYI float16 inference is now broken on CPU and MPS, but I already have a fix for that in a follow-up PR, so no need to worry about that!

But when comparing the outputs with the Sentence Transformers counterpart, I realized that they still don't match, even though both you and @kozistr got matching results. Could you please share the snippet you used? For reference, I deployed Text Embeddings Inference (TEI) from this branch plus the causal attention mask boolean flag from #650, ran it as cargo run --release --features candle-cuda,dynamic-linking,http --no-default-features -- --model-id Qwen/Qwen3-Embedding-0.6B --dtype float16 on an instance with a single NVIDIA L40 48GB, and then compared it against the Sentence Transformers output in Python with numpy as follows:

import numpy as np
import requests
from sentence_transformers import SentenceTransformer

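# Reference model: float16 + FlashAttention-2 on CUDA, with left padding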
model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-0.6B",
    model_kwargs={
        "attn_implementation": "flash_attention_2",
        "torch_dtype": "float16",
        "device_map": "cuda",
    },
    tokenizer_kwargs={"padding_side": "left"},
)

out_py = model.encode(
    ["What is Deep Learning?", "Who is Walt Disney?"], normalize_embeddings=True
)
print(out_py)

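# Same inputs through the TEI /embed endpoint served from this branch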
response = requests.post(
    "http://localhost:3000/embed",
    json={
        "inputs": ["What is Deep Learning?", "Who is Walt Disney?"],
        "normalize": True,
    },
)

response.raise_for_status()
out = response.json()
out_http = np.array(out, dtype=np.float16)

np.testing.assert_allclose(out_py, out_http)

And that will fail with:

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1883 / 2048 (91.9%)
Max absolute difference among violations: 0.0005035
Max relative difference among violations: 5.
 ACTUAL: array([[-0.01566 , -0.03204 , -0.010994, ..., -0.004463,  0.03345 ,
         0.001066],
       [ 0.00918 ,  0.03442 , -0.004593, ...,  0.03952 , -0.01244 ,
        -0.01973 ]], shape=(2, 1024), dtype=float16)
 DESIRED: array([[-0.01567 , -0.0318  , -0.010956, ..., -0.00437 ,  0.0334  ,
         0.001115],
       [ 0.00939 ,  0.0343  , -0.004593, ...,  0.03952 , -0.01247 ,
        -0.01971 ]], shape=(2, 1024), dtype=float16)

P.S. There's a mismatch on both single and batched inference; the cosine similarity value seems to be 1.0, but the allclose check fails, when AFAIK it should produce the exact same embedding.

Thanks again for the PR @lance-miles 🤗 And we can merge as soon as we clarify that, then I'll patch the float16 on both CPU and MPS.

@kozistr
Contributor

kozistr commented Jun 24, 2025

(Quoting @alvarobartt's comment above in full; the allclose check fails on both single and batch inference on my end as well.)

Hi! Personally, I use this script to validate the outputs: code (it needs to be modified to work with Qwen3-Embedding).

I might be wrong, but IMHO the difference between Sentence Transformers and TEI seems marginal, primarily due to precision and minor differences at some layers (e.g. activation, ...).
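
For what it's worth, here is a quick numpy sketch (synthetic vectors, not the actual embeddings) of how float16 noise on the order reported above (~5e-4 per element) breaks a zero-tolerance assert_allclose while leaving the cosine similarity at ~1.0; an atol around 1e-3 would be a more realistic bar for float16 outputs:

import numpy as np

rng = np.random.default_rng(0)

# A unit-norm "embedding", stored in float16 like the TEI output above.
a = rng.standard_normal(1024).astype(np.float32)
a = (a / np.linalg.norm(a)).astype(np.float16)

# Perturb each element by at most 5e-4, roughly the max absolute
# difference reported in the failing check, then renormalize.
b = a.astype(np.float32) + rng.uniform(-5e-4, 5e-4, a.shape)
b = (b / np.linalg.norm(b)).astype(np.float16)

print("cosine similarity:", float(a.astype(np.float32) @ b.astype(np.float32)))  # ~1.0

try:
    np.testing.assert_allclose(a, b)  # default rtol=1e-7, atol=0
except AssertionError:
    print("strict allclose fails, as in the report above")

np.testing.assert_allclose(a, b, atol=1e-3)  # float16-appropriate tolerance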

@alvarobartt alvarobartt merged commit f7aa35b into huggingface:main Jun 24, 2025
3 of 13 checks passed
@pocman

pocman commented Jun 27, 2025

Amazing work! Could we schedule a 1.7.3 release to ship this fix?

Development

Successfully merging this pull request may close these issues.

Qwen3-Embedding models: embeddings from TEI differ sharply from Sentence-Transformers reference