
More robust Online DPO changes for RL update #1664

Open · wants to merge 14 commits into main
Conversation

pluesclues

I wanted to get this reviewed: I think at least the preliminary framework for Online DPO with the Llama model examples I have now officially works with the RL update. I will work towards the other RL trainers and getting full model support.

@pluesclues
Author

Yeah, I only added the LoRA changes because I get LoRA downcasting errors otherwise. We can remove it if it's not correct, and we probably should not hard-code bfloat16 in case the GPU uses float16.
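
For example, a small sketch of picking the dtype from the hardware instead of hard-coding it (just an illustration, not code from this PR):

```python
import torch

# Use bfloat16 only on GPUs that support it (e.g. Ampere+); otherwise fall
# back to float16 so older cards like T4/V100 still work.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
```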

@shimmyshimmer
Collaborator

Thank you! @danielhanchen will get on reviewing it :)

@danielhanchen
Contributor

Oh, I think I found why there's the LoRA downcasting issue - it looks like accelerate needs the env variable ACCELERATE_MIXED_PRECISION set to either fp16 or bf16 to make it use mixed precision.

Also, unwrap_model from accelerate will remove all mixed precision hooks - hence the issue.
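
A minimal sketch of what that could look like before the trainer/Accelerator is constructed (an assumption about the fix, not the exact change in this PR):

```python
import os
import torch

# accelerate reads this env variable when it builds its Accelerator, so the
# mixed-precision hooks stay in place and LoRA weights aren't downcast.
os.environ["ACCELERATE_MIXED_PRECISION"] = (
    "bf16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "fp16"
)
```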

Contributor

@danielhanchen left a comment

Some questions - otherwise great work!

@@ -1596,8 +1596,37 @@ def _fast_generate(*args, **kwargs):
pass


original_attention_forward = LlamaAttention.forward
Contributor

Actually that's very smart on keeping the old function!

lm_backbone = getattr(model, model.base_model_prefix)
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
FastLanguageModel.reset_functions()
output = lm_backbone(
Contributor

So my main question is that I didn't understand the lm_backbone part - does this remove the lm_head or something? Or is it because we need output.hidden_states[-1]? Is it necessary to reset the functions?

Also, after calling reset_functions, are we sure the new functions actually patch over the current model's functions?

Author

Ok so, this function presumes you are using AutoSequenceForClassification, and that module replaces the LM head so the model produces a 1x2 tensor with the logits (scores) of the chosen and rejected responses respectively. The Unsloth forward functions are not compatible with AutoSequenceForClassification just yet, which is why I store the old function pointers.

When Unsloth is first imported it swaps the Attention forward implementations for its own. Since those are not compatible with AutoSequenceForClassification, reset_functions swaps the old ones back in: once to initialize the correct layers in the model, AND again when doing a forward pass on the reward model, since it is still a plain Hugging Face model. The speed is generally fast enough, and if reset_functions did not work correctly we would either get an error on the forward pass or the rewards would not increase - I sent you the wandb at some point.
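
Roughly, the save/restore pattern described above looks like this (a simplified sketch; apart from original_attention_forward and LlamaAttention the names are hypothetical, and the real reset_functions patches more than just the attention forward):

```python
from transformers.models.llama.modeling_llama import LlamaAttention

# Keep a pointer to the stock HF forward before Unsloth patches it.
original_attention_forward = LlamaAttention.forward

def use_hf_attention():
    # Swap the stock forward back in so HF-only code paths
    # (e.g. the sequence-classification reward model) run unmodified.
    LlamaAttention.forward = original_attention_forward

def use_unsloth_attention(unsloth_forward):
    # Restore Unsloth's fast forward once the HF-only call is done.
    LlamaAttention.forward = unsloth_forward
```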

Contributor

I actually tried out your code a while back (from the issue you created). If I remember correctly, I had to set model = model.eval() for HF's outputs to match Unsloth's. I couldn't dig deeper into why, though.
Have you noticed something like that?

Author

Hey, thanks so much for this comment! I'm trying to get the RLOO trainer to produce decent generations, but for some reason it was not. This might actually help me fix that. Oddly enough, I did not have that particular issue in this code. Hm, I am not exactly sure why that would do the trick there, but I do switch Unsloth from training to inference mode often in the code, and that may just enable the right settings.

Contributor

@Datta0 Feb 19, 2025

[screenshot: HF generations repeating the same token]
This is what I observed with HF (on your us-open-instruct) with the arguments you had in your code.
Adding model.eval() to HF's code stabilised the output (it's weird that it keeps generating the same token).
I think your implementation (the Unsloth one) was fine, but with model = model.eval() somewhere around here.
Also you might want to check whether generation_config is appropriate. I remember there being something with that, but I don't fully recollect.

Off topic, but on your repo, batch_generation_vllm might have a bug - you might want to expand the attention mask. With these changes I observed that Unsloth's and HF's losses on the tldr dataset only differ by about 1e-2 in magnitude, which is fine, I guess...
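
For reference, a rough sketch of the two fixes being discussed: switching to eval mode around generation and passing the full padding attention mask. The helper name and details are assumptions, not code from either repo:

```python
import torch

def batch_generate(model, tokenizer, prompts, **gen_kwargs):
    # Training-mode behaviour (e.g. dropout) can make HF and Unsloth
    # generations drift apart, so generate in eval mode.
    was_training = model.training
    model.eval()

    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=batch["input_ids"],
            # Pass the padding attention mask so padded positions are ignored.
            attention_mask=batch["attention_mask"],
            **gen_kwargs,
        )

    if was_training:
        model.train()
    return out
```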

Author

Oh yeah, so my repo does not actually call batch_generation_vllm - I am trying to get that integrated for RLOO and PPO at the moment, but not the vLLM version yet. I am going to try the model.eval() thing to see if it works; otherwise I will step through the Online DPO code here, which does work, and figure out how to patch batch generation properly as well.

Contributor

Cool, lmk if you need help with anything

Author

Thanks! I am currently trying to recover from being sick, but I will get back to this as soon as possible, once I catch up on my school work.


from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
IS_SAGEMAKER_MP_POST_1_10 = False
Contributor

Are these primarily for Amazon Sagemaker?

Author

This is just for Amazon SageMaker. I wanted to keep the old functionality of the function itself; sadly it sits very deep inside transformers and transformers.trainer, and this is exclusively in the saving steps. The function had to be overwritten since it does evaluation outside of the trainer's code directly. I think we could theoretically remove this if you want to, I just did not want to strip it from the function if we did not have to. There are a lot of SMP forward passes in the function, which is why I decided to keep it and tried to make it compatible.
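
The kind of guard being kept looks roughly like this (a sketch of the pattern; the flag name here is made up, transformers has its own is_sagemaker_mp_enabled / IS_SAGEMAKER_MP_POST_1_10 checks):

```python
# Only pull in the SageMaker model-parallel helpers when smdistributed is
# actually installed; otherwise fall back so the patched save/evaluate path
# still works without the dependency.
try:
    import smdistributed.modelparallel.torch as smp  # noqa: F401
    SAGEMAKER_MP_AVAILABLE = True
except ImportError:
    SAGEMAKER_MP_AVAILABLE = False
```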

Contributor

https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L1956-L1961
Wasn't unsloth already setting this to False? Or were we missing something so far?

Author

So actually yes, they do that. This is more about patching the trainer function itself: when the Hugging Face trainer wants to save at, say, 500 steps, it will call this function, and if we do not include the Amazon SageMaker stuff here it will error out due to dependency issues. I wanted to keep most of the features of the function itself intact.

Contributor

Ah cool. So no more force-disabling SageMaker after these changes, I presume.

Author

So SageMaker will be disabled by default; it's just that this function needs to check whether SageMaker is there or not in order to also bring in some of the other dependencies that are needed in the function. I did not really want to touch too much of the internal transformers stuff except for what Unsloth needed to patch to get working.
