@@ -70,6 +70,20 @@ namespace refactor::kernel {
7070 }
7171 }
7272
73+ static __global__ void concatCache (
74+ void *__restrict__ cache,
75+ void const *__restrict__ value,
76+ dim_t pageStrideI,
77+ dim_t pageStrideO,
78+ dim_t lineStride,
79+ dim_t pastOffset) {
80+
81+ auto tid = blockIdx .x * blockDim .x + threadIdx .x ,
82+ dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO;
83+ reinterpret_cast <float4 *>(cache)[dst] = reinterpret_cast <float4 const *>(value)[tid];
84+ }
85+ constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20 ;// determined by trial: 40 MiB has proven to be enough
86+
7387 RoutineWorkspace K::lower (Resources &res) const {
7488 auto handle = res.fetchOrStore <CublasLtContext>()->handle ;
7589
@@ -125,8 +139,8 @@ namespace refactor::kernel {
125139 .batchCount = static_cast <int32_t >(info.batch * info.nHead ),
126140 .batchStride = static_cast <int64_t >(info.seqLen * info.seqLen ),
127141 }) {
128- auto [algoQK_, workspaceSizeQK_] = tune (context.handle , mul, q, k, att);
129- auto [algoAV_, workspaceSizeAV_] = tune (context.handle , mul, att, v, q);
142+ auto [algoQK_, workspaceSizeQK_] = tune (context.handle , mul, q, k, att, DYNAMIC_WORKSPACE_SIZE );
143+ auto [algoAV_, workspaceSizeAV_] = tune (context.handle , mul, att, v, q, DYNAMIC_WORKSPACE_SIZE );
130144 algoQK = algoQK_;
131145 algoAV = algoAV_;
132146 workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@ namespace refactor::kernel {
187201 &d->algoAV ,
188202 workspaceAV, d->workspaceSizeAV ,
189203 stream);
190- };
204+ }
191205 };
192206
193207 return {std::move (routine), workspaceSize};
194208 }
209+ TODO (" " );
195210 }
211+ if (info.concatCache && !info.resetCache) {
212+ if (info.nHead == info.nKVHead ) {
213+
214+ // RAII for closure
215+ struct Descriptors {
216+ MatMulDescriptor mul;
217+
218+ Descriptors (AttentionInfo info)
219+ : mul(computeTypeConvert(info.dataType),
220+ dataTypeConvert (info.dataType)) {}
221+ };
222+
223+ auto const &context = *res.fetchOrStore<CublasLtContext>();
224+ auto d = std::make_shared<Descriptors>(info);
225+ auto attentionSize = info.maxAttSize();
226+ auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
227+
228+ auto routine = [d = std::move(d), info = this ->info]//
229+ (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
230+ auto handle = res.fetchOrStore <CublasLtContext>()->handle ;
231+ auto q = inputs[0 ];
232+ auto k = inputs[1 ];
233+ auto v = inputs[2 ];
234+ auto past = *reinterpret_cast <int64_t const *>(inputs[3 ]);
235+ auto attLen = info.attLen (past);
236+ auto o = reinterpret_cast <half *>(outputs[0 ]);
237+ auto kCache = reinterpret_cast <half *>(outputs[1 ]);
238+ auto vCache = reinterpret_cast <half *>(outputs[2 ]);
239+ auto att = reinterpret_cast <half *>(reinterpret_cast <uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
240+ auto stream = cudaStreamLegacy;
241+ {
242+ auto itemsPerLine = info.headDim * sizeof (half) / sizeof (float4 );
243+ auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
244+ auto blocks = (threads + 1023 ) / 1024 ;
245+
246+ concatCache<<<blocks, 1024 , 0 , stream>>> (
247+ kCache , k,
248+ info.seqLen * itemsPerLine,
249+ info.cacheLen * itemsPerLine,
250+ itemsPerLine,
251+ past * itemsPerLine);
252+ concatCache<<<blocks, 1024 , 0 , stream>>> (
253+ vCache, v,
254+ info.seqLen * itemsPerLine,
255+ info.cacheLen * itemsPerLine,
256+ itemsPerLine,
257+ past * itemsPerLine);
258+ }
259+ MatrixDescriptor
260+ q_ (MatrixLayout{
261+ .dataType = dataTypeConvert (info.dataType ),
262+ .rows = static_cast <uint64_t >(info.seqLen ),
263+ .cols = static_cast <uint64_t >(info.headDim ),
264+ .majorStride = static_cast <int64_t >(info.headDim ),
265+ .order = ROW_MAJOR,
266+ .batchCount = static_cast <int32_t >(info.batch * info.nHead ),
267+ .batchStride = static_cast <int64_t >(info.seqLen * info.headDim ),
268+ }),
269+ k_ (MatrixLayout{
270+ .dataType = dataTypeConvert (info.dataType ),
271+ .rows = static_cast <uint64_t >(info.headDim ),
272+ .cols = static_cast <uint64_t >(attLen),
273+ .majorStride = static_cast <int64_t >(info.headDim ),
274+ .order = COL_MAJOR,
275+ .batchCount = static_cast <int32_t >(info.batch * info.nHead ),
276+ .batchStride = static_cast <int64_t >(info.cacheLen * info.headDim ),
277+ }),
278+ v_ (MatrixLayout{
279+ .dataType = dataTypeConvert (info.dataType ),
280+ .rows = static_cast <uint64_t >(attLen),
281+ .cols = static_cast <uint64_t >(info.headDim ),
282+ .majorStride = static_cast <int64_t >(info.headDim ),
283+ .order = ROW_MAJOR,
284+ .batchCount = static_cast <int32_t >(info.batch * info.nHead ),
285+ .batchStride = static_cast <int64_t >(info.cacheLen * info.headDim ),
286+ }),
287+ att_ (MatrixLayout{
288+ .dataType = dataTypeConvert (info.dataType ),
289+ .rows = static_cast <uint64_t >(info.seqLen ),
290+ .cols = static_cast <uint64_t >(attLen),
291+ .majorStride = static_cast <int64_t >(info.cacheLen ),
292+ .order = ROW_MAJOR,
293+ .batchCount = static_cast <int32_t >(info.batch * info.nHead ),
294+ .batchStride = static_cast <int64_t >(info.cacheLen * info.seqLen ),
295+ });
296+ {
297+ auto [algo, workspaceSize] = tune (
298+ handle, d->mul ,
299+ q_, k_, att_,
300+ DYNAMIC_WORKSPACE_SIZE);
301+ half alpha = rsqrtf (info.headDim ), beta = 0 ;
302+ cublasLtMatmul (
303+ handle, d->mul .get (),
304+ &alpha,
305+ q, q_.get (),
306+ kCache , k_.get (),
307+ &beta,
308+ att, att_.get (),
309+ att, att_.get (),
310+ &algo,
311+ workspace, workspaceSize,
312+ stream);
313+ }
314+ softmax<<<dim3 (info.batch * info.nHead, info.seqLen),
315+ std::min (1024u , attLen),
316+ attLen * sizeof(float ),
317+ stream>>>(
318+ att, AttentionCausualMask(), attLen, info.cacheLen);
319+ {
320+ auto [algo, workspaceSize] = tune (
321+ handle, d->mul ,
322+ att_, v_, q_,
323+ DYNAMIC_WORKSPACE_SIZE);
324+ half alpha = 1 , beta = 0 ;
325+ cublasLtMatmul (
326+ handle, d->mul .get (),
327+ &alpha,
328+ att, att_.get (),
329+ vCache, v_.get (),
330+ &beta,
331+ o, q_.get (),
332+ o, q_.get (),
333+ &algo,
334+ workspace, workspaceSize,
335+ stream);
336+ }
337+ };
338+
339+ return {std::move (routine), workspaceSize};
340+ }
341+ TODO (" " );
342+ }
343+
196344 TODO (" " );
197345 }
198346
0 commit comments