Gemma with softprompt raises error #2452

Open
krishnakanthnakkav2 opened this issue Mar 26, 2025 · 3 comments · May be fixed by #2458

@krishnakanthnakkav2
System Info

I am doing soft-prompt tuning on gemma-2-2b. There is an issue during generation:

File "/home/krishna/PII/fs-llm/libs/peft/src/peft/peft_model.py", line 1920, in prepare_inputs_for_generation
model_kwargs["attention_mask"] = torch.cat(
RuntimeError: Tensors must have same number of dimensions: got 2 and 4

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Prompt tuning (soft prompt) with a Gemma model; see the sample code in the comment below.

Expected behavior

No error

@krishnakanthnakkav2 (Author) commented Mar 26, 2025

Sample code

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, TaskType, PromptTuningConfig, PromptTuningInit
import torch

# Load the pre-trained gemma-2-2b model and tokenizer
model_name = "google/gemma-2-2b"
cache_dir = "/assets/hub"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    attn_implementation="eager" if "gemma" in model_name else None,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    prompt_tuning_init_text="email",  # "phone" # "address"
    num_virtual_tokens=20,
    tokenizer_name_or_path=model_name,
)

model = get_peft_model(model, config)

# Define a batch of text for generation
input_texts = [
    "In the world of artificial intelligence,",
]

# Tokenize the batch of input texts
inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)

print(f"Input attention mask shape: {inputs['attention_mask'].shape}")

# Generate text with the model for the batch
generated_output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
)

# Decode and print the generated text for each example in the batch
generated_texts = tokenizer.batch_decode(
    generated_output, skip_special_tokens=True
)

for idx, generated_text in enumerate(generated_texts):
    print(f"Generated text for input {idx + 1}: {generated_text}")

The error happens in peft_model.py during the concatenation at

model_kwargs["attention_mask"] = torch.cat(

where the shapes of model_kwargs["attention_mask"] and prefix_attention_mask are [1, 1, 8, 49] and [1, 20] just before the concatenation step.

Here 49 corresponds to max_length minus 1 (changing max_length changes this value accordingly), and 8 is the number of tokens in the input.

I tested other models such as EleutherAI/pythia-6.9b; there the code works and it prints these variables as 2-D tensors, e.g. [1, 8] and [1, 20].

Edit: The issue possibly comes from the transformers function self.base_model_prepare_inputs_for_generation(), which returns a 4-D attention mask.
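
To make the mismatch concrete, here is a minimal sketch (independent of transformers/PEFT, with hypothetical variable names that only mirror the ones in peft_model.py) reproducing the torch.cat failure with the shapes reported above:

import torch

# 4-D causal mask of shape (batch, 1, query_len, key_len), as reported for Gemma
model_attention_mask = torch.ones(1, 1, 8, 49)

# 2-D mask for the 20 virtual tokens prepended by prompt tuning: (batch, num_virtual_tokens)
prefix_attention_mask = torch.ones(1, 20)

try:
    torch.cat((prefix_attention_mask, model_attention_mask), dim=1)
except RuntimeError as e:
    print(e)  # Tensors must have same number of dimensions: got 2 and 4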

Environment:

Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.5.2                    pypi_0    pypi
blobfile                  3.0.0                    pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2025.2.25            h06a4308_0
certifi                   2025.1.31                pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
expat                     2.6.4                h6a678d5_0
filelock                  3.18.0                   pypi_0    pypi
fsspec                    2025.3.0                 pypi_0    pypi
huggingface-hub           0.29.3                   pypi_0    pypi
idna                      3.10                     pypi_0    pypi
jinja2                    3.1.6                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0
libffi                    3.4.4                h6a678d5_1
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libmpdec                  4.0.0                h5eee18b_0
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
lxml                      5.3.1                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
networkx                  3.4.2                    pypi_0    pypi
numpy                     2.2.4                    pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
openssl                   3.0.16               h5eee18b_0
packaging                 24.2                     pypi_0    pypi
peft                      0.14.0                   pypi_0    pypi
pip                       25.0            py313h06a4308_0
protobuf                  6.30.1                   pypi_0    pypi
psutil                    7.0.0                    pypi_0    pypi
pycryptodomex             3.22.0                   pypi_0    pypi
python                    3.13.2          hf623796_100_cp313
python_abi                3.13                    0_cp313
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
safetensors               0.5.3                    pypi_0    pypi
setuptools                75.8.0          py313h06a4308_0
sqlite                    3.45.3               h5eee18b_0
sympy                     1.13.1                   pypi_0    pypi
tiktoken                  0.8.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.21.1                   pypi_0    pypi
torch                     2.6.0                    pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.49.0                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-extensions         4.13.0                   pypi_0    pypi
tzdata                    2025a                h04d1e81_0
urllib3                   2.3.0                    pypi_0    pypi
wheel                     0.45.1          py313h06a4308_0
xz                        5.6.4                h5eee18b_1
zlib                      1.2.13               h5eee18b_1

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Mar 27, 2025
Resolves huggingface#2452

Some causal language models in transformers have 4d attention masks at
the input preparation stage. So far, we have assumed 2d attention masks,
which results in an error in that case. This PR fixes the situation.

My first attempt was to transform the 2d prefix attention mask (from the
virtual tokens) into a 4d attention mask before concatenating them.
However, this was error prone and I was unsure whether the approach would
generalize to architectures other than the one tested (Gemma), as it
involved using private transformers methods. The simpler approach was
thus to just create a 2d attention mask and let the model handle it.

The test suite has been extended to include a tiny Gemma model. To
prevent the test suite from ballooning, I removed another model,
GPT-NeoX, which according to HF download stats is one of the least
popular architectures in our test suite.
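
For illustration, a minimal sketch of the 2d-mask approach described above; the helper name and signature are hypothetical and this is not the actual code from #2458:

import torch

def build_prompt_tuning_attention_mask(input_ids, num_virtual_tokens, pad_token_id=None):
    # Always build a plain 2d mask of shape (batch, num_virtual_tokens + seq_len)
    # and let the model expand it to 4d internally if it needs to.
    batch_size, seq_len = input_ids.shape
    if pad_token_id is None:
        base_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    else:
        base_mask = (input_ids != pad_token_id).long()
    prefix_mask = torch.ones(batch_size, num_virtual_tokens, dtype=torch.long)
    return torch.cat((prefix_mask, base_mask), dim=1)

# 8 input tokens plus 20 virtual tokens -> mask of shape (1, 28)
ids = torch.randint(0, 1000, (1, 8))
print(build_prompt_tuning_attention_mask(ids, num_virtual_tokens=20).shape)  # torch.Size([1, 28])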

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan (Member)

not stale, waiting for #2458 to be merged
