Fix Qwen3-Embedding batch vs single inference inconsistency #648
Conversation
thanks for catching this quickly :)
and I've also verified that the output is identical with Transformers!
looks good to me!
@kozistr Thank you so much for the initial PR and initial fix!! Tagging @Narsil or @alvarobartt for review
Co-authored-by: Hyeongchan Kim <[email protected]>
I've run the pre-commit hooks and fixed the formatting issues.
LGTM
@alvarobartt and @Narsil,
Hey @lance-miles, apologies for the delay! When comparing the outputs with the Sentence Transformers counterpart, I realized that those are still not matching. Since both you and @kozistr got matching results, could you please share the snippet you used?

For reference, I deployed Text Embeddings Inference (TEI) from this branch plus the causal attention mask boolean flag from #650, and ran it as:

```python
import numpy as np
import requests
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-0.6B",
    model_kwargs={
        "attn_implementation": "flash_attention_2",
        "torch_dtype": "float16",
        "device_map": "cuda",
    },
    tokenizer_kwargs={"padding_side": "left"},
)
out_py = model.encode(
    ["What is Deep Learning?", "Who is Walt Disney?"], normalize_embeddings=True
)
print(out_py)

response = requests.post(
    "http://localhost:3000/embed",
    json={
        "inputs": ["What is Deep Learning?", "Who is Walt Disney?"],
        "normalize": True,
    },
)
response.raise_for_status()
out = response.json()
out_http = np.array(out, dtype=np.float16)

np.testing.assert_allclose(out_py, out_http)
```

And that will fail with:

```
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1883 / 2048 (91.9%)
Max absolute difference among violations: 0.0005035
Max relative difference among violations: 5.
 ACTUAL: array([[-0.01566 , -0.03204 , -0.010994, ..., -0.004463,  0.03345 ,
         0.001066],
       [ 0.00918 ,  0.03442 , -0.004593, ...,  0.03952 , -0.01244 ,
        -0.01973 ]], shape=(2, 1024), dtype=float16)
 DESIRED: array([[-0.01567 , -0.0318  , -0.010956, ..., -0.00437 ,  0.0334  ,
         0.001115],
       [ 0.00939 ,  0.0343  , -0.004593, ...,  0.03952 , -0.01247 ,
        -0.01971 ]], shape=(2, 1024), dtype=float16)
```

P.S. There's a mismatch on both single and batched inference, but the cosine similarity value seems to be 1.0, whereas the

Thanks again for the PR @lance-miles 🤗 And we can merge as soon as we clarify that, then I'll patch the
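For completeness, a minimal sketch (not part of the original thread) of how the cosine-similarity observation above could be checked, reusing `out_py` and `out_http` from the snippet; the fp16-scale tolerance used for the element-wise check is an assumption:

```python
import numpy as np

def rowwise_cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Row-wise cosine similarity between two (n, d) matrices.
    a = a.astype(np.float32)
    b = b.astype(np.float32)
    a /= np.linalg.norm(a, axis=1, keepdims=True)
    b /= np.linalg.norm(b, axis=1, keepdims=True)
    return (a * b).sum(axis=1)

# `out_py` (Sentence Transformers) and `out_http` (TEI) come from the snippet above.
print(rowwise_cosine_similarity(out_py, out_http))  # expected to be ~1.0 per row

# Element-wise check with an fp16-scale tolerance instead of the strict default
# rtol=1e-07 (the atol value here is an assumption, not from the thread).
np.testing.assert_allclose(
    out_py.astype(np.float32), out_http.astype(np.float32), atol=1e-3
)
```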
Hi! Personally, I use this script to validate the outputs: code (needs to be modified to work with Qwen3 Embedding). I might be wrong, but IMHO the difference between Sentence Transformers and TEI seems marginal, primarily due to precision and differences at some layers (e.g. activation, ...).
Amazing work, could we schedule a version release 1.7.3 to ship this fix?
What does this PR do?
Fix Qwen3-Embedding batch vs single inference inconsistency
Fixes #642 (a regression introduced by PR #646)
Problem
PR #646 introduced a regression where Qwen3-Embedding models produced inconsistent embeddings between batch and single-sequence inference for identical inputs. The test `backends/candle/tests/test_qwen3.rs` was failing with an assertion error on this line:

```rust
assert_eq!(embeddings_batch[0], embeddings_single[0]);
```
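For reference, a rough sketch (not the repository's actual Rust test) of the same batch-vs-single consistency check expressed against a running TEI endpoint; the URL and the choice of inputs are assumptions:

```python
import numpy as np
import requests

TEI_URL = "http://localhost:3000/embed"  # assumed local TEI deployment serving a Qwen3-Embedding model
inputs = ["What is Deep Learning?", "Who is Walt Disney?"]

# Batched request: both inputs in one call.
batch = requests.post(TEI_URL, json={"inputs": inputs, "normalize": True})
batch.raise_for_status()
embeddings_batch = np.array(batch.json(), dtype=np.float32)

# Single-sequence request: only the first input.
single = requests.post(TEI_URL, json={"inputs": [inputs[0]], "normalize": True})
single.raise_for_status()
embeddings_single = np.array(single.json(), dtype=np.float32)

# The first row of the batched result should match the single-sequence result,
# which is what the failing assertion above checks inside the candle backend.
np.testing.assert_allclose(embeddings_batch[0], embeddings_single[0])
```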
Root Cause
The issue stemmed from inconsistent attention bias handling between batch and single-sequence processing: the two code paths did not construct the same attention bias, so identical inputs could produce different embeddings.
Investigation
Further analysis revealed that Qwen3-Embedding models use causal attention by design (not bidirectional like BERT), which imposes the requirements addressed by the fix below.
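To illustrate the requirement (a simplified numpy sketch under assumed shapes, not the candle implementation): with left padding, the additive attention bias combines an upper-triangular causal mask with masking of the padded prefix, and the same construction has to be applied in both the single-sequence and the batched path:

```python
import numpy as np

NEG_INF = np.float32(-1e9)  # large negative value standing in for -inf in the additive bias

def causal_bias_left_padded(seq_len: int, pad_len: int) -> np.ndarray:
    """Additive attention bias for one left-padded sequence of total length seq_len,
    whose first pad_len positions are padding."""
    # Causal part: position i must not attend to any position j > i.
    bias = np.triu(np.full((seq_len, seq_len), NEG_INF, dtype=np.float32), k=1)
    # Padding part: no position may attend to the left-padding prefix.
    bias[:, :pad_len] = NEG_INF
    return bias

# Both code paths need to build exactly this bias per sequence; if the batched
# path masks differently from the single-sequence path, identical inputs yield
# different embeddings, which is the inconsistency described above.
print(causal_bias_left_padded(seq_len=5, pad_len=2))
```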
This fix ensures consistent behavior by:
- Padding sequences at the beginning (left) rather than at the end (right)
  - Aligns with Qwen3-Embedding's causal attention requirements
- Creating a causal attention bias for both single and batch processing
  - Applies identical upper-triangular masking in both code paths
- Accounting for left padding when extracting the last-token embeddings (see the sketch after this list)
  - Ensures the EOS token is correctly identified for pooling
- Maintaining proper causal mask generation that prevents attention to future tokens
  - Consistent with Qwen3-Embedding's architectural requirements
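As referenced in the list above, a small illustrative sketch of last-token extraction under left padding (assumed shapes; not the candle code): with left padding the EOS token used for pooling is always the final position, while with right padding its index depends on the number of real tokens.

```python
import numpy as np

def last_token_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Select the embedding of the last real (non-padding) token per sequence.

    hidden_states: (batch, seq_len, hidden)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    """
    # With left padding, every sequence ends with a real token (the EOS token),
    # so pooling is simply the last position.
    if bool(attention_mask[:, -1].all()):
        return hidden_states[:, -1]
    # With right padding, the last real token sits just before the first pad.
    last_indices = attention_mask.sum(axis=1) - 1
    return hidden_states[np.arange(hidden_states.shape[0]), last_indices]

# Usage: embeddings = last_token_pool(hidden, mask), then L2-normalize for cosine similarity.
```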
Changes
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.