Add Support for Frequency Penalties in On Device Sampling #523
Signed-off-by: quic-sanising <[email protected]>
Signed-off-by: sanising <[email protected]>
Depends on PR #463.
    scatter_values,
    torch.ones(last_accepted_output_tokens.shape, dtype=torch.bool),
)
gather_values = past_presence_penalty_buffer[batch_index, last_accepted_output_tokens]
For improved performance, I intended to use CtxGatherFuncCB3D, but it doesn't work here: last_accepted_output_tokens is a tensor of shape (batch_size, seq_len), whereas the function expects shape (batch_size, 1). Please let me know if there is a workaround.
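To make the shape question concrete, here is a minimal, framework-free sketch of what the indexing expression `past_presence_penalty_buffer[batch_index, last_accepted_output_tokens]` computes. It assumes that `batch_index` is a (batch_size, 1) column of row numbers that broadcasts across seq_len; that assumption, the helper name `gather_rows`, and the sample data are all hypothetical and not taken from the PR.

```python
def gather_rows(buffer, token_ids):
    """Emulate buffer[batch_index, token_ids] with nested lists.

    buffer:    (batch_size, vocab_size) nested list
    token_ids: (batch_size, seq_len) nested list of vocabulary indices
    Assumes batch_index holds each row's own number (hypothetical).
    """
    return [
        [buffer[b][tok] for tok in row]  # one lookup per (batch, position)
        for b, row in enumerate(token_ids)
    ]


# Hypothetical data: batch_size=2, vocab_size=5, seq_len=3.
presence = [
    [False, True, False, False, True],
    [True, False, False, True, False],
]
last_accepted = [
    [1, 4, 0],
    [3, 3, 2],
]

gathered = gather_rows(presence, last_accepted)
print(gathered)
```

The result has shape (batch_size, seq_len), one value per accepted token, which is why a gather that only accepts a (batch_size, 1) index cannot be dropped in directly.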
@quic-sanising is it ready for review? Can you rebase the PR?
✨ Add Frequency Penalty Support to On Device Sampling
This PR adds support for the frequency_penalty parameter in On Device Sampling for QEffForCausalLM models. This parameter adjusts token selection based on how often each token has already appeared in the generated output.

The implementation tracks token frequencies directly on the QAIC device using optimized scratch buffers, which keeps overhead low and throughput high. The feature works with the existing include_sampler=True workflow and complements the other supported strategies, such as repetition and presence penalties.
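As an illustration of the general idea, the sketch below applies a frequency penalty in the way the parameter is commonly defined: each token's logit is reduced by the penalty multiplied by the number of times that token has already been generated. This is a plain-Python sketch under that assumption, not the on-device implementation from this PR; the function name `apply_frequency_penalty` and the sample values are hypothetical.

```python
from collections import Counter


def apply_frequency_penalty(logits, generated_token_ids, frequency_penalty):
    """Return logits with count-proportional penalties subtracted.

    logits:              list of floats, one per vocabulary entry
    generated_token_ids: token ids produced so far
    frequency_penalty:   scalar; positive values discourage repeats
    """
    counts = Counter(generated_token_ids)  # occurrences per token id
    return [
        logit - frequency_penalty * counts[token_id]  # unseen tokens have count 0
        for token_id, logit in enumerate(logits)
    ]


# Hypothetical example: vocabulary of 4 tokens, token 0 generated twice.
logits = [2.0, 1.0, 0.5, 0.0]
generated = [0, 0, 2]
penalized = apply_frequency_penalty(logits, generated, frequency_penalty=0.5)
print(penalized)
```

Unlike a presence penalty, which applies a fixed offset once a token has appeared at all, the offset here grows with every repetition, so a running count per token has to be kept across decode steps.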