
LightGlue-full-compile won't work with torch '2.2.1+cu121' #123

Closed
mnauf opened this issue Apr 16, 2024 · 3 comments · Fixed by #124

mnauf commented Apr 16, 2024

I have torch >=2, yet I can't run lightglue_full_compile. Logs are attached. Please help.

Environment:

ubuntu 20.04.1
python 3.9.18
torch '2.2.1+cu121'
gpu NVIDIA GeForce RTX 4060 Ti

Logs:

/home/mnauf/anaconda3/envs/lightglue2/bin/python /home/mnauf/Desktop/lightglue/benchmark.py --device cuda --num_keypoints 512 1024 2048 4096 --compile 
Running benchmark on device: cuda
Run benchmark for: LightGlue-full
Run benchmark for: LightGlue-adaptive
Run benchmark for: LightGlue-full-compile
/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:140: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Traceback (most recent call last):
  File "/home/mnauf/Desktop/lightglue/benchmark.py", line 195, in <module>
    runtime = measure(
  File "/home/mnauf/Desktop/lightglue/benchmark.py", line 25, in measure
    _ = matcher(data)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 473, in forward
    return self._forward(data)
  File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 533, in _forward
    desc0, desc1 = self.transformers[i](
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 242, in forward
    return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 249, in masked_forward
    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 665, in run
    return compiled_fn(new_inputs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 380, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 408, in cudagraphify
    return manager.add_function(
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1941, in add_function
    return fn, fn(inputs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1755, in run
    out = self._run(new_inputs, function_id)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1796, in _run
    return self.run_eager(new_inputs, function_id)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1911, in run_eager
    return node.run(new_inputs)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 600, in run
    non_cudagraph_inps = get_non_cudagraph_inps()
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 595, in get_non_cudagraph_inps
    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 255, in masked_forward
    return self.cross_attn(desc0, desc1, mask)
  File "/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnauf/Desktop/lightglue/lightglue/lightglue.py", line 221, in forward
    x0 = x0 + self.ffn(torch.cat([x0, m0], -1)). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
Phil26AT (Collaborator) commented:

Hi @mnauf, thank you for pointing out this issue. I pushed a fix to PR #124. Please let us know if this solves your problem.

mnauf (Author) commented Apr 16, 2024

Thanks @Phil26AT, it works now! Are the following warnings expected?

/home/mnauf/anaconda3/envs/lightglue2/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:140: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Run benchmark for: LightGlue-adaptive-compile
/home/mnauf/Desktop/lightglue/benchmark.py:187: UserWarning: Point pruning is partially disabled for compiled forward.
  matcher.compile()

Phil26AT added a commit that referenced this issue Apr 16, 2024
Mark cudagraph begin in compilation.
Fixes #123
Phil26AT (Collaborator) commented:

Great! Yes, these warnings are expected. The first one comes from torch. You could change the matmul precision to "high", but this sometimes results in slightly inaccurate estimates, so for safety we do not enable it by default.
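Silencing the warning is a one-line opt-in, exactly as the torch message suggests:

```python
import torch

# Allow TF32 tensor cores for float32 matmuls (Ampere+ GPUs, e.g. the RTX
# 4060 Ti above). Faster, but the reduced mantissa can shift results
# slightly, which is why the benchmark leaves it at the default.
torch.set_float32_matmul_precision("high")
```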

The second warning is about point pruning, which is not compatible with the padding and masking used in the compiled forward pass. Newer torch versions support dynamic shapes, which would be compatible with point pruning, but we have not tested this with LightGlue yet.
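For reference, torch exposes dynamic shapes through the `dynamic` flag of `torch.compile`. A minimal sketch with a hypothetical similarity function (not something tested with LightGlue, per the comment above):

```python
import torch

def match_scores(desc0, desc1):
    # Similarity matrix between two variable-length descriptor sets
    # (hypothetical stand-in for a matcher's score computation).
    return desc0 @ desc1.transpose(-1, -2)

# dynamic=True asks the compiler to generate shape-generic kernels instead of
# recompiling (or padding and masking) for every sequence length.
compiled = torch.compile(match_scores, dynamic=True)

a = compiled(torch.randn(512, 256), torch.randn(512, 256))
b = compiled(torch.randn(300, 256), torch.randn(450, 256))  # different lengths
```

With shape-generic kernels, pruned keypoint sets could in principle be passed in directly, without the padding that currently disables pruning.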
