Skip to content

Commit 6e66bf2

Browse files
committed
feat(kernel): 封装 attLen 计算
Signed-off-by: YdrMaster <[email protected]>
1 parent b15bd66 commit 6e66bf2

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/04kernel/include/kernel/attributes/attention_info.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ namespace refactor::kernel {
99
DataType dataType;
1010
dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
1111
bool concatCache, resetCache;
12+
13+
dim_t attLen(dim_t pastSeqLen) const noexcept;
14+
size_t attSize(dim_t pastSeqLen) const noexcept;
1215
};
1316

1417
}// namespace refactor::kernel
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include "kernel/attributes/attention_info.h"

namespace refactor::kernel {

    /// Total attention length for one step: the tokens already in the
    /// KV cache (`pastSeqLen`) plus the current input sequence (`seqLen`).
    dim_t AttentionInfo::attLen(dim_t pastSeqLen) const noexcept {
        return pastSeqLen + seqLen;
    }

    /// Byte size of the attention score buffer, shaped
    /// [batch, nHead, seqLen, attLen(pastSeqLen)].
    /// The first operand is widened to size_t so the whole product is
    /// computed in size_t; multiplying dim_t values first could overflow
    /// the narrower type before the conversion happens.
    size_t AttentionInfo::attSize(dim_t pastSeqLen) const noexcept {
        return static_cast<size_t>(batch) * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
    }

}// namespace refactor::kernel

src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ namespace refactor::kernel {
7171
MatMulDescriptor mul;
7272
MatrixDescriptor q, k, v, att;
7373
cublasLtMatmulAlgo_t algoQK, algoAV;
74-
size_t attSize, workspaceSizeQK, workspaceSizeAV;
74+
size_t workspaceSizeQK, workspaceSizeAV;
7575

7676
Descriptors(CublasLtContext const &context,
7777
AttentionInfo info)
@@ -112,8 +112,7 @@ namespace refactor::kernel {
112112
.order = ROW_MAJOR,
113113
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
114114
.batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
115-
}),
116-
attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
115+
}) {
117116
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
118117
auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
119118
algoQK = algoQK_;
@@ -125,7 +124,7 @@ namespace refactor::kernel {
125124

126125
auto const &context = *res.fetchOrStore<CublasLtContext>();
127126
auto d = std::make_shared<Descriptors>(context, info);
128-
auto workspaceSize = d->attSize;
127+
auto workspaceSize = info.attSize(0);
129128
workspaceSize = hardware::alignBytes(workspaceSize, 256);
130129
workspaceSize += d->workspaceSizeQK;
131130
workspaceSize += d->workspaceSizeAV;
@@ -139,7 +138,7 @@ namespace refactor::kernel {
139138
auto v = inputs[2];
140139
auto o = outputs[0];
141140
auto att = reinterpret_cast<half *>(workspace);
142-
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
141+
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(info.attSize(0), 256);
143142
auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
144143
{
145144
half alpha = rsqrtf(info.headDim), beta = 0;
@@ -155,10 +154,12 @@ namespace refactor::kernel {
155154
workspaceQK, d->workspaceSizeQK,
156155
cudaStreamLegacy);
157156
}
157+
auto attLen = info.attLen(0);
158+
auto bufLen = attLen;
158159
softmax<<<dim3(info.batch * info.nHead, info.seqLen),
159-
info.seqLen,
160-
info.seqLen * sizeof(float)>>>(
161-
att, causualMask, info.seqLen, info.seqLen);
160+
std::min(1024u, attLen),
161+
attLen * sizeof(float)>>>(
162+
att, causualMask, attLen, bufLen);
162163
{
163164
half alpha = 1, beta = 0;
164165
cublasLtMatmul(

0 commit comments

Comments
 (0)