@@ -88,8 +88,8 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(__nv_fp8
8888
8989template <typename ElementA, typename ElementB, typename ElementD>
9090void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void * mat_d, void const * mat_a,
91- void const * mat_b, int64_t const * problem_m_offsets, size_t num_problems, size_t shape_n , size_t shape_k ,
92- cudaStream_t stream, float const * scales_a, float const * scales_b)
91+ void const * mat_b, int64_t const * problem_m_offsets, size_t num_problems, size_t expected_m , size_t shape_n ,
92+ size_t shape_k, cudaStream_t stream, float const * scales_a, float const * scales_b)
9393{
9494 constexpr bool internal_quantize_a = !std::is_same_v<ElementA, __nv_fp8_e4m3>;
9595 constexpr bool internal_quantize_b = !std::is_same_v<ElementB, __nv_fp8_e4m3>;
@@ -138,21 +138,21 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
138138 {
139139 fp8_grouped_gemm_run (reinterpret_cast <__nv_bfloat16 const *>(mat_a), fp8_mat_a, per_token_per_128c_scales,
140140 reinterpret_cast <__nv_bfloat16 const *>(mat_b), fp8_mat_b, per_block_scales,
141- reinterpret_cast <__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_ , max_shape_m_4_align_,
141+ reinterpret_cast <__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m , max_shape_m_4_align_,
142142 max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
143143 }
144144 else if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
145145 {
146146 fp8_grouped_gemm_run (reinterpret_cast <__nv_bfloat16 const *>(mat_a), fp8_mat_a, per_token_per_128c_scales,
147147 nullptr , fp8_mat_b, per_block_scales, reinterpret_cast <__nv_bfloat16*>(mat_d), problem_m_offsets,
148- num_problems, expected_m_ , max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
148+ num_problems, expected_m , max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
149149 internal_quantize_a, internal_quantize_b);
150150 }
151151 else if constexpr (std::is_same_v<ElementA, __nv_fp8_e4m3> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
152152 {
153153 fp8_grouped_gemm_run (nullptr , fp8_mat_a, per_token_per_128c_scales,
154154 reinterpret_cast <__nv_bfloat16 const *>(mat_b), fp8_mat_b, per_block_scales,
155- reinterpret_cast <__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_ , max_shape_m_4_align_,
155+ reinterpret_cast <__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m , max_shape_m_4_align_,
156156 max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
157157 }
158158 else
@@ -164,6 +164,15 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
164164#endif
165165}
166166
167+ template <typename ElementA, typename ElementB, typename ElementD>
168+ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void * mat_d, void const * mat_a,
169+ void const * mat_b, int64_t const * problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
170+ cudaStream_t stream, float const * scales_a, float const * scales_b)
171+ {
172+ moeGemm (mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m_, shape_n, shape_k, stream, scales_a,
173+ scales_b);
174+ }
175+
167176template <typename ElementA, typename ElementB, typename ElementD>
168177void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::strideBatchGemm(__nv_bfloat16* mat_d, int ld_d,
169178 int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b, int stride_b,
0 commit comments