6 changes: 5 additions & 1 deletion setup.py
@@ -634,6 +634,10 @@ def get_extensions():
mxfp8_src_files_exist = all(os.path.exists(f) for f in mxfp8_sources)
if mxfp8_src_files_exist and build_for_sm100a:
print("Building mxfp8_cuda extension")
arch_flags = [
Contributor:
In your compile logs from before this change in the PR description, I already see the proper gencodes:

...
-DTORCH_EXTENSION_NAME=mxfp8_cuda -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_100a,code=sm_100a -gencode=arch=compute_120,code=compute_120 -gencode=arch=compute_120,code=sm_120 -gencode=arch=compute_120a,code=sm_120a -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_90a,code=sm_90a

Contributor Author:
The proper gencodes are there, but the kernel compilation was still getting skipped... If you remove the CUDA_ARCH check, you'll see other errors instead, such as instructions not being available on SM90, since all of these gencodes are passed to every source file. That's why I needed to conditionally add the gencodes through setup.py for these specific source files, just as is already done for other files in setup.py.

"-gencode=arch=compute_100,code=sm_100",
"-gencode=arch=compute_120,code=sm_120"
]
ext_modules.append(
CUDAExtension(
name="torchao.prototype.mxfp8_cuda",
@@ -647,7 +651,7 @@ def get_extensions():
],
extra_compile_args={
"cxx": ["-std=c++17", "-O3"],
"nvcc": nvcc_args,
"nvcc": nvcc_args + arch_flags,
},
extra_link_args=["-lcuda", "-lcudart"],
),
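
A minimal sketch of the per-file gencode pattern described in the reply above, using a hypothetical helper name and source path; the real logic lives inside get_extensions() in setup.py and carries more conditions.

from torch.utils.cpp_extension import CUDAExtension

def mxfp8_extension(nvcc_args, build_for_sm100a):
    # Emit SASS only for the architectures whose instructions these kernels
    # use; older gencodes such as sm_90 cannot compile them.
    arch_flags = []
    if build_for_sm100a:
        arch_flags = [
            "-gencode=arch=compute_100,code=sm_100",
            "-gencode=arch=compute_120,code=sm_120",
        ]
    return CUDAExtension(
        name="torchao.prototype.mxfp8_cuda",
        sources=["torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cu"],  # hypothetical path
        extra_compile_args={
            "cxx": ["-std=c++17", "-O3"],
            # Per-extension arch flags are appended to the shared nvcc args.
            "nvcc": nvcc_args + arch_flags,
        },
        extra_link_args=["-lcuda", "-lcudart"],
    )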
21 changes: 0 additions & 21 deletions torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh
@@ -22,21 +22,6 @@
#include <cuda/barrier>
#include <cuda/ptx>

#define MIN_CUDA_SM 1000 // SM90 = 900, SM100 = 1000

// Check if we're compiling for supported architecture
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < MIN_CUDA_SM)
#warning \
"MXFP8 quantization requires SM90+ (Hopper) or SM100+ (Blackwell) architecture. Kernel will be disabled for this architecture."
#endif

// Architecture detection for native FP8 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
#define HAS_NATIVE_FP8_CONVERSION 1
#else
#define HAS_NATIVE_FP8_CONVERSION 0
#endif

enum class DType {
kByte,
kFloat32,
@@ -975,11 +960,6 @@ public:
output_bits_per_elem); // bits per elem in output fp8e4m3
}

// Launch kernel based on input/output types and scaling dimensions
// Only compile kernel launches for SM90+
#if defined(__CUDACC__) && \
(!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM)

// Use TMA and mbarrier instructions
#define LAUNCH_KERNEL(IType, OType, SCALE_Y, SCALE_X, ScalingMode) \
mxfp8_quantize_kernel<IType, OType, SCALE_Y, SCALE_X, ScalingMode> \
@@ -1044,6 +1024,5 @@ public:

#undef LAUNCH_KERNEL

#endif
}
};
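
For reference, the general effect of the __CUDA_ARCH__ guard removed above, shown as a standalone sketch with a hypothetical kernel: nvcc compiles device code once per -gencode pass, and anything inside the guard is dropped for passes whose __CUDA_ARCH__ is below the threshold. Restricting the gencodes for these sources in setup.py makes such a guard unnecessary.

// Standalone sketch (hypothetical kernel), illustrating the guard pattern only.
#include <cstdio>

#define MIN_CUDA_SM 1000  // SM100 = 1000

__global__ void guarded_kernel() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= MIN_CUDA_SM)
  // Compiled only when the current gencode pass targets compute capability
  // 10.0 or newer; for lower gencodes the kernel body is empty.
  printf("running on SM100+\n");
#endif
}

int main() {
  guarded_kernel<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}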