Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions csrc/apis/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor,
d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
}

static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
static std::tuple<c10::optional<int64_t>, c10::optional<int64_t>> m_grouped_fp8_fp4_gemm_nt_masked(
const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
Expand All @@ -228,7 +229,10 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
std::optional<std::tuple<int, int>> recipe_a,
std::optional<std::tuple<int, int>> recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const int& max_block_n,
const bool& enable_overlap,
const c10::optional<torch::Tensor>& signal) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
Expand All @@ -255,17 +259,28 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);

// Dispatch implementation
std::optional<std::pair<int, int>> overlap_result = std::nullopt;
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto& major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
overlap_result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims,
max_block_n, enable_overlap, signal);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
DG_HOST_ASSERT(not enable_overlap);
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}

if (!overlap_result) {
return std::make_tuple(c10::nullopt, c10::nullopt);
}
return std::make_tuple(
c10::optional<int64_t>(overlap_result->first),
c10::optional<int64_t>(overlap_result->second)
);
}

static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
Expand Down Expand Up @@ -640,7 +655,8 @@ static void register_apis(pybind11::module_& m) {
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false,
py::arg("max_block_n") = 256, py::arg("enable_overlap") = false, py::arg("signal") = std::nullopt);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
Expand Down
6 changes: 6 additions & 0 deletions csrc/apis/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ static void register_apis(pybind11::module_& m) {
m.def("get_num_sms", [&]() {
return device_runtime->get_num_sms();
});
m.def("set_compile_mode", [&](const int& new_compile_mode) {
device_runtime->set_compile_mode(new_compile_mode);
});
m.def("get_compile_mode", [&]() {
return device_runtime->get_compile_mode();
});
m.def("set_tc_util", [&](const int& new_tc_util) {
device_runtime->set_tc_util(new_tc_util);
});
Expand Down
10 changes: 10 additions & 0 deletions csrc/jit/device_runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace deep_gemm {
class DeviceRuntime {
int num_sms = 0, tc_util = 0;
std::shared_ptr<cudaDeviceProp> cached_prop;
int compile_mode = 0;

// cuBLASLt utils
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
Expand Down Expand Up @@ -80,6 +81,15 @@ class DeviceRuntime {
return num_sms;
}

// Set the JIT compile mode.
//   0: normal operation — kernels are compiled and launched.
//   1: compile-only — kernels are still JIT-compiled, but call sites guarded
//      by `MAYBE_LAUNCH` skip the actual launch.
// Any other value trips the host assert.
void set_compile_mode(const int& new_compile_mode) {
    // Only modes 0 and 1 are defined; reject everything else early.
    DG_HOST_ASSERT(0 <= new_compile_mode and new_compile_mode <= 1);
    compile_mode = new_compile_mode;
}

// Return the current compile mode (0 = compile + launch, 1 = compile-only).
int get_compile_mode() {
    return compile_mode;
}

int get_l2_cache_size() {
return get_prop()->l2CacheSize;
}
Expand Down
9 changes: 7 additions & 2 deletions csrc/jit_kernels/heuristics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ struct GemmConfig {
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
int signal_threshold;
int num_stages, num_last_stages;

// Templated device configs
int num_sms;
int tc_util;
bool enable_overlap;

// Structured configs
MulticastConfig multicast_config;
Expand Down Expand Up @@ -154,7 +156,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& a_dtype, const at::ScalarType& b_dtype,
const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
const bool& with_accumulation, const int& num_sms,
const int& max_block_n = 256, const bool& enable_overlap = false) {
const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4);
if (mma_kind == MmaKind::BF16) {
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
Expand All @@ -170,7 +173,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout)
block_ms = std::vector{64, 128}; // Exclude 256 for performance
auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype, max_block_n);

// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
// TODO: Optimize it
Expand Down Expand Up @@ -297,10 +300,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.signal_threshold = ceil_div(n, best_block_n),
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.tc_util = device_runtime->get_tc_util(),
.enable_overlap = enable_overlap,
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
Expand Down
4 changes: 2 additions & 2 deletions csrc/jit_kernels/heuristics/sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ struct SM100ArchSpec {
return candidates;
}

static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
// Enumerate the block-N tile sizes the SM100 heuristic may pick, capped at
// `max_block_n`. 16 is always offered for better SM usage; larger candidates
// advance in strides of 32 because swizzle-16/32B modes are low-performance.
static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype, const int& max_block_n) {
    std::vector<int> block_ns;
    block_ns.reserve(1 + max_block_n / 32);
    block_ns.push_back(16);
    for (int bn = 32; bn <= max_block_n; bn += 32)
        block_ns.push_back(bn);
    return block_ns;
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/jit_kernels/heuristics/sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct SM90ArchSpec {
return candidates;
}

static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype, const int& max_block_n) {
int start = 16;

// Avoid bank conflicts for 1D1D kernel FP32 output
Expand All @@ -32,7 +32,7 @@ struct SM90ArchSpec {
}

// Push the strided options
for (int i = start; i <= 256; i += 16)
for (int i = start; i <= max_block_n; i += 16)
candidates.push_back(i);
return candidates;
}
Expand Down
9 changes: 9 additions & 0 deletions csrc/jit_kernels/impls/runtime_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "../heuristics/sm90.hpp"
#include "../../jit/handle.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "../../utils/exception.hpp"
Expand Down Expand Up @@ -234,4 +235,12 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
allow_tf32);
}

// Conditionally evaluate a kernel-launch expression: the expression runs only
// when the device runtime is in compile mode 0 (compile + launch). In
// compile-only mode (1) the caller still JIT-compiles the kernel, but the
// launch itself is skipped.
//
// Variadic (`__VA_ARGS__`) so that launch expressions containing a top-level
// comma — e.g. a brace-init list or a template argument list outside
// parentheses — are still re-assembled into a single expression instead of
// failing macro-argument parsing.
//
// The `do { ... } while (0)` wrapper makes the macro a single statement, safe
// inside unbraced if/else bodies.
#ifndef MAYBE_LAUNCH
#define MAYBE_LAUNCH(...) do { \
    if (device_runtime->get_compile_mode() == 0) { \
        (__VA_ARGS__); \
    } \
} while (0)
#endif

} // namespace deep_gemm
12 changes: 6 additions & 6 deletions csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_gemm", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
Expand Down Expand Up @@ -177,7 +177,7 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
Expand Down Expand Up @@ -227,7 +227,7 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
Expand Down Expand Up @@ -289,7 +289,7 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
Expand Down Expand Up @@ -337,7 +337,7 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
Expand Down Expand Up @@ -385,7 +385,7 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code);
SM100BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
}

} // namespace deep_gemm
10 changes: 5 additions & 5 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -210,7 +210,7 @@ static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a,
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -271,7 +271,7 @@ static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, con
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -342,7 +342,7 @@ static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::T
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args));
}

static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -410,7 +410,7 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args));
}

} // namespace deep_gemm
12 changes: 6 additions & 6 deletions csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
Expand Down Expand Up @@ -173,7 +173,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
Expand Down Expand Up @@ -228,7 +228,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
Expand Down Expand Up @@ -290,7 +290,7 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
Expand Down Expand Up @@ -337,7 +337,7 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
Expand Down Expand Up @@ -384,7 +384,7 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code);
SM90BF16GemmRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
}

} // namespace deep_gemm
4 changes: 2 additions & 2 deletions csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);

SM90FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90FP8Gemm1D1DRuntime::launch(runtime, args));
}

static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
Expand Down Expand Up @@ -212,7 +212,7 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);

SM90FP8Gemm1D1DRuntime::launch(runtime, args);
MAYBE_LAUNCH(SM90FP8Gemm1D1DRuntime::launch(runtime, args));
}

} // namespace deep_gemm
Loading
Loading