Skip to content

Commit 0d5fb09

Browse files
henrylhtsangpytorchmergebot
authored andcommitted
[cutlass backend] check against arch >= 100 (pytorch#145812)
Summary: Want to add a guard against silent fallback to SM90. GenerateSM100 was just added 3 days ago. https://github.com/NVIDIA/cutlass/blame/main/python/cutlass_library/generator.py#L8896 It should show up in CUTLASS 3.8 (not pinned yet). Test Plan: ci Differential Revision: D68748705 Pull Request resolved: pytorch#145812 Approved by: https://github.com/chenyang78, https://github.com/ColinPeppler, https://github.com/Aidyn-A
1 parent bab35eb commit 0d5fb09

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,19 @@ def try_import_cutlass() -> bool:
114114
return False
115115

116116

117+
@functools.lru_cache(8)
117118
def _normalize_cuda_arch(arch: str) -> str:
118-
if int(arch) >= 90:
119+
if int(arch) >= 100:
120+
log.warning(
121+
"Detected CUDA architecture >= 100: %s. We will generate operations with "
122+
"GenerateSM100 (if available) and GenerateSM90. Please file an "
123+
"issue for any problems and feedback. ",
124+
arch,
125+
)
126+
127+
if int(arch) >= 100:
128+
return "100"
129+
elif int(arch) >= 90:
119130
return "90"
120131
elif int(arch) >= 80:
121132
return "80"
@@ -186,7 +197,15 @@ def _gen_ops_cached(arch, version) -> list[Any]:
186197
)
187198
manifest = cutlass_manifest.Manifest(args)
188199

189-
if arch == "90":
200+
if arch == "100":
201+
try:
202+
from cutlass_generator import GenerateSM100 # type: ignore[import]
203+
204+
GenerateSM100(manifest, args.cuda_version)
205+
except ImportError:
206+
log.warning("Cannot find GenerateSM100. Only GenerateSM90 will be used. ")
207+
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
208+
elif arch == "90":
190209
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
191210
cutlass_generator.GenerateSM80(manifest, args.cuda_version)
192211
else:

0 commit comments

Comments
 (0)