Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,29 @@ enumerateMatmulTileRiscv64(TypeRange elementTypes, DictionaryAttr config) {
};
}
}
// This adds support for our s8*s8->s32 kernel.
if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
out.isSignlessInteger(32)) {
// This logic follows the f32 case, as both use 32-bit accumulators.
// For a +zvl512b target (vlen = 512 bits):
//   LMUL=8 path: N0 = (512 / 32) * 8 = 128 = vlen / 4.
//   LMUL=4 path: N0 = (512 / 32) * 4 =  64 = vlen / 8.

// N0 for LMUL=8 path (M0=16)
int N0_lmul8 = vlen / 4;
// N0 for LMUL=4 path (M0=8, 4, 2, 1)
int N0_lmul4 = vlen / 8;

return {
// --- LMUL=8 Path ---
TileMxNxK{16, N0_lmul8, 1}, // Target tile for s8s8s32 (LMUL=8)

// --- LMUL=4 Paths ---
TileMxNxK{8, N0_lmul4, 1}, // Truncation (LMUL=4)
TileMxNxK{4, N0_lmul4, 1}, // Truncation (LMUL=4)
TileMxNxK{2, N0_lmul4, 1}, // Truncation (LMUL=4)
TileMxNxK{1, N0_lmul4, 1}, // Truncation (vecmat) (LMUL=4)
};
}
// Fallback - no architecture-optimized tile size for this case.
return {};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ iree_bitcode_library(
"${PROJECT_BINARY_DIR}/runtime/src/iree/schemas/cpu_data_headers_filegroup.stamp"
"common_riscv_64.h"
"mmt4d_riscv_64_internal.h"
"bme.h"
"mmt4d_riscv_64_tiles.inl"
"pack_riscv_64_internal.h"
"unpack_riscv_64_internal.h"
Expand Down Expand Up @@ -177,6 +178,8 @@ iree_cc_library(
riscv_64_v
SRCS
"mmt4d_riscv_64_v.c"
HDRS
"bme.h"
COPTS
"${IREE_UK_COPTS_RISCV_64_V}"
DEPS
Expand Down
64 changes: 64 additions & 0 deletions runtime/src/iree/builtins/ukernel/arch/riscv_64/bme.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// HACK: reuse the scalar x-register names to stand in for the custom matrix
// registers (m0-m7) and the standard vector registers (v0-v31), to avoid
// assembler hacking for now. The custom instructions below are emitted via
// the `.insn r` directive, which only needs register *numbers*, so aliasing
// to x-registers of the same index is sufficient.
//
// CAUTION: these are unprefixed, lowercase object-like macros. Any
// translation unit that includes this header must not use the identifiers
// m0-m7 / v0-v31 for anything else.
#ifndef IREE_BUILTINS_UKERNEL_ARCH_RISCV_64_BME_H_
#define IREE_BUILTINS_UKERNEL_ARCH_RISCV_64_BME_H_

#include <stddef.h>  // For size_t
#include <stdint.h>  // For int8_t and int32_t

// Matrix-register operand aliases (encoded as register numbers 0-7).
#define m0 "x0"
#define m1 "x1"
#define m2 "x2"
#define m3 "x3"
#define m4 "x4"
#define m5 "x5"
#define m6 "x6"
#define m7 "x7"

// Vector-register operand aliases (encoded as register numbers 0-31).
#define v0 "x0"
#define v1 "x1"
#define v2 "x2"
#define v3 "x3"
#define v4 "x4"
#define v5 "x5"
#define v6 "x6"
#define v7 "x7"
#define v8 "x8"
#define v9 "x9"
#define v10 "x10"
#define v11 "x11"
#define v12 "x12"
#define v13 "x13"
#define v14 "x14"
#define v15 "x15"
#define v16 "x16"
#define v17 "x17"
#define v18 "x18"
#define v19 "x19"
#define v20 "x20"
#define v21 "x21"
#define v22 "x22"
#define v23 "x23"
#define v24 "x24"
#define v25 "x25"
#define v26 "x26"
#define v27 "x27"
#define v28 "x28"
#define v29 "x29"
#define v30 "x30"
#define v31 "x31"

// NOTE(review): none of the asm macros below declare "memory" or register
// clobbers; callers currently rely on `asm volatile` ordering alone.

// opmvx. f6=b101010, f7=b1010101
// Presumably: move scalar rs1 / vector vs2 into matrix register md.
#define VMV_RV(md, rs1, vs2) \
asm volatile(".insn r 0x57, 0x6, 0x55, " md ", %0, " vs2 : : "r"(rs1));

// opmvx. f6=b101110, f7=b1011101
// Presumably: move a row (selected by scalar rs1) of matrix ms2 into vd.
#define VMV_VR(vd, rs1, ms2) \
asm volatile(".insn r 0x57, 0x6, 0x5d, " vd ", %0, " ms2 : : "r"(rs1));

// opmvx. f6=b101100, f7=b1011001
// Presumably: broadcast vector vs2 into every row of matrix md (used with a
// zeroed vector to clear the accumulator).
#define OPMVINBCAST(md, vs2) \
asm volatile(".insn r 0x57, 0x6, 0x59, " md ", x0, " vs2);

// opmvv. f6=b101000, f7=b1010001
// Presumably: outer-product accumulate of vs1 x vs2 into matrix md.
#define VOPACC(md, vs2, vs1) \
asm volatile(".insn r 0x57, 0x2, 0x51, " md ", " vs1 ", " vs2);

//void i8_mm_bme_1x2(int32_t* c_bias, int32_t* c_out, int8_t* at, int8_t* b,
//                   size_t M, size_t N, size_t K);

#endif  // IREE_BUILTINS_UKERNEL_ARCH_RISCV_64_BME_H_
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 1, 1, _zvfh)
IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 2, 1, _zvfh)
IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 4, 1, _zvfh)
IREE_UK_MMT4D_TILE(riscv_64, f16, f16, f16, 7, 1, _zvfh)

// s8s8s32 tiles using the 'v' extension.
// Each entry registers an M0xXXx1 tile (XX = vector-length-dependent N0);
// the kernels backing these entries are in mmt4d_riscv_64_v.c.
IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 1, 1, _v)
IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 2, 1, _v)
IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 4, 1, _v)
IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 8, 1, _v)
IREE_UK_MMT4D_TILE(riscv_64, s8, s8, s32, 16, 1, _v)
101 changes: 101 additions & 0 deletions runtime/src/iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "iree/builtins/ukernel/arch/riscv_64/common_riscv_64.h"
#include "iree/builtins/ukernel/arch/riscv_64/mmt4d_riscv_64_internal.h"
#include "bme.h"

IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v(
Expand Down Expand Up @@ -121,6 +122,87 @@ iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v(
}
}

// s8 x s8 -> s32 mmt4d tile kernel for riscv_64 + 'v', generic over M0
// (1..16). Uses the custom matrix-accumulator instructions from bme.h
// (OPMVINBCAST / VOPACC / VMV_VR, emitted via `.insn`).
//
// Panel layout, as established by the indexing below:
//   lhs_panel: K steps of M0 int8 elements (lhs_ptr[k * M0]).
//   rhs_panel: K steps of N0 int8 elements (rhs_ptr[k * N0]).
//   out_tile:  M0 rows of N0 int32 elements, row stride N0 (out_ptr[r * N0]).
//
// NOTE(review): the matrix accumulator is unconditionally zero-initialized
// and out_tile is never loaded, so any pre-existing accumulator contents are
// discarded -- confirm this is correct w.r.t. the mmt4d accumulate flag.
// NOTE(review): the inline asm declares no "memory" or vector-register
// clobbers; correctness currently relies on `asm volatile` ordering alone --
// consider adding explicit clobber lists.
IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v(
    void* IREE_UK_RESTRICT out_tile,
    const void* IREE_UK_RESTRICT lhs_panel,
    const void* IREE_UK_RESTRICT rhs_panel,
    const iree_uk_mmt4d_params_t* params, int M0) {
  IREE_UK_ASSERT(M0 >= 1 && M0 <= 16);
  iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
  const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
  const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;

  const int N0 = params->N0;
  const int K = params->K;
  size_t ml = M0;  // Element count for lhs loads (tile height).
  size_t vl = N0;  // Element count for rhs loads and output stores.

  // Performance case for M0=16 (LMUL=8)
  if (M0 == 16) {
    // init m0 to zero (LMUL=8): broadcast a zeroed vector into the
    // accumulator matrix register.
    asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl));
    asm volatile("vmv.v.i v0, 0");
    OPMVINBCAST(m0, v0);

    // K-loop unrolled by 2. NOTE(review): `k + 2 <= K` compares size_t
    // against int; fine for K >= 0, which mmt4d params should guarantee.
    size_t k = 0;
    while (k + 2 <= K) {
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
      asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0]));
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
      asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0]));
      VOPACC(m0, v18, v16);
      k++;
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
      asm volatile("vle8.v v20, (%0)" : : "r"(&lhs_ptr[k * M0]));
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
      asm volatile("vle8.v v22, (%0)" : : "r"(&rhs_ptr[k * N0]));
      VOPACC(m0, v22, v20);
      k++;
    }
    // Remainder iteration when K is odd.
    if (k < K) {
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(ml));  // ml=16
      asm volatile("vle8.v v16, (%0)" : : "r"(&lhs_ptr[k * M0]));
      asm volatile("vsetvli zero, %0, e8, m2, ta, ma" : : "r"(vl));  // vl=N0
      asm volatile("vle8.v v18, (%0)" : : "r"(&rhs_ptr[k * N0]));
      VOPACC(m0, v18, v16);
    }

    // store results: extract each accumulator row into v0, then store it.
    asm volatile("vsetvli zero, %0, e32, m8, ta, ma" : : "r"(vl));
    for (size_t r = 0; r < ml; r++) {  // ml=16
      VMV_VR(v0, r, m0);
      asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0]));
    }
  }
  // Tail case for M0 < 16 (using LMUL=4)
  else {
    // 1. Initialize accumulators to ZERO (LMUL=4)
    asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl));
    asm volatile("vmv.v.i v0, 0");
    OPMVINBCAST(m3, v0);  // Initialize m3 to zero

    // 2. Main K-loop (not unrolled; loads use LMUL=1 here).
    for (int k = 0; k < K; ++k) {
      asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(ml));
      asm volatile("vle8.v v5, (%0)" : : "r"(&lhs_ptr[k * M0]));

      asm volatile("vsetvli zero, %0, e8, m1, ta, ma" : : "r"(vl));
      asm volatile("vle8.v v4, (%0)" : : "r"(&rhs_ptr[k * N0]));

      VOPACC(m3, v4, v5);
    }

    // 3. Store results: one accumulator row per output row.
    asm volatile("vsetvli zero, %0, e32, m4, ta, ma" : : "r"(vl));
    for (size_t r = 0; r < ml; r++) {
      VMV_VR(v0, r, m3);
      asm volatile("vse32.v v0, (%0)" : : "r"(&out_ptr[r * N0]));
    }
  }
}

IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v,
iree_uk_mmt4d_tile_f32f32f32_1xXXx1_riscv_64_v, 1)
Expand All @@ -133,3 +215,22 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f32f32f32_1xXXx1_to_7xXXx1_riscv_64_v,
iree_uk_mmt4d_tile_f32f32f32_7xXXx1_riscv_64_v, 7)

// Instantiate the s8s8s32 tile entry points. All M0 variants (1, 2, 4, 8,
// 16) forward to the single generic kernel
// iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v, which branches
// internally on M0 (dedicated LMUL=8 path for M0=16, LMUL=4 otherwise).
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_riscv_64_v, 1)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
    iree_uk_mmt4d_tile_s8s8s32_2xXXx1_riscv_64_v, 2)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
    iree_uk_mmt4d_tile_s8s8s32_4xXXx1_riscv_64_v, 4)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
    iree_uk_mmt4d_tile_s8s8s32_8xXXx1_riscv_64_v, 8)
// M0=16 is the performance tile (LMUL=8 path in the generic kernel).
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
    iree_uk_mmt4d_tile_s8s8s32_1xXXx1_to_16xXXx1_riscv_64_v,
    iree_uk_mmt4d_tile_s8s8s32_16xXXx1_riscv_64_v, 16)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32(
return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8};
}

static iree_uk_matmul_tile_sizes_t
iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32(
const iree_uk_query_tile_sizes_2d_params_t* params) {
#if defined(IREE_UK_BUILD_RISCV_64_V)
if (iree_uk_cpu_riscv_64_v(params->cpu_data)) {
// Corresponds to the new target M0=16.
// N=32 is based on a minimum-VLEN (128-bit) and the new LMUL=8.
// N0 = (128 bits / 32 bits_per_element) * 8_LMUL = 32.
return (iree_uk_matmul_tile_sizes_t){.M = 16, .K = 1, .N = 32};
}
#endif
// generic fallback
return (iree_uk_matmul_tile_sizes_t){.M = 8, .K = 1, .N = 8};
}

bool iree_uk_query_matmul_tile_sizes_arch(
const iree_uk_query_tile_sizes_2d_params_t* params,
iree_uk_matmul_tile_sizes_t* out_matmul_tile_sizes) {
Expand All @@ -27,8 +42,12 @@ bool iree_uk_query_matmul_tile_sizes_arch(
*out_matmul_tile_sizes =
iree_uk_query_matmul_tile_sizes_riscv_64_f32f32f32(params);
return true;
} else if (op == IREE_UK_FLAG_QUERY_TILE_SIZES_OPERATION_MATMUL_I8I8I32) {
*out_matmul_tile_sizes =
iree_uk_query_matmul_tile_sizes_riscv_64_s8s8s32(params);
return true;
} else {
// Shouldn't happen, validated earlier.
return false;
}
}
}
Loading