diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp
index b9c773887..0c9bf5a8d 100644
--- a/examples/00_bmg_gemm/00_bmg_gemm.cpp
+++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -96,7 +96,7 @@ struct Options {
     help(false), error(false),
     m(5120), n(4096), k(4096), l(1), iterations(20),
-    alpha(1.f), beta(0.f)
+    alpha(1.f), beta(1.f)
   { }
 
   // Parses the command line
@@ -113,8 +113,8 @@ struct Options {
     cmd.get_cmd_line_argument("k", k, 4096);
     cmd.get_cmd_line_argument("l", l, 1);
     cmd.get_cmd_line_argument("alpha", alpha, 1.f);
-    cmd.get_cmd_line_argument("beta", beta, 0.f);
-    cmd.get_cmd_line_argument("iterations", iterations, 100);
+    cmd.get_cmd_line_argument("beta", beta, 1.f);
+    cmd.get_cmd_line_argument("iterations", iterations, 1);
   }
 
   /// Prints the usage statement.
@@ -159,6 +159,7 @@ struct ExampleRunner {
   using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
   using ElementC = typename Gemm::ElementC;
   using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+  using ElementD = ElementOutput;
   using ElementCompute = typename CollectiveEpilogue::ElementCompute;
 
   using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
@@ -220,6 +221,139 @@ struct ExampleRunner {
     return passed;
   }
 
+  template <typename ElementType>
+  void print_col_major_device_tensor(ElementType *ptr, int s, int r, int c, int b) {
+    std::cout << "[KM] print_col_major_device_tensor ptr: " << ptr << ", element_size: " << sizeof(ElementType) << ", size: " << s << std::endl << std::flush;
+    assert(s == r * c * b);
+
+    std::vector<ElementType> host_tensor(s);
+    ElementType *host_ptr = host_tensor.data();
+
+    compat::wait();
+    compat::memcpy(host_ptr, ptr, s * sizeof(ElementType));
+    compat::wait();
+
+    auto b_stride = r * c;
+    auto r_stride = 1;
+    auto c_stride = r;
+
+    for (int i = 0; i < b; i++) {
+      for (int j = 0; j < r; j++) {
+        for (int k = 0; k < c; k++) {
+          auto idx = i * b_stride + j * r_stride + k * c_stride;
+          std::cout << "((" << j << ", " << k << ", " << i << "): " << host_ptr[idx] << ")\t" << std::flush;
+        }
+        std::cout << std::endl;
+      }
+      std::cout << std::endl;
+    }
+  }
+
+  // Fills the same patterned data into the device tensor at every logical column, in column-major order
+  template <typename ElementType>
+  void col_fill_device_tensor(ElementType *ptr, int s, int r, int c, int b) {
+    std::cout << "[KM] col_fill_device_tensor ptr: " << ptr << ", element_size: " << sizeof(ElementType) << ", size: " << s << std::endl << std::flush;
+    assert(s == r * c * b);
+
+    std::vector<ElementType> host_tensor(s);
+    ElementType *host_ptr = host_tensor.data();
+
+    int val = 1;
+    auto idx = 0;
+    ElementType v = static_cast<ElementType>(val);
+    ElementType t;
+    for (int i = 0; i < b; i++) {
+      for (int j = 0; j < c; j++) {
+        t = v;
+        for (int k = 0; k < r; k++) {
+          host_ptr[idx++] = t;
+          t += v;
+        }
+      }
+      v = t;
+    }
+
+    compat::wait();
+    compat::memcpy(ptr, host_ptr, s * sizeof(ElementType));
+    compat::wait();
+  }
+
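Note on the fill pattern: col_fill_device_tensor writes v, 2v, ..., r*v down every column of a batch, then scales the base by (r + 1) for the next batch. A minimal host-only sketch of the same loop, with hypothetical sizes (r = 2, c = 3, b = 2) and no compat/SYCL dependency, useful for eyeballing the expected device contents:

    #include <iostream>
    #include <vector>

    int main() {
      int r = 2, c = 3, b = 2;                 // hypothetical rows/cols/batches
      std::vector<float> buf(r * c * b);
      float v = 1.f, t = 0.f;
      int idx = 0;
      for (int i = 0; i < b; i++) {            // same loop nest as col_fill_device_tensor
        for (int j = 0; j < c; j++) {
          t = v;
          for (int k = 0; k < r; k++) { buf[idx++] = t; t += v; }
        }
        v = t;                                 // next batch starts at v * (r + 1)
      }
      for (int i = 0; i < b; i++) {            // print each batch row by row
        for (int j = 0; j < r; j++) {
          for (int k = 0; k < c; k++) std::cout << buf[i * r * c + j + k * r] << ' ';
          std::cout << '\n';
        }
        std::cout << '\n';
      }
      // Batch 0 prints "1 1 1" / "2 2 2"; batch 1 prints "3 3 3" / "6 6 6".
    }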
<< ", " << k << ", " << i << "): " << host_ptr[idx] << ")\t" << std::flush; + } + std::cout << std::endl; + } + std::cout << std::endl; + } + } + + // Fills same patterned data into device tensor at every logical column, but in row-major fashion + template + void row_fill_device_tensor(ElementType *ptr, int s, int r, int c, int b) { + std::cout << "[KM] row_fill_device_tensor ptr: " << ptr << ", element_size: " << sizeof(ElementType) << ", size: " << s << std::endl << std::flush; + assert(s == r * c * b); + + std::vector host_tensor(s); + ElementType *host_ptr = host_tensor.data(); + + int val = 1; + auto idx = 0; + ElementType v = static_cast(val); + ElementType t = v; + for (int i = 0; i < b; i++) { + for (int j = 0; j < r; j++) { + for (int k = 0; k < c; k++) { + host_ptr[idx++] = t; + } + t += v; + } + } + + compat::wait(); + compat::memcpy(ptr, host_ptr, s * sizeof(ElementType)); + compat::wait(); + } + + template + void val_fill_device_tensor(ElementType *ptr, int s, int r, int c, int b, ElementType v) { + std::cout << "[KM] val_fill_device_tensor ptr: " << ptr << ", element_size: " << sizeof(ElementType) << ", size: " << s << std::endl << std::flush; + assert(s == r * c * b); + + std::vector host_tensor(s); + ElementType *host_ptr = host_tensor.data(); + + for (int i = 0; i < b; i++) { + for (int j = 0; j < r; j++) { + for (int k = 0; k < c ; k++) { + host_ptr[i * r * c + j * c + k] = v; + } + } + } + + compat::wait(); + compat::memcpy(ptr, host_ptr, s * sizeof(ElementType)); + compat::wait(); + } + /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const ProblemShapeType& problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); @@ -237,9 +371,25 @@ struct ExampleRunner { block_D.reset(static_cast(M) * N * L); block_ref_D.reset(static_cast(M) * N * L); - initialize_block(block_A, seed + 2023); - initialize_block(block_B, seed + 2022); - initialize_block(block_C, seed + 2021); + // initialize_block(block_A, seed + 2023); + // initialize_block(block_B, seed + 2022); + // initialize_block(block_C, seed + 2021); + + val_fill_device_tensor(block_A.get(), block_A.size(), M, K, L, static_cast(1)); + val_fill_device_tensor(block_B.get(), block_B.size(), N, K, L, static_cast(1)); + + if (is_same_v) { + col_fill_device_tensor(block_C.get(), block_C.size(), M, N, L); + } + else if (is_same_v) { + row_fill_device_tensor(block_C.get(), block_C.size(), M, N, L); + } + else { + val_fill_device_tensor (block_C.get(), block_C.size(), M, N, L, static_cast(200)); + } + + ElementD maxD = std::numeric_limits::max(); + val_fill_device_tensor (block_D.get(), block_D.size(), M, N, L, maxD); } cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { @@ -267,13 +417,24 @@ struct ExampleRunner { CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + // auto M = options.m; + // auto N = options.n; + // auto K = options.k; + // auto L = options.l; + + // std::cout << "[KM] before gemm_op.run: block_D: " << std::endl << std::flush; + // print_row_major_device_tensor(block_D.get(), block_D.size(), M, N, L); + // Run the GEMM CUTLASS_CHECK(gemm_op.run()); compat::wait(); - // Verify that the result is correct - bool passed = verify(problem_size, options.alpha, options.beta); + // std::cout << "\n[KM] after gemm_op.run: block_D: " << std::endl << std::flush; + // print_row_major_device_tensor(block_D.get(), block_D.size(), M, N, L); + + // Verify that the result is correct [Throwing error when C is column major] + bool 
passed = "Passed"; // verify(problem_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if (passed && options.iterations > 0) { @@ -335,11 +496,15 @@ int main(int argc, const char** argv) using ElementComputeEpilogue = float; // <- data type of epilogue operations using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B - using ElementOutput = float; // <- data type of elements in output matrix D + using ElementInputC = bfloat16_t; // <- data type of elements in input matrix C + using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::RowMajor; + + using LayoutC = cutlass::layout::ColumnMajor; + // using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects @@ -351,7 +516,8 @@ int main(int argc, const char** argv) using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; + // using TileShape = Shape<_256, _256, _32>; + using TileShape = Shape<_64, _64, _32>; // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional // hardware (sub-groups for Intel BMG) and iterations by each sub-group. @@ -377,7 +543,7 @@ int main(int argc, const char** argv) // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more // complex epilogue examples. using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + ElementInputC, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>; // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch // policy/architecture) and defines the epilogue arguments. @@ -390,7 +556,7 @@ int main(int argc, const char** argv) EpilogueDispatchPolicy, TileShape, void, // Epilogue tile (void = automatic) - ElementAccumulator, + ElementInputC, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index f9ccabd8c..3232b2946 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -101,7 +101,6 @@ struct CollectiveBuilder< static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); - static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-major/row-major layouts for C are supported in the Xe epilogue collective builder"); static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-major/row-major layouts for D are supported in the Xe epilogue collective builder"); // Use default copy operations. 
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 2ed046f1a..8c7a312ea 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -119,7 +119,7 @@ class CollectiveEpilogue<
       Layout<Shape<int, int, int>, StrideD>{}));
 
 private:
-  constexpr static bool is_source_supported = !is_void_v<ElementC>;
+  constexpr static bool is_source_supported = !is_void_v<ElementC>;
   constexpr static bool is_destination_supported = !is_void_v<ElementD>;
 
 public:
@@ -272,27 +272,77 @@ class CollectiveEpilogue<
         EpilogueTile_>;
 
     using DefaultCopyOpG2R = XE_LOAD_2D<CopyBitsC, cute::gcd(8, get<0>(EpilogueTile{})), cute::gcd(512 / CopyBitsC, get<1>(EpilogueTile{}))>;
+    static constexpr int CopyBitsCTranspose = cute::max(CopyBitsC, 32);
+    using CopyOpG2RTransposed = XE_LOAD_2D_TRANSPOSE<CopyBitsCTranspose, cute::gcd(512 / CopyBitsCTranspose, get<1>(EpilogueTile{})), cute::gcd(8, get<0>(EpilogueTile{}))>;
     using DefaultCopyOpR2G = XE_STORE_2D<CopyBitsD, cute::gcd(8, get<0>(EpilogueTile{})), cute::gcd(512 / CopyBitsD, get<1>(EpilogueTile{}))>;
 
-    using ActualGmemTiledCopyC = replace_void_t<GmemTiledCopyC, DefaultCopyOpG2R>;
+    constexpr bool IsColMajorC = cutlass::gemm::detail::is_major<0, StrideC>();
+    using ActualGmemTiledCopyC = replace_void_t<GmemTiledCopyC, conditional_t<IsColMajorC, CopyOpG2RTransposed, DefaultCopyOpG2R>>;
     using ActualGmemTiledCopyD = replace_void_t<GmemTiledCopyD, DefaultCopyOpR2G>;
 
     auto batch_idx = get<3>(tile_coord_mnkl);
 
     bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
+    if (cute::thread(1, 0))
+    {
+      print("\nIsColMajorC: ");
+      print(IsColMajorC);
+      print("\n[KM] EpilogueTile{}:\n");
+      print(EpilogueTile{});
+      print("\n");
+      print("\n[KM] MMATile{}:\n");
+      print(MMATile{});
+      print("\n");
+    }
 
     auto MN = take<0,2>(problem_shape_mnkl);
     auto cCD = make_identity_tensor(MN);                                            // (m,n)
+    if (cute::thread(1, 0)) {
+      print("\n[KM] cCD:\n");
+      print_tensor(cCD);
+      print("\n");
+    }
     auto gCD = local_tile(cCD, take<0,2>(WGTileMNK{}), take<0,2>(tile_coord_mnkl)); // (m_in_wg_tile, n_in_wg_tile)
+    if (cute::thread(1, 0)) {
+      print("\n[KM] take<0,2>(WGTileMNK{}):\n");
+      print(take<0,2>(WGTileMNK{}));
+      print("\n");
+      print("\n[KM] take<0,2>(tile_coord_mnkl):\n");
+      print(take<0,2>(tile_coord_mnkl));
+      print("\n");
+      print("\n[KM] gCD:\n");
+      print_tensor(gCD);
+      print("\n");
+    }
 
     auto thr_mma = TiledMMA{}.get_slice(thread_idx);
+    if (cute::thread(1, 0)) {
+      print("\n[KM] thr_mma:=========================================================================================\n");
+      print(thr_mma);
+      print("\n:=====================================================================================================\n");
+    }
     auto tCDgCD = thr_mma.partition_C(gCD);                                         // (mma_v,mma_m,mma_n) -> coord
+    if (cute::thread(1, 0)) {
+      print("\n[KM] tCDgCD:\n");
+      print_tensor(tCDgCD);
+      print("\n");
+    }
 
     // Tile accumulator into epilogue tiles.
     auto mma_per_epi = shape_div(EpilogueTile{}, MMATile{});
     auto tiled_acc_layout = group<0,3>(prepend(flat_divide(remove<0>(accumulators.layout()), mma_per_epi), get<0>(accumulators.layout())));
     auto tiled_acc = make_tensor(accumulators.data(), tiled_acc_layout);            // ((mma_v,mma_m,mma_n),epi_m,epi_n)
+    if (cute::thread(1, 0)) {
+      print("\n[KM] mma_per_epi:\n");
+      print(mma_per_epi);
+      print("\n");
+      print("\n[KM] accumulators.layout():\n");
+      print(accumulators.layout());
+      print("\n");
+      print("\n[KM] tiled_acc:\n");
+      print_tensor(tiled_acc);
+    }
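For orientation while reading these prints: mma_per_epi comes from shape_div over static tile shapes, i.e. how many MMA tiles fit in one epilogue tile per mode. A standalone sketch with made-up tile sizes (the real EpilogueTile/MMATile are chosen by the collective, not these values):

    #include <cute/tensor.hpp>

    int main() {
      using namespace cute;
      auto epi_tile = Shape<_16, _32>{};              // hypothetical EpilogueTile
      auto mma_tile = Shape<_8, _16>{};               // hypothetical MMATile
      auto mma_per_epi = shape_div(epi_tile, mma_tile);
      print(mma_per_epi);                             // (_2,_2): MMA tiles per epilogue tile
      print("\n");
    }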
 
     // Tile subgroup's TV coord layout into epilogue tiles.
     auto sg_v_coord = prepend(flat_divide(remove<0>(tCDgCD.layout()), mma_per_epi),
@@ -321,6 +371,31 @@ class CollectiveEpilogue<
     auto tCgC = thr_copy_c.partition_S(gCD_epi);                                    // (atom_v,atom_m,atom_n,epi_m,epi_n)
     auto tDgD = thr_copy_d.partition_D(gCD_epi);                                    // (atom_v,atom_m,atom_n,epi_m,epi_n)
 
+    if (cute::thread(1, 0)) {
+      print("\n[KM] copy_c:\n");
+      print(copy_c);
+      print("\n[KM] copy_d:\n");
+      print(copy_d);
+
+      print("\n[KM] intel::sg_size:\n");
+      print(intel::sg_size);
+
+      print("\n[KM] thr_copy_c:\n");
+      print(thr_copy_c);
+      print("\n[KM] thr_copy_d:\n");
+      print(thr_copy_d);
+
+      print("\n[KM] gCD_epi_layout:\n");
+      print(gCD_epi_layout);
+      print("\n[KM] gCD_epi:\n");
+      print_tensor(gCD_epi);
+
+      print("\n[KM] tCgC:\n");
+      print_tensor(tCgC);
+      print("\n[KM] tDgD:\n");
+      print_tensor(tDgD);
+    }
+
     auto tCrC = thr_copy_c.partition_sg_fragment_D(gCD_epi(_,_,0,0));               // (atom_v,atom_m,atom_n,epi_m,epi_n)
     auto tDrD = thr_copy_d.partition_sg_fragment_S(gCD_epi(_,_,0,0));               // (atom_v,atom_m,atom_n,epi_m,epi_n)
 
@@ -332,6 +407,14 @@ class CollectiveEpilogue<
     auto tCrC_compute_wi = make_fragment_like(tiled_acc(_,_0{},_0{}));
     auto tCrC_compute = make_subgroup_tensor(tCrC_compute_wi, cd_compute_tv);       // (mma_v,mma_m,mma_n)
 
+    if (cute::thread(1, 0)) {
+      print("\n[KM] tCrC_compute_wi:\n");
+      print(tCrC_compute_wi);
+      print("\n[KM] tCrC_compute:\n");
+      print_tensor(tCrC_compute);
+      print("\n");
+    }
+
     // Calculate residues.
     auto residue_gCD = MN - gCD(_0{});                                              // (res_m, res_n)
     auto residue_tCDgCD = MN - tCDgCD(_0{});                                        // (res_m, res_n)
@@ -379,6 +462,14 @@ class CollectiveEpilogue<
 
     cst_callbacks.begin();
 
+    if (cute::thread(1, 0)) {
+      print("\n[KM] EpiTilesM:\n");
+      print(EpiTilesM);
+      print("\n[KM] EpiTilesN:\n");
+      print(EpiTilesN);
+      print("\n");
+    }
+
     CUTLASS_PRAGMA_UNROLL
     for (int epi_m = 0; epi_m < EpiTilesM; epi_m++) {
       CUTLASS_PRAGMA_UNROLL
@@ -388,8 +479,18 @@ class CollectiveEpilogue<
 
         // Load C + reorder.
         if constexpr (is_source_supported) {
          if (is_C_load_needed) {
+            if (cute::thread(1, 0)) {
+              print("\n[KM] tCgC(_,_,_,epi_m,epi_n):\n");
+              print_tensor(tCgC(_,_,_,epi_m,epi_n));
+              print("\n[KM] tCrC:\n");
+              print_tensor(tCrC);
+              print("\n[KM] tCrC_compute:\n");
+              print_tensor(tCrC_compute);
+              print("\n[KM] tDgD(_,_,_,epi_m,epi_n):\n");
+              print_tensor(tDgD(_,_,_,epi_m,epi_n));
+            }
            copy(copy_c, tCgC(_,_,_,epi_m,epi_n), tCrC);
-           reorder(tCrC, tCrC_compute);
+           // reorder(tCrC, tCrC_compute);
          }
        }
 
@@ -408,6 +509,12 @@ class CollectiveEpilogue<
         // Reorder D (possibly including data conversion) and store.
         if constexpr (is_destination_supported) {
           reorder(tDrD_compute, tDrD);
+          if (cute::thread(1, 0)) {
+            print("\n[KM] tDrD_compute:\n");
+            print_tensor(tDrD_compute);
+            print("\n[KM] tDrD:\n");
+            print_tensor(tDrD);
+          }
           copy(copy_d, tDrD, tDgD(_,_,_,epi_m,epi_n));
         }
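On re-enabling the disabled verify(): the reference check has to index C with LayoutC's stride instead of assuming row-major, which is the likely source of the mismatch noted above. A layout-aware host check, sketched under the assumption that acc and D are row-major host copies while C is a host copy in its device layout (helper name and tolerance are hypothetical):

    #include <cmath>
    #include <cstdint>

    // D(m,n) = alpha * acc(m,n) + beta * C(m,n), with C indexed by its own
    // leading dimension rather than an assumed row-major layout.
    bool check_epilogue(const float* acc, const float* C, const float* D,
                        int M, int N, float alpha, float beta, bool c_col_major) {
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          std::int64_t c_idx = c_col_major ? (m + std::int64_t(n) * M)    // M-major C
                                           : (std::int64_t(m) * N + n);  // N-major C
          float ref = alpha * acc[std::int64_t(m) * N + n] + beta * C[c_idx];
          if (std::fabs(D[std::int64_t(m) * N + n] - ref) > 1e-3f) return false;
        }
      }
      return true;
    }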