Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
float beta,
float * R, BLASLONG strideR);

/*
 * Direct single-precision triangular matrix-multiply kernels.
 *
 * The four-letter suffix selects the variant the interface dispatches to:
 *   1st letter  L     : A is applied from the left
 *   2nd letter  N / T : A is used as stored / transposed
 *   3rd letter  U / L : A is upper / lower triangular
 *   4th letter  N     : A has a non-unit diagonal
 *
 * Arguments: M, N, K are the problem dimensions, alpha the scalar factor,
 * A the triangular matrix and B the matrix that receives the result.
 * NOTE(review): the interface passes K == M and the SME1 implementation does
 * not read strideA / strideB -- confirm callers only use contiguous operands.
 */
void strmm_direct_LNUN(BLASLONG M, BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LNLN(BLASLONG M, BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LTUN(BLASLONG M, BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);
void strmm_direct_LTLN(BLASLONG M, BLASLONG N, BLASLONG K,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand Down
6 changes: 6 additions & 0 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
int (*strsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
#endif
#if (BUILD_SINGLE==1)
#ifdef ARCH_ARM64
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
#endif
int (*strmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*strmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
int (*strmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
Expand Down
12 changes: 12 additions & 0 deletions common_s.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@
#define SGEMM_ITCOPY sgemm_itcopy
#endif

#define STRMM_DIRECT_LNUN strmm_direct_LNUN
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
#define STRMM_DIRECT_LTLN strmm_direct_LTLN

#define STRMM_OUNUCOPY strmm_ounucopy
#define STRMM_OUNNCOPY strmm_ounncopy
#define STRMM_OUTUCOPY strmm_outucopy
Expand Down Expand Up @@ -227,6 +232,13 @@
#define SGEMM_INCOPY gotoblas -> sgemm_incopy
#define SGEMM_ITCOPY gotoblas -> sgemm_itcopy

#ifdef ARCH_ARM64
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
#endif

#define STRMM_OUNUCOPY gotoblas -> strmm_ounucopy
#define STRMM_OUTUCOPY gotoblas -> strmm_outucopy
#define STRMM_OLNUCOPY gotoblas -> strmm_olnucopy
Expand Down
15 changes: 15 additions & 0 deletions interface/trsm.c
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,21 @@ void CNAME(enum CBLAS_ORDER order,
#endif

PRINT_DEBUG_CNAME;
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16) && defined(TRMM)
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
  /*
   * Fast path: row-major, left-side, non-unit single-precision TRMM handled by
   * the direct SME kernels.
   *
   * The path is taken only when the kernels can actually honour the request:
   *   - Uplo must be a valid value, so invalid arguments still reach the
   *     regular argument checking below instead of being treated as "lower";
   *   - the operands must be non-empty and contiguous (lda == m, ldb == n),
   *     because the direct kernels index A with row stride m and B with row
   *     stride n and do not use the leading dimensions;
   *   - for real data the conjugated transpose codes mean the same as their
   *     plain counterparts, so they are mapped instead of being dropped.
   * We return only after a kernel has really been called; every other case
   * falls through to the generic implementation.
   *
   * NOTE(review): this file is presumably also compiled as the TRSM
   * interface; the TRMM guard keeps the multiply kernels out of the solver.
   * Confirm TRMM is the macro the build defines for the TRMM objects.
   */
#if defined(DYNAMIC_ARCH)
  if (support_sme1())
#endif
  if (order == CblasRowMajor && Side == CblasLeft && Diag == CblasNonUnit &&
      (Uplo == CblasUpper || Uplo == CblasLower) &&
      m > 0 && n > 0 && lda == m && ldb == n) {
    if (Trans == CblasNoTrans || Trans == CblasConjNoTrans) {
      (Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, m, alpha, a, lda, b, ldb);
      return;
    }
    if (Trans == CblasTrans || Trans == CblasConjTrans) {
      (Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, m, alpha, a, lda, b, ldb);
      return;
    }
  }
#endif
#endif

args.a = (void *)a;
args.b = (void *)b;
Expand Down
15 changes: 15 additions & 0 deletions kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
set(USE_TRMM true)
endif ()
set(USE_DIRECT_STRMM false)
if (ARM64)
set(USE_DIRECT_STRMM true)
endif()
set(USE_DIRECT_SGEMM false)
if (X86_64 OR ARM64)
set(USE_DIRECT_SGEMM true)
Expand Down Expand Up @@ -442,6 +446,17 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
endif ()

if (USE_DIRECT_STRMM)
  # Direct single-precision TRMM kernels (ARM64 SME1).
  # One source file is compiled four times; the UPPER / TRANSA definitions
  # select the triangle and the transposition, exactly as the Makefile build
  # does with -DUPPER / -DTRANSA. Without them every object would contain the
  # same lower, non-transposed code under four different names.
  # The second argument of GenerateNamedObjects carries those definitions.
  set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
  set (STRMMDIRECTPREKERNEL strmm_direct_arm64_sme1_preprocess.c)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "UPPER" "trmm_direct_LNUN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "TRANSA;UPPER" "trmm_direct_LTUN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "TRANSA" "trmm_direct_LTLN" false "" "" false SINGLE)
  # Packing helpers for the non-transposed kernels: upper and lower variants.
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "UPPER" "trmm_direct_sme1_preprocess_UN" false "" "" false SINGLE)
  GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_LN" false "" "" false SINGLE)
endif ()

if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")

# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects
Expand Down
35 changes: 35 additions & 0 deletions kernel/Makefile.L3
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ endif
ifeq ($(ARCH), arm64)
USE_TRMM = 1
USE_DIRECT_SGEMM = 1
USE_DIRECT_STRMM = 1
endif

ifeq ($(ARCH), riscv64)
Expand Down Expand Up @@ -137,6 +138,10 @@ endif
endif
endif

ifdef USE_DIRECT_STRMM
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
endif

ifeq ($(BUILD_BFLOAT16), 1)
ifndef BGEMMKERNEL
BGEMM_BETA = ../generic/gemm_beta.c
Expand Down Expand Up @@ -286,6 +291,15 @@ SBLASOBJS += \
strsm_kernel_RN$(TSUFFIX).$(SUFFIX) strsm_kernel_RT$(TSUFFIX).$(SUFFIX)
endif

ifdef USE_DIRECT_STRMM
SBLASOBJS += \
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
SBLASOBJS += \
strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) \
strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX)
endif

ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
DBLASOBJS += \
dgemm_beta$(TSUFFIX).$(SUFFIX) \
Expand Down Expand Up @@ -1150,6 +1164,27 @@ else
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
endif


ifdef USE_DIRECT_STRMM
# Direct single-precision TRMM kernels: one source, four objects.
# TRANSA selects the transposed variants (LT*), UPPER the upper-triangular
# ones (L?U*); the flags are explicitly undefined otherwise.
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@

$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@

$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@

$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@

# Packing helpers used by the non-transposed kernels. The source file is
# listed as a prerequisite so that $< names it and the objects are rebuilt
# when it changes.
$(KDIR)strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER $< -o $@

$(KDIR)strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c
	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER $< -o $@
endif

$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s
Expand Down
191 changes: 191 additions & 0 deletions kernel/arm64/strmm_direct_arm64_sme1.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/

#include "common.h"
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#if defined(HAVE_SME)

#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
#include "sme_abi.h"
#include <arm_sme.h>
#endif

/* Function prototypes */
// For Upper and NonUnit Triangular preprocess
extern void strmm_direct_sme1_preprocess_UN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);
// For Lower and NonUnit Triangular preprocess
extern void strmm_direct_sme1_preprocess_LN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);

/* Function Definitions */
/*
 * Return the number of 32-bit elements in one streaming vector.
 *
 * RDSVL #1 yields the streaming vector length in bytes; the shift by two
 * divides it by sizeof(float). Inline assembly is used so the value can be
 * read from ordinary (non-streaming) code without the SME intrinsics.
 */
static uint64_t sve_cntw(void) {
  uint64_t cnt;
  asm volatile(
    "rdsvl  %[res], #1\n"
    "lsr    %[res], %[res], #2\n"
    : [res] "=r" (cnt) ::
  );
  return cnt;
}

#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
// Outer-product micro-kernel: accumulates one block of alpha * op(A) * B in ZA
// and stores it to C.
// A block is at most 2*SVL rows by 2*SVL columns (SVL = svcntw() 32-bit lanes)
// and uses all four FP32 ZA tiles: tiles 0/1 receive the first SVL rows, tiles
// 2/3 the rows from SVL onwards; tiles 0/2 receive the first SVL columns,
// tiles 1/3 the columns from SVL onwards.
//   A          - A data for this row block (see the two load forms below)
//   B          - input matrix, indexed by absolute k as B[k * ldb]; the caller
//                has already offset it to the block's first column
//   C          - first element of the output block
//   shared_dim - extent of the shared (k) dimension
//   ldc        - row stride of C in elements; also used as the row stride of B
//   block_rows - number of valid rows in this block
//   block_cols - number of valid columns in this block
//   alpha      - scalar applied to the A vector before accumulation
//   row_idx    - index of the block's first row in the full matrix; drives the
//                triangular masking of A
__attribute__((always_inline)) inline void
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
size_t ldc, size_t block_rows, size_t block_cols, float alpha, uint64_t row_idx)
__arm_out("za") __arm_streaming {
const uint64_t svl = svcntw();
// B is read with the same row stride as C is written with.
size_t ldb = ldc;
// Predicate set-up: pg is all-true; pg_a_0 / pg_a_1 select the valid rows of
// the first / second vector of rows.
svbool_t pg = svptrue_b32();
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);

#if (!defined(TRANSA) && !defined(UPPER)) || (defined(TRANSA) && defined(UPPER))
// These are macro aliases, not saved copies: each use below reads the current
// value of pg_a_0 / pg_a_1, so the row mask only ever shrinks across iterations.
#define pg_a_0_full pg_a_0
#define pg_a_1_full pg_a_1
#endif
// pg_b_0 / pg_b_1 select the valid columns of the first / second vector of columns.
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);

// C is stored with the same column predicates that B is loaded with.
#define pg_c_0 pg_b_0
#define pg_c_1 pg_b_1

// Clear the accumulators and broadcast alpha.
svzero_za();
svfloat32_t alpha_vec = svdup_f32(alpha);
// Iterate through shared dimension (K)
#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER))
// k starts at the block's first row; at step k only the first
// min(k - row_idx + 1, block_rows) rows of the block are enabled.
for (size_t k = row_idx, valid_index = 1; k < shared_dim; k++,valid_index++) {
pg_a_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_rows));
pg_a_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_rows));
#else
// k runs from 0 up to the block's last row (clamped to shared_dim).
for (size_t k = 0; k < MIN(row_idx + block_rows, shared_dim); k++) {
// Once k exceeds row_idx, disable the first (k - row_idx) rows of the block
// so that only rows at or below k keep contributing.
if (k > row_idx) {
pg_a_0 = svnot_b_z(pg_a_0_full, svwhilelt_b32_u64(0, k - row_idx));
pg_a_1 = svnot_b_z(pg_a_1_full, svwhilelt_b32_u64(svl, k - row_idx));
}
#endif

#if !defined(TRANSA)
// Load column k of A for both row vectors: SVL contiguous elements per k,
// with the second row vector located shared_dim * SVL elements further on.
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
#else
// Transposed case: read A directly, row k with stride shared_dim; the second
// row vector starts SVL elements further along that row.
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * shared_dim]);
svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * shared_dim + svl]);
#endif
// Scale the A vectors by alpha before accumulating.
col_a_0 = svmul_x(pg_a_0, alpha_vec, col_a_0);
col_a_1 = svmul_x(pg_a_1, alpha_vec, col_a_1);
// Load row k of B for both column vectors.
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
// Accumulate the four outer products; rows are masked by the A predicates,
// the column predicate passed here is all-true.
svmopa_za32_m(/*tile*/0, pg_a_0, pg, col_a_0, row_b_0);
svmopa_za32_m(/*tile*/1, pg_a_0, pg, col_a_0, row_b_1);
svmopa_za32_m(/*tile*/2, pg_a_1, pg, col_a_1, row_b_0);
svmopa_za32_m(/*tile*/3, pg_a_1, pg, col_a_1, row_b_1);
}

// Store to C from ZA: rows 0 .. min(svl, block_rows) - 1 come from tiles 0/1.
for (size_t i = 0; i < MIN(svl, block_rows); i++) {
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
}
// Rows svl .. block_rows - 1 come from tiles 2/3. The slice argument is i
// itself (>= svl), not i - svl.
// NOTE(review): presumably the slice index wraps modulo the tile dimension --
// confirm against the ACLE / ISA description.
for (size_t i = svl; i < block_rows; i++) {
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
}

}

/*
 * Blocked driver for the in-place product: bb is both the input matrix B and
 * the destination, ba is the triangular operand.
 *
 *   m, n  - rows and columns of B
 *   k     - shared dimension handed on to the micro-kernel
 *   alpha - pointer to the scalar factor
 *   ba    - A data (packed panels when TRANSA is not defined, the caller's
 *           matrix otherwise)
 *   bb    - B, n elements per row, overwritten with the result
 *
 * B is walked in blocks of up to 2*SVL rows by 2*SVL columns. The row order
 * depends on the variant so that a block never needs rows of B that an
 * earlier block has already overwritten: the micro-kernel reads rows from
 * row_idx onwards in the top-to-bottom variants and rows up to the end of the
 * current block in the bottom-to-top variants.
 */
__arm_new("za") __arm_locally_streaming
static inline void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
                        const float *ba, float *restrict bb) {
  const uint64_t num_rows = m;
  const uint64_t num_cols = n;

  const float *restrict a_ptr = ba;
  /* The kernel reads and writes the same buffer, so these two views must not
   * be restrict-qualified: aliased restrict pointers to modified storage are
   * undefined behaviour. */
  const float *b_ptr = bb;
  float *c_ptr = bb;

  const uint64_t svl = svcntw();
  const uint64_t svl_x2 = 2*svl;
  const uint64_t ldc = n;

  uint64_t row_idx = 0;
#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER))
  /* Top to bottom; the last block is clamped to the rows that remain. */
  uint64_t row_batch = svl_x2;
  for (; row_idx < num_rows; row_idx += row_batch) {
    row_batch = MIN(row_batch, num_rows - row_idx);
#else
  /* Bottom to top. The partial block (num_rows mod 2*SVL rows), if any, is
   * processed first so that every later block is a full 2*SVL rows. */
  uint64_t row_batch = num_rows % svl_x2;
  if (row_batch == 0) row_batch = svl_x2;
  for (uint64_t index = num_rows; index > 0; index -= row_batch, row_batch = svl_x2) {
    /* First row of the current block. */
    row_idx = index - row_batch;
#endif
    uint64_t col_idx = 0;
    uint64_t col_batch = svl_x2;
    /* Block over the columns of B / C. */
    for (; col_idx < num_cols; col_idx += col_batch) {
      col_batch = MIN(col_batch, num_cols - col_idx);
#if !defined(TRANSA)
      /* Packed A: the panel for this row block starts at row_idx * k. */
      kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx],
                 &c_ptr[row_idx * ldc + col_idx], k,
                 ldc, row_batch, col_batch, *alpha, row_idx);
#else
      /* Transposed A: start at column row_idx of the caller's matrix. */
      kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx],
                 &c_ptr[row_idx * ldc + col_idx], k,
                 ldc, row_batch, col_batch, *alpha, row_idx);
#endif
    }
  }

  return;
}

#else
/* Fallback used when the compiler does not provide the SME intrinsics: the
 * driver is an empty function and leaves bb untouched. It is static inline,
 * like the real definition, because this file is compiled into several
 * objects and an external definition would be emitted once per object. */
static inline void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
                        const float *ba, float *restrict bb){}
#endif

/*
 * Entry point of the direct single-precision TRMM kernel (CNAME is supplied
 * by the build). Computes B = alpha * op(A) * B in place.
 *
 *   M, N    - rows and columns of B
 *   K       - shared dimension (the interface passes K == M)
 *   alpha   - scalar factor
 *   A       - triangular operand
 *   B       - input matrix, overwritten with the result
 *   strideA, strideB - accepted for interface compatibility; not read here
 *
 * When TRANSA is not defined, A is first repacked into a temporary buffer
 * whose row count is rounded up to a multiple of the vector length; the
 * transposed variants read A directly.
 */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
            BLASLONG strideA, float * __restrict B, BLASLONG strideB){
  /* Empty problem: nothing to compute, and a zero-sized allocation is avoided. */
  if (M <= 0 || N <= 0 || K <= 0) return;

#if !defined(TRANSA)
  const uint64_t vl_elms = sve_cntw();
  /* Round M up to a whole number of vectors with integer arithmetic. */
  const uint64_t m_mod = (((uint64_t)M + vl_elms - 1) / vl_elms) * vl_elms;
  const size_t bytes = (size_t)m_mod * (size_t)K * sizeof(float);

  float *A_mod = (float *) malloc(bytes);
  if (A_mod == NULL) {
    /* This function cannot report an error, and continuing would write
     * through a null pointer: stop with a diagnostic instead. */
    fprintf(stderr, "strmm_direct: failed to allocate %zu bytes for the packed copy of A\n", bytes);
    abort();
  }
#if defined(UPPER)
  strmm_direct_sme1_preprocess_UN(M, K, A, A_mod);
#else
  strmm_direct_sme1_preprocess_LN(M, K, A, A_mod);
#endif
  /* Calculate B = alpha*A*B */
  strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B);
  free(A_mod);
#else
  /* Calculate B = alpha*A^T*B straight from the caller's A. */
  strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A, B);
#endif
}

#else
/* Built without HAVE_SME: the entry point is an empty function, so the symbol
 * exists but B is left unchanged.
 * NOTE(review): presumably never reached at run time -- the interface checks
 * support_sme1() under DYNAMIC_ARCH; confirm the same holds for builds that
 * define USE_STRMM_KERNEL_DIRECT. */
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
BLASLONG strideA, float * __restrict B, BLASLONG strideB){}

#endif
Loading
Loading