max token length for finetune and completion endpoints on Llama-2? #208
Comments
BUMP!
Hi @urimerhav, thanks for reaching out. Here are some answers to your questions:
The max token length seems to be 4096, as you may have found searching elsewhere.
Do you have some more detailed repro steps? We just tried a long input and got an exception.
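For reference, the kind of long-input request being discussed can be reproduced roughly like this (a sketch using the llmengine Python client; the model name and prompt are placeholders, and an API key or self-hosted gateway must be configured):

```python
from llmengine import Completion

# Sketch only: model name is an example, and SCALE_API_KEY (or a self-hosted
# LLM Engine gateway) must be configured for the client to work.
long_prompt = "lorem ipsum dolor sit amet " * 2000  # far beyond a 4096-token window

response = Completion.create(
    model="llama-2-7b",
    prompt=long_prompt,
    max_new_tokens=32,
    temperature=0.2,
)
print(response.json())
```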
We haven't yet open sourced our fine-tuning code, although we fully intend to! The issue is that, unlike the rest of LLM Engine, our fine-tuning scripts still have some internal dependencies that need to be ripped out. In the meantime, here is a code snippet that might give some visibility into the tokenization process. We currently truncate to 1024 tokens for fine-tuning, because we're currently just using A10s and wanted to avoid OOM:

```python
from typing import Dict, Sequence

import torch
from transformers import PreTrainedTokenizer

# Referenced below; defined elsewhere in the training code (the values shown
# here are typical assumptions, not the exact internal definitions).
PROMPT_KEY = "prompt"
COMPLETION_KEY = "completion"
IGNORE_INDEX = -100


class SFTCollator(object):
    """Collate examples for supervised fine-tuning.

    We intentionally mask out the prompt tokens to avoid training on them.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_sequence_length: int = None,
    ):
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Each instance holds pre-tokenized prompt and completion id tensors.
        prompt_input_ids, completion_input_ids = tuple(
            [instance[key] for instance in instances] for key in (PROMPT_KEY, COMPLETION_KEY)
        )
        max_input_length = max(
            [len(p) + len(c) for p, c in zip(prompt_input_ids, completion_input_ids)]
        )
        # Start fully padded / masked, then fill in the real tokens per example.
        input_ids = (
            torch.ones((len(prompt_input_ids), max_input_length), dtype=prompt_input_ids[0].dtype)
            * self.tokenizer.pad_token_id
        )
        labels = torch.ones(input_ids.shape, dtype=prompt_input_ids[0].dtype) * IGNORE_INDEX
        attention_mask = torch.zeros(input_ids.shape, dtype=torch.bool)
        for i, (prompt_ids, completion_ids) in enumerate(
            zip(prompt_input_ids, completion_input_ids)
        ):
            sequence_ids = torch.concatenate([prompt_ids, completion_ids])
            if self.tokenizer.padding_side == "right":
                input_ids[i][: len(sequence_ids)] = sequence_ids
                attention_mask[i][: len(sequence_ids)] = True
                # Only the completion tokens contribute to the loss.
                labels[i][len(prompt_ids) : len(prompt_ids) + len(completion_ids)] = completion_ids
            else:
                input_ids[i][-len(sequence_ids) :] = sequence_ids
                attention_mask[i][-len(sequence_ids) :] = True
                labels[i][-len(completion_ids) :] = completion_ids
        # Truncate everything to max_sequence_length (currently 1024 for fine-tuning).
        return dict(
            input_ids=input_ids[:, : self.max_sequence_length],
            labels=labels[:, : self.max_sequence_length],
            attention_mask=attention_mask[:, : self.max_sequence_length],
        )
```
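For illustration, here is a minimal usage sketch of the collator above (the tokenizer id, pad-token handling, and example data are assumptions, not our production setup):

```python
import torch
from transformers import AutoTokenizer

# Assumed tokenizer id; the meta-llama repos on the Hub are gated and require access.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token

# Each instance holds pre-tokenized prompt/completion id tensors.
instances = [
    {
        PROMPT_KEY: torch.tensor(tokenizer.encode("Q: What is 2 + 2?\nA:")),
        COMPLETION_KEY: torch.tensor(tokenizer.encode("4", add_special_tokens=False)),
    }
]

collator = SFTCollator(tokenizer=tokenizer, max_sequence_length=1024)
batch = collator(instances)
print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)
```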
Great job with this repo. I was able to fine-tune Llama-2 and it certainly seems to have an effect.
Unfortunately, fine-tuning silently accepts all inputs, and the documentation states that you simply truncate inputs to the max length. But it's not specified anywhere what Llama-2's max length actually is. Originally, Meta released it with a bug that caused the max length to be 2048, while the native max length seems to be 4096. So which is it?
Also, I tested my fine-tuned model's completion endpoint with inputs as large as 12,000 tokens, and it still returns a completion. So I assume you truncate there as well? Only taking the tail of the prompt, presumably?
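For context, the prompt length can be checked along these lines (a sketch; the tokenizer id is an assumption and the repo is gated):

```python
from transformers import AutoTokenizer

# Assumed tokenizer id; requires access to the gated meta-llama repo.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "lorem ipsum dolor sit amet " * 2000  # well past any 2048/4096 window
print(len(tokenizer.encode(prompt)))  # number of tokens the endpoint would receive
```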
tl;dr: What is Llama-2's max token length for fine-tuning and for the completions endpoint, and how are over-length inputs truncated?