Skip to content

Commit 94a4a5a

Browse files
authored
Add PTX helpers (#686)
* add PTX helpers * avoid clobbering __init__.py + address review comments * add simple test + fix linter errors * add docs
1 parent 248edca commit 94a4a5a

File tree

7 files changed

+201
-1
lines changed

7 files changed

+201
-1
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4+
from cuda.bindings import utils
45
from cuda.bindings._version import __version__
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import re
5+
6+
# Mapping based on the official PTX ISA <-> CUDA Release table
7+
# https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
8+
_ptx_to_cuda = {
9+
"1.0": (1, 0),
10+
"1.1": (1, 1),
11+
"1.2": (2, 0),
12+
"1.3": (2, 1),
13+
"1.4": (2, 2),
14+
"2.0": (3, 0),
15+
"2.1": (3, 1),
16+
"2.2": (3, 2),
17+
"2.3": (4, 0),
18+
"3.0": (4, 1),
19+
"3.1": (5, 0),
20+
"3.2": (5, 5),
21+
"4.0": (6, 0),
22+
"4.1": (6, 5),
23+
"4.2": (7, 0),
24+
"4.3": (7, 5),
25+
"5.0": (8, 0),
26+
"6.0": (9, 0),
27+
"6.1": (9, 1),
28+
"6.2": (9, 2),
29+
"6.3": (10, 0),
30+
"6.4": (10, 1),
31+
"6.5": (10, 2),
32+
"7.0": (11, 0),
33+
"7.1": (11, 1),
34+
"7.2": (11, 2),
35+
"7.3": (11, 3),
36+
"7.4": (11, 4),
37+
"7.5": (11, 5),
38+
"7.6": (11, 6),
39+
"7.7": (11, 7),
40+
"7.8": (11, 8),
41+
"8.0": (12, 0),
42+
"8.1": (12, 1),
43+
"8.2": (12, 2),
44+
"8.3": (12, 3),
45+
"8.4": (12, 4),
46+
"8.5": (12, 5),
47+
"8.6": (12, 7),
48+
"8.7": (12, 8),
49+
"8.8": (12, 9),
50+
}
51+
52+
53+
def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
54+
"""
55+
Maps the PTX ISA version to the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
56+
that is needed to load a PTX of the given ISA version.
57+
58+
Parameters
59+
----------
60+
ptx_version : str
61+
PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the ``.version``
62+
directive in the PTX header.
63+
64+
Returns
65+
-------
66+
int
67+
Minimal CUDA version as 1000 * major + 10 * minor, e.g. 12090 for CUDA 12.9.
68+
69+
Raises
70+
------
71+
ValueError
72+
If the PTX version is unknown.
73+
74+
Examples
75+
--------
76+
>>> get_minimal_required_driver_ver_from_ptx_ver("8.8")
77+
12090
78+
>>> get_minimal_required_driver_ver_from_ptx_ver("7.0")
79+
11000
80+
"""
81+
try:
82+
major, minor = _ptx_to_cuda[ptx_version]
83+
return 1000 * major + 10 * minor
84+
except KeyError:
85+
raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
86+
87+
88+
# Regex pattern to match .version directive and capture the version number
89+
# TODO: if import speed is a concern, consider lazy-initializing it.
90+
_ptx_ver_pattern = re.compile(r"\.version\s+([0-9]+\.[0-9]+)")
91+
92+
93+
def get_ptx_ver(ptx: str) -> str:
94+
"""
95+
Extract the PTX ISA version string from PTX source code.
96+
97+
Parameters
98+
----------
99+
ptx : str
100+
The PTX assembly source code as a string.
101+
102+
Returns
103+
-------
104+
str
105+
The PTX ISA version string, e.g., "8.8".
106+
107+
Raises
108+
------
109+
ValueError
110+
If the .version directive is not found in the PTX source.
111+
112+
Examples
113+
--------
114+
>>> ptx = r'''
115+
... .version 8.8
116+
... .target sm_86
117+
... .address_size 64
118+
...
119+
... .visible .entry test_kernel()
120+
... {
121+
... ret;
122+
... }
123+
... '''
124+
>>> get_ptx_ver(ptx)
125+
'8.8'
126+
"""
127+
m = _ptx_ver_pattern.search(ptx)
128+
if m:
129+
return m.group(1)
130+
else:
131+
raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")

cuda_bindings/docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ CUDA Python API Reference
1515
module/nvjitlink
1616
module/nvvm
1717
module/cufile
18+
module/utils
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
.. SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
.. module:: cuda.bindings.utils
5+
6+
Utils module
7+
============
8+
9+
Functions
10+
---------
11+
12+
.. autosummary::
13+
:toctree: generated/
14+
15+
get_minimal_required_cuda_ver_from_ptx_ver
16+
get_ptx_ver

cuda_bindings/docs/source/release/12.X.Y-notes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
.. SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33
4+
.. module:: cuda.bindings
5+
46
``cuda-bindings`` 12.X.Y Release notes
57
======================================
68

@@ -24,6 +26,8 @@ Bug fixes
2426
Miscellaneous
2527
-------------
2628

29+
* Added PTX utilities including :func:`~utils.get_minimal_required_cuda_ver_from_ptx_ver` and :func:`~utils.get_ptx_ver`.
30+
2731

2832
Known issues
2933
------------

cuda_bindings/tests/test_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import pytest
5+
6+
from cuda.bindings.utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
7+
8+
ptx_88_kernel = r"""
9+
.version 8.8
10+
.target sm_75
11+
.address_size 64
12+
13+
// .globl empty_kernel
14+
15+
.visible .entry empty_kernel()
16+
{
17+
ret;
18+
}
19+
"""
20+
21+
22+
ptx_72_kernel = r"""
23+
.version 7.2
24+
.target sm_75
25+
.address_size 64
26+
27+
// .globl empty_kernel
28+
29+
.visible .entry empty_kernel()
30+
{
31+
ret;
32+
}
33+
"""
34+
35+
36+
@pytest.mark.parametrize(
37+
"kernel,actual_ptx_ver,min_cuda_ver", ((ptx_88_kernel, "8.8", 12090), (ptx_72_kernel, "7.2", 11020))
38+
)
39+
def test_ptx_utils(kernel, actual_ptx_ver, min_cuda_ver):
40+
ptx_ver = get_ptx_ver(kernel)
41+
assert ptx_ver == actual_ptx_ver
42+
cuda_ver = get_minimal_required_cuda_ver_from_ptx_ver(ptx_ver)
43+
assert cuda_ver == min_cuda_ver

0 commit comments

Comments
 (0)