Dim 1 contains the index (i) over the query sequence entries, and dim 2 contains the index (j) over the key sequence entries.
If my understanding is correct, for any query (dim 1) we would like the sum over its associated keys to be 1, so that the information copied to that query position keeps the same scale, and so that the query may ignore information from many keys.
However, the current implementation has it such that for any key, the attention from all query positions to it sums to 1 after the softmax. Some query positions may then be close to 0 for all keys, while every key is forced to be used by at least one query position.
We should not take the softmax over the key dimension (2) but over the query dimension (1).
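To make this easy to check, here is a small self-contained sketch (not from the repo; the tensor layout is an assumption matching the index convention above, with scores shaped [batch, query i, key j, heads], and all variable names are made up) that prints what each softmax dimension sums to:

import torch

batch, seq, heads, d_k = 2, 4, 1, 8
q = torch.randn(batch, seq, heads, d_k)
k = torch.randn(batch, seq, heads, d_k)

# Assumed layout: dim 1 of `attn` indexes queries (i), dim 2 indexes keys (j).
attn = torch.einsum('bihd,bjhd->bijh', q, k) * d_k ** -0.5

for dim in (1, 2):
    a = attn.softmax(dim=dim)
    print(f"softmax(dim={dim}):")
    print("  sums over the query index (dim 1):", a.sum(dim=1)[0, :, 0])
    print("  sums over the key index   (dim 2):", a.sum(dim=2)[0, :, 0])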
PLEASE CORRECT ME IF I'M WRONG.
I believe the line
attn = attn.softmax(dim=2)
in annotated_deep_learning_paper_implementations/labml_nn/diffusion/ddpm/unet.py (line 188 at commit 05321d6) is incorrect.
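For reference, the surrounding computation in that attention block looks roughly like the following (a paraphrased sketch under the index convention above, not a verbatim excerpt of unet.py; the helper name attention is made up):

import torch

def attention(q, k, v, scale):
    # q, k, v: [batch, seq, heads, d_k]
    # Scores: dim 1 indexes the query position i, dim 2 the key position j.
    attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale
    # The line in question: softmax over dim 2.
    attn = attn.softmax(dim=2)
    # Weighted sum of the values over the key index j.
    return torch.einsum('bijh,bjhd->bihd', attn, v)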
This implementation, which is based on the current one, uses dim 1: https://github.com/pdearena/pdearena/blob/db7664bb8ba1fe6ec3217e4079979a5e4f800151/pdearena/modules/conditioned/twod_unet.py#L223
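Concretely, the change argued for above would be (a sketch, assuming the same tensor layout as in the snippets above):

attn = attn.softmax(dim=1)  # normalize over dim 1 (the query index) instead of dim 2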
Or am I mistaken about how the output of the softmax is normalized?