
Commit a7509fc

jinyangyuan-nvidia authored and dominicshanshan committed
[None][fix] Fix the performance issue of FP8 blockwise grouped GEMM when using attention DP (NVIDIA#8501)
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent a746650 · commit a7509fc

File tree

12 files changed: +192, -151 lines


cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

Lines changed: 4 additions & 4 deletions

```diff
@@ -994,8 +994,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
             ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
-            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-            mHiddenSize, mInterSize, mNumExperts, mK,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
+            mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
             mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
             mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
             mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
@@ -1007,8 +1007,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
             ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
-            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-            mHiddenSize, mInterSize, mNumExperts, mK,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
+            mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
             mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
             mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
             mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
```
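Both benchmark call sites gain one argument: mTotalTokens is now passed twice, so the runner receives the actual token count and, in line with the expected_m parameter introduced in the GEMM runner below, an expected token count for its grouped-GEMM configuration heuristics. For orientation, here is a minimal sketch of how an expected-M hint is commonly derived in MoE settings; this is generic MoE arithmetic, not necessarily the exact formula TensorRT-LLM applies internally:

```cpp
#include <cstddef>

// Illustrative MoE sizing arithmetic (not the repo's exact formula): each of
// num_tokens tokens is routed to k of num_experts experts, so the average
// expert processes num_tokens * k / num_experts rows.
std::size_t expectedMPerExpert(std::size_t num_tokens, std::size_t k, std::size_t num_experts)
{
    // Round up so the heuristic never undersizes the launch configuration.
    return (num_tokens * k + num_experts - 1) / num_experts;
}
```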

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu

Lines changed: 14 additions & 5 deletions

```diff
@@ -88,8 +88,8 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(__nv_fp8
 
 template <typename ElementA, typename ElementB, typename ElementD>
 void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
-    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
-    cudaStream_t stream, float const* scales_a, float const* scales_b)
+    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t expected_m, size_t shape_n,
+    size_t shape_k, cudaStream_t stream, float const* scales_a, float const* scales_b)
 {
     constexpr bool internal_quantize_a = !std::is_same_v<ElementA, __nv_fp8_e4m3>;
     constexpr bool internal_quantize_b = !std::is_same_v<ElementB, __nv_fp8_e4m3>;
@@ -138,21 +138,21 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
     {
         fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
             reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
-            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
+            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
             max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
     }
     else if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
     {
         fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
             nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
-            num_problems, expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
+            num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
             internal_quantize_a, internal_quantize_b);
     }
     else if constexpr (std::is_same_v<ElementA, __nv_fp8_e4m3> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
     {
         fp8_grouped_gemm_run(nullptr, fp8_mat_a, per_token_per_128c_scales,
             reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
-            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
+            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
             max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
     }
     else
@@ -164,6 +164,15 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
 #endif
 }
 
+template <typename ElementA, typename ElementB, typename ElementD>
+void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
+    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
+    cudaStream_t stream, float const* scales_a, float const* scales_b)
+{
+    moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m_, shape_n, shape_k, stream, scales_a,
+        scales_b);
+}
+
 template <typename ElementA, typename ElementB, typename ElementD>
 void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::strideBatchGemm(__nv_bfloat16* mat_d, int ld_d,
     int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b, int stride_b,
```
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h

Lines changed: 9 additions & 0 deletions

```diff
@@ -40,6 +40,11 @@ class CutlassFp8BlockScaleGemmRunnerInterface
         cudaStream_t stream)
         = 0;
 
+    virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
+        size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
+        float const* scales_a = nullptr, float const* scales_b = nullptr)
+        = 0;
+
     virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
         size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
         float const* scales_b = nullptr)
@@ -95,6 +100,10 @@ class CutlassFp8BlockScaleGemmRunner : public CutlassFp8BlockScaleGemmRunnerInterface
         int ld_d, int shape_m, int shape_n, int shape_k, float const* scales_a, float const* scales_b,
         cudaStream_t stream) override;
 
+    void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
+        size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
+        float const* scales_a = nullptr, float const* scales_b = nullptr) override;
+
     void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
         size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
         float const* scales_b = nullptr) override;
```