diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py
index ea680c57..344b7e21 100644
--- a/operatorspy/tests/random_sample.py
+++ b/operatorspy/tests/random_sample.py
@@ -31,21 +31,9 @@ class RandomSampleDescriptor(Structure):
 def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
     indices = torch.zeros([topk], dtype = torch.int64)
-    dataNp = data.clone().detach()
-    sorted_indices = torch.arange(voc)
-
-    for i in range(topk):
-        for j in range(i + 1, voc):
-            if(dataNp[i] < dataNp[j]):
-                tmp = dataNp[i].clone().detach()
-                dataNp[i] = dataNp[j].clone().detach()
-                dataNp[j] = tmp
-
-                tmpInd = sorted_indices[i].clone().detach()
-                sorted_indices[i] = sorted_indices[j].clone().detach()
-                sorted_indices[j] = tmpInd
+    dataNp = data.clone()
 
-    #sorted_indices = torch.argsort(dataNp, descending=True)
+    sorted_indices = torch.argsort(dataNp, descending=True)
 
     indices = sorted_indices[:topk]
     dataNp = dataNp[sorted_indices]
 
@@ -53,25 +41,22 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
 
     globalM = dataNp[0]
     dataNp = (dataNp - globalM) / temperature
     dataNp = torch.softmax(dataNp.float(), dim = 0)
-    sum_s = 0
+
+    for i in range(1, topk):
+        dataNp[i] = dataNp[i] + dataNp[i - 1]
+
     for end in range(topk):
-        sum_s += dataNp[end]
-        if(sum_s >= topp):
+        if(dataNp[end] >= topp):
             break
     if(end < topk - 1):
         end += 1
     else:
         end = topk
-    sum_s = 0
-    for i in range(end):
-        sum_s += dataNp[i]
-    random_val *= sum_s
+    random_val *= dataNp[end - 1]
 
-    sum_s = 0
     for i in range(end):
-        sum_s += dataNp[i]
-        if(random_val < sum_s):
+        if(random_val < dataNp[i]):
             return indices[i]
 
 def random_sample_0(data):
@@ -129,7 +114,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     )
     if torch_device == "npu":
         torch.npu.synchronize()
-    
+    assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
     check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
     print("Test passed!")
 
@@ -168,7 +153,13 @@ def test_ascend(lib, test_cases):
     for (voc, random_val, topp, topk, temperature) in test_cases:
         test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
-
+def test_teco(lib, test_cases):
+    import torch_sdaa
+    device = DeviceEnum.DEVICE_TECO
+    handle = create_handle(lib, device)
+    for (voc, random_val, topp, topk, temperature) in test_cases:
+        test(lib, handle, "sdaa", voc, random_val, topp, topk, temperature)
+    destroy_handle(lib, handle)
 
 if __name__ == "__main__":
     test_cases = [
@@ -224,6 +215,9 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.teco:
+        test_teco(lib, test_cases)
+
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/src/ops/random_sample/operator.cc b/src/ops/random_sample/operator.cc
index ff241e77..0bb830a6 100644
--- a/src/ops/random_sample/operator.cc
+++ b/src/ops/random_sample/operator.cc
@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/random_sample.h"
 #endif
+#ifdef ENABLE_TECO_SDAA
+#include "teco/random_sample_teco.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) {
     switch (handle->device) {
@@ -35,8 +38,13 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle,
 #ifdef ENABLE_ASCEND_NPU
         case DevAscendNpu: {
             return ascendCreateRandomSampleDescriptor((AscendHandle_t) handle,
-                                                  (RandomSampleAscendDescriptor_t *) desc_ptr, result, probs);
+                                                      (RandomSampleAscendDescriptor_t *) desc_ptr, result, probs);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoCreateRandomSampleDescriptor((TecoHandle_t) handle,
+                                                    (RandomSampleTecoDescriptor_t *) desc_ptr, result, probs);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -64,6 +73,10 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe
         case DevAscendNpu: {
             return ascendGetRandomSampleWorkspaceSize((RandomSampleAscendDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoGetRandomSampleWorkspaceSize((RandomSampleTecoDescriptor_t) desc, size);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -97,6 +110,10 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
         case DevAscendNpu: {
             return ascendRandomSample((RandomSampleAscendDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoRandomSample((RandomSampleTecoDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -121,6 +138,10 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD
         case DevAscendNpu: {
             return ascendDestroyRandomSampleDescriptor((RandomSampleAscendDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoDestroyRandomSampleDescriptor((RandomSampleTecoDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/random_sample/teco/random_sample_teco.h b/src/ops/random_sample/teco/random_sample_teco.h
new file mode 100644
index 00000000..f8432295
--- /dev/null
+++ b/src/ops/random_sample/teco/random_sample_teco.h
@@ -0,0 +1,42 @@
+#ifndef __SDAA_RANDOM_SAMPLE_H__
+#define __SDAA_RANDOM_SAMPLE_H__
+
+#include "../../../devices/teco/teco_handle.h"
+#include "../../utils.h"
+#include "operators.h"
+#include <tecodnn.h>
+
+struct RandomSampleTecoDescriptor {
+    Device device;
+    int device_id;
+    tecodnnHandle_t handle;
+    sdaaStream_t stream;
+    DT dtype;
+    int voc;
+    DT rDtype;
+    int rLength;
+};
+
+typedef struct RandomSampleTecoDescriptor *RandomSampleTecoDescriptor_t;
+
+infiniopStatus_t tecoCreateRandomSampleDescriptor(TecoHandle_t handle,
+                                                  RandomSampleTecoDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
+                                                  infiniopTensorDescriptor_t probs);
+
+infiniopStatus_t tecoGetRandomSampleWorkspaceSize(RandomSampleTecoDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
+                                  void *workspace,
+                                  uint64_t workspace_size,
+                                  void *result,
+                                  void const *probs,
+                                  float random_val,
+                                  float topp,
+                                  int topk,
+                                  float temperature,
+                                  void *stream);
+
+infiniopStatus_t tecoDestroyRandomSampleDescriptor(RandomSampleTecoDescriptor_t desc);
+
+
+#endif
diff --git a/src/ops/random_sample/teco/random_sample_teco.scpp b/src/ops/random_sample/teco/random_sample_teco.scpp
new file mode 100644
index 00000000..ddda8cd5
--- /dev/null
+++ b/src/ops/random_sample/teco/random_sample_teco.scpp
@@ -0,0 +1,272 @@
+#include "random_sample_teco.h"
+
+infiniopStatus_t tecoCreateRandomSampleDescriptor(TecoHandle_t handle,
+                                                  RandomSampleTecoDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
+                                                  infiniopTensorDescriptor_t probs) {
+    if (probs->ndim != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (!dtype_eq(probs->dt, F16)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (!dtype_eq(result->dt, U64))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    int voc = probs->shape[0];
+    int rLength = result->shape[0];
+
+    if (result->ndim != 1 && rLength != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    tecodnnHandle_t tecodnn_handle;
+    tecodnnCreate(&tecodnn_handle);
+
+    *desc_ptr = new RandomSampleTecoDescriptor{
+        handle->device,
+        handle->device_id,
+        tecodnn_handle,
+        handle->stream,
+        probs->dt,
+        voc,
+        result->dt,
+        rLength};
+    tecodnnSetStream((*desc_ptr)->handle, (*desc_ptr)->stream);
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t tecoGetRandomSampleWorkspaceSize(RandomSampleTecoDescriptor_t desc, uint64_t *size) {
+    *size = desc->voc * (sizeof(uint64_t) + 3 * sizeof(desc->dtype));
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t tecoDestroyRandomSampleDescriptor(RandomSampleTecoDescriptor_t desc) {
+    //sdaaStreamDestroy(desc->stream);
+    delete desc;
+    return STATUS_SUCCESS;
+}
+
+template<typename T>
+void topkKernel(const void *probs, void *index, void *value, int topk, int voc, tecodnnHandle_t handle, sdaaStream_t stream) {
+    tecodnnSetStream(handle, stream);
+    tecodnnStatus_t status;
+
+    tecodnnTensorDescriptor_t input_desc_teco, value_desc_teco, index_desc_teco;
+    tecodnnCreateTensorDescriptor(&input_desc_teco);
+    tecodnnCreateTensorDescriptor(&value_desc_teco);
+    tecodnnCreateTensorDescriptor(&index_desc_teco);
+
+    int32_t probsDims[2] = {1, voc}, probsStrides[2] = {voc, 1};
+    int32_t resultDims[2] = {1, topk}, resultStrides[2] = {topk, 1};
+    if constexpr (sizeof(T) == 2) {
+        tecodnnSetTensorNdDescriptor(input_desc_teco, TECODNN_DATA_HALF, 2, probsDims, probsStrides);
+        tecodnnSetTensorNdDescriptor(value_desc_teco, TECODNN_DATA_HALF, 2, resultDims, resultStrides);
+    } else if constexpr (sizeof(T) == 4) {
+        tecodnnSetTensorNdDescriptor(input_desc_teco, TECODNN_DATA_FLOAT, 2, probsDims, probsStrides);
+        tecodnnSetTensorNdDescriptor(value_desc_teco, TECODNN_DATA_FLOAT, 2, resultDims, resultStrides);
+    }
+
+    tecodnnSetTensorNdDescriptor(index_desc_teco, TECODNN_DATA_INT64, 2, resultDims, resultStrides);
+
+    size_t workSpaceSizeInBytes;
+    int axis = 1;
+    bool largest = true;
+    bool sorted = true;
+    tecodnnGetTopkExWorkspaceSize(handle, axis, topk, largest, sorted, input_desc_teco, value_desc_teco,
+                                  index_desc_teco, &workSpaceSizeInBytes);
+    void *compute_workspace;
+    sdaaMalloc((void **) &compute_workspace, workSpaceSizeInBytes);
+
+    status = tecodnnTopkEx(handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
+                           index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
+    sdaaStreamSynchronize(stream);
+    sdaaFree(compute_workspace);
+    tecodnnDestroyTensorDescriptor(input_desc_teco);
+    tecodnnDestroyTensorDescriptor(value_desc_teco);
+    tecodnnDestroyTensorDescriptor(index_desc_teco);
+    if (status != TECODNN_STATUS_SUCCESS) {
+        printf("topk %s\n", tecodnnGetErrorString(status));
+    }
+}
+template<typename T>
+void softmaxKernel(const void *probs, void *destination, int voc, tecodnnHandle_t handle, sdaaStream_t stream) {
+    tecodnnSetStream(handle, stream);
+    tecodnnStatus_t status;
+
+    tecodnnTensorDescriptor_t x_desc_teco, y_desc_teco;
+    tecodnnCreateTensorDescriptor(&x_desc_teco);
+    tecodnnCreateTensorDescriptor(&y_desc_teco);
+
+    tecodnnSoftmaxAlgorithm_t algo = TECODNN_SOFTMAX_ACCURATE;
+    tecodnnSoftmaxMode_t mode = TECODNN_SOFTMAX_MODE_INSTANCE;
+    float alpha = 1.0f, beta = 0.0f;
+    if constexpr (sizeof(T) == 2) {
+        tecodnnSetTensor4dDescriptor(x_desc_teco, TECODNN_TENSOR_NHWC, TECODNN_DATA_HALF, 1, 1, 1, voc);
+        tecodnnSetTensor4dDescriptor(y_desc_teco, TECODNN_TENSOR_NHWC, TECODNN_DATA_HALF, 1, 1, 1, voc);
+    } else if constexpr (sizeof(T) == 4) {
+        tecodnnSetTensor4dDescriptor(x_desc_teco, TECODNN_TENSOR_NHWC, TECODNN_DATA_FLOAT, 1, 1, 1, voc);
+        tecodnnSetTensor4dDescriptor(y_desc_teco, TECODNN_TENSOR_NHWC, TECODNN_DATA_FLOAT, 1, 1, 1, voc);
+    }
+
+    status = tecodnnSoftmaxForward(handle, algo, mode, &alpha, x_desc_teco, probs, &beta, y_desc_teco, destination);
+    sdaaStreamSynchronize(stream);
+    tecodnnDestroyTensorDescriptor(x_desc_teco);
+    tecodnnDestroyTensorDescriptor(y_desc_teco);
+    if (status != TECODNN_STATUS_SUCCESS) {
+        printf("softmax %s\n", tecodnnGetErrorString(status));
+    }
+}
+
+template<typename T>
+__global__ void memKernel(T *destination, T *value, int64_t *index, int topk) {
+    int remain = topk % threadDim;
+    int step_easy = (topk - remain) / threadDim;
+    int step_hard = step_easy + 1;
+    int step = (threadIdx < remain ? step_hard : step_easy);
+    int ind_start = (threadIdx < remain ? threadIdx * step_hard : remain * step_hard + (threadIdx - remain) * step_easy);
+
+    for (int i = ind_start; i < ind_start + step; i++) {
+        value[i] = destination[index[i]];
+    }
+}
+
+template<typename T>
+void cumSumKernel(void *value, void *scan_value, int topk_, tecodnnHandle_t handle, sdaaStream_t stream) {
+    tecodnnSetStream(handle, stream);
+    tecodnnStatus_t status;
+
+    tecodnnTensorDescriptor_t a_desc_teco, c_desc_teco;
+    tecodnnCreateTensorDescriptor(&a_desc_teco);
+    tecodnnCreateTensorDescriptor(&c_desc_teco);
+
+    if constexpr (sizeof(T) == 2) {
+        tecodnnSetTensor4dDescriptor(a_desc_teco, TECODNN_TENSOR_NCHW, TECODNN_DATA_HALF, 1, 1, 1, topk_);
+        tecodnnSetTensor4dDescriptor(c_desc_teco, TECODNN_TENSOR_NCHW, TECODNN_DATA_HALF, 1, 1, 1, topk_);
+    } else if constexpr (sizeof(T) == 4) {
+        tecodnnSetTensor4dDescriptor(a_desc_teco, TECODNN_TENSOR_NCHW, TECODNN_DATA_FLOAT, 1, 1, 1, topk_);
+        tecodnnSetTensor4dDescriptor(c_desc_teco, TECODNN_TENSOR_NCHW, TECODNN_DATA_FLOAT, 1, 1, 1, topk_);
+    }
+
+
+    status = tecodnnCumSum(handle, 3, a_desc_teco, value, c_desc_teco, scan_value);
+    sdaaStreamSynchronize(stream);
+    tecodnnDestroyTensorDescriptor(a_desc_teco);
+    tecodnnDestroyTensorDescriptor(c_desc_teco);
+    if (status != TECODNN_STATUS_SUCCESS) {
+        printf("scan %s\n", tecodnnGetErrorString(status));
+    }
+}
+template<typename T>
+__global__ void sample(T *scan_value, int64_t *index, uint64_t *result, float random_val, float topp, int topk) {
+    if (threadIdx == 0) {
+        int end = 0;
+        for (end = 0; end < topk; end++) {
+
+            if (static_cast<float>(scan_value[end]) >= topp) {
+                break;
+            }
+        }
+
+        if (end < topk - 1) {
+            end += 1;
+        } else {
+            end = topk;
+        }
+
+        random_val *= static_cast<float>(scan_value[end - 1]);
+
+        for (int i = 0; i < end; i++) {
+            if (random_val < static_cast<float>(scan_value[i])) {
+                result[0] = static_cast<uint64_t>(index[i]);
+                break;
+            }
+        }
+    }
+
+}
+
+__global__ void randomSampleKernel(uint64_t *result, int64_t *index) {
+    if (threadIdx == 0) {
+        result[0] = index[0];
+    }
+}
+infiniopStatus_t tecoRandomSample(RandomSampleTecoDescriptor_t desc,
+                                  void *workspace,
+                                  uint64_t workspace_size,
+                                  void *result,
+                                  const void *probs,
+                                  float random_val,
+                                  float topp,
+                                  int topk,
+                                  float temperature,
+                                  void *stream) {
+
+    if (dtype_eq(desc->dtype, F16)) {
+        int topk_ = ((topk + 31) / 32) * 32;// cumsum requires 64-byte alignment
+
+        char *origin = reinterpret_cast<char *>(workspace);
+        half *value = (half *) origin;
+        half *scan_value = value + topk_;
+        half *destination = scan_value + topk_;
+
+        char *tmp_index = origin + (2 * topk_ + desc->voc) * sizeof(half);
+        int64_t *index = (int64_t *) tmp_index;
+
+
+        tecodnnMemset(desc->handle, value, 0, topk_);
+
+        int voc = desc->voc;
+        tecodnnSetStream(desc->handle, desc->stream);
+        tecodnnStatus_t status;
+
+        tecodnnTensorDescriptor_t input_desc_teco, value_desc_teco, index_desc_teco;
+        tecodnnCreateTensorDescriptor(&input_desc_teco);
+        tecodnnCreateTensorDescriptor(&value_desc_teco);
+        tecodnnCreateTensorDescriptor(&index_desc_teco);
+
+        int32_t probsDims[2] = {1, voc}, probsStrides[2] = {voc, 1};
+        int32_t resultDims[2] = {1, topk}, resultStrides[2] = {topk, 1};
+        tecodnnSetTensorNdDescriptor(input_desc_teco, TECODNN_DATA_HALF, 2, probsDims, probsStrides);
+        tecodnnSetTensorNdDescriptor(value_desc_teco, TECODNN_DATA_HALF, 2, resultDims, resultStrides);
+
+        tecodnnSetTensorNdDescriptor(index_desc_teco, TECODNN_DATA_INT64, 2, resultDims, resultStrides);
+
+        size_t workSpaceSizeInBytes;
+        int axis = 1;
+        bool largest = true;
+        bool sorted = true;
+        tecodnnGetTopkExWorkspaceSize(desc->handle, axis, topk, largest, sorted, input_desc_teco, value_desc_teco,
+                                      index_desc_teco, &workSpaceSizeInBytes);
+        void *compute_workspace;
+        sdaaMalloc((void **) &compute_workspace, workSpaceSizeInBytes);
+
+        status = tecodnnTopkEx(desc->handle, axis, topk, largest, sorted, input_desc_teco, probs, value_desc_teco, value,
+                               index_desc_teco, index, compute_workspace, workSpaceSizeInBytes);
+        sdaaStreamSynchronize(desc->stream);
+        sdaaFree(compute_workspace);
+        tecodnnDestroyTensorDescriptor(input_desc_teco);
+        tecodnnDestroyTensorDescriptor(value_desc_teco);
+        tecodnnDestroyTensorDescriptor(index_desc_teco);
+        if (status != TECODNN_STATUS_SUCCESS) {
+            printf("topk %s\n", tecodnnGetErrorString(status));
+        }
+
+        //topkKernel<half>(probs, (void *) index, (void *) value, topk, desc->voc, desc->handle, desc->stream);
+
+        if (topp > 0 && topk > 1) {
+            softmaxKernel<half>(probs, (void *) destination, desc->voc, desc->handle, desc->stream);
+            memKernel<<<1, desc->stream>>>(destination, value, index, topk);
+            sdaaDeviceSynchronize();
+            cumSumKernel<half>((void *) value, (void *) scan_value, topk_, desc->handle, desc->stream);
+            sample<<<1, desc->stream>>>(scan_value, index, (uint64_t *) result, random_val, topp, topk);
+            sdaaDeviceSynchronize();
+        } else {
+            randomSampleKernel<<<1, desc->stream>>>((uint64_t *) result, index);
+            sdaaDeviceSynchronize();
+        }
+
+
+        return STATUS_SUCCESS;
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}