Fix torch.compile recompilation issue with HF modeling + TP
#2130
base: main
Conversation
…e issue when combined with TP
flavors = {
    "debugperf": HFTransformerModelArgs(
What's the difference between debugperf / debugperf_large and debugmodel? Can we just keep one of them?
Do we need to ship this folder to fix the issue? It adds about 2k LoC of complexity.
+1, I think we could remove these test scripts to keep the code simple.
wwwjn left a comment:
Thanks for finding this! To check my understanding, the bug is:
calling the function with kwargs results in a new object id for the hook -> causing a recompile.
Is this correct?
flavors = {
    "debugperf": HFTransformerModelArgs(
Should we remove these 2 test models?
+1, I think we could remove these test scripts to keep the code simple.
llama3_args = {
    "debugperf": TransformerModelArgs(
Same here, could we remove these 2 models?
class HFTransformers:
    model: str = ""
    """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')"""
    tie_word_embeddings: bool = False
Putting tie_word_embeddings into the job config is a little confusing, and it seems unrelated to this error?
IIUC this field is decided by the model architecture, not by each training run. So previously we put Qwen3's weight-tying config into model_args:
enable_weight_tying: bool = False
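For illustration only, here is a rough sketch of what this suggestion would look like, assuming dataclass-style config sections; the class names come from the diff but the fields are simplified and not the actual torchtitan code:

import dataclasses


@dataclasses.dataclass
class HFTransformerModelArgs:
    """Model-architecture args: fixed per model flavor, not per training run."""
    enable_weight_tying: bool = False  # mirrors the existing Qwen3-style field


@dataclasses.dataclass
class HFTransformers:
    """Job-config section: per-run knobs only."""
    model: str = ""  # HuggingFace model ID, e.g. 'Qwen/Qwen3-4B-Instruct-2507'
    # no tie_word_embeddings here -- weight tying is an architecture property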
Fixing the bug huggingface#6

TODO: the same change needs to be applied for transformers V5. That requires waiting for V5 to be a bit more stable before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue

torch.compile recompiles when the HF transformers modeling backend is combined with TP.
Fix

In transformers at modeling_llama.py, change the self.attn(hidden_states=hidden_states) call so the attention module is no longer called with kwargs (see the explanation below).

Repro: ./tooling_dev/debug_local.sh debugperf_large --compile
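For concreteness, here is a minimal sketch of the kind of call-site change involved, assuming the fix is to call the attention module positionally instead of with kwargs. The classes below are illustrative stand-ins, not the actual modeling_llama.py code (which takes many more arguments):

import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    """Illustrative stand-in for the HF attention module."""
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


class DecoderLayer(nn.Module):
    """Illustrative stand-in for a transformers decoder layer."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.self_attn = TinyAttention(dim)

    def forward(self, hidden_states):
        # Before: keyword call. Under TP, the keyword call means the
        # input-preparation pre-hook is registered with with_kwargs=True,
        # so torch.compile guards on per-layer hook ids (see Explanation below).
        # hidden_states = self.self_attn(hidden_states=hidden_states)

        # After: positional call. _forward_pre_hooks_with_kwargs stays empty,
        # the guard is identical across layers, and nothing recompiles.
        hidden_states = self.self_attn(hidden_states)
        return hidden_states


# Quick smoke test of the sketch.
layer = DecoderLayer()
out = layer(torch.randn(2, 8))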
Explanation

When torch.compile traces your model, it creates a compiled graph along with guards. Guards are conditions that must be true for that graph to be reused; if a guard fails, torch.compile recompiles.

In modeling_llama.py, self.attn(hidden_states=hidden_states) is called with kwargs, and TP registers a forward pre-hook on the attention module via register_forward_pre_hook. Depending on whether kwargs are used or not, a different registration is taken (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576):

module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)

Because of the check if hook_id in self._forward_pre_hooks_with_kwargs: (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808), using kwargs results in a different hook_id for each layer, hence the failing guard:

___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)

Without kwargs, self._forward_pre_hooks_with_kwargs will always be empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the if check is not triggered, the same guard holds for every attention layer, and there is no recompile.
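To make the hook bookkeeping concrete, here is a small standalone sketch (not part of the PR) showing that registering a forward pre-hook with with_kwargs=True records a globally unique hook id in each module's _forward_pre_hooks_with_kwargs, while the default registration leaves that dict empty:

import torch.nn as nn

# kwargs-aware registration, like TP's input preparation when inputs are kwargs.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
for layer in layers:
    handle = layer.register_forward_pre_hook(
        lambda mod, args, kwargs: None, with_kwargs=True
    )
    # Each handle gets a globally incrementing id, and that id is stored in the
    # per-module dict that appears in the ___dict_contains(...) guard above.
    # The exact numbers vary, but each layer gets a different id.
    print(handle.id, dict(layer._forward_pre_hooks_with_kwargs))

# Default (positional) registration: the kwargs dict stays empty for every
# layer, so the guard state is identical across layers.
plain = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
for layer in plain:
    layer.register_forward_pre_hook(lambda mod, args: None)
    print(dict(layer._forward_pre_hooks_with_kwargs))  # {} every time

Running the repro with TORCH_LOGS="recompiles" should surface guard failures of exactly this ___dict_contains form when the kwargs path is used.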