Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f613a2d

Browse files
committed Feb 7, 2025
migrate some updates to compile() from numba-cuda
1 parent 82bd416 commit f613a2d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed
 

‎numbast/src/numbast/numba_patch.py

+10 -3
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
# added.
3737

3838

39-
def nvrtc_compile(src, name, cc):
39+
def nvrtc_compile(src, name, cc, ltoir=False):
4040
"""
4141
Compile a CUDA C/C++ source to PTX for a given compute capability.
4242
@@ -68,6 +68,9 @@ def nvrtc_compile(src, name, cc):
6868
options = [arch, *extra_include_paths, include, numba_include, "-rdc", "true"]
6969
options += extra_options
7070

71+
if ltoir:
72+
options += ["-dlto"]
73+
7174
# Compile the program
7275
compile_error = nvrtc.compile_program(program, options)
7376

@@ -84,8 +87,12 @@ def nvrtc_compile(src, name, cc):
8487
msg = f"NVRTC log messages whilst compiling {name}:\n\n{log}"
8588
warnings.warn(msg)
8689

87-
ptx = nvrtc.get_ptx(program)
88-
return ptx, log
90+
if ltoir:
91+
lto = nvrtc.get_lto(program)
92+
return lto, log
93+
else:
94+
ptx = nvrtc.get_ptx(program)
95+
return ptx, log
8996

9097

9198
# Monkey-patch the existing implementation

0 commit comments

Comments
 (0)
Please sign in to comment.