@@ -71,7 +71,7 @@ namespace refactor::kernel {
         MatMulDescriptor mul;
         MatrixDescriptor q, k, v, att;
         cublasLtMatmulAlgo_t algoQK, algoAV;
-        size_t attSize, workspaceSizeQK, workspaceSizeAV;
+        size_t workspaceSizeQK, workspaceSizeAV;
 
         Descriptors(CublasLtContext const &context,
                     AttentionInfo info)
@@ -112,8 +112,7 @@ namespace refactor::kernel {
                 .order = ROW_MAJOR,
                 .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                 .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
-            }),
-            attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
+            }) {
             auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
             auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
             algoQK = algoQK_;
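
Here tune evidently asks cuBLASLt which algorithm to use for each of the two batched GEMMs (Q·Kᵀ and att·V) and how much scratch each needs. A minimal sketch of such a helper against the raw cuBLASLt API, with the diff's MatMulDescriptor/MatrixDescriptor wrappers replaced by the underlying cuBLASLt handles; the real tune() is defined elsewhere in this file:

#include <cublasLt.h>
#include <utility>

// Sketch: ask cuBLASLt for its top heuristic for C = A * B and return the
// chosen algorithm together with its workspace requirement in bytes.
static std::pair<cublasLtMatmulAlgo_t, size_t>
tuneSketch(cublasLtHandle_t handle,
           cublasLtMatmulDesc_t op,
           cublasLtMatrixLayout_t a,
           cublasLtMatrixLayout_t b,
           cublasLtMatrixLayout_t c) {
    cublasLtMatmulPreference_t preference;
    cublasLtMatmulPreferenceCreate(&preference);
    cublasLtMatmulHeuristicResult_t result{};
    int returned = 0;
    // Request a single candidate; production code would configure the
    // preference (e.g. a workspace cap) and check `returned` and the status.
    cublasLtMatmulAlgoGetHeuristic(handle, op, a, b, c, c,
                                   preference, 1, &result, &returned);
    cublasLtMatmulPreferenceDestroy(preference);
    return {result.algo, result.workspaceSize};
}
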
@@ -125,7 +124,7 @@ namespace refactor::kernel {
 
             auto const &context = *res.fetchOrStore<CublasLtContext>();
             auto d = std::make_shared<Descriptors>(context, info);
-            auto workspaceSize = d->attSize;
+            auto workspaceSize = info.attSize(0);
             workspaceSize = hardware::alignBytes(workspaceSize, 256);
             workspaceSize += d->workspaceSizeQK;
             workspaceSize += d->workspaceSizeAV;
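
With the cached attSize member gone, the score-buffer size is now asked of AttentionInfo, parameterized by how many past (cached) tokens an attention row covers; this call site passes 0. A minimal sketch of what those helpers plausibly compute, assuming attLen(pastSeqLen) = pastSeqLen + seqLen; every name and type not visible in the diff is an assumption:

#include <cstddef>
#include <cstdint>

// Hypothetical sketch of the AttentionInfo helpers called above; the real
// definitions live elsewhere in this PR.
struct AttentionInfoSketch {
    size_t dataTypeSize;// stand-in for info.dataType.size()
    uint32_t batch, nHead, seqLen, headDim;

    // Assumption: one attention row spans the cached tokens plus the new ones.
    uint64_t attLen(uint64_t pastSeqLen) const noexcept {
        return pastSeqLen + seqLen;
    }
    // Bytes of the (batch, nHead, seqLen, attLen) score buffer. With
    // pastSeqLen == 0 this reproduces the deleted attSize member:
    // batch * nHead * seqLen * seqLen * dataType.size().
    size_t attSize(uint64_t pastSeqLen) const noexcept {
        return batch * nHead * seqLen * attLen(pastSeqLen) * dataTypeSize;
    }
};
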
@@ -139,7 +138,7 @@ namespace refactor::kernel {
             auto v = inputs[2];
             auto o = outputs[0];
             auto att = reinterpret_cast<half *>(workspace);
-            auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
+            auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(info.attSize(0), 256);
             auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
             {
                 half alpha = rsqrtf(info.headDim), beta = 0;
@@ -155,10 +154,12 @@ namespace refactor::kernel {
                     workspaceQK, d->workspaceSizeQK,
                     cudaStreamLegacy);
             }
+            auto attLen = info.attLen(0);
+            auto bufLen = attLen;
             softmax<<<dim3(info.batch * info.nHead, info.seqLen),
-                      info.seqLen,
-                      info.seqLen * sizeof(float)>>>(
-                att, causualMask, info.seqLen, info.seqLen);
+                      std::min(1024u, attLen),
+                      attLen * sizeof(float)>>>(
+                att, causualMask, attLen, bufLen);
             {
                 half alpha = 1, beta = 0;
                 cublasLtMatmul(
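
The new launch caps the block size at 1024 threads (the CUDA per-block limit) instead of assuming seqLen threads always fit, and sizes the dynamic shared memory by attLen rather than seqLen. A minimal sketch of a row-wise masked softmax compatible with that launch shape, assuming fp16 scores with row stride bufLen and a boolean mask functor; the real softmax kernel and causualMask are defined elsewhere in this PR:

#include <cuda_fp16.h>
#include <cfloat>

// Fold red[0..n) with `op`; the result lands in red[0]. Works for any
// block size (blockDim.x need not be a power of two).
template<class Op>
static __device__ void blockFold(float *red, unsigned n, Op op) {
    while (n > 1) {
        unsigned mid = (n + 1) / 2;
        if (threadIdx.x + mid < n)
            red[threadIdx.x] = op(red[threadIdx.x], red[threadIdx.x + mid]);
        __syncthreads();
        n = mid;
    }
}

// Hypothetical mask functor; the real causualMask is defined elsewhere.
struct CausalMaskSketch {
    __device__ bool operator()(unsigned rowId, unsigned colId) const {
        return colId <= rowId;// each token sees itself and the past
    }
};

// One block per attention row (gridDim = {batch * nHead, seqLen}); threads
// are capped at 1024, so each one strides over the row when attLen > 1024.
// Dynamic shared memory holds the attLen floats named in the launch above.
template<class Mask>
static __global__ void softmaxSketch(half *att, Mask mask,
                                     unsigned attLen, unsigned bufLen) {
    extern __shared__ float row[];// attLen floats
    __shared__ float red[1024];   // reduction scratch, blockDim.x <= 1024
    auto base = att + ((size_t) blockIdx.x * gridDim.y + blockIdx.y) * bufLen;

    // Load and mask the row, tracking each thread's running maximum.
    float m = -FLT_MAX;
    for (auto j = threadIdx.x; j < attLen; j += blockDim.x)
        m = fmaxf(m, row[j] = mask(blockIdx.y, j) ? __half2float(base[j]) : -FLT_MAX);
    red[threadIdx.x] = m;
    __syncthreads();
    blockFold(red, blockDim.x, [](float a, float b) { return fmaxf(a, b); });
    m = red[0];
    __syncthreads();

    // Exponentiate (shifted by the max for stability) and sum the row.
    float s = 0;
    for (auto j = threadIdx.x; j < attLen; j += blockDim.x)
        s += row[j] = expf(row[j] - m);
    red[threadIdx.x] = s;
    __syncthreads();
    blockFold(red, blockDim.x, [](float a, float b) { return a + b; });
    s = red[0];

    // Normalize and store back in fp16.
    for (auto j = threadIdx.x; j < attLen; j += blockDim.x)
        base[j] = __float2half(row[j] / s);
}

With the launch shape above, each block normalizes one row; capping threads at 1024 only costs extra loop iterations when attLen exceeds that, and bufLen (kept equal to attLen in this commit) leaves room for a row stride larger than the masked length.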