Skip to content

Layerwise offload is not compatible with torch.compile #1012

@DefTruth

Description

@DefTruth
 python3 -m cache_dit.generate qwen_image_edit_2511_lightning \
  --layerwise-offload \
  --layerwise-async-transfer \
  --layerwise-transfer-buckets 4 \
  --layerwise-persistent-buckets 32 \
  --layerwise-persistent-bins 4 \
  --layerwise-max-inflight-prefetch-bytes 8gib \
  --layerwise-text-transfer-buckets 1 \
  --layerwise-text-persistent-buckets 8 \
  --layerwise-text-persistent-bins 1 \
  --layerwise-text-max-inflight-prefetch-bytes 4gib --compile

    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 1669, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 810, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 489, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1408, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 5241, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 5462, in inline_call_
    self.run()
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1813, in run
    while self.step():
          ^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1480, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1017, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4171, in CALL
    self._call(inst)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4162, in _call
    self.call_function(fn, args, kwargs)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1381, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 294, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 810, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 489, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1408, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 5241, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 5462, in inline_call_
    self.run()
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1813, in run
    while self.step():
          ^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1480, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1017, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2929, in STORE_ATTR
    VariableTracker.build(self, setattr).call_function(
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 1497, in call_function
    return handler(tx, args, kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 1261, in builtin_dispatch
    rv = handler(tx, args, kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 1135, in call_self_handler
    return self_handler(tx, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 2688, in call_setattr
    out = wrap_fx_proxy(
          ^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 3090, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 3165, in wrap_fx_proxy_cls
    out: VTTypeAlias = _wrap_fx_proxy(
                       ^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 3289, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3751, in get_fake_value
    return _get_fake_value_impl(node, tx, allow_non_graph_fake)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3942, in _get_fake_value_impl
    _wrap_graph_break_with_torch_runtime_err(
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3740, in _wrap_graph_break_with_torch_runtime_err
    raise exc.with_traceback(e.__traceback__) from None
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3737, in _wrap_graph_break_with_torch_runtime_err
    gb_fn()
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3943, in <lambda>
    lambda: unimplemented(
            ^^^^^^^^^^^^^^
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 653, in unimplemented
    raise Unsupported(
torch._dynamo.exc.TorchRuntimeError: RuntimeError when making fake tensor call
  Explanation: Dynamo failed to run FX node with fake tensors: call_function <method 'set_' of 'torch._C.TensorBase' objects>(*(Parameter(FakeTensor(..., size=(12288, 3072), dtype=torch.bfloat16)), FakeTensor(..., device='cuda:0', size=(12288, 3072), dtype=torch.bfloat16)), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.set_.source_Tensor, found two different devices cpu, cuda:0')
  Hint: Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance("force_eager")`.

  Developer debug context:

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb4315.html

from user code:
   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 676, in forward
    txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/nn/modules/container.py", line 253, in forward
    input = module(input)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1881, in _call_impl
    return inner()
  File "/workspace/dev/miniconda3/envs/cdit/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1811, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/offload/layerwise.py", line 1515, in pre_hook
    handle._prefetch_bucket_targets(target)
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/offload/layerwise.py", line 1083, in _prefetch_bucket_targets
    self._schedule_target_onload(next_target, allow_wait=False)
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/offload/layerwise.py", line 955, in _schedule_target_onload
    _assign_direct_state_tensor(
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/offload/layerwise.py", line 407, in _assign_direct_state_tensor
    parameter.data = tensor

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions