Attention takes softmax over wrong dimension #282

Open

Trezorro opened this issue Feb 4, 2025 · 0 comments
Trezorro commented Feb 4, 2025

PLEASE CORRECT ME IF I'M WRONG.

I believe the line `attn = attn.softmax(dim=2)` is incorrect.

Dim 1 contains the index (i) over the query sequence entries, and dim 2 contains the index (j) over the key sequence entries.
If my understanding is correct, for any query (dim 1) we would like the attention weights over its associated keys to sum to 1, so that the information copied to that query position stays at the same scale and the query can ignore information from many keys.

However, the current implementation is such that for any key, the attention from all query positions to it sums to 1 after the softmax. Some query positions may then be close to 0 for all keys, while EVERY key is forced to be used by at least one query position.

We should therefore take the softmax not over the key dimension (dim 2) but over the query dimension (dim 1).
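
To spell out the two normalizations being contrasted, with $\alpha_{ij}$ denoting the post-softmax attention weight from query $i$ to key $j$ (notation added here only for clarity):

$$\text{what I argue we want:}\quad \sum_j \alpha_{ij} = 1 \ \text{ for every query } i$$

$$\text{what I believe the current code does:}\quad \sum_i \alpha_{ij} = 1 \ \text{ for every key } j$$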

This implementation, which is based on the current one, uses dim 1: https://github.com/pdearena/pdearena/blob/db7664bb8ba1fe6ec3217e4079979a5e4f800151/pdearena/modules/conditioned/twod_unet.py#L223

Or am I mistaken about what the softmax outputs?
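
For anyone who wants to check this concretely, here is a small sketch of how I would verify which axis a given softmax call normalizes over. The shapes, the einsum string, and the variable names are my assumptions for illustration, not the repository's actual code:

```python
import torch

# Minimal sketch (assumed shapes, not the repository's code): attention scores
# of shape [batch, i, j, heads], with i indexing queries and j indexing keys,
# e.g. produced by something like torch.einsum("bihd,bjhd->bijh", q, k).
b, n_q, n_k, h, d = 1, 4, 5, 2, 8
q = torch.randn(b, n_q, h, d)
k = torch.randn(b, n_k, h, d)
attn = torch.einsum("bihd,bjhd->bijh", q, k) / d**0.5  # [batch, i, j, heads]

a_dim2 = attn.softmax(dim=2)  # normalizes along dim 2 (the j axis)
a_dim1 = attn.softmax(dim=1)  # normalizes along dim 1 (the i axis)

# Which sums come out as 1 shows what each call normalizes over:
print(a_dim2.sum(dim=2)[0, :, 0])  # one value per query i -> all 1.0
print(a_dim1.sum(dim=1)[0, :, 0])  # one value per key j   -> all 1.0
```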
