diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 6770cf92..35f2815d 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -219,7 +219,8 @@ static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair& a, +static std::tuple, c10::optional> m_grouped_fp8_fp4_gemm_nt_masked( + const std::pair& a, const std::pair& b, const torch::Tensor& d, const torch::Tensor& masked_m, @@ -228,7 +229,10 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair> recipe_a, std::optional> recipe_b, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const int& max_block_n, + const bool& enable_overlap, + const c10::optional& signal) { // Shape must be `[G, M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -255,17 +259,28 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair> overlap_result = std::nullopt; if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { const auto& major_sfb = get_major_type_ab(sfb); - sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims); + overlap_result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims, + max_block_n, enable_overlap, signal); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + DG_HOST_ASSERT(not enable_overlap); sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, num_groups, m, n, k, expected_m, gran_k_a, gran_k_b, major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } + + if (!overlap_result) { + return std::make_tuple(c10::nullopt, c10::nullopt); + } + return std::make_tuple( + c10::optional(overlap_result->first), + c10::optional(overlap_result->second) + ); } static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, @@ -640,7 +655,8 @@ static void register_apis(pybind11::module_& m) { py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, - py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, + py::arg("max_block_n") = 256, py::arg("enable_overlap") = false, py::arg("signal") = std::nullopt); m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"), py::arg("c") = std::nullopt, diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp index a5d313e2..44e4b107 100644 --- a/csrc/apis/runtime.hpp +++ b/csrc/apis/runtime.hpp @@ -14,6 +14,12 @@ static void register_apis(pybind11::module_& m) { m.def("get_num_sms", [&]() { return device_runtime->get_num_sms(); }); + m.def("set_compile_mode", [&](const int& new_compile_mode) { + device_runtime->set_compile_mode(new_compile_mode); + }); + m.def("get_compile_mode", [&]() { + return device_runtime->get_compile_mode(); + }); m.def("set_tc_util", [&](const int& new_tc_util) { device_runtime->set_tc_util(new_tc_util); }); diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index d33743ef..ae881a03 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -12,6 +12,7 @@ namespace deep_gemm { class DeviceRuntime { int num_sms = 0, tc_util = 0; std::shared_ptr cached_prop; + int compile_mode = 0; // cuBLASLt utils static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; @@ -80,6 +81,15 @@ class DeviceRuntime { return num_sms; } + void set_compile_mode(const int& new_compile_mode) { + DG_HOST_ASSERT(0 <= new_compile_mode and new_compile_mode <= 1); + compile_mode = new_compile_mode; + } + + int get_compile_mode() { + return compile_mode; + } + int get_l2_cache_size() { return get_prop()->l2CacheSize; } diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index a49584f4..b1571ae9 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -65,11 +65,13 @@ struct GemmConfig { cute::UMMA::Major major_b; bool with_accumulation; int block_m, block_n, block_k; + int signal_threshold; int num_stages, num_last_stages; // Templated device configs int num_sms; int tc_util; + bool enable_overlap; // Structured configs MulticastConfig multicast_config; @@ -154,7 +156,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& a_dtype, const at::ScalarType& b_dtype, const at::ScalarType& cd_dtype, - const bool& with_accumulation, const int& num_sms) { + const bool& with_accumulation, const int& num_sms, + const int& max_block_n = 256, const bool& enable_overlap = false) { const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4); if (mma_kind == MmaKind::BF16) { DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); @@ -170,7 +173,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout) block_ms = std::vector{64, 128}; // Exclude 256 for performance - auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); + auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype, max_block_n); // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B // TODO: Optimize it @@ -297,10 +300,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k .block_m = best_block_m, .block_n = best_block_n, .block_k = block_k, + .signal_threshold = ceil_div(n, best_block_n), .num_stages = best_num_stages, .num_last_stages = ceil_div(k, block_k) % best_num_stages, .num_sms = num_min_sms, .tc_util = device_runtime->get_tc_util(), + .enable_overlap = enable_overlap, .multicast_config = best_multicast_config, // ReSharper disable once CppLocalVariableMightNotBeInitialized .smem_config = best_smem_config, diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index dd1e6024..71aa52f0 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -22,11 +22,11 @@ struct SM100ArchSpec { return candidates; } - static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) { + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype, const int& max_block_n) { // 16 is for better SM usage // Stride 32 is due to low-performance swizzle-16/32B std::vector candidates = {16}; - for (int i = 32; i <= 256; i += 32) + for (int i = 32; i <= max_block_n; i += 32) candidates.push_back(i); return candidates; } diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 2fd2e9ec..40a1235f 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -21,7 +21,7 @@ struct SM90ArchSpec { return candidates; } - static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) { + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype, const int& max_block_n) { int start = 16; // Avoid bank conflicts for 1D1D kernel FP32 output @@ -32,7 +32,7 @@ struct SM90ArchSpec { } // Push the strided options - for (int i = start; i <= 256; i += 16) + for (int i = start; i <= max_block_n; i += 16) candidates.push_back(i); return candidates; } diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index b245b94a..dce0581c 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -5,6 +5,7 @@ #include "../heuristics/sm90.hpp" #include "../../jit/handle.hpp" +#include "../../jit/device_runtime.hpp" #include "../../utils/math.hpp" #include "../../utils/system.hpp" #include "../../utils/exception.hpp" @@ -234,4 +235,12 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, allow_tf32); } +#ifndef MAYBE_LAUNCH +#define MAYBE_LAUNCH(EXPR) do { \ + if (device_runtime->get_compile_mode() == 0) { \ + (EXPR); \ + } \ +} while (0) +#endif + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 95f72729..2db83eea 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -118,7 +118,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_gemm", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, @@ -177,7 +177,7 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, @@ -227,7 +227,7 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, @@ -289,7 +289,7 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, @@ -337,7 +337,7 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, @@ -385,7 +385,7 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); - SM100BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 07a977d7..201455c3 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -141,7 +141,7 @@ static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); - SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -210,7 +210,7 @@ static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); - SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -271,7 +271,7 @@ static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, con }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); - SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -342,7 +342,7 @@ static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::T }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); - SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, @@ -410,7 +410,7 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); - SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 32003f88..71721550 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -119,7 +119,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_gemm", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, @@ -173,7 +173,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, @@ -228,7 +228,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, @@ -290,7 +290,7 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, @@ -337,7 +337,7 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, @@ -384,7 +384,7 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); - SM90BF16GemmRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index e61841b3..d47e7980 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -136,7 +136,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); - SM90FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D1DRuntime::launch(runtime, args)); } static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -212,7 +212,7 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); - SM90FP8Gemm1D1DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D1DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 2696b5a0..f1992c73 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -25,7 +25,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime GemmConfig gemm_config; LaunchArgs launch_args; - void *sfb, *grouped_layout; + void *sfb, *grouped_layout, *signal; CUtensorMap tensor_map_a; CUtensorMap tensor_map_b; CUtensorMap tensor_map_d; @@ -49,7 +49,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, - {} + {}, {} >); }}; )", @@ -63,13 +63,14 @@ static void __instantiate_kernel() {{ args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), - get_default_epilogue_type(args.epilogue_type)); + get_default_epilogue_type(args.epilogue_type), + args.gemm_config.enable_overlap); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.sfb, args.grouped_layout, + args.sfb, args.grouped_layout, args.signal, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa)); @@ -128,6 +129,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = nullptr, + .signal = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -135,7 +137,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, @@ -189,6 +191,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = m_indices.data_ptr(), + .signal = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -196,26 +199,34 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } -static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, +static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, const int& num_groups, const int& m, const int& n, const int& k, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const int& max_block_n, + const bool& enable_overlap, + const c10::optional& signal) { DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + if (enable_overlap) { + DG_HOST_ASSERT(signal.has_value()); + DG_HOST_ASSERT(signal.value().is_contiguous()); + DG_HOST_ASSERT(signal.value().scalar_type() == torch::kInt32); + } const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, a.scalar_type(), b.scalar_type(), d.scalar_type(), false, - device_runtime->get_num_sms()); + device_runtime->get_num_sms(), max_block_n, enable_overlap); // Requires no TMA splits DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); @@ -251,6 +262,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), + .signal = enable_overlap ? signal.value().data_ptr() : nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -258,7 +270,8 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); + return enable_overlap ? std::optional(std::make_pair(config.block_m, config.signal_threshold)) : std::nullopt; } static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, @@ -318,6 +331,7 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = nullptr, + .signal = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -325,7 +339,7 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, }; const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); - SM90FP8Gemm1D2DRuntime::launch(runtime, args); + MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 1c07f5d9..f8bc6d52 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -17,6 +17,8 @@ from ._C import ( set_num_sms, get_num_sms, + set_compile_mode, + get_compile_mode, set_tc_util, get_tc_util, ) diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index 8fb6c2fc..ef098b31 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -162,6 +162,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) { asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); } +__device__ __forceinline__ void store_wait() { + asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory"); +} + +__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) { + int ret; + asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value)); + return ret; +} + template struct Vectorized { static auto zeros() { diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 9247304c..ce0220c5 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -39,9 +39,9 @@ template + typename epilogue_type_t, bool kEnableOverlap> __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void -sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -427,6 +427,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, cute::tma_store_arrive(); } __syncwarp(); + + if constexpr (kEnableOverlap) { + DG_TRAP_ONLY_DEVICE_ASSERT(signal != nullptr); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + store_wait(); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 2); + if (threadIdx.x == 0) { + atomic_add_release_global( + signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); + } + } } } #else diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py index a42c4318..f98c3f12 100644 --- a/deep_gemm/testing/numeric.py +++ b/deep_gemm/testing/numeric.py @@ -19,3 +19,19 @@ def count_bytes(*tensors): elif t is not None: total += t.numel() * t.element_size() return total + + +def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m): + ceil_div = lambda a, b: (a + b - 1) // b + + expert_len = max_m // block_m + for expert in range(num_local_expert): + mask = masked_m[expert] + start = expert * expert_len + end = expert * expert_len + expert_len + valid_len = ceil_div(mask, block_m) + for i in range(start, end): + if i < start + valid_len: + assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + else: + assert signal[i] == 0, f'{i=}, {signal[i]=}' diff --git a/tests/generators.py b/tests/generators.py index ee22e515..7f102715 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -178,7 +178,8 @@ def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: reset_seed() for num_groups, m in m_group_list: for n, k in n_k_list: - yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout + for enable_overlap in (False, True): + yield kernel_type, enable_overlap, quant_config, num_groups, max_m, m, n, k, use_psum_layout def enumerate_k_grouped_contiguous(dtype: torch.dtype): @@ -332,7 +333,8 @@ def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool = False, use_bf16: bool = False, use_psum_layout: bool = False, - quant_config: Optional[QuantConfig] = None): + quant_config: Optional[QuantConfig] = None, + enable_overlap: bool = False): a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) @@ -346,13 +348,21 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: assert masked_m.amax().item() <= max_m if use_bf16: - return a, b, masked_m, psum_m, d, ref_d + signal = None + if enable_overlap and (not use_psum_layout): + max_signal_size = num_groups * ceil_div(max_m, 64) + signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + return a, b, masked_m, psum_m, d, ref_d, signal quant_config = QuantConfig() if quant_config is None else quant_config a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) - return a, b, masked_m, psum_m, d, ref_d + signal = None + if enable_overlap and (not use_psum_layout): + max_signal_size = num_groups * ceil_div(max_m, 64) + signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + return a, b, masked_m, psum_m, d, ref_d, signal def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], diff --git a/tests/test_bf16.py b/tests/test_bf16.py index 1a3b0467..8c8df94e 100644 --- a/tests/test_bf16.py +++ b/tests/test_bf16.py @@ -86,14 +86,17 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + for _, enable_overlap, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + if enable_overlap: + continue num_tests = 8 sum_t, max_t = 0, 0 sum_ops, sum_bytes = 0, 0 for i in range(num_tests): - a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, - use_bf16=True, use_psum_layout=use_psum_layout) + a, b, masked_m, psum_m, d, ref_d, _ = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_bf16=True, use_psum_layout=use_psum_layout, + enable_overlap=False) if use_psum_layout: a_psum = layout_masked_to_psum(a, psum_m) d_psum = layout_masked_to_psum(d, psum_m) diff --git a/tests/test_fp8_fp4.py b/tests/test_fp8_fp4.py index f7e3e1c4..54c97692 100644 --- a/tests/test_fp8_fp4.py +++ b/tests/test_fp8_fp4.py @@ -7,6 +7,7 @@ from deep_gemm.testing import ( bench_kineto, calc_diff, count_bytes, + check_signal, ignore_env, get_arch_major ) @@ -102,20 +103,23 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): + for kernel_type, enable_overlap, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 recipe, recipe_a, recipe_b = quant_config.get_recipes() + if enable_overlap and (use_psum_layout or get_arch_major() != 9): + continue + num_tests = 8 sum_t, max_t = 0, 0 sum_ops, sum_bytes = 0, 0 for i in range(num_tests): - a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, - use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, - quant_config=quant_config) + a, b, masked_m, psum_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config, enable_overlap=enable_overlap) if use_psum_layout: a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m)) d_psum = layout_masked_to_psum(d, psum_m) @@ -127,8 +131,15 @@ def test_func(): use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) else: - deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, - recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + result = deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked( + a, b, d, masked_m, expected_m_per_group, + disable_ue8m0_cast=disable_ue8m0_cast, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b, + enable_overlap=enable_overlap, signal=signal, + ) + if enable_overlap: + block_m, threshold = result + check_signal(num_groups, max_m, block_m, threshold, signal, masked_m) test_func() for j in range(num_groups): @@ -151,7 +162,7 @@ def test_func(): sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' - f'{kernel_opt}, psum={1 if use_psum_layout else 0}): ' + f'{kernel_opt}, psum={1 if use_psum_layout else 0}, enable_overlap={enable_overlap}): ' f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' f'{sum_bytes / sum_t / 1e9:4.0f} GB/s')