
max token length for finetune and completion endpoints on Llama-2? #208

Open
urimerhav opened this issue Aug 12, 2023 · 2 comments

@urimerhav

Great job with this repo. I was able to fine-tune Llama-2 and it certainly seems to have an effect.

Unfortunately the fine-tune silently accepts all inputs, and the documentation states that you simply truncate inputs to max length. But it's not specified anywhere what Llama-2's max length is. Originally Meta released it with a bug that caused max length to be 2048 while the native max length seems to be 4096. So which is it?

Also, I tested my fine-tuned model's completion code with inputs as large as 12,000 tokens and it still returns a completion. So I assume you truncate there as well? Only taking the tail of the prompt, presumably?

tldr:

  1. What is Llama-2's max token length?
  2. Is there anything we can do to affect this or get better visibility into how the input got tokenized, etc.?
@urimerhav
Author

BUMP!

@yixu34 yixu34 self-assigned this Aug 15, 2023
@yixu34
Member

yixu34 commented Aug 15, 2023

Hi @urimerhav, thanks for reaching out. Here are some answers to your questions:

Originally Meta released it with a bug that caused max length to be 2048 while the native max length seems to be 4096. So which is it?

It seems to be 4096, as you may have found searching elsewhere.
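
If you want to double-check on your end, one quick sanity check (a sketch, assuming you load the model through Hugging Face transformers; the meta-llama/Llama-2-7b-hf checkpoint name is just an example) is to read the context window off the model config and count the tokens in your prompt:

from transformers import AutoConfig, AutoTokenizer

# The context window as reported by the model config (4096 for Llama-2).
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
print(config.max_position_embeddings)

# Token count for a given prompt, using the model's own tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(len(tokenizer("your prompt here")["input_ids"]))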

Also, I tested my fine-tuned model's completion code with inputs as large as 12,000 tokens and it still returns a completion. So I assume you truncate there as well? Only taking the tail of the prompt, presumably?

Do you have some more detailed repro steps? We just tried a long input and got an exception.

Is there anything we can do to affect this or get better visibility into how the input got tokenized, etc.?

We haven't yet open-sourced our fine-tuning code, although we fully intend to! The issue is that, unlike the rest of LLM Engine, our fine-tuning scripts still have some internal dependencies that need to be ripped out. In the meantime, I can share a code snippet that might give some visibility into the tokenization process. We currently truncate to 1024 tokens for fine-tuning because we're running on A10s and wanted to avoid OOM:

from typing import Dict, Optional, Sequence

import torch
from transformers import PreTrainedTokenizer

# PROMPT_KEY, COMPLETION_KEY, and IGNORE_INDEX are constants defined elsewhere in our
# fine-tuning code (IGNORE_INDEX is the label value the loss function skips).


class SFTCollator(object):
    """Collate examples for supervised fine-tuning.
    We intentionally mask out the prompt tokens to avoid training on them.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_sequence_length: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Each instance holds pre-tokenized prompt and completion ids.
        prompt_input_ids, completion_input_ids = tuple(
            [instance[key] for instance in instances] for key in (PROMPT_KEY, COMPLETION_KEY)
        )

        # Pad every sequence in the batch to the length of the longest prompt + completion.
        max_input_length = max(
            [len(p) + len(c) for p, c in zip(prompt_input_ids, completion_input_ids)]
        )
        input_ids = (
            torch.ones((len(prompt_input_ids), max_input_length), dtype=prompt_input_ids[0].dtype)
            * self.tokenizer.pad_token_id
        )
        labels = torch.ones(input_ids.shape, dtype=prompt_input_ids[0].dtype) * IGNORE_INDEX
        attention_mask = torch.zeros(input_ids.shape, dtype=torch.bool)

        for i, (prompt_ids, completion_ids) in enumerate(
            zip(prompt_input_ids, completion_input_ids)
        ):
            sequence_ids = torch.concatenate([prompt_ids, completion_ids])
            if self.tokenizer.padding_side == "right":
                input_ids[i][: len(sequence_ids)] = sequence_ids
                attention_mask[i][: len(sequence_ids)] = True
                # Only completion tokens get real labels; prompt positions stay IGNORE_INDEX.
                labels[i][len(prompt_ids) : len(prompt_ids) + len(completion_ids)] = completion_ids
            else:
                input_ids[i][-len(sequence_ids) :] = sequence_ids
                attention_mask[i][-len(sequence_ids) :] = True
                labels[i][-len(completion_ids) :] = completion_ids

        # Truncate everything to max_sequence_length (currently 1024 for fine-tuning).
        return dict(
            input_ids=input_ids[:, : self.max_sequence_length],
            labels=labels[:, : self.max_sequence_length],
            attention_mask=attention_mask[:, : self.max_sequence_length],
        )
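
If you want to see exactly what the model would be fed, you could push a single example through the collator yourself. Here's a rough sketch building on the snippet above (the key names, the placeholder constants, the checkpoint, and the pad-token handling are my assumptions for illustration, not necessarily what our training pipeline does):

from transformers import AutoTokenizer

PROMPT_KEY, COMPLETION_KEY = "prompt", "completion"  # placeholder key names for illustration
IGNORE_INDEX = -100  # placeholder value; the real constant lives in our fine-tuning code

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id  # Llama-2 ships without a pad token

example = {
    PROMPT_KEY: torch.tensor(tokenizer("What is 2 + 2?")["input_ids"]),
    COMPLETION_KEY: torch.tensor(tokenizer("4", add_special_tokens=False)["input_ids"]),
}
collator = SFTCollator(tokenizer=tokenizer, max_sequence_length=1024)
batch = collator([example])
print(batch["input_ids"].shape)       # padded/truncated shape the model will see
print(batch["attention_mask"].sum())  # number of non-padding tokens
print(batch["labels"])                # prompt positions are masked with IGNORE_INDEX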
