More robust Online DPO changes for RL update #1664
base: main
Conversation
…se autoSequenceClassification
Yeah, I only added the LoRA changes because I get LoRA downcasting errors otherwise; we can remove it if it's not correct. We probably should not hard-code bfloat16 in case the GPU uses float16.
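As a minimal sketch of what that could look like (a hypothetical helper, not part of this PR), the compute dtype could be chosen from hardware support instead of being hardcoded:

import torch

# Hypothetical helper: only use bfloat16 when the GPU actually supports it,
# otherwise fall back to float16 so LoRA weights are not cast to an
# unsupported dtype.
def pick_compute_dtype() -> torch.dtype:
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16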
Thank you @danielhanchen, will get on reviewing it :)
Oh I think I found why there's the LoRA downcasting issue - it looks like accelerate needs the env variable […]. Also […]
Some questions - otherwise great work!
@@ -1596,8 +1596,37 @@ def _fast_generate(*args, **kwargs):
pass

original_attention_forward = LlamaAttention.forward
Actually that's very smart on keeping the old function!
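As a rough sketch of the pattern being praised here (only the LlamaAttention.forward pointer comes from the diff; the helper names below are illustrative): keep a reference to the stock forward before patching, so it can be swapped back for plain Hugging Face models.

from transformers.models.llama.modeling_llama import LlamaAttention

# Remember the stock implementation before patching.
original_attention_forward = LlamaAttention.forward

def apply_unsloth_attention(patched_forward):
    # Illustrative: install the fast attention forward.
    LlamaAttention.forward = patched_forward

def restore_hf_attention():
    # Illustrative: put the original Hugging Face forward back,
    # e.g. before running a stock reward model.
    LlamaAttention.forward = original_attention_forward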
lm_backbone = getattr(model, model.base_model_prefix)
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
FastLanguageModel.reset_functions()
output = lm_backbone(
So my main question is that I didn't understand the lm_backbone part - does this remove the lm_head or something? Or is it because we need output.hidden_states[-1]? Is it necessary to reset the functions?
Also, after calling reset_functions, are we sure the new functions actually patch over the current model's functions?
Ok so, this function presumes you are using AutoModelForSequenceClassification, and that module inherently changes the LM head to produce a 1x2 tensor with the logits or scores of the chosen and rejected responses respectively. If we do not reset the functions, unfortunately, the Unsloth forward functions are not compatible with AutoModelForSequenceClassification just yet. This is why I store the old function pointers. What reset_functions does is this: when we first load and import Unsloth, we change the attention forward implementations to the ones that Unsloth has. However, since they are not exactly compatible with AutoModelForSequenceClassification, you have to reset the functions to the old ones in order to initialize the correct layers in the model, AND also reset them when doing a forward pass on the reward model, as it is still a Hugging Face model. The speed is generally fast enough, and if reset_functions did not work correctly, we would get an error on the forward pass or the rewards would not increase, but I sent you the wandb at some point.
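For context, a rough sketch of what the diff snippet above is doing, following the TRL-style reward scoring path (the function name is illustrative, and the score attribute is specific to Llama-style sequence-classification heads):

import torch

def score_with_reward_model(model, query_responses, attention_mask):
    # model is an AutoModelForSequenceClassification reward model.
    # Unsloth's patched attention is assumed to have been reset beforehand,
    # as described above.
    # Zero out padding positions so the backbone never sees pad ids.
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    # Take the transformer backbone (model.model for Llama) without the score head.
    lm_backbone = getattr(model, model.base_model_prefix)
    output = lm_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    # Apply the classification head to the last hidden states to get scores.
    return model.score(output.hidden_states[-1])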
I actually tried out your code a while back (from the issue you created). If I remember correctly, I had to set model = model.eval() for HF's outputs to match Unsloth's. I couldn't dig deeper into why, though.
Have you noticed something like that?
Hey, thanks so much for this comment! I'm trying to get the RLOO trainer to produce decent generations, but for some reason it was not. This might actually help me fix that. Oddly enough though, I did not have that particular issue in this code. Hm, I am not exactly sure why that would do the trick there, but I do switch Unsloth between training and inference mode often in the code, and this may just enable the right settings.
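For context, the mode switching mentioned here is along these lines (a sketch; generation_kwargs is a placeholder, not the PR's actual call site):

from unsloth import FastLanguageModel

# Switch to Unsloth's inference mode before generating...
FastLanguageModel.for_inference(model)
output_ids = model.generate(**generation_kwargs)

# ...and back to training mode before the RL update step.
FastLanguageModel.for_training(model)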
This is what I observed with HF (on your us-open-instruct) with the arguments you had in your code.
Adding model.eval to HF's code stabilised the output (it's weird that it keeps generating the same token).
I think your implementation was fine (the Unsloth one), but with model = model.eval() somewhere around here.
Also, you might want to check whether generation_config is appropriate. I remember there being something with that, but I don't fully recall.
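A tiny sketch of that workaround (input_ids, attention_mask and generation_config are assumed context here, not names taken from the PR):

import torch

model.eval()  # disable dropout so the forward pass matches HF's eval-mode outputs
with torch.no_grad():
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,  # worth double-checking the sampling settings
    )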
Off topic, but on your repo batch_generation_vllm might have a bug. You might want to expand the attention mask. With these changes I observed that Unsloth's and HF's losses on the TL;DR dataset only differ by 1e-2 in magnitude, which is fine I guess...
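A minimal sketch of what expanding the attention mask could mean here (this is my reading, not the actual fix in that repo): after generation, the mask only covers the prompt and has to be extended over the newly generated tokens.

import torch

# query_responses: (batch, prompt_len + gen_len) token ids returned by generation
# attention_mask:  (batch, prompt_len) mask that only covers the prompt
gen_len = query_responses.shape[1] - attention_mask.shape[1]
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones(attention_mask.shape[0], gen_len)],
    dim=1,
)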
Oh yeah, so my repo does not actually call batch_generation_vllm. I am trying to get that integrated for RLOO and PPO at the moment, but not the vLLM version. I am going to try the model.eval() thing to see if it works; otherwise I will step through the Online DPO code here, which does work, and figure out how to patch batch generation properly as well.
Cool, lmk if you need help with anything
Thanks! I am currently trying to recover from being sick, but I will try to get back to this as soon as possible and as soon as I catch up with my school work.
    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Are these primarily for Amazon Sagemaker?
This is just for Amazon SageMaker. I wanted to keep the old functionality of the function itself; sadly it is very deep inside transformers and transformers.trainer, and this is exclusively in the saving steps. This function had to be overwritten since it does evaluation outside of the trainer's code directly. I think we could theoretically remove this if you want to; I just did not want to remove it from the function if we did not have to. There are a lot of SMP forward passes in the function, hence why I decided to keep it and tried to make it compatible.
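For reference, the guard being described mirrors what transformers.trainer itself does at import time (simplified sketch, not the exact PR diff):

from packaging import version
from transformers.utils import is_sagemaker_mp_enabled

if is_sagemaker_mp_enabled():
    # Only import the SageMaker model-parallel helpers when SMP is active,
    # so ordinary installs never hit the missing dependency.
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION
    from transformers.trainer_pt_utils import (
        smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat,
    )
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False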
https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L1956-L1961
Wasn't unsloth already setting this to False? Or were we missing something so far?
So actually yes, they do that. This is more about patching the trainer function itself: when the Hugging Face trainer wants to perform saving at, say, 500 steps, it will call this function, and if we do not include the Amazon SageMaker checks here it will error out due to dependency issues. I wanted to keep most of the features of the function itself intact.
Ah cool. So no more force-disabling SageMaker after these changes, I presume.
So SageMaker will be disabled by default; it's just that this function needs to check whether SageMaker is there or not in order to also bring in some of the other dependencies that are needed in the function. I did not really want to touch too much of the internal transformers stuff except for what Unsloth needed to patch to get working.
I wanted to get this reviewed because I think at least the preliminary framework for Online DPO, with the Llama model examples I have, actually works officially with the RL update. I will work towards the other RL trainers and getting full model support.