diff --git a/3rd-party/cccl b/3rd-party/cccl index b7d4228ab..d27b58963 160000 --- a/3rd-party/cccl +++ b/3rd-party/cccl @@ -1 +1 @@ -Subproject commit b7d4228ab7268ed928984cd61096079bd671d25d +Subproject commit d27b58963128f17a6c2f3f867301d54e9f4b48cd diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ddcda61..2d76dd81b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,12 +12,20 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +# Download with: +# +# mkdir -p cmake +# wget -O cmake/CPM.cmake https://github.com/cpm-cmake/CPM.cmake/releases/latest/download/get_cpm.cmake +include(cmake/CPM.cmake) + if(USE_CUDA) + CPMAddPackage(NAME CCCL SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rd-party/cccl) + add_compile_definitions(USE_CUDA) enable_language(CUDA) set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 80) + set(CMAKE_CUDA_ARCHITECTURES native) endif() if(NOT DEFINED CMAKE_CUDA_STANDARD) set(CMAKE_CUDA_STANDARD 17) @@ -45,7 +53,7 @@ endif() if (USE_BANG) add_compile_definitions(USE_BANG) include_directories(src/kernels/mlu/include) - + # Neuware Evironment if ((NOT DEFINED NEUWARE_HOME) AND (NOT DEFINED ENV{NEUWARE_HOME})) message(FATAL_ERROR "NEUWARE_HOME is not defined from cmake or env") @@ -55,14 +63,14 @@ if (USE_BANG) set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development") endif() message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}") - + # cnrt cndrv cnnl include_directories("${NEUWARE_HOME}/include") find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall") - + if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE) set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH") diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake new file mode 100644 index 000000000..cc25ec280 --- /dev/null +++ b/cmake/CPM.cmake @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: MIT +# +# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors + +set(CPM_DOWNLOAD_VERSION 0.38.7) +set(CPM_HASH_SUM "83e5eb71b2bbb8b1f2ad38f1950287a057624e385c238f6087f94cdfc44af9c5") + +if(CPM_SOURCE_CACHE) + set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +elseif(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +else() + set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +endif() + +# Expand relative path. 
This is important if the provided path contains a tilde (~) +get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE) + +file(DOWNLOAD + https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake + ${CPM_DOWNLOAD_LOCATION} EXPECTED_HASH SHA256=${CPM_HASH_SUM} +) + +include(${CPM_DOWNLOAD_LOCATION}) diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h index d19dd3152..18a4269d9 100644 --- a/src/02hardware/include/hardware/devices/nvidia.h +++ b/src/02hardware/include/hardware/devices/nvidia.h @@ -3,6 +3,12 @@ #include "../device.h" +#define CUDA_ASSERT(STATUS) \ + if (auto status = (STATUS); status != cudaSuccess) { \ + RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ + cudaGetErrorString(status), (int) status)); \ + } + namespace refactor::hardware { class Nvidia final : public Device { diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc index fd10cb704..20f63c0fc 100644 --- a/src/02hardware/src/devices/nvidia/device.cc +++ b/src/02hardware/src/devices/nvidia/device.cc @@ -4,12 +4,6 @@ #ifdef USE_CUDA #include "memory.hh" #include - -#define CUDA_ASSERT(STATUS) \ - if (auto status = (STATUS); status != cudaSuccess) { \ - RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ - cudaGetErrorString(status), (int) status)); \ - } #endif namespace refactor::hardware { diff --git a/src/02hardware/src/devices/nvidia/memory.cc b/src/02hardware/src/devices/nvidia/memory.cc index 42310196c..1c3be21e6 100644 --- a/src/02hardware/src/devices/nvidia/memory.cc +++ b/src/02hardware/src/devices/nvidia/memory.cc @@ -1,15 +1,9 @@ #ifdef USE_CUDA #include "memory.hh" -#include "common.h" +#include "hardware/devices/nvidia.h" #include -#define CUDA_ASSERT(STATUS) \ - if (auto status = (STATUS); status != cudaSuccess) { \ - RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ - cudaGetErrorString(status), (int) status)); \ - } - namespace refactor::hardware { using M = NvidiaMemory; diff --git a/src/04kernel/CMakeLists.txt b/src/04kernel/CMakeLists.txt index 77b655c0e..efdeb0dac 100644 --- a/src/04kernel/CMakeLists.txt +++ b/src/04kernel/CMakeLists.txt @@ -26,7 +26,8 @@ if(USE_CUDA) # nvrtc for cuda kernel compile # cublas for matmul # cudnn for conv and others - target_link_libraries(kernel PUBLIC cuda nvrtc cublas cublasLt cudnn kernel_cuda) + target_link_libraries(kernel PUBLIC cuda kernel_cuda) + target_link_libraries(kernel PRIVATE nvrtc cublas cublasLt cudnn) target_include_directories(kernel PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) find_package(NCCL REQUIRED) diff --git a/src/04kernel/cuda/CMakeLists.txt b/src/04kernel/cuda/CMakeLists.txt index 4c976e33d..07223090b 100644 --- a/src/04kernel/cuda/CMakeLists.txt +++ b/src/04kernel/cuda/CMakeLists.txt @@ -4,7 +4,7 @@ project(kernel_cuda) file(GLOB_RECURSE KERNEL_CUDA_SUB_SRC src/*.cu) add_library(kernel_cuda STATIC ${KERNEL_CUDA_SUB_SRC}) -target_link_libraries(kernel_cuda PUBLIC common) +target_link_libraries(kernel_cuda PUBLIC common CCCL::CCCL) target_include_directories(kernel_cuda PUBLIC include) file(GLOB_RECURSE KERNEL_CUDA_TEST test/*.cu) diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h new file mode 100644 index 000000000..386776816 --- /dev/null +++ 
b/src/04kernel/include/kernel/attributes/attention_info.h @@ -0,0 +1,20 @@ +#ifndef KERNEL_ATTENTION_INFO_H +#define KERNEL_ATTENTION_INFO_H + +#include "../tensor.h" + +namespace refactor::kernel { + + struct AttentionInfo { + DataType dataType; + dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen; + bool concatCache, resetCache; + + dim_t attLen(dim_t pastSeqLen) const noexcept; + size_t attSize(dim_t pastSeqLen) const noexcept; + size_t maxAttSize() const noexcept; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_ATTENTION_INFO_H diff --git a/src/04kernel/include/kernel/collectors/attention.h b/src/04kernel/include/kernel/collectors/attention.h index 527bc63fe..abf33957d 100644 --- a/src/04kernel/include/kernel/collectors/attention.h +++ b/src/04kernel/include/kernel/collectors/attention.h @@ -6,9 +6,8 @@ namespace refactor::kernel { struct AttentionCollector final : public InfoCollector { - dim_t maxSeqLen; - AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept; + AttentionCollector(decltype(_target)) noexcept; std::vector filter(TensorRefs inputs, TensorRefs outputs) const final; diff --git a/src/04kernel/src/attributes/attention_info.cc b/src/04kernel/src/attributes/attention_info.cc new file mode 100644 index 000000000..a867fd3fb --- /dev/null +++ b/src/04kernel/src/attributes/attention_info.cc @@ -0,0 +1,17 @@ +#include "kernel/attributes/attention_info.h" + +namespace refactor::kernel { + + dim_t AttentionInfo::attLen(dim_t pastSeqLen) const noexcept { + return pastSeqLen + seqLen; + } + + size_t AttentionInfo::attSize(dim_t pastSeqLen) const noexcept { + return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size(); + } + + size_t AttentionInfo::maxAttSize() const noexcept { + return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size(); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/attention.cc b/src/04kernel/src/collectors/attention.cc index 3933097fa..a778c1280 100644 --- a/src/04kernel/src/collectors/attention.cc +++ b/src/04kernel/src/collectors/attention.cc @@ -1,38 +1,57 @@ #include "kernel/collectors/attention.h" +#include "kernel/attributes/attention_info.h" // #include "../kernels/attention/cpu_kernel.hh" #include "../kernels/attention/cuda_kernel.hh" namespace refactor::kernel { AttentionCollector::AttentionCollector( - decltype(_target) target, - decltype(maxSeqLen) maxSeqLen_) noexcept - : InfoCollector(target), - maxSeqLen(maxSeqLen_) {} + decltype(_target) target) noexcept + : InfoCollector(target) {} std::vector AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const { auto const &query = inputs[0].get(); auto const &key = inputs[1].get(); - auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get(); - auto cacheLen = outputs.size() == 1 ? 
0 : outputs[1].get().shape[2]; - std::vector ans; + AttentionInfo info{ + .dataType = query.dataType, + .batch = query.shape[0], + .nHead = query.shape[1], + .nKVHead = key.shape[1], + .seqLen = query.shape[2], + .headDim = query.shape[3], + .cacheLen = 0, + .concatCache = false, + .resetCache = false, + }; + switch (outputs.size()) { + case 1: + // no kv cache + ASSERT(inputs.size() == 3, ""); + break; + case 3: + switch (inputs.size()) { + case 6: + info.resetCache = true; + case 4: + info.concatCache = true; + case 3: + info.cacheLen = outputs[1].get().shape[2]; + break; + default: + UNREACHABLE(); + } + break; + default: + UNREACHABLE(); + } + + std::vector ans; switch (_target) { case decltype(_target)::Cpu: break; case decltype(_target)::Nvidia: { - decltype(AttentionCuda::info) info{ - .dataType = query.dataType, - .batch = query.shape[0], - .nHead = query.shape[1], - .nKVHead = key.shape[1], - .pastSeqLen = static_cast(pastSeqLen), - .seqLen = query.shape[2], - .cacheLen = cacheLen, - .headDim = query.shape[3], - .resetCache = false, - }; if (auto ptr = AttentionCuda::build(info); ptr) { ans.emplace_back(std::move(ptr)); } diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu new file mode 100644 index 000000000..aca6e8b76 --- /dev/null +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -0,0 +1,399 @@ +#include "../../utilities/cuda/cublas_context.hh" +#include "cuda_kernel.hh" +#include "hardware/functions.h" +#include "kernel/cuda/functions.cuh" +#include + +namespace refactor::kernel { + using K = AttentionCuda; + using namespace cublas; + + // Attention mask for a causal system. + // tokenId: index of the current token + // seqLen: number of tokens processed in this pass + // posId: position in the kv cache + // attLen = pastSeqLen + seqLen + struct AttentionCausualMask { + __forceinline__ __device__ bool + operator()(int tokenId, int seqLen, + int posId, int attLen) { + // tokenId ↓ |<---attLen---->| + // 0 | * * ... * | + // 1 | * * ... * * | + // 2 | * * ... * * * | + // seqLen: 3 |---------------| + return attLen + tokenId >= posId + seqLen; + } + }; + + // gridDim.x = batch * nHead + // gridDim.y = seqLen + // blockDim.x = 1024 + // sizeof(shared) = attLen * sizeof(float) + template + static __global__ void softmax( + T *__restrict__ att, + Mask mask, + uint32_t attLen, + uint32_t bufLen) { + // locate the attention region owned by this thread block + att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen; + // load the input into shared memory, applying cast + mask + extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ?
float(att[i]) : -__FLT_MAX__; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tempStorage; + __shared__ float sharedMax, sharedSum; + + float localMax = -1e20; + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + localMax = cub::Max()(localMax, shared[i]); + } + localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen); + if (threadIdx.x == 0) { sharedMax = localMax; } + __syncthreads(); + + float localSum = 1e-20; + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + localSum += shared[i] = expf(shared[i] - sharedMax); + } + localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen); + if (threadIdx.x == 0) { sharedSum = localSum; } + __syncthreads(); + + auto reciprocal = fdividef(1, sharedSum); + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + att[i] = shared[i] * reciprocal; + } + } + + static __global__ void concatCache( + void *__restrict__ cache, + void const *__restrict__ value, + dim_t pageStrideI, + dim_t pageStrideO, + dim_t pastOffset, + dim_t n_items) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n_items) { + auto dst = tid / pageStrideI * pageStrideO + pastOffset + (tid % pageStrideI); + reinterpret_cast(cache)[dst] = reinterpret_cast(value)[tid]; + } + } + constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// determined empirically: 40 MiB is enough + + RoutineWorkspace K::lower(Resources &res) const { + auto handle = res.fetchOrStore()->handle; + + constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW; + constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL; + + // if (!info.cacheLen) { + // if (info.nHead == info.nKVHead) { + // // RAII for closure + // struct Descriptors { + // MatMulDescriptor mul; + // MatrixDescriptor q, k, v, att; + // cublasLtMatmulAlgo_t algoQK, algoAV; + // size_t workspaceSizeQK, workspaceSizeAV; + + // Descriptors(CublasLtContext const &context, + // AttentionInfo info) + // : mul(computeTypeConvert(info.dataType), + // dataTypeConvert(info.dataType)), + // q(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // k(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.headDim), + // .cols = static_cast(info.seqLen), + // .majorStride = static_cast(info.headDim), + // .order = COL_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // v(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // att(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.seqLen), + // .majorStride = static_cast(info.seqLen), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.seqLen), + // }) { + // auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE); + // auto [algoAV_, workspaceSizeAV_] =
tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE); + // algoQK = algoQK_; + // algoAV = algoAV_; + // workspaceSizeQK = workspaceSizeQK_; + // workspaceSizeAV = workspaceSizeAV_; + // } + // }; + + // auto const &context = *res.fetchOrStore(); + // auto d = std::make_shared(context, info); + // auto workspaceSize = info.attSize(0); + // workspaceSize = hardware::alignBytes(workspaceSize, 256); + // workspaceSize += d->workspaceSizeQK; + // workspaceSize += d->workspaceSizeAV; + // workspaceSize = hardware::alignBytes(workspaceSize, 256); + + // auto routine = [d = std::move(d), info = this->info]// + // (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + // auto handle = res.fetchOrStore()->handle; + // auto q = inputs[0]; + // auto k = inputs[1]; + // auto v = inputs[2]; + // auto o = outputs[0]; + // auto att = reinterpret_cast(workspace); + // auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); + // auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); + // auto stream = cudaStreamLegacy; + // { + // half alpha = rsqrtf(info.headDim), beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // q, d->q.get(), + // k, d->k.get(), + // &beta, + // att, d->att.get(), + // att, d->att.get(), + // &d->algoQK, + // workspaceQK, d->workspaceSizeQK, + // stream); + // } + // auto attLen = info.attLen(0); + // auto bufLen = attLen; + // softmax<<>>( + // att, AttentionCausualMask(), attLen, bufLen); + // { + // half alpha = 1, beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // att, d->att.get(), + // v, d->v.get(), + // &beta, + // o, d->q.get(), + // o, d->q.get(), + // &d->algoAV, + // workspaceAV, d->workspaceSizeAV, + // stream); + // } + // }; + + // return {std::move(routine), workspaceSize}; + // } + // TODO(""); + // } + + if (info.concatCache && !info.resetCache) { + if (info.nHead == info.nKVHead) { + + // RAII for closure + struct Descriptors { + MatMulDescriptor mul; + + Descriptors(AttentionInfo info) + : mul(computeTypeConvert(info.dataType), + dataTypeConvert(info.dataType)) {} + }; + + auto const &context = *res.fetchOrStore(); + auto d = std::make_shared(info); + auto attentionSize = info.maxAttSize(); + auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize; + + auto routine = [d = std::move(d), info = this->info]// + (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore()->handle; + auto q = inputs[0]; + auto k = inputs[1]; + auto v = inputs[2]; + int64_t past; + cudaMemcpy(&past, inputs[3], sizeof(int64_t), cudaMemcpyDeviceToHost); + auto attLen = info.attLen(past); + auto o = reinterpret_cast(outputs[0]); + auto kCache = reinterpret_cast(outputs[1]); + auto vCache = reinterpret_cast(outputs[2]); + auto att = reinterpret_cast(reinterpret_cast(workspace) + DYNAMIC_WORKSPACE_SIZE); + auto stream = cudaStreamLegacy; + { + auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4); + auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine; + auto blocks = (threads + 1023) / 1024; + concatCache<<>>( + kCache, k, + info.seqLen * itemsPerLine, + info.cacheLen * itemsPerLine, + past * itemsPerLine, + threads); + concatCache<<>>( + vCache, v, + info.seqLen * itemsPerLine, + info.cacheLen * itemsPerLine, + past * itemsPerLine, + threads); + } + // MatrixDescriptor + // q_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = 
static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // k_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.headDim), + // .cols = static_cast(attLen), + // .majorStride = static_cast(info.headDim), + // .order = COL_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.headDim), + // }), + // v_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(attLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.headDim), + // }), + // att_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(attLen), + // .majorStride = static_cast(info.cacheLen), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.seqLen), + // }); + { + // auto [algo, workspaceSize] = tune( + // handle, d->mul, + // q_, k_, att_, + // DYNAMIC_WORKSPACE_SIZE); + half alpha = rsqrtf(info.headDim), beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // q, q_.get(), + // kCache, k_.get(), + // &beta, + // att, att_.get(), + // att, att_.get(), + // &algo, + // workspace, workspaceSize, + // stream); + cublasGemmStridedBatchedEx( + handle, // handle + CUBLAS_OP_T, // trans a + CUBLAS_OP_N, // trans b + attLen, // m + info.seqLen, // n + info.headDim, // k + &alpha, // alpha + kCache, // a + CUDA_R_16F, // a type + info.headDim, // lda + info.cacheLen * info.headDim,// a stride + q, // b + CUDA_R_16F, // b type + info.headDim, // ldb + info.seqLen * info.headDim, // b stride + &beta, // beta + att, // c + CUDA_R_16F, // c type + info.cacheLen, // ldc + info.cacheLen * info.seqLen, // c stride + info.batch * info.nHead, // batch count + CUDA_R_32F, // compute type + CUBLAS_GEMM_DEFAULT // algo + ); + } + softmax<<>>( + att, AttentionCausualMask(), attLen, info.cacheLen); + { + // auto [algo, workspaceSize] = tune( + // handle, d->mul, + // att_, v_, q_, + // DYNAMIC_WORKSPACE_SIZE); + half alpha = 1, beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // att, att_.get(), + // vCache, v_.get(), + // &beta, + // o, q_.get(), + // o, q_.get(), + // &algo, + // workspace, workspaceSize, + // stream); + cublasGemmStridedBatchedEx( + handle, // handle + CUBLAS_OP_N, // trans a + CUBLAS_OP_N, // trans b + attLen, // m + info.seqLen, // n + info.headDim, // k + &alpha, // alpha + vCache, // a + CUDA_R_16F, // a type + info.headDim, // lda + info.cacheLen * info.headDim,// a stride + att, // b + CUDA_R_16F, // b type + info.cacheLen, // ldb + info.cacheLen * info.seqLen, // b stride + &beta, // beta + o, // c + CUDA_R_16F, // c type + info.headDim, // ldc + info.seqLen * info.headDim, // c stride + info.batch * info.nHead, // batch count + CUDA_R_32F, // compute type + CUBLAS_GEMM_DEFAULT // algo + ); + } + }; + + return {std::move(routine), workspaceSize}; + } + TODO(""); + } + + TODO(""); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.hh 
b/src/04kernel/src/kernels/attention/cuda_kernel.hh index 5ea19ae88..20cf9712d 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.hh +++ b/src/04kernel/src/kernels/attention/cuda_kernel.hh @@ -1,17 +1,13 @@ #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH #define KERNEL_ATTENTION_CUDA_KERNEL_HH +#include "kernel/attributes/attention_info.h" #include "kernel/kernel.h" -#include "kernel/tensor.h" namespace refactor::kernel { struct AttentionCuda final : public Kernel { - struct { - DataType dataType; - dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim; - bool resetCache; - } info; + AttentionInfo info; AttentionCuda(decltype(info)) noexcept; diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.cu b/src/04kernel/src/utilities/cuda/cublaslt_context.cu deleted file mode 100644 index 2fc8fb182..000000000 --- a/src/04kernel/src/utilities/cuda/cublaslt_context.cu +++ /dev/null @@ -1,33 +0,0 @@ -#include "common.h" -#include "cublaslt_context.hh" - -namespace refactor::kernel::cublas { - - CublasLtContext::CublasLtContext() : runtime::Resource() { - if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) { - RUNTIME_ERROR("Failed to create cublasLt handle"); - } - } - CublasLtContext::~CublasLtContext() { - if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) { - fmt::println("Failed to destroy cublasLt handle"); - abort(); - } - } - - auto CublasLtContext::typeId() noexcept -> size_t { - static uint8_t ID = 1; - return reinterpret_cast(&ID); - } - auto CublasLtContext::build() noexcept -> runtime::ResourceBox { - return std::make_unique(); - } - - auto CublasLtContext::resourceTypeId() const noexcept -> size_t { - return typeId(); - } - auto CublasLtContext::description() const noexcept -> std::string_view { - return "CublasLtContext"; - } - -}// namespace refactor::kernel::cublas diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.hh b/src/04kernel/src/utilities/cuda/cublaslt_context.hh deleted file mode 100644 index 84e1d2d90..000000000 --- a/src/04kernel/src/utilities/cuda/cublaslt_context.hh +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef KERNEL_CUBLASLT_CONTEXT_HH -#define KERNEL_CUBLASLT_CONTEXT_HH - -#include "runtime/resource.h" -#include - -#define CUBLAS_ASSERT(STATUS) \ - if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \ - fmt::println("cublas failed on \"" #STATUS "\" with {}", \ - (int) status); \ - abort(); \ - } - -namespace refactor::kernel::cublas { - - struct CublasLtContext final : public runtime::Resource { - cublasLtHandle_t handle; - - CublasLtContext(); - ~CublasLtContext(); - CublasLtContext(CublasLtContext const &) noexcept = delete; - CublasLtContext(CublasLtContext &&) noexcept = delete; - - static size_t typeId() noexcept; - static runtime::ResourceBox build() noexcept; - - size_t resourceTypeId() const noexcept final; - std::string_view description() const noexcept final; - }; - -}// namespace refactor::kernel::cublas - -#endif// KERNEL_CUBLASLT_CONTEXT_HH diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu new file mode 100644 index 000000000..ab797e8f7 --- /dev/null +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu @@ -0,0 +1,160 @@ +#include "cublaslt_utils.cuh" +#include "hardware/devices/nvidia.h" + +namespace refactor::kernel::cublas { + + CublasLtContext::CublasLtContext() : runtime::Resource() { + if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) { + RUNTIME_ERROR("Failed to create cublasLt handle"); + } + } + CublasLtContext::~CublasLtContext() { 
+ if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) { + fmt::println("Failed to destroy cublasLt handle"); + abort(); + } + } + + auto CublasLtContext::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto CublasLtContext::build() noexcept -> runtime::ResourceBox { + return std::make_unique(); + } + + auto CublasLtContext::resourceTypeId() const noexcept -> size_t { + return typeId(); + } + auto CublasLtContext::description() const noexcept -> std::string_view { + return "CublasLtContext"; + } + + cudaDataType dataTypeConvert(DataType dt) { + switch (dt) { + case DataType::F32: + return CUDA_R_32F; + case DataType::FP16: + return CUDA_R_16F; + case DataType::BF16: + return CUDA_R_16BF; + default: + TODO(""); + } + } + cublasComputeType_t computeTypeConvert(DataType dt) { + switch (dt) { + case DataType::F32: + case DataType::BF16: + return CUBLAS_COMPUTE_32F; + case DataType::FP16: + return CUBLAS_COMPUTE_16F; + default: + TODO(""); + } + } + + MatMulDescriptor::MatMulDescriptor(cublasComputeType_t compute, cudaDataType data) + : _internal(nullptr) { + CUBLASLT_ASSERT(cublasLtMatmulDescCreate(&_internal, compute, data)); + } + MatMulDescriptor::~MatMulDescriptor() { + CUBLASLT_ASSERT(cublasLtMatmulDescDestroy(_internal)); + } + cublasLtMatmulDesc_t MatMulDescriptor::get() const noexcept { + return _internal; + } + + MatrixDescriptor::MatrixDescriptor(MatrixLayout layout) + : _internal(nullptr) { + CUBLASLT_ASSERT(cublasLtMatrixLayoutCreate( + &_internal, + layout.dataType, + layout.rows, + layout.cols, + layout.majorStride)); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_ORDER, + &layout.order, + sizeof(layout.order))); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &layout.batchCount, + sizeof(layout.batchCount))); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &layout.batchStride, + sizeof(layout.batchStride))); + } + MatrixDescriptor::~MatrixDescriptor() { + CUBLASLT_ASSERT(cublasLtMatrixLayoutDestroy(_internal)); + } + cublasLtMatrixLayout_t MatrixDescriptor::get() const noexcept { + return _internal; + } + + std::pair + tune(cublasLtHandle_t handle, + MatMulDescriptor const &matmul, + MatrixDescriptor const &a, + MatrixDescriptor const &b, + MatrixDescriptor const &c, + uint64_t maxWorkspace) { + + int device; + CUDA_ASSERT(cudaGetDevice(&device)); + cudaDeviceProp prop; + CUDA_ASSERT(cudaGetDeviceProperties(&prop, device)); + + uint32_t alignment = prop.textureAlignment; + + cublasLtMatmulPreference_t preference; + CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference)); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &maxWorkspace, + sizeof(maxWorkspace))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, + &alignment, + sizeof(alignment))); + + cublasLtMatmulHeuristicResult_t 
result; + int ansN; + CUBLASLT_ASSERT(cublasLtMatmulAlgoGetHeuristic( + handle, + matmul.get(), + a.get(), + b.get(), + c.get(), + c.get(), + preference, + 1, + &result, + &ansN)); + ASSERT(ansN == 1, ""); + + return {result.algo, result.workspaceSize}; + } + +}// namespace refactor::kernel::cublas diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh new file mode 100644 index 000000000..33de075a9 --- /dev/null +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh @@ -0,0 +1,76 @@ +#ifndef KERNEL_CUBLASLT_UTILS_CUH +#define KERNEL_CUBLASLT_UTILS_CUH + +#include "common.h" +#include "runtime/resource.h" +#include + +#define CUBLASLT_ASSERT(STATUS) \ + if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \ + fmt::println("cublasLt failed on \"" #STATUS "\" with {}", \ + (int) status); \ + abort(); \ + } + +namespace refactor::kernel::cublas { + + struct CublasLtContext final : public runtime::Resource { + cublasLtHandle_t handle; + + CublasLtContext(); + ~CublasLtContext(); + CublasLtContext(CublasLtContext const &) noexcept = delete; + CublasLtContext(CublasLtContext &&) noexcept = delete; + + static size_t typeId() noexcept; + static runtime::ResourceBox build() noexcept; + + size_t resourceTypeId() const noexcept final; + std::string_view description() const noexcept final; + }; + + cudaDataType dataTypeConvert(DataType); + cublasComputeType_t computeTypeConvert(DataType); + + class MatMulDescriptor { + cublasLtMatmulDesc_t _internal; + + public: + MatMulDescriptor(cublasComputeType_t, cudaDataType); + ~MatMulDescriptor(); + MatMulDescriptor(MatMulDescriptor const &) noexcept = delete; + MatMulDescriptor(MatMulDescriptor &&) noexcept = delete; + cublasLtMatmulDesc_t get() const noexcept; + }; + + struct MatrixLayout { + cudaDataType dataType; + uint64_t rows, cols; + int64_t majorStride; + cublasLtOrder_t order; + int32_t batchCount; + int64_t batchStride; + }; + + class MatrixDescriptor { + cublasLtMatrixLayout_t _internal; + + public: + MatrixDescriptor(MatrixLayout layout); + ~MatrixDescriptor(); + MatrixDescriptor(MatrixDescriptor const &) noexcept = delete; + MatrixDescriptor(MatrixDescriptor &&) noexcept = delete; + cublasLtMatrixLayout_t get() const noexcept; + }; + + std::pair + tune(cublasLtHandle_t, + MatMulDescriptor const &, + MatrixDescriptor const &, + MatrixDescriptor const &, + MatrixDescriptor const &, + uint64_t); + +}// namespace refactor::kernel::cublas + +#endif// KERNEL_CUBLASLT_UTILS_CUH diff --git a/src/04kernel/test/kernels/attention/test_cuda.cpp b/src/04kernel/test/kernels/attention/test_cuda.cpp new file mode 100644 index 000000000..794ae1748 --- /dev/null +++ b/src/04kernel/test/kernels/attention/test_cuda.cpp @@ -0,0 +1,59 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/attention/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include "kernel/cuda/functions.cuh" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, AttentionCudaNoKvCache) { + // build routine + AttentionInfo info{ + .dataType = DataType::F32, + .batch = 1, + .nHead = 4, + .nKVHead = 4, + .seqLen = 31, + .headDim = 256, + .cacheLen = 0, + .concatCache = false, + .resetCache = false, + }; + auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}), + k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}), + v = Tensor::share(DataType::F32, Shape{info.batch, 
info.nKVHead, info.seqLen, info.headDim}), + o = q; + auto kernel = AttentionCuda::build(info); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto qGpu = dev.malloc(q->bytesSize()), + kGpu = dev.malloc(k->bytesSize()), + vGpu = dev.malloc(v->bytesSize()), + oGpu = dev.malloc(o->bytesSize()), + workspace = dev.malloc(workspaceSize); + // put input data + std::vector + q_(q->elementsSize(), 1), + k_(k->elementsSize(), 1), + v_(v->elementsSize(), 1), + o_(o->elementsSize()); + qGpu->copyFromHost(q_.data()); + kGpu->copyFromHost(k_.data()); + vGpu->copyFromHost(v_.data()); + // inference + { + void const *inputs[]{*qGpu, *kGpu, *vGpu}; + void *outputs[]{*oGpu}; + routine(res, *workspace, inputs, outputs); + } + cuda::sync(); +} + +#endif diff --git a/src/04kernel/test/kernels/softmax/test_cuda.cpp b/src/04kernel/test/kernels/softmax/test_cuda.cpp index 4290e852e..ce3cc1ad4 100644 --- a/src/04kernel/test/kernels/softmax/test_cuda.cpp +++ b/src/04kernel/test/kernels/softmax/test_cuda.cpp @@ -4,18 +4,19 @@ #include "../../../src/kernels/softmax/cuda_kernel.hh" #include "hardware/device_manager.h" #include +#include using namespace refactor; using namespace kernel; using namespace hardware; -TEST(kernel, SoftmaxCuda) { +static void test(Shape shape, int axis) { // build routine - auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); - auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); - dim_t axis = 1; - auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis)); - auto kCuda = SoftmaxCuda::build(SoftmaxInfo(*xTensor, axis)); + auto xTensor = Tensor::share(DataType::F32, shape); + auto outTensor = Tensor::share(DataType::F32, shape); + SoftmaxInfo info(*xTensor, axis); + auto kCpu = SoftmaxCpu::build(info); + auto kCuda = SoftmaxCuda::build(info); ASSERT_TRUE(kCpu && kCuda); auto res = runtime::Resources(); auto rCpu = kCpu->lower(res).routine; @@ -28,6 +29,7 @@ TEST(kernel, SoftmaxCuda) { std::vector data(xTensor->elementsSize(), 0), cpuOut(outTensor->elementsSize()); + std::iota(data.begin(), data.end(), 0); gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); // inference { @@ -49,4 +51,9 @@ TEST(kernel, SoftmaxCuda) { } } +TEST(kernel, SoftmaxCuda) { + test({2, 3, 2, 5, 4}, 1); + test({2, 2048, 2, 5, 4}, 1); +} + #endif diff --git a/src/05computation/include/computation/operators/attention.h b/src/05computation/include/computation/operators/attention.h index d5f37997f..753df9461 100644 --- a/src/05computation/include/computation/operators/attention.h +++ b/src/05computation/include/computation/operators/attention.h @@ -6,14 +6,14 @@ namespace refactor::computation { struct Attention final : public Operator { - dim_t maxSeqLen; - constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept - : Operator(), maxSeqLen(maxSeqLen_) {} + constexpr Attention() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const final; + std::string serialize() const noexcept final; }; }// namespace refactor::computation diff --git a/src/05computation/src/operators/attention.cc b/src/05computation/src/operators/attention.cc index 4624482af..b57886391 100644 --- a/src/05computation/src/operators/attention.cc +++ b/src/05computation/src/operators/attention.cc @@ -1,4 +1,5 @@ #include 
"computation/operators/attention.h" +#include "kernel/collectors/attention.h" namespace refactor::computation { using Op = Attention; @@ -9,5 +10,12 @@ namespace refactor::computation { } auto Op::opTypeId() const noexcept -> size_t { return typeId(); } auto Op::name() const noexcept -> std::string_view { return "Attention"; } + auto Op::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::AttentionCollector; + return std::make_unique(target); + } + auto Op::serialize() const noexcept -> std::string { + return "Attention()"; + } }// namespace refactor::computation diff --git a/src/08-01llm/src/operators/attention.cc b/src/08-01llm/src/operators/attention.cc index 15479c6a1..cc2e9ce48 100644 --- a/src/08-01llm/src/operators/attention.cc +++ b/src/08-01llm/src/operators/attention.cc @@ -9,7 +9,7 @@ namespace refactor::llm { : Operator(), maxSeqLen(maxSeqLen_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_(); + auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_(); return OpBox(std::make_unique(maxSeqLen)); } auto Op::typeId() -> size_t { @@ -80,10 +80,10 @@ namespace refactor::llm { if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) { return Err(InferError(ERROR_MSG("Past seqlen error"))); } - auto pastSeqLenVal = pastSeqLen.data->get()[0]; if (maxSeqLen <= 0) { + auto pastSeqLenVal = pastSeqLen.data->get()[0]; return outputs(pastSeqLenVal + seqlen); - } else if (maxSeqLen >= pastSeqLenVal + seqlen) { + } else if (maxSeqLen >= 1 + seqlen) { return outputs(maxSeqLen); } else { return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen"))); @@ -94,7 +94,6 @@ namespace refactor::llm { if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) { return Err(InferError(ERROR_MSG("Past seqlen error"))); } - auto pastSeqLenVal = pastSeqLen.data->get()[0]; auto const &kCahce = inputs[4], &vCache = inputs[5]; @@ -107,15 +106,14 @@ namespace refactor::llm { kCahce.shape[3] != kvShape[3] || kCahce.shape[0] != kvShape[0] || kCahce.shape[2] != kvShape[2] || - kCahce.shape[3] != kvShape[3] || - pastSeqLenVal < kCacheSeqLen || - pastSeqLenVal < vCacheSeqLen) { + kCahce.shape[3] != kvShape[3]) { return Err(InferError(ERROR_MSG("KV cache error"))); } if (maxSeqLen <= 0) { + auto pastSeqLenVal = pastSeqLen.data->get()[0]; return outputs(pastSeqLenVal + seqlen); - } else if (maxSeqLen >= pastSeqLenVal + seqlen) { + } else if (maxSeqLen >= 1 + seqlen) { return outputs(maxSeqLen); } else { return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen"))); @@ -129,7 +127,7 @@ namespace refactor::llm { auto Op::lower(TensorRefs) const -> computation::OpBox { using Op_ = computation::Attention; - return std::make_unique(maxSeqLen); + return std::make_unique(); } }// namespace refactor::llm diff --git a/src/08-01llm/test/test_attention.cpp b/src/08-01llm/test/test_attention.cpp new file mode 100644 index 000000000..fbe7d2e7a --- /dev/null +++ b/src/08-01llm/test/test_attention.cpp @@ -0,0 +1,44 @@ +#include "../src/operators/attention.hh" +#include "llm/operators.h" +#include + +using namespace refactor; +using namespace llm; + +TEST(infer, AttentionNoKvCache) { + llm::register_(); + auto batch = DimExpr("N"); + auto numHead = DimExpr(16); + auto seqLen = DimExpr(31); + auto headDim = DimExpr(64); + { + auto edges = Edges{ + {Tensor::share(DataType::FP16, Shape{batch, numHead, 
seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + }; + count_t inputs[]{0, 1, 2}; + auto infered = Attention(0).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::FP16); + ASSERT_EQ(y->shape, edges[0].tensor->shape); + } + { + auto edges = Edges{ + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, DimExpr(4), seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, DimExpr(4), seqLen, headDim}, {}), ""}, + }; + count_t inputs[]{0, 1, 2}; + auto infered = Attention(0).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::FP16); + ASSERT_EQ(y->shape, edges[0].tensor->shape); + } +}
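
Note on the arithmetic introduced by this patch: AttentionInfo::attLen/attSize (src/04kernel/src/attributes/attention_info.cc) size the attention score buffer, and AttentionCausualMask (src/04kernel/src/kernels/attention/cuda_kernel.cu) decides which cache positions a token may attend to. The sketch below is a minimal host-side C++ check of both, not part of the patch: the standalone main(), the dim_t alias, and the 2-byte FP16 element size are illustrative assumptions; only the predicate and the sizing formula are taken from the diff.

#include <cstddef>
#include <cstdint>
#include <cstdio>

using dim_t = uint32_t;// assumed to match refactor::dim_t

// Same predicate as AttentionCausualMask::operator() in cuda_kernel.cu:
// token `tokenId` may attend to cache position `posId` iff attLen + tokenId >= posId + seqLen.
static bool causalMask(int tokenId, int seqLen, int posId, int attLen) {
    return attLen + tokenId >= posId + seqLen;
}

int main() {
    dim_t pastSeqLen = 2, seqLen = 3;
    dim_t attLen = pastSeqLen + seqLen;// AttentionInfo::attLen(pastSeqLen)

    // Prints the pattern drawn in the AttentionCausualMask comment:
    // row i keeps positions 0 .. pastSeqLen + i, a causal mask shifted by the kv cache length.
    for (dim_t t = 0; t < seqLen; ++t) {
        for (dim_t p = 0; p < attLen; ++p) {
            std::printf("%c ", causalMask((int) t, (int) seqLen, (int) p, (int) attLen) ? '*' : '.');
        }
        std::printf("\n");
    }

    // Score-buffer size as in AttentionInfo::attSize:
    // batch * nHead * seqLen * attLen * dataType.size(); FP16 (2 bytes) assumed here.
    dim_t batch = 1, nHead = 4;
    std::size_t elementBytes = 2;
    std::printf("attSize(past=%u) = %zu bytes\n",
                (unsigned) pastSeqLen,
                (std::size_t) batch * nHead * seqLen * attLen * elementBytes);
    return 0;
}

With pastSeqLen = 2 and seqLen = 3 this prints three rows over five positions, matching the diagram in the kernel comment, and an attSize of 1 * 4 * 3 * 5 * 2 = 120 bytes; maxAttSize() in the patch is the same product with attLen replaced by cacheLen (or by seqLen when there is no cache).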