
Problems on generating with llama model #921

Open
wiio12 opened this issue May 4, 2023 · 21 comments

wiio12 commented May 4, 2023

Hi, I tried loading the LLaMA model for inference and ran into some problems. I use 4 V100 GPUs with a model parallel size of 4 to load the LLaMA 7B checkpoint.

  1. Error loading the LLaMA checkpoint. The converted checkpoint generated by the script tools/convert_raw_llama_weights_to_neox.py provides no optimizer states, but DeepSpeed keeps trying to load the optimizer states even when I set the finetune flag to true. The cause seems to be here: regardless of whether the files exist, DeepSpeed returns the file list of optimizer states. I fixed this by adding an additional check that the files exist, returning None if they do not.

  2. A tensor shape mismatch occurred during inference. This is fixed by changing the line here, where

attention_mask = attention_mask[
                    ..., : attention_scores.size(3), : attention_scores.size(3)
                ]

is changed to

attention_mask = attention_mask[
                    ..., : attention_scores.size(2), : attention_scores.size(3)
                ]

I wonder if my fixes are correct, or if there is a better way to fix this. I think I am only treating the symptoms of the problem, not its causes. A minimal sketch of the shape issue in point 2 follows below.
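For anyone who runs into the same thing, here is a minimal, self-contained sketch of the shape issue in point 2. The shapes, the mask convention (True = masked), and the in-place masked_fill_ pattern are my assumptions modeled on Megatron-style attention code, not the actual NeoX implementation:

    import torch

    # Hypothetical shapes during incremental decoding with a KV cache:
    # the new query covers a single position, the keys cover the full context,
    # so attention_scores is [batch, heads, q_len=1, k_len].
    batch, heads, q_len, k_len = 1, 8, 1, 16
    attention_scores = torch.randn(batch, heads, q_len, k_len)

    # Causal mask built for the maximum sequence length (True = masked out).
    max_seq = 32
    full_mask = torch.triu(
        torch.ones(1, 1, max_seq, max_seq), diagonal=1
    ).bool()

    # Original slice: both dims use size(3) -> (1, 1, 16, 16). An in-place
    # masked_fill_ cannot broadcast a 16-row mask onto a 1-row score tensor,
    # hence the shape mismatch during generation.
    bad_mask = full_mask[..., : attention_scores.size(3), : attention_scores.size(3)]
    try:
        attention_scores.clone().masked_fill_(bad_mask, -10000.0)
    except RuntimeError as err:
        print("shape mismatch:", err)

    # Fixed slice: the query dim uses size(2) and the key dim uses size(3),
    # giving (1, 1, 1, 16), which broadcasts over batch and heads as intended.
    # (Which rows of the causal mask should be selected is a separate question;
    # this only illustrates the shape/broadcast behavior.)
    good_mask = full_mask[..., : attention_scores.size(2), : attention_scores.size(3)]
    attention_scores.masked_fill_(good_mask, -10000.0)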

wiio12 (Author) commented May 4, 2023

The error mentioned here also occurred when I set self.pipe_parallel_size >= 1, so I have to stay with self.pipe_parallel_size >= 2.

DaoD commented May 4, 2023

> Hi, I tried loading the LLaMA model for inference and ran into some problems. […]

Hi, I set the default value of load_optimizer_states to False in the load_checkpoint function of engine.py.
Then I get the error IndexError: list index out of range at current_rank_sd = state_dict_list[dp_rank].
Have you run into this problem?
How did you fix the problem you mentioned in 1?
Can you fine-tune it now?

wiio12 (Author) commented May 4, 2023 via email

DaoD commented May 4, 2023

Thx. I'm trying to continue training the model to see if the loss looks correct. I will update my results here if it runs successfully.

wiio12 (Author) commented May 5, 2023

Hi @DaoD, specifically, I made the following change for problem 1:

    def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
            load_dir=load_dir,
            tag=tag,
            mp_rank=mp_rank,
            dp_world_size=self.loaded_checkpoint_dp_world_size,
            bf16_mode=bf16_mode)
        for i, ckpt_name in enumerate(zero_ckpt_names):
            if not os.path.exists(ckpt_name):
                # transparently handle the old file pattern for optim_states
                if "optim_states.pt" in ckpt_name:
                    ckpt_name_try = ckpt_name.replace("_optim_states.pt",
                                                      "optim_states.pt")
                    if os.path.exists(ckpt_name_try):
                        zero_ckpt_names[i] = ckpt_name_try
                        continue

        # Added check: if any expected ZeRO checkpoint file is still missing
        # (the converted LLaMA checkpoint ships no optimizer states), return
        # None so that no optimizer state is loaded.
        for ckpt_name in zero_ckpt_names:
            if not os.path.exists(ckpt_name):
                return None

        return zero_ckpt_names

The function is defined here.
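As an aside, depending on the DeepSpeed version, an alternative that avoids patching engine.py might be to ask the engine to restore module weights only when loading the converted checkpoint. This is just a sketch (model, load_dir and tag are placeholders), and I have not verified how it interacts with the ZeRO loading path:

    # Hypothetical alternative: skip optimizer/scheduler state entirely, since
    # the converted LLaMA checkpoint only contains model weights. Whether these
    # flags are available depends on the installed DeepSpeed version.
    load_path, client_state = model.load_checkpoint(
        load_dir,
        tag=tag,
        load_optimizer_states=False,
        load_lr_scheduler_states=False,
        load_module_only=True,
    )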

DaoD commented May 5, 2023

> Hi @DaoD, specifically, I made the following change for problem 1: […]

Thx! I have continued training the model. I find the loss goes down as expected, but I'm not sure if it really works. I will test it soon.

StellaAthena (Member) commented
Apologies for the issues y'all're having. We tested this at pretraining before merging, but not for finetuning or inference, which in retrospect is somewhat silly. Thank you for your help in debugging this.

wiio12 (Author) commented May 5, 2023

> Thx! I have continued training the model. I find the loss goes down as expected, but I'm not sure if it really works. I will test it soon.

Have you tried generating with the loaded checkpoint? I found that the text generated with greedy decoding tends to repeat itself.
For example:
This is the output of greedy decoding:

Input: Tell me about alpacas.
Output: \n the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the

This is the output of sampling with temperature=1:

Input: Tell me about alpacas.
Output: Whatersung\n Data you\n how youib them us how the you it I the To my In youral And youdown ifing your Youinas your them to mety it story that the the itoray U the k your D him youstein me whenisonaw you me me much Startor if youcom You You\n You this Your Breas at You it\nha what the your them me you\n a iten us Your You Hol to\nly K\nur me What your

@DaoD Could you kindly provide your config file and the running command for fine-tuning? Thx!

DaoD commented May 5, 2023

For the config, you can just copy configs/llama/7B.yml into the model-settings part of configs/6-7B.yml. The running command is the same as for training other models.

yoosan commented May 8, 2023

Hi @wiio12, I don't see a LlamaTokenizer in the code. How do you perform inference and verify the results?

wiio12 (Author) commented May 8, 2023

Hi @yoosan, I use the SPMTokenizer with the vocab file tokenizer.model, but there is a problem with this tokenizer, reported here.
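As a quick sanity check of the vocab file itself, independent of NeoX, you can round-trip a prompt through tokenizer.model with sentencepiece directly; a rough sketch, with the path as a placeholder:

    import sentencepiece as spm

    # Load the raw LLaMA vocab file and round-trip a prompt through it.
    sp = spm.SentencePieceProcessor(model_file="path/to/tokenizer.model")
    ids = sp.encode("Tell me about alpacas.")
    print(ids)
    print(sp.decode(ids))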

DaoD commented May 8, 2023

> Hi @wiio12, I don't see a LlamaTokenizer in the code. How do you perform inference and verify the results?

I replaced the SPMTokenizer with the LlamaTokenizer myself.

yoosan commented May 8, 2023

That's great. Have you checked that the HF version and GPT-NeoX inference now give consistent results?

PS. LLaMA itself doesn't have the ability to engage in conversations, so it's better to verify the results with a continuation task, i.e. providing a prefix and letting it generate the rest, e.g.
Input: Charles was born in Buckingham Palace during the reign of

DaoD commented May 8, 2023

I do not use the inference code of GPT-NeoX; it seems to have some minor problems. I have just tested the HF version, and it works well.

yoosan commented May 8, 2023

Thanks DaoD. This project has already converted the ckpt into Megatron/GPT-NeoX format. I'm curious about how you used HF for validation.


DaoD commented May 8, 2023

> Thanks DaoD. This project has already converted the ckpt into Megatron/GPT-NeoX format. I'm curious about how you used HF for validation.

I just use the NeoX format for training/fine-tuning. After training, I convert the checkpoint to the HF version for inference/testing.
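The HF-side check I mean is roughly the following sketch (the model directory is a placeholder, and the prompt follows the continuation suggestion above); it is not my exact script:

    from transformers import LlamaForCausalLM, LlamaTokenizer

    # Load the converted HF checkpoint and run a plain continuation prompt.
    model_dir = "path/to/converted-hf-llama-7b"
    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
    model = LlamaForCausalLM.from_pretrained(model_dir)

    inputs = tokenizer(
        "Charles was born in Buckingham Palace during the reign of",
        return_tensors="pt",
    )
    outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))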

StellaAthena (Member) commented
So it sounds like this issue is a combination of two other issues:

  1. Our recurring issues with generation in GPT-NeoX
  2. The fact that we don’t currently support the SPM Tokenizer.

If that’s the case, I think it probably makes sense to close this issue as both of those are known problems we are currently working on.

Quentin-Anthony (Member) commented May 15, 2023

> If that’s the case, I think it probably makes sense to close this issue as both of those are known problems we are currently working on.

Keep this issue open until these issues are resolved. We'll add a "Fixes xxx" clause to whatever PR fixes things so that this issue is auto-closed.

yoosan commented May 22, 2023

Hi @StellaAthena, has there been any progress recently? I have updated to the latest commit and found no relevant code updates yet. I am really looking forward to a usable version.

sxthunder commented

> I just use the NeoX format for training/fine-tuning. After training, I convert the checkpoint to the HF version for inference/testing.

Hello, can you share your script for converting a NeoX checkpoint to an HF LLaMA model?

ghost commented Mar 4, 2024

> Hello, can you share your script for converting a NeoX checkpoint to an HF LLaMA model?

Hello, there is a new conversion script; you can go to tools/ckpts to check it out.
