diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
index 2b8eee71a9ec..fda222d34b24 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp
@@ -452,6 +452,29 @@ enumerateMatmulTileRiscv64(TypeRange elementTypes, DictionaryAttr config) {
       };
     }
   }
+  // This adds support for the s8*s8->s32 kernel.
+  if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
+      out.isSignlessInteger(32)) {
+    // This logic follows the f32 case, as both use 32-bit accumulators.
+    // For a +zvl512b target, vlen = 512, so at LMUL=4:
+    // N0 = (512 bits / 32 bits per element) * 4 = 64 elements.
+
+    // N0 for the LMUL=8 path (M0=16).
+    int N0_lmul8 = vlen / 4;
+    // N0 for the LMUL=4 paths (M0=8, 4, 2, 1).
+    int N0_lmul4 = vlen / 8;
+
+    return {
+        // --- LMUL=8 path ---
+        TileMxNxK{16, N0_lmul8, 1},  // Target tile for s8s8s32 (LMUL=8).
+
+        // --- LMUL=4 paths ---
+        TileMxNxK{8, N0_lmul4, 1},  // Truncation (LMUL=4).
+        TileMxNxK{4, N0_lmul4, 1},  // Truncation (LMUL=4).
+        TileMxNxK{2, N0_lmul4, 1},  // Truncation (LMUL=4).
+        TileMxNxK{1, N0_lmul4, 1},  // Truncation (vecmat, LMUL=4).
+    };
+  }
   // Fallback - no architecture-optimized tile size for this case.
   return {};
 }
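Aside: the N0 arithmetic used above, spelled out as a standalone, compilable sketch. The helper name and the main() harness are illustrative only, not part of the patch:

#include <assert.h>

// How many 32-bit accumulator lanes one LMUL-grouped register set holds:
// N0 = (VLEN / accumulator bits) * LMUL.
static int n0_for_lmul(int vlen_bits, int acc_bits, int lmul) {
  return (vlen_bits / acc_bits) * lmul;
}

int main(void) {
  // +zvl512b target with 32-bit accumulators:
  assert(n0_for_lmul(512, 32, 8) == 128);  // == vlen / 4, the M0=16 path.
  assert(n0_for_lmul(512, 32, 4) == 64);   // == vlen / 8, the M0<=8 paths.
  return 0;
}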
diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt
index 9469b86b277e..ab6c80b44c22 100644
--- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/CMakeLists.txt
@@ -42,6 +42,7 @@ iree_bitcode_library(
     "${PROJECT_BINARY_DIR}/runtime/src/iree/schemas/cpu_data_headers_filegroup.stamp"
     "common_riscv_64.h"
     "mmt4d_riscv_64_internal.h"
+    "bme.h"
     "mmt4d_riscv_64_tiles.inl"
     "pack_riscv_64_internal.h"
     "unpack_riscv_64_internal.h"
@@ -177,6 +178,8 @@ iree_cc_library(
     riscv_64_v
   SRCS
     "mmt4d_riscv_64_v.c"
+  HDRS
+    "bme.h"
   COPTS
     "${IREE_UK_COPTS_RISCV_64_V}"
   DEPS
diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h b/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h
new file mode 100644
index 000000000000..6ee763e309f1
--- /dev/null
+++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h
@@ -0,0 +1,64 @@
+// HACK: reuse the scalar register names to avoid assembler changes for now.
+#include <stddef.h>  // For size_t
+#include <stdint.h>  // For int8_t and int32_t
+
+#define m0 "x0"
+#define m1 "x1"
+#define m2 "x2"
+#define m3 "x3"
+#define m4 "x4"
+#define m5 "x5"
+#define m6 "x6"
+#define m7 "x7"
+
+#define v0 "x0"
+#define v1 "x1"
+#define v2 "x2"
+#define v3 "x3"
+#define v4 "x4"
+#define v5 "x5"
+#define v6 "x6"
+#define v7 "x7"
+#define v8 "x8"
+#define v9 "x9"
+#define v10 "x10"
+#define v11 "x11"
+#define v12 "x12"
+#define v13 "x13"
+#define v14 "x14"
+#define v15 "x15"
+#define v16 "x16"
+#define v17 "x17"
+#define v18 "x18"
+#define v19 "x19"
+#define v20 "x20"
+#define v21 "x21"
+#define v22 "x22"
+#define v23 "x23"
+#define v24 "x24"
+#define v25 "x25"
+#define v26 "x26"
+#define v27 "x27"
+#define v28 "x28"
+#define v29 "x29"
+#define v30 "x30"
+#define v31 "x31"
+
+// opmvx. f6=b101010, f7=b1010101
+#define VMV_RV(md, rs1, vs2) \
+  asm volatile(".insn r 0x57, 0x6, 0x55, " md ", %0, " vs2 : : "r"(rs1));
+
+// opmvx. f6=b101110, f7=b1011101
+#define VMV_VR(vd, rs1, ms2) \
+  asm volatile(".insn r 0x57, 0x6, 0x5d, " vd ", %0, " ms2 : : "r"(rs1));
+
+// opmvx. f6=b101100, f7=b1011001
+#define OPMVINBCAST(md, vs2) \
+  asm volatile(".insn r 0x57, 0x6, 0x59, " md ", x0, " vs2);
+
+// opmvv. f6=b101000, f7=b1010001
+#define VOPACC(md, vs2, vs1) \
+  asm volatile(".insn r 0x57, 0x2, 0x51, " md ", " vs1 ", " vs2);
+
+// void i8_mm_bme_1x2(int32_t* c_bias, int32_t* c_out, int8_t* at, int8_t* b,
+//                    size_t M, size_t N, size_t K);
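A quick cross-check of the f6/f7 pairs in the comments above: assuming these custom encodings follow the standard RVV layout, the funct7 passed to .insn is funct6 with the vm (mask) bit appended as the LSB, and vm=1 (unmasked) throughout. A compilable sketch, helper name illustrative:

#include <assert.h>
#include <stdint.h>

// funct7 = funct6:vm, i.e. (funct6 << 1) | vm.
static uint8_t funct7_from_funct6(uint8_t funct6, uint8_t vm) {
  return (uint8_t)((funct6 << 1) | (vm & 1));
}

int main(void) {
  assert(funct7_from_funct6(0x2a, 1) == 0x55);  // VMV_RV:      b101010 -> b1010101
  assert(funct7_from_funct6(0x2e, 1) == 0x5d);  // VMV_VR:      b101110 -> b1011101
  assert(funct7_from_funct6(0x2c, 1) == 0x59);  // OPMVINBCAST: b101100 -> b1011001
  assert(funct7_from_funct6(0x28, 1) == 0x51);  // VOPACC:      b101000 -> b1010001
  return 0;
}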
diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl
index 4f4d0413a907..48247dc98dc9 100644
--- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl
+++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_tiles.inl
@@ -32,3 +32,10 @@ IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 1, 1, _zvfh)
 IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 2, 1, _zvfh)
 IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 4, 1, _zvfh)
 IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 7, 1, _zvfh)
+
+// s8s8s32 tiles using the 'v' extension.
+IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 1, 1, _v)
+IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 2, 1, _v)
+IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 4, 1, _v)
+IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 8, 1, _v)
+IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 16, 1, _v)
diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c
index 2517080da24a..7280e3aef0af 100644
--- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c
+++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c
@@ -8,6 +8,7 @@
 
 #include "iree/builtins/ukernel/arch/riscv_64/common_riscv_64.h"
 #include "iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_internal.h"
+#include "iree/builtins/ukernel/arch/riscv_64/bme.h"
 
 IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
 iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v(
@@ -121,6 +122,87 @@ iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v(
   }
 }
 
+IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
+iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v(
+    void* IREE_UK_RESTRICT out_tile,
+    const void* IREE_UK_RESTRICT lhs_panel,
+    const void* IREE_UK_RESTRICT rhs_panel,
+    const iree_uk_mmt4d_params_t* params, int M0) {
+  IREE_UK_ASSERT(M0 >= 1 && M0 <= 16);
+  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
+  const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
+  const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
+
+  const int N0 = params->N0;
+  const int K = params->K;
+  size_t ml = M0;
+  size_t vl = N0;
+
+  // Fast path for M0=16 (LMUL=8).
+  if (M0 == 16) {
+    // Zero-init m0 (LMUL=8). out_tile is never loaded: ACCUMULATE is unhandled.
+    asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl));
+    asm volatile("vmv.v.i v0, 0");
+    OPMVINBCAST(m0, v0);
+
+    // K-loop, unrolled by 2.
+    int k = 0;
+    while (k + 2 <= K) {
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
+      asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0]));
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
+      asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0]));
+      VOPACC(m0, v18, v16);
+      k++;
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
+      asm volatile("vle8.v v20, (%0)" : : "r"(&lhs_ptr[k * M0]));
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
+      asm volatile("vle8.v v22, (%0)" : : "r"(&rhs_ptr[k * N0]));
+      VOPACC(m0, v22, v20);
+      k++;
+    }
+    if (k < K) {
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
+      asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0]));
+      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
+      asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0]));
+      VOPACC(m0, v18, v16);
+    }
+
+    // Store results, one row of the m0 accumulator at a time.
+    asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl));
+    for (size_t r = 0; r < ml; r++) {  // ml=16
+      VMV_VR(v0, r, m0);
+      asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0]));
+    }
+  }
+  // Tail cases for M0 < 16 (LMUL=4).
+  else {
+    // 1. Zero-init the m3 accumulator (LMUL=4).
+    asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl));
+    asm volatile("vmv.v.i v0, 0");
+    OPMVINBCAST(m3, v0);
+
+    // 2. Main K-loop.
+    for (int k = 0; k < K; ++k) {
+      asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(ml));
+      asm volatile("vle8.v v5, (%0)" : : "r"(&lhs_ptr[k * M0]));
+
+      asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(vl));
+      asm volatile("vle8.v v4, (%0)" : : "r"(&rhs_ptr[k * N0]));
+
+      VOPACC(m3, v4, v5);
+    }
+
+    // 3. Store results.
+    asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl));
+    for (size_t r = 0; r < ml; r++) {
+      VMV_VR(v0, r, m3);
+      asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0]));
+    }
+  }
+}
+
 IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
     iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v,
     iree_uk_mmt4d_tile_f32f32f32_1xXXx1_riscv_64_v, 1)
@@ -133,3 +215,20 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
 IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
     iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v,
     iree_uk_mmt4d_tile_f32f32f32_7xXXx1_riscv_64_v, 7)
+
+// Point all s8s8s32 tiles (M0 = 1, 2, 4, 8, 16) at the generic function.
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_riscv_64_v, 1)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
+    iree_uk_mmt4d_tile_s8s8s32_2xXXx1_riscv_64_v, 2)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
+    iree_uk_mmt4d_tile_s8s8s32_4xXXx1_riscv_64_v, 4)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
+    iree_uk_mmt4d_tile_s8s8s32_8xXXx1_riscv_64_v, 8)
+IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
+    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
+    iree_uk_mmt4d_tile_s8s8s32_16xXXx1_riscv_64_v, 16)
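For testing, here is a plain-C statement of what the tile function above computes. The panel layout is read off the asm's indexing (lhs: K x M0, rhs: K x N0, out: M0 x N0, all row-major); the function name is illustrative and, like the asm, it overwrites out rather than accumulating into it:

#include <stdint.h>

// out[m][n] = sum over k of lhs[k][m] * rhs[k][n], in 32-bit arithmetic.
static void mmt4d_tile_s8s8s32_ref(int32_t* out, const int8_t* lhs,
                                   const int8_t* rhs, int M0, int N0, int K) {
  for (int m = 0; m < M0; ++m) {
    for (int n = 0; n < N0; ++n) {
      int32_t acc = 0;  // Like the asm: zero start, no accumulate flag.
      for (int k = 0; k < K; ++k) {
        acc += (int32_t)lhs[k * M0 + m] * (int32_t)rhs[k * N0 + n];
      }
      out[m * N0 + n] = acc;
    }
  }
}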
diff --git a/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c b/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c
index 352a0725a7ad..0f9a6d0b926e 100644
--- a/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c
+++ b/runtime/src/iree/builtins/ukernel/arch/riscv_64/query_tile_sizes_riscv_64_entry_point.c
@@ -19,6 +19,21 @@ iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32(
   return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8};
 }
 
+static iree_uk_matmul_tile_sizes_t
+iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32(
+    const iree_uk_query_tile_sizes_2d_params_t* params) {
+#if defined(IREE_UK_BUILD_RISCV_64_V)
+  if (iree_uk_cpu_riscv_64_v(params->cpu_data)) {
+    // Corresponds to the new target tile M0=16.
+    // N=32 assumes the minimum VLEN (128 bits) at LMUL=8:
+    // N0 = (128 bits / 32 bits per element) * 8 = 32.
+    return (iree_uk_matmul_tile_sizes_t){.M = 16, .K = 1, .N = 32};
+  }
+#endif
+  // Generic fallback.
+  return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8};
+}
+
 bool iree_uk_query_matmul_tile_sizes_arch(
     const iree_uk_query_tile_sizes_2d_params_t* params,
     iree_uk_matmul_tile_sizes_t* out_matmul_tile_sizes) {
@@ -27,8 +42,12 @@ bool iree_uk_query_matmul_tile_sizes_arch(
     *out_matmul_tile_sizes =
         iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32(params);
     return true;
+  } else if (op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) {
+    *out_matmul_tile_sizes =
+        iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32(params);
+    return true;
   } else {
     // Shouldn't happen, validated earlier.
     return false;
   }
 }
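The query above hardcodes the architectural minimum VLEN of 128 bits rather than asking the hardware. If runtime sizing were ever wanted, the actual VLEN is readable from the vlenb CSR on any core with the V extension; a minimal sketch, helper name illustrative and RISC-V only:

#include <stddef.h>

// vlenb holds VLEN/8, the byte width of one vector register.
static size_t riscv_vlen_bits(void) {
  size_t vlenb;
  asm("csrr %0, vlenb" : "=r"(vlenb));
  return vlenb * 8;
}

// N at LMUL=8 with 32-bit accumulators would then be
// (riscv_vlen_bits() / 32) * 8: 32 on a VLEN=128 core, 128 on a VLEN=512 core.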