Draft: Adding GPT2 support to Adaption prompts #2440

Open · wants to merge 8 commits into main

Conversation

efraimdahl

Llama-Adapters, despite what the name suggests, are model-agnostic. This contribution adds an adjustment to the llama-adapter implementation to support GPT2 models. Currently this is achieved through an additional class, AdaptedAttentionGPT, that wraps the attention layer of GPT2-type models and handles the differences in attention calculation and in the forward function's input format between Llama and GPT transformers.
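
To make the idea concrete, here is a minimal, deliberately simplified sketch of the wrapper pattern described above. The names (GatedPromptAttentionWrapper, adaption_prompt, adaption_gate) and the single-head attention over the prompt tokens are illustrative assumptions, not the PR's actual AdaptedAttentionGPT code (a real implementation would reuse the block's own query projection and multi-head layout).

import math

import torch
import torch.nn as nn


class GatedPromptAttentionWrapper(nn.Module):
    """Simplified illustration of the Llama-Adapter idea on a GPT2-style block:
    learned prompt tokens are attended to by the incoming hidden states, and the
    result is added through a zero-initialized gate, so a freshly initialized
    adapter leaves the base model's output unchanged."""

    def __init__(self, attn_module: nn.Module, hidden_size: int, adapter_len: int = 10):
        super().__init__()
        self.attn = attn_module  # the original GPT2 attention layer
        # Learned "adaption prompt" tokens.
        self.adaption_prompt = nn.Parameter(torch.randn(1, adapter_len, hidden_size))
        # Zero-initialized gate: at init the adapter contributes nothing.
        self.adaption_gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states, *args, **kwargs):
        # Run the wrapped attention untouched; GPT2 blocks return a tuple
        # (attn_output, present, ...), unlike the Llama attention signature.
        outputs = self.attn(hidden_states, *args, **kwargs)
        attn_output = outputs[0]

        # Attend from the hidden states to the learned prompt tokens
        # (single-head and without the block's query projection, for brevity).
        prompt = self.adaption_prompt.expand(hidden_states.size(0), -1, -1)
        scores = hidden_states @ prompt.transpose(1, 2) / math.sqrt(hidden_states.size(-1))
        prompt_out = torch.softmax(scores, dim=-1) @ prompt

        attn_output = attn_output + self.adaption_gate * prompt_out
        return (attn_output,) + outputs[1:]

The zero-initialized gate is what ties into the requirement discussed below: before any training step, the wrapped attention returns exactly the base model's output.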

Currently I am testing that the learning behavior of this implementation is as expected by comparing similar Llama and GPT configurations on the same datasets. It passes initial tests for saving, loading, and passing data.

Llama-Adapter requires that the freshly initialized adapter does not change the generation of the base model.
I am having trouble testing for this non-invasiveness as mentioned here: the following test fails with or without the adapter, and I am looking for alternative ways to test this (a possible shape for such a check is sketched below, after the snippet).

import random

import numpy as np
import torch
from torch.testing import assert_close
from transformers import GPT2Config, GPT2Model

device = "cuda" if torch.cuda.is_available() else "cpu"

config = GPT2Config(
    vocab_size=16,
    hidden_size=8,  # mapped to n_embd
    num_hidden_layers=8,  # mapped to n_layer
    num_attention_heads=4,  # mapped to n_head
    use_cache=True,
    attn_implementation="eager",
)
input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(device)
target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(device)
attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(device)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Create a GPT2 model and compare two forward passes on the same inputs.
model_gpt2 = GPT2Model(config)
model_gpt2 = model_gpt2.to(device)
a = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
b = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
assert_close(a.last_hidden_state, b.last_hidden_state, rtol=0, atol=0)  # fails: the two passes differ
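
One possible shape for the non-invasiveness check, as a hedged sketch: wrap the base model with a freshly initialized adaption-prompt adapter and compare logits before and after wrapping. This assumes the GPT2 support added in this PR; whether target_modules must be set explicitly for GPT2 (and which value it takes) is an assumption, and the model is put in eval mode so dropout does not mask the comparison.

from peft import AdaptionPromptConfig, get_peft_model
from transformers import GPT2LMHeadModel

base = GPT2LMHeadModel(config).to(device).eval()
with torch.no_grad():
    logits_base = base(input_ids=input_ids, attention_mask=attention_mask).logits

# Freshly initialized adapter: the zero-initialized gate should leave outputs unchanged.
peft_config = AdaptionPromptConfig(adapter_len=4, adapter_layers=2, task_type="CAUSAL_LM")
adapted = get_peft_model(base, peft_config)
with torch.no_grad():
    logits_adapted = adapted(input_ids=input_ids, attention_mask=attention_mask).logits

torch.testing.assert_close(logits_base, logits_adapted, rtol=1e-6, atol=1e-6)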

@BenjaminBossan
Member

Thanks a lot for the PR. I haven't checked the details yet, but regarding your testing question, the missing piece is that you need to set the seed before each call. Here is code that passes:

import torch
from transformers import GPT2Config, GPT2Model

device = 0
config = GPT2Config(
    vocab_size=16,
    hidden_size=8,  # mapped to n_embd
    num_hidden_layers=8,  # mapped to n_layer
    num_attention_heads=4,  # mapped to n_head
    use_cache=True,
    attn_implementation="eager",
)
input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(device)
target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(device)
attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(device)

# Create and compare gpt2 model outputs
model_gpt2 = GPT2Model(config)
model_gpt2 = model_gpt2.to(device)
torch.manual_seed(42)
a = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
torch.manual_seed(42)  # <================= important
b = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
torch.testing.assert_close(a.last_hidden_state, b.last_hidden_state, rtol=1e-6, atol=1e-6)
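
For context, the reseeding is needed because GPT2Config enables dropout by default (resid_pdrop, embd_pdrop, and attn_pdrop are all 0.1) and a freshly constructed model is in training mode, so two forward passes sample different dropout masks. If the test only needs deterministic outputs rather than train-mode behavior, an alternative (a sketch reusing the variables from the snippet above) is to switch the model to eval mode:

model_gpt2.eval()  # disables dropout, so repeated forward passes match
with torch.no_grad():
    a = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
    b = model_gpt2(input_ids=input_ids, attention_mask=attention_mask)
torch.testing.assert_close(a.last_hidden_state, b.last_hidden_state, rtol=1e-6, atol=1e-6)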

github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

@efraimdahl are you still working on this?

@efraimdahl
Author

@efraimdahl are you still working on this?
Cheers @BenjaminBossan, thank you for checking in. I am still planning on completing this. I am currently experimenting with projecting outside conditioning into the model (to fine-tune unimodal LLMs for multi-modal reasoning) and trying to verify that it behaves as described in the paper.
It will take a while until I finish this. Let me know if it's easier to close this PR for now and reopen it at a later point, and whether I should separate the GPT2 implementation from the conditioning projection.

@BenjaminBossan
Member

@efraimdahl No hurry, I was just checking, as sometimes people just forget about their PRs. No need to close this one. As to separating PRs, yes, it's always a good idea to keep them small.

github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this May 25, 2025
@githubnemo githubnemo reopened this May 26, 2025