
Conversation

@3outeille
Contributor

Fixing the bug huggingface#6

TODO: this change also needs to be applied in transformers v5. That requires waiting for v5 to be a bit more stable before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue

[rank3]:/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
[rank3]:  warnings.warn(
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] torch._dynamo hit config.recompile_limit (8)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    function: 'forward' (/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:145)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    last reason: 0/7: ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)  # if hook_id in self._forward_pre_hooks_with_kwargs:  # nn/modules/module.py:1815 in inner
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
[rank3]:[rank3]: Traceback (most recent call last):

Fix

  • Apply the current PR changes and, in transformers' modeling_llama.py, pass hidden_states positionally instead of as a keyword argument:
        hidden_states, _ = self.self_attn(
    -       hidden_states=hidden_states,
    +       hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
  • ./tooling_dev/debug_local.sh debugperf_large --compile

Explanation

  • When torch.compile traces your model, it creates a compiled graph along with guards. Guards are conditions that must hold for that graph to be reused; if a guard fails, torch.compile recompiles.
  • In modeling_llama.py, self.self_attn(hidden_states=hidden_states, ...) is called with hidden_states passed as a keyword argument.
  • In torchtitan, applying TP registers a forward pre-hook via register_forward_pre_hook. However, depending on whether kwargs are used, a different function is called (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576).
    • In our case, it calls module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)
  • Calling this function is problematic because the module's forward then goes through the if hook_id in self._forward_pre_hooks_with_kwargs: check (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808).
    • This means that using kwargs results in a different hook_id per layer, hence the guard failure ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)
  • When we don't use kwargs, self._forward_pre_hooks_with_kwargs is always empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the if check is never taken; each attention layer then produces the same guard set, thus no recompile (see the sketch below).
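
A minimal sketch of the mechanism, assuming plain nn.Linear modules in place of the attention layers and no-op hook bodies (_forward_pre_hooks_with_kwargs is a private nn.Module attribute, inspected here purely for illustration):

    import torch.nn as nn

    layers = [nn.Linear(8, 8) for _ in range(3)]

    for layer in layers:
        # Roughly the registration the TP style performs when kwargs are
        # involved (cf. the style.py link above): a pre-hook with with_kwargs=True.
        layer.register_forward_pre_hook(
            lambda mod, args, kwargs: (args, kwargs), with_kwargs=True
        )

    for layer in layers:
        # Every registration gets a globally unique hook id, so each layer ends
        # up with a different key in _forward_pre_hooks_with_kwargs, which is
        # exactly what the ___dict_contains(148, ...) guard keys on.
        print(dict(layer._forward_pre_hooks_with_kwargs))

    # With a positional-only hook, _forward_pre_hooks_with_kwargs stays empty,
    # the `if hook_id in self._forward_pre_hooks_with_kwargs:` branch is never
    # taken, and the guard set is identical across layers.
    plain = nn.Linear(8, 8)
    plain.register_forward_pre_hook(lambda mod, args: None)
    print(dict(plain._forward_pre_hooks_with_kwargs))  # {}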



flavors = {
"debugperf": HFTransformerModelArgs(
Contributor

What's the difference between debugperf / debugperf_large and debugmodel? Can we just keep one of them?

Contributor

Do we need to ship this folder to fix the issue? It adds about 2k LoC of complexity.

Contributor

+1, I think we could remove these test scripts to keep the code simple

@wwwjn (Contributor) left a comment

Thanks for finding this! To check my understanding, the bug is:

calling the function with kwargs results in a new object id for the hook -> causing recompiles

Is this correct?



flavors = {
"debugperf": HFTransformerModelArgs(
Contributor

Should we remove these 2 test models?

Contributor

+1, I think we could remove these test scripts to keep the code simple



llama3_args = {
"debugperf": TransformerModelArgs(
Contributor

Same here, could we remove these 2 models?

class HFTransformers:
model: str = ""
"""HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')"""
tie_word_embeddings: bool = False
Contributor

Putting tie_word_embeddings into the job config is a little confusing, and it seems unrelated to this error?

IIUC this field is decided by the model architecture, not by each training run. So previously we put Qwen3's weight tying config into model_args:

enable_weight_tying: bool = False
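
For illustration, a minimal sketch of that placement; the class and field names come from the snippet and comment above, and treating both config classes as dataclasses is an assumption:

    from dataclasses import dataclass

    @dataclass
    class HFTransformerModelArgs:
        # Architecture-level choice, mirroring the existing Qwen3-style field,
        # so it lives with the model args rather than in the per-run job config.
        enable_weight_tying: bool = False

    @dataclass
    class HFTransformers:
        # The job config keeps only run-level choices, e.g. which HF model to load.
        model: str = ""
        """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')"""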
