About the word_emb for cross attention #8

Open
buptxyb666 opened this issue Apr 22, 2024 · 3 comments

Comments

@buptxyb666

Thanks for your great work! I notice that the length of the text is usually less than 77 tokens. Why not mask the padding tokens in word_emb when performing cross attention?

@exitudio
Owner

Hi, we use [MASK] tokens for generation by iterative decoding and [PAD] tokens to fill up shorter samples. [PAD] tokens in the CLIP model can be handled in a similar manner. Since we only use the text tokens as a condition (not for generation), there is no need for [MASK] tokens on the text side.
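
A minimal sketch of this distinction (illustrative token ids and a hypothetical helper, not the code in this repository): [PAD] only fills a short sequence up to a fixed length, while [MASK] replaces real tokens so the transformer can be trained to predict them during iterative decoding.

import torch

# Illustrative ids; the real codebook indices in MMM may differ.
PAD_ID, MASK_ID = 0, 1

def prepare_motion_tokens(motion_tokens, max_len, mask_ratio=0.5):
    # Fill the sequence up to max_len with [PAD] tokens.
    seq = torch.full((max_len,), PAD_ID, dtype=torch.long)
    seq[: len(motion_tokens)] = torch.as_tensor(motion_tokens)
    # Only real (non-[PAD]) positions may be replaced by [MASK],
    # so the model learns to reconstruct them at those positions.
    real = torch.zeros(max_len, dtype=torch.bool)
    real[: len(motion_tokens)] = True
    masked = real & (torch.rand(max_len) < mask_ratio)
    seq[masked] = MASK_ID
    return seq, real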

@buptxyb666
Author

I mean that when performing cross attention between the word embeddings (key and value) and the motion tokens (query), will the [PAD] tokens from CLIP introduce noise into the motion tokens?

Compared with the global text condition, does additionally using the fine-grained word embeddings bring a performance gain?

Looking forward to your reply.
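
A minimal sketch of what such masking could look like (hypothetical code using PyTorch's built-in attention, not the implementation in this repository): CLIP pads token sequences with id 0, so those positions can be turned into a key_padding_mask that stops the motion queries from attending to [PAD] keys.

import torch
import torch.nn as nn
import clip

# Hypothetical shapes: B motion sequences attending to 77 CLIP word embeddings.
B, T_MOTION, T_TEXT, D = 4, 49, 77, 512

text = clip.tokenize(["a person walks forward"] * B)   # (B, 77); CLIP pads with id 0
pad_mask = text == 0                                    # True at [PAD] positions

cross_attn = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True)
motion_q = torch.randn(B, T_MOTION, D)                  # placeholder motion-token features
word_emb = torch.randn(B, T_TEXT, D)                    # placeholder CLIP word embeddings

# key_padding_mask=True means "ignore this key", so [PAD] embeddings contribute nothing.
out, _ = cross_attn(motion_q, word_emb, word_emb, key_padding_mask=pad_mask)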

@exitudio
Owner

The model should learn to ignore [PAD] tokens (following CLIP).
For more information: to get the global (sentence) text embedding, CLIP simply applies a linear projection to the local (word) embeddings.
https://github.com/openai/CLIP/blob/main/clip/model.py#L343-L356
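
Paraphrasing the linked lines as a sketch (clip_sentence_embedding is a hypothetical name, not a function in either repository): encode_text runs the text transformer over the word embeddings, picks the feature at the EOT token position (the highest token id), and multiplies it by text_projection to get the sentence embedding.

import torch

def clip_sentence_embedding(clip_model, text_tokens):
    # Same steps as the word-embedding wrapper below, then EOT pooling + projection.
    x = clip_model.token_embedding(text_tokens).type(clip_model.dtype)
    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)                      # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)                      # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)
    # The EOT token has the highest id, so argmax finds its position per sample.
    return x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ clip_model.text_projection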

We create a wrapper class here:

MMM/train_t2m_trans.py, lines 76 to 80 (commit 2f7e3b2):

word_emb = self.model.token_embedding(text).type(self.model.dtype)            # token ids -> word embeddings
word_emb = word_emb + self.model.positional_embedding.type(self.model.dtype)  # add positional encoding
word_emb = word_emb.permute(1, 0, 2)  # NLD -> LND
word_emb = self.model.transformer(word_emb)                                   # CLIP text transformer
word_emb = self.model.ln_final(word_emb).permute(1, 0, 2).float()             # final LayerNorm, back to NLD

Applying the local text embedding shows a trade-off between R-precision and FID score. Please see Table 9 in the supplementary material.
