Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fp8]Support deep gemm #9955

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/third_party/cutlass
Submodule cutlass updated 2225 files
34 changes: 34 additions & 0 deletions ops/csrc/fp8/deep_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# MIT License
#
# Copyright (c) 2025 DeepSeek-Ai/DeepGEMM
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import paddle

from . import jit
from .jit_kernels import (
cell_div,
gemm_fp8_fp8_bf16_nt,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout,
get_num_sms,
set_num_sms,
)
from .utils import bench, calc_diff
466 changes: 466 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh

Large diffs are not rendered by default.

907 changes: 907 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// MIT License
//
// Copyright (c) 2025 DeepSeek-Ai/DeepGEMM
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "utils.cuh"

namespace deep_gemm {

enum class GemmType {
Normal,
GroupedContiguous,
GroupedMasked
};

#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks = cell_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
uint32_t num_aligned_m_blocks;

// For normal GEMM
// Maybe not used in the masked grouped GEMM
uint32_t num_blocks;

// For grouped GEMM
int* grouped_layout;
// Only used for masked layout
uint32_t curr_group_idx, curr_cumsum;

__device__ __forceinline__ explicit Scheduler(const uint32_t shape_m,
int* grouped_layout = nullptr) {
num_aligned_m_blocks = cell_div(shape_m, BLOCK_M);
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
} else if (kGemmType == GemmType::GroupedContiguous) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::GroupedMasked) {
curr_group_idx = curr_cumsum = 0;
this->grouped_layout = grouped_layout;
}
}

__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

// Swizzle for better L2 usages
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}

template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
return offset * shape_dim + block_idx * block_size;
} else if (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}

__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;

if constexpr (kGemmType == GemmType::GroupedMasked) {
uint32_t num_m_blocks;
while (true) {
// End of the task
if (curr_group_idx == kNumGroups)
return false;

// Within current group
num_m_blocks = cell_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
break;

// Move to check the next group
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
}

get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;

get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
};
#pragma clang diagnostic pop

} // namespace deep_gemm

118 changes: 118 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// MIT License
//
// Copyright (c) 2025 DeepSeek-Ai/DeepGEMM
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

#include <cassert>
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/barrier>

#include "utils.cuh"

namespace deep_gemm {

template <class T>
constexpr CUtensorMapDataType get_CUtensorMapDataType() {
if constexpr (std::is_same<T, uint8_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if constexpr (std::is_same<T, uint16_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT16;
} else if constexpr (std::is_same<T, uint32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT32;
} else if constexpr (std::is_same<T, uint64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_UINT64;
} else if constexpr (std::is_same<T, int32_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT32;
} else if constexpr (std::is_same<T, int64_t>::value) {
return CU_TENSOR_MAP_DATA_TYPE_INT64;
} else if constexpr (std::is_same<T, __half>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if constexpr (std::is_same<T, float>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if constexpr (std::is_same<T, double>::value) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
}
}

PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;

#if CUDA_VERSION >= 12050
cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000,
cudaEnableDefault, &driver_status);
#else
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr,
cudaEnableDefault, &driver_status);
#endif

if (driver_status != cudaDriverEntryPointSuccess)
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
}

template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
uint64_t global_stride[rank - 1] = {stride_in_bytes};
uint32_t elem_strides[rank] = {1, 1};

if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();

auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
DG_HOST_ASSERT(result == CUDA_SUCCESS);
return tensor_map;
}

template <uint32_t kNumTMAMulticast = 1>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
int32_t const& crd_0, int32_t const& crd_1) {
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
if constexpr (kNumTMAMulticast == 1) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
} else if (cute::block_rank_in_cluster() == 0) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
}
}

} // namespace deep_gemm
70 changes: 70 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// MIT License
//
// Copyright (c) 2025 DeepSeek-Ai/DeepGEMM
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

#include <exception>

#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif

class AssertionException : public std::exception {
private:
std::string message{};

public:
explicit AssertionException(const std::string& message) : message(message) {}

const char *what() const noexcept override { return message.c_str(); }
};

#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", \
__FILE__, __LINE__, #cond); \
throw AssertionException("Assertion failed: " #cond); \
} \
} while (0)
#endif

#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif

#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif

template <typename T>
__device__ __host__ constexpr T cell_div(T a, T b) {
return (a + b - 1) / b;
}
25 changes: 25 additions & 0 deletions ops/csrc/fp8/deep_gemm/jit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# MIT License
#
# Copyright (c) 2025 DeepSeek-Ai/DeepGEMM
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .compiler import build, get_nvcc_compiler
from .runtime import Runtime
from .template import cpp_format, generate
Loading
Loading