Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_interpret_call could reuse names in a trace when it's called in the lookaside of torch.autograd.Function #1776

Open
crcrpar opened this issue Feb 18, 2025 · 1 comment · May be fixed by #1777
Assignees
Labels

Comments

@crcrpar
Copy link
Collaborator

crcrpar commented Feb 18, 2025

🐛 Bug

As the title says, `_interpret_call` (shown below)

def _interpret_call(fn: Callable, /, *args, **kwargs) -> Any | INTERPRETER_SIGNALS:
    """Dispatch a call to ``fn`` through the interpreter and return its result.

    The call and its return value are recorded on the active interpreter
    runtime context. When provenance tracking is enabled on the compile
    context, the result must be either an ``INTERPRETER_SIGNALS`` member or a
    ``WrappedValue``.

    Args:
        fn: The callable to interpret (positional-only).
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``_call_dispatch`` returns for this call.
    """
    compilectx: InterpreterCompileCtx = get_interpretercompilectx()
    runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx()

    # TODO: Implement generics and fix WrappedValue[T] everywhere.
    runtimectx.record_interpreter_call(fn)
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
    if compilectx._with_provenance_tracking:
        # With provenance tracking on, every result is expected to be wrapped
        # (or to be an interpreter signal such as an exception marker).
        assert isinstance(rval, (INTERPRETER_SIGNALS, WrappedValue)), f"return {rval} unexpected calling {unwrap(fn)}"
    runtimectx.record_interpreter_return(fn, rval)  # type: ignore

    return rval
could reuse proxy names in some cases when it is called inside `_convert_pytorchfunc_to_thundertrace`, shown below
def _convert_pytorchfunc_to_thundertrace(
    func: Callable[[Any], Any],
    shallow_copy_output: bool,
    *args,
    **kwargs,
) -> tuple[TraceCtx | INTERPRETER_SIGNALS, ProvenanceRecord | None]:
    """Converts pytorch function to thunder trace.

    Note that the generated trace would not have _siginfo and args set.

    Args:
        func: A callable composed of pytorch functions.
        shallow_copy_output: Needs to be :obj:`True` only if func is `torch.autograd.Function.apply` as
            it produces views of the tensor to attach the autograd node to.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        ``(trace, provenance)``, where ``provenance`` is taken from the first
        wrapped output of ``func``; or ``(INTERPRETER_SIGNALS.EXCEPTION_RAISED, None)``
        if interpreting ``func`` raised.
    """
    from thunder.core.baseutils import sequencify

    # Collect the bound symbols produced while interpreting `func` in a fresh
    # scope of the active computation trace.
    active_jit_ctx: JitCtx = get_jit_ctx()
    active_jit_ctx.computation_trace.push_scope([])
    wrapped_func_result = _interpret_call(func, *args, **kwargs)
    if wrapped_func_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
        # NOTE(review): the scope pushed above is not popped on this path --
        # confirm that is intended.
        return wrapped_func_result, None

    trace = TraceCtx()
    bsyms = active_jit_ctx.computation_trace.pop_scope()
    trace.bound_symbols.extend(bsyms)
    func_result = unwrap(wrapped_func_result)
    if shallow_copy_output and not bsyms:
        # No bound symbols were recorded: add a shallow copy of each output to
        # the trace and return those copies in place of the original values.
        out_to_shallow_copy: dict[Variable, TensorProxy] = {}
        for a in sequencify(func_result):
            shallow_copy_of_a = prims.shallow_copy.meta(a)
            bsym = prims.shallow_copy.bind(a, output=shallow_copy_of_a)
            trace.add_bound_symbol(bsym)
            out_to_shallow_copy[variableify(a)] = shallow_copy_of_a
        func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
    with tracectx(trace):
        prims.python_return(func_result)
    return trace, sequencify(wrapped_func_result)[0].provenance
.

I am using lightning-thunder at commit 7c16a1a.

To Reproduce

Code sample

import torch
import thunder


# Minimal custom autograd function used to reproduce the failure.
class Func(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        scale: torch.Tensor,
    ):
        # Cast to float32, then multiply by `scale`. The driver below passes a
        # tensor created without an explicit dtype; per the note in this
        # report, the error does not occur for a bfloat16 input.
        tensor_scaled = tensor.to(torch.float32) * scale
        return tensor_scaled

    @staticmethod
    def backward(ctx, g):
        # One entry per `forward` input: the incoming gradient for `tensor`,
        # and no gradient for `scale`.
        return g, None


# Not working: jitting a function that goes through the custom
# `torch.autograd.Function` raises the `InterpreterError` shown in this report.
@thunder.jit
def f(tensor: torch.Tensor, scale: torch.Tensor):
    return Func.apply(tensor, scale)


# Working reference: the same computation as `Func.forward`, written inline
# without the custom `torch.autograd.Function`.
@thunder.jit
def g(tensor: torch.Tensor, scale: torch.Tensor):
    tensor_scaled = tensor.to(torch.float32) * scale
    return tensor_scaled


if __name__ == "__main__":
    # Create both inputs with "cuda" as the default device; no dtype is given.
    with torch.device("cuda"):
        t = torch.randn((4, 4))
        s = torch.tensor(0.1)
    # Run the plain function first, then the one that uses `Func.apply`.
    print("Call `g` which is free from a custom `torch.autograd.Function`")
    g(t, s)
    print("Call `f` which is dependent on a custom `torch.autograd.Function`")
    f(t, s)

Error

NOTE: If `t` is created by `torch.randn((4, 4), dtype=torch.bfloat16)`, this error does not occur. It therefore seems to be related to `ltorch.to` returning its input as-is when certain conditions are met (presumably when the tensor already has the requested dtype).

Call `f` which is dependent on a custom `torch.autograd.Function`
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7222, in fn_
    interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6683, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7208, in fn_2
    return fn(*args, **kwargs)

  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6886, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6936, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
                                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 411, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 1252, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 3781, in _call_function_ex_handler
    return check_and_append(stack, _interpret_call(func, *args, **kwargs))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6683, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py", line 24, in f
    return Func.apply(tensor, scale)
^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6886, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6936, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
                                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 411, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 1252, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 3741, in _call_handler
    res = _interpret_call(func, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6626, in _call_dispatch
    res = lookaside_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 776, in _general_jit_torch_autograd_function_apply_lookaside
    unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/symbol.py", line 323, in __call__
    result = self.meta(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 770, in core_of_forward
    return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/trace_interpreter.py", line 71, in interpret_trace
    safe_map_flat(write, list(sequencify(symbol.output)), list(sequencify(result)))
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/utils.py", line 884, in safe_map_flat
    out_flat = list(map(f, *[a for a, _ in args_flat_spec]))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/trace_interpreter.py", line 55, in write
    raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
ValueError: Variable t_0 is being overwritten this is not allowed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py", line 40, in <module>
    f(t, s)
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 743, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 779, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 725, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 237, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 528, in get_computation_and_inputs
    jit_results: TraceResults = thunder_general_jit(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 2055, in thunder_general_jit
    result = jfn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7235, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception ValueError: Variable t_0 is being overwritten this is not allowed while tracing <function f at 0x744e8ebd1f80>:

Trace of g

def computation(tensor, scale):
  # tensor: "cuda:0 f32[4, 4]"
  # scale: "cuda:0 f32[]"

  # /home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py:29:         tensor_scaled = tensor.to(torch.float32) * scale
  t3 = torch.mul(tensor, scale)  # t3: "cuda:0 f32[4, 4]"
    # t3 = ltorch.mul(tensor, scale)  # t3: "cuda:0 f32[4, 4]"
      # t2 = prims.broadcast_in_dim(scale, (4, 4), ())  # t2: "cuda:0 f32[4, 4]"
      # t3 = prims.mul(tensor, t2)  # t3: "cuda:0 f32[4, 4]"
  return (t3,)

Expected behavior

It shouldn't reuse proxy names.

Initial Attempts to debug

1. check proxy names

With the following diff, I got a slightly friendlier message:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 42b932df..8384e604 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -700,7 +700,7 @@ def _convert_pytorchfunc_to_thundertrace(
         *args:
         **kwargs
     """
-    from thunder.core.baseutils import sequencify
+    from thunder.core.baseutils import check, sequencify

     active_jit_ctx: JitCtx = get_jit_ctx()
     active_jit_ctx.computation_trace.push_scope([])
@@ -722,6 +722,12 @@ def _convert_pytorchfunc_to_thundertrace(
         func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
     with tracectx(trace):
         prims.python_return(func_result)
+
+    for bsym in bsyms:
+        output_name = set(variableify(a) for a in bsym.flat_proxy_outs)
+        args_name = set(variableify(a) for a in bsym.flat_proxy_args)
+        name_dup = output_name & args_name
+        check(not name_dup, lambda: f"{output_name = } reuses {name_dup} of arg names of {args_name}, seen in the following trace\n{trace}\n")
     return trace, sequencify(wrapped_func_result)[0].provenance

The message I get with the diff above:

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7222, in fn_
    interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6683, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7208, in fn_2
    return fn(*args, **kwargs)

  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6886, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6936, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
                                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 411, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 1252, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 3781, in _call_function_ex_handler
    return check_and_append(stack, _interpret_call(func, *args, **kwargs))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6683, in _call_dispatch
    return _setup_frame_and_run_python_function(compilectx, runtimectx, wrapped_fn, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py", line 24, in f
    return Func.apply(tensor, scale)
^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6886, in _setup_frame_and_run_python_function
    res, status = _run_frame(frame, compilectx, runtimectx)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6936, in _run_frame
    interpretation_result: None | int | INTERPRETER_SIGNALS = compilectx.interpret(
                                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 411, in interpret
    return self._opcode_interpreter(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 1252, in default_opcode_interpreter
    return handler(inst, **interpreter_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 3741, in _call_handler
    res = _interpret_call(func, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6465, in _interpret_call
    rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 6626, in _call_dispatch
    res = lookaside_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 755, in _general_jit_torch_autograd_function_apply_lookaside
    trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace(
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 730, in _convert_pytorchfunc_to_thundertrace
    check(not name_dup, lambda: f"{output_name = } reuses {name_dup} of arg names of {args_name}, seen in the following trace\n{trace}\n")
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/baseutils.py", line 146, in check
    raise exception_type(s())
RuntimeError: output_name = {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>} reuses {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>} of arg names of {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>}, seen in the following trace
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
# No signature available
  # /home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py:13:             tensor_scaled = tensor.to(torch.float32) * scale
  t_0 = ltorch.to(t_0, torch.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # t_0: "cuda:0 f32[4, 4]"
  t2 = ltorch.mul(t_0, t_1)  # t2: "cuda:0 f32[4, 4]"
    # t1 = prims.broadcast_in_dim(t_1, (4, 4), ())  # t1: "cuda:0 f32[4, 4]"
    # t2 = prims.mul(t_0, t1)  # t2: "cuda:0 f32[4, 4]"
  return t2


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py", line 40, in <module>
    f(t, s)
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 743, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 779, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 725, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 237, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/__init__.py", line 528, in get_computation_and_inputs
    jit_results: TraceResults = thunder_general_jit(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/jit_ext.py", line 2061, in thunder_general_jit
    result = jfn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/core/interpreter.py", line 7235, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception RuntimeError: output_name = {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>} reuses {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>} of arg names of {<TensorProxy(name="t_0", dtype=thunder.dtypes.float32, shape=(4, 4))>}, seen in the following trace
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
# No signature available
  # /home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/a.py:13:             tensor_scaled = tensor.to(torch.float32) * scale
  t_0 = ltorch.to(t_0, torch.float32, None, device=None, dtype=None, copy=False, memory_format=None)  # t_0: "cuda:0 f32[4, 4]"
  t2 = ltorch.mul(t_0, t_1)  # t2: "cuda:0 f32[4, 4]"
    # t1 = prims.broadcast_in_dim(t_1, (4, 4), ())  # t1: "cuda:0 f32[4, 4]"
    # t2 = prims.mul(t_0, t1)  # t2: "cuda:0 f32[4, 4]"
  return t2
 while tracing <function f at 0x766a2ebd1f80>:

2. Assign the return value of tensor.to(...) to a variable

I tried the following `Func.forward`, but got the same error.

# Variant of the reproducer that binds the result of `tensor.to(...)` to its
# own variable before multiplying.
class Func(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        scale: torch.Tensor,
    ):
        # Keep the cast result under a separate name from the input.
        cast_tensor = tensor.to(torch.float32)
        scaled_tensor = cast_tensor * scale
        # Return the product bound just above.
        return scaled_tensor

Environment

  • PyTorch Version (e.g., 1.0): 2.7.0a0+git71855a1
  • OS (e.g., Linux): Ubuntu 22.04
  • How you installed PyTorch (conda, pip, source): source
  • Build command you used (if compiling from source): MAX_JOBS=16 BUILD_TEST=0 USE_FLASH_ATTENTION=0 USE_MKLDNN=0 USE_SYSTEM_NCCL=1 NCCL_ROOT=/usr/local python setup.py develop --cmake
  • Python version: 3.11.7
  • CUDA/cuDNN version: 12.8 / 9.7.0
  • GPU models and configuration: RTX 6000 Ada
  • Any other relevant information:

Additional context

@crcrpar crcrpar changed the title _interpret_call called in a lookaside of torch.autograd.Function could reuse names in a trace _interpret_call could reuse names in a trace when it's called in the lookaside of torch.autograd.Function Feb 18, 2025
@crcrpar
Copy link
Collaborator Author

crcrpar commented Feb 19, 2025

A naive alternative to #1777 is:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 42b932df..9ebe98bb 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import math
-from typing import Any
+from typing import Any, TYPE_CHECKING
 import collections
 from collections.abc import Callable, Sequence
 import dataclasses
@@ -75,6 +75,9 @@ from thunder.torch import _torch_to_thunder_function_map
 from thunder.clang import _clang_fn_set
 from thunder.core.pytree import tree_map, tree_iter

+if TYPE_CHECKING:
+    from thunder.core.symbol import BoundSymbol
+
 #
 # jit_ext.py implements extensions of thunder's interpreter
 #
@@ -722,6 +725,19 @@ def _convert_pytorchfunc_to_thundertrace(
         func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
     with tracectx(trace):
         prims.python_return(func_result)
+
+    bsym: BoundSymbol
+    new_bsyms: list[BoundSymbol] = []
+    for bsym in trace.bound_symbols:
+        should_skip = False
+        for o in filter(lambda o: isinstance(o, TensorProxy), bsym.flat_proxy_outs):
+            var_o = variableify(o)
+            for var_a in [variableify(a) for a in bsym.flat_proxy_args if isinstance(a, TensorProxy)]:
+                if var_o == var_a:
+                    should_skip = True
+        if not should_skip:
+            new_bsyms.append(bsym)
+    trace.bound_symbols = new_bsyms
     return trace, sequencify(wrapped_func_result)[0].provenance

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants