diff --git a/include/ops/softmax/softmax.h b/include/ops/softmax/softmax.h new file mode 100644 index 00000000..adc4a887 --- /dev/null +++ b/include/ops/softmax/softmax.h @@ -0,0 +1,27 @@ +#ifndef SOFTMAX_H +#define SOFTMAX_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct SoftmaxDescriptor { + Device device; +} SoftmaxDescriptor; + +typedef SoftmaxDescriptor *infiniopSoftmaxDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle, + infiniopSoftmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc); + +__C infiniopStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, uint64_t *size); +__C __export infiniopStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream); + +__C __export infiniopStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc); + + +#endif diff --git a/operatorspy/tests/softmax.py b/operatorspy/tests/softmax.py new file mode 100644 index 00000000..aa3be9ea --- /dev/null +++ b/operatorspy/tests/softmax.py @@ -0,0 +1,144 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p +import ctypes +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + create_workspace, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch + + +class SoftmaxDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopSoftmaxDescriptor_t = POINTER(SoftmaxDescriptor) + + +def softmax(x, axis): + return torch.softmax(x, axis = axis).to(x.dtype) + + +def test(lib, handle, torch_device, x_shape, axis, x_dtype=torch.float16): + print( + f"Testing Softmax on {torch_device} with x_shape:{x_shape} , axis:{axis} ,dtype:{x_dtype}" + ) + x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + y = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + ans = softmax(x, axis) + x_tensor = to_tensor(x, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopSoftmaxDescriptor_t() + check_error( + lib.infiniopCreateSoftmaxDescriptor( + handle, ctypes.byref(descriptor), x_tensor.descriptor, axis, y_tensor.descriptor + ) + ) + workspace_size = c_uint64(0) + check_error( + lib.infiniopGetSoftmaxWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = create_workspace(workspace_size.value, torch_device) + check_error( + lib.infiniopSoftmax( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + x_tensor.data, + y_tensor.data, + None, + ) + ) + err = y.reshape(-1,1) - ans.reshape(-1,1) + print(max(abs(err))) + assert torch.allclose(y, ans, atol=0, rtol=1e-2) + check_error(lib.infiniopDestroySoftmaxDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "mlu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # x_shape, axis + # 寒武纪芯片的国产CPU可能不支持f16 + ((32, 20, 512), 0, torch.float16), + ((32, 20, 512), 1, torch.float16), + ((32, 20, 512), 2, torch.float16), + + ((32, 20, 512), 0, torch.float32), + ((32, 20, 512), 1, torch.float32), + ((32, 20, 512), 2, torch.float32), + + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateSoftmaxDescriptor.restype = c_int32 + lib.infiniopCreateSoftmaxDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopSoftmaxDescriptor_t), + infiniopTensorDescriptor_t, + ] + + lib.infiniopSoftmax.restype = c_int32 + lib.infiniopSoftmax.argtypes = [ + infiniopSoftmaxDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroySoftmaxDescriptor.restype = c_int32 + lib.infiniopDestroySoftmaxDescriptor.argtypes = [ + infiniopSoftmaxDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("Test passed!") diff --git a/src/ops/softmax/bang/softmax_bang.cc b/src/ops/softmax/bang/softmax_bang.cc new file mode 100644 index 00000000..343c08f2 --- /dev/null +++ b/src/ops/softmax/bang/softmax_bang.cc @@ -0,0 +1,61 @@ +#include "softmax_bang.h" +#include "../../utils.h" + +infiniopStatus_t bangCreateSoftmaxDescriptor(BangHandle_t handle, + SoftmaxBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc) { + + if (input_desc->ndim != output_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + + int ndim = input_desc->ndim; + + for (int i = 0; i < ndim; i++) { + if (input_desc->shape[i] != output_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + + int stride = 1; + int dimsize = static_cast(input_desc->shape[axis]); + int othersize = 1; + int frontsize = 1; + + for (int s = ndim - 1; s >= 0; s--) { + if (s > axis) { + stride *= static_cast(input_desc->shape[s]); + } + if (s < axis) { + frontsize *= static_cast(input_desc->shape[s]); + } + if (s != axis) { + othersize *= static_cast(input_desc->shape[s]); + } + } + *desc_ptr = new SoftmaxBangDescriptor{ + handle->device, + handle->device_id, + input_desc->dt, + ndim, + axis, + dimsize, + stride, + othersize, + frontsize}; + + return STATUS_SUCCESS; +} +infiniopStatus_t bangGetSoftmaxWorkspaceSize(SoftmaxBangDescriptor_t desc, unsigned long int *size) { + *size = 32 * desc->othersize * sizeof(desc->dtype);//taskDim * othersize * sizeof(T),taskDim不超过32 + return STATUS_SUCCESS; +} + +infiniopStatus_t bangDestroySoftmaxDescriptor(SoftmaxBangDescriptor_t desc) { + + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/softmax/bang/softmax_bang.h b/src/ops/softmax/bang/softmax_bang.h new file mode 100644 index 00000000..76310928 --- /dev/null +++ b/src/ops/softmax/bang/softmax_bang.h @@ -0,0 +1,36 @@ +#ifndef __BANG_SOFTMAX_H__ +#define __BANG_SOFTMAX_H__ + +#include "../../../devices/bang/bang_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct SoftmaxBangDescriptor { + Device device; + int device_id; + DT dtype; + int ndim; + int axis; + int dimsize; + int stride; + int othersize; + int frontsize; +}; + +typedef struct SoftmaxBangDescriptor *SoftmaxBangDescriptor_t; + +infiniopStatus_t bangCreateSoftmaxDescriptor(BangHandle_t handle, + SoftmaxBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc); + +infiniopStatus_t bangGetSoftmaxWorkspaceSize(SoftmaxBangDescriptor_t desc, unsigned long int *size); +infiniopStatus_t bangSoftmax(SoftmaxBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream); + +infiniopStatus_t bangDestroySoftmaxDescriptor(SoftmaxBangDescriptor_t desc); + + +#endif \ No newline at end of file diff --git a/src/ops/softmax/bang/softmax_bang.mlu b/src/ops/softmax/bang/softmax_bang.mlu new file mode 100644 index 00000000..1af3417b --- /dev/null +++ b/src/ops/softmax/bang/softmax_bang.mlu @@ -0,0 +1,996 @@ +#include "../../../devices/bang/common_bang.h" +#include "bang.h" +#include "softmax_bang.h" +#include "cnrt.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void softmaxKernelAxis_e(T *destination, T const *source, int othersize, int dimsize, int dimS) {// axis = -1 + + const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 8; + const int wSize = 128 / sizeof(T); + + const int maxNum = SRC_MAX_SIZE/sizeof(T); + __nram__ T srcMax[2]; + if(dimsize >= maxNum){ + T *src = (T *)nram_buffer; + T *destSum = src + 3 * maxNum; + T *destSumFinal = destSum + maxNum; + T destOldMax; + T destNewMax; + + int remain = dimsize % maxNum; + int repeat = (dimsize - remain)/maxNum; + + int otherRemain = othersize % taskDim; + int stepEasy = (othersize - otherRemain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < otherRemain ? stepHard : stepEasy); + int startHard = taskId * stepHard; + int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy; + int indStart = (taskId < otherRemain ? startHard : startEasy); + source = source + indStart * dimsize; + destination = destination + indStart * dimsize; + + for(int s = 0; s < step; s++){ + + destOldMax = -INFINITY; + destNewMax = -INFINITY; + __bang_write_zero(destSum, maxNum); + for(int i = 0; i < repeat + 1; i++){ + if(i < repeat){ + __memcpy_async(src + i % 2 * maxNum, source + s * dimsize + i * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(i > 0){ + __bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum); + if(destNewMax < srcMax[0]){ + destNewMax = srcMax[0]; + } + __bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum); + __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum); + if(i > 1){ + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum); + } + __bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum); + destOldMax = destNewMax; + } + __sync_all_ipu(); + } + //------------ + if(remain){ + __bang_write_value(src, maxNum, -INFINITY); + __memcpy(src, source + s * dimsize + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM); + + __bang_argmax(srcMax, src, maxNum); + if(destNewMax < srcMax[0]){ + destNewMax = srcMax[0]; + } + + __bang_sub_scalar(src, src, destNewMax, maxNum); + __bang_active_exp_less_0(src, src, maxNum); + if(repeat > 0){ + __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum); + } + __bang_add(destSum, destSum, src, maxNum); + destOldMax = destNewMax; + } + //-------------- + //-------------------------------- + + int segNum = maxNum / wSize; + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + + //----------- + T globalSumInv = 1.0/destSumFinal[0]; + for(int i = 0; i < repeat + 2; i++){ + if(i < repeat){ + __memcpy_async(src + i % 3 * maxNum, source + s * dimsize + i * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < repeat + 1){ + __bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum); + __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum); + __bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum); + } + if(i > 1){ + __memcpy_async(destination + s * dimsize + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + + } + if(remain){ + __bang_write_value(src, maxNum, destNewMax); + __memcpy(src, source + s * dimsize + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, destNewMax, maxNum); + __bang_active_exp_less_0(src, src, maxNum); + __bang_mul_scalar(src, src, globalSumInv, maxNum); + __memcpy(destination + s * dimsize + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM); + } + } + + } + else{ + int multiple = maxNum / dimsize;//一个src可以处理multiple个otherIdx + int size = taskDim * multiple;//所有core可以处理size个otherIdx + int remain = othersize % size;// remain < taskDim * multiple + int repeat = (othersize - remain) / size; + + int remainT = remain % taskDim; + int stepEasy = (remain - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int startHard = taskId * stepHard * dimsize;//前面remainT个taskId分配到stepHard个dimsize + int startEasy = remainT * stepHard * dimsize + (taskId - remainT) * stepEasy * dimsize; + int indStart = (taskId < remainT ? startHard : startEasy); + + //-----------------------------------------allocate memory + T* src = (T *)nram_buffer;//src[maxNum] + T* tmp = src + 3 * maxNum;//tmp[dimS] + T* destSum = tmp + dimS;//destSum[dimS],dimS >= max(dimsize, wSize), dimS = pow(2,K) ,pow(2,K - 1) < dimsize + T* destSumFinal = destSum + wSize; + //----------------------------------------- + //printf("taskId:%d, repeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, repeat, step, repeatDim, indStart, indStart * dimsize); + int tid; + __bang_write_value(tmp, dimS, -INFINITY); + __bang_write_zero(destSum, dimS); + if(repeat >= 2){ + int s = 0; + tid = s * size * dimsize + taskId * multiple * dimsize; + __memcpy(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM); + s = 1; + tid = s * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM); + + // compute ------------------------ + for(int j = 0; j < multiple; j++){ + + __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + __bang_sub_scalar(tmp, tmp, srcMax[0], dimS); + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum); + for(int j = 0; j < multiple; j++){ + + __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM); + int segNum = dimS / wSize;//Starting numerical summation + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T globalSumInv = 1.0/destSumFinal[0]; + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + // compute ------------------------ + + for(int s = 2; s < repeat; s++){ + tid = (s - 2) * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM); + + tid = s * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM); + + // compute ------------------------ + + __bang_argmax(srcMax, src + (s - 1) %3 * maxNum, maxNum);//这一段特殊处理取全局max + __bang_sub_scalar(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, srcMax[0], maxNum); + __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum); + + for(int j = 0; j < multiple; j++){ + __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM); + int segNum = dimS / wSize;//Starting numerical summation + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T globalSumInv = 1.0/destSumFinal[0]; + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + // compute ------------------------ + } + s = repeat; + tid = (s - 2) * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM); + // compute ------------------------ + for(int j = 0; j < multiple; j++){ + + __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + __bang_sub_scalar(tmp, tmp, srcMax[0], dimS); + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum); + for(int j = 0; j < multiple; j++){ + + __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM); + int segNum = dimS / wSize;//Starting numerical summation + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T globalSumInv = 1.0/destSumFinal[0]; + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + // compute ------------------------ + s = repeat + 1; + tid = (s - 2) * size * dimsize + taskId * multiple * dimsize; + __memcpy(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM); + } + else{ + for(int s = 0; s < repeat + 2; s++){ + if(s < repeat){ + tid = s * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM); + } + if(s > 0 && s < repeat + 1){ + // compute ------------------------ + + for(int j = 0; j < multiple; j++){ + __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __bang_argmax(srcMax, tmp, dimS); + __bang_sub_scalar(tmp, tmp, srcMax[0], dimS); + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum); + + for(int j = 0; j < multiple; j++){ + __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM); + __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM); + int segNum = dimS / wSize;//Starting numerical summation + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T globalSumInv = 1.0/destSumFinal[0]; + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + + __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + // compute ------------------------ + } + if(s > 1){ + tid = (s - 2) * size * dimsize + taskId * multiple * dimsize; + __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu();//如果maxNum比较小,此时访存时间>计算时间,无法延迟 + } + } + if(step){ + tid = repeat * size * dimsize + indStart; + __memcpy(src, source + tid, step * dimsize * sizeof(T), GDRAM2NRAM); + for(int s = 0; s < step; s++){//Step targets parts of othersize that cannot be divided by multiple * dimsize + __bang_write_zero(destSum, dimS); + + __bang_write_value(tmp, dimS, -INFINITY); + __memcpy(tmp, src + s * dimsize, dimsize * sizeof(T), NRAM2NRAM); + + __bang_argmax(srcMax, tmp, dimS); + + __bang_sub_scalar(tmp, tmp, srcMax[0], dimS); + + __bang_active_exp_less_0(tmp, tmp, dimS); + __memcpy(destSum, tmp, dimsize * sizeof(T), NRAM2NRAM); + + int segNum = dimS / wSize; + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + + T globalSumInv = 1.0/destSumFinal[0]; + __bang_mul_scalar(tmp, tmp, globalSumInv, dimS); + __memcpy(src + s * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); + } + __memcpy(destination + tid, src, step * dimsize * sizeof(T), NRAM2GDRAM); + } + + } +} +template +__mlu_global__ void softmaxKernelAxis_s(T *destination, T const *source, T *tmpGdram, int othersize, int dimsize, int stride) {// axis = 0 + + const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 8; + + + const int maxNum = SRC_MAX_SIZE/sizeof(T); + if(othersize > taskDim * maxNum){ + //-----------------------------------------allocate memory + T* src = (T *)nram_buffer;// src[3 * maxNum] + T* tmpSum = src + 3 * maxNum;//tmpSum[maxNum] + T* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum] + T* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum] + //----------------------------------------- + int remain = othersize % taskDim; + int stepEasy = (othersize - remain)/taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element + int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy); + int remainNram = step%maxNum; + int repeat = (step - remainNram)/maxNum; + + for(int j = 0; j < repeat; j++){ + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + for(int i = 0; i < dimsize + 1; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(i > 0){ + __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value + __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M) + if(i > 1){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM + } + __sync_all_ipu(); + } + __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum + + for(int i = 0; i < dimsize + 2; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < dimsize + 1){ + __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M) + __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum); + } + if(i > 1){ + __memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + } + if(remainNram){ + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_zero(src, 3 * maxNum); + + for(int i = 0; i < dimsize + 1; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(T), GDRAM2NRAM); + } + if(i > 0){ + __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum); + __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M) + if(i > 1){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM + } + __sync_all_ipu(); + } + + __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + + for(int i = 0; i < dimsize + 2; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < dimsize + 1){ + __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M) + __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum); + } + if(i > 1){ + __memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + } + } + else if (othersize > maxNum && othersize <= taskDim * maxNum){ + T* src = (T *)nram_buffer;// src[3 * maxNum] + T* tmpSum = src + 3 * maxNum;//tmpSum[maxNum] + T* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum] + T* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum] + //----------------------------------------- + int remain = othersize % taskDim; + int stepEasy = (othersize - remain)/taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element + int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy); + + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_zero(src, 3 * maxNum); + + for(int i = 0; i < dimsize + 1; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart, step * sizeof(T), GDRAM2NRAM); + } + if(i > 0){ + __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum); + __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M) + if(i > 1){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM + } + __sync_all_ipu(); + } + + __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum + //Start exponential transformation and write back to GDRAM + + for(int i = 0; i < dimsize + 2; i++){ + if(i < dimsize){ + __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart, step * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < dimsize + 1){ + __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M) + __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum); + } + if(i > 1){ + __memcpy_async(destination + (i - 2) * stride + indStart, src + (i - 2) % 3 * maxNum, step * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + } + else{ + + int multiple = maxNum / othersize; + int size = taskDim * multiple; + int remain = dimsize % size; + int repeat = (dimsize - remain) / size; + + int remainT = remain % taskDim; + int stepEasy = (remain - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int indStart = (taskId < remainT ? taskId * stepHard : remainT * stepHard + (taskId - remainT) * stepEasy); + + T* src = (T *)nram_buffer;// src[3 * maxNum] + T* tmpSum = src + 3 * maxNum;//tmpSum[othersize] + T* tmpNewMax = tmpSum + othersize;//tmpNewMax[othersize] + T* tmpOldMax = tmpNewMax + othersize;//tmpOldMax[othersize] + T* tmpGlobal = tmpOldMax + othersize; + __bang_write_value(tmpNewMax, othersize, -INFINITY); + + __bang_write_zero(tmpSum, othersize); + __bang_write_zero(src, 3 * maxNum); + + for(int i = 0; i < repeat + 1; i++){ + if (i < repeat){ + __memcpy_async(src + (i % 2) * maxNum, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize + } + if(i > 0){ + for(int m = 0; m < multiple; m++){ + __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum + m * othersize, othersize); + } + for(int m = 0; m < multiple; m++){ + __bang_sub(src + (i - 1) % 2 * maxNum + m * othersize, src + (i - 1) % 2 * maxNum + m * othersize, tmpNewMax, othersize);//x - M + } + __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, multiple * othersize);//exp(x - M) + if(i > 1){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); //sum = sum * exp(oldM - newM) + } + for(int m = 0; m < multiple; m++){ + __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum + m * othersize, othersize); + } + __memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(T), NRAM2NRAM); + } + __sync_all_ipu(); + } + + if(step) { + __memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize + + for(int m = 0; m < step; m++){ + __bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize, othersize); + } + for(int m = 0; m < step; m++){ + __bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);//x - M + } + __bang_active_exp_less_0(src, src, step * othersize);//exp(x - M) + if(repeat > 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); //sum = sum * exp(oldM - newM) + } + for(int m = 0; m < step; m++){ + __bang_add(tmpSum, tmpSum, src + m * othersize, othersize); + } + __memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(T), NRAM2NRAM); + } + //---------------- + if(repeat > 0 || dimsize >= taskDim){ + __memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(T), NRAM2GDRAM); + __sync_all(); + __bang_write_value(tmpNewMax, othersize, -INFINITY); + for(int id = 0; id < taskDim; id++){ + __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize); + } + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize); + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize); + __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); + __memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(T), NRAM2GDRAM); + __sync_all(); + __bang_write_zero(tmpSum, othersize); + for(int id = 0; id < taskDim; id++){ + __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM); + __bang_add(tmpSum, tmpSum, tmpGlobal, othersize); + } + __bang_active_recip_greater_1(tmpSum, tmpSum, othersize); + } + else{ + __memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(T), NRAM2GDRAM); + __sync_all(); + __bang_write_value(tmpNewMax, othersize, -INFINITY); + for(int id = 0; id < dimsize; id++){ + __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize); + } + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize); + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize); + __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize); + __memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(T), NRAM2GDRAM); + __sync_all(); + __bang_write_zero(tmpSum, othersize); + for(int id = 0; id < dimsize; id++){ + __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM); + __bang_add(tmpSum, tmpSum, tmpGlobal, othersize); + } + __bang_active_recip_greater_1(tmpSum, tmpSum, othersize); + } + + //------------------- + for(int i = 0; i < repeat + 2; i++){ + if(i < repeat){ + __memcpy_async(src + (i % 3) * maxNum, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize + } + if(i > 0){ + for(int m = 0; m < multiple; m++){ + __bang_sub(src + (i - 1) % 3 * maxNum + m * othersize, src + (i - 1) % 3 * maxNum + m * othersize, tmpNewMax, othersize); + } + __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, multiple * othersize); + for(int m = 0; m < multiple; m++){ + __bang_mul(src + (i - 1) % 3 * maxNum + m * othersize, src + (i - 1) % 3 * maxNum + m * othersize, tmpSum, othersize); + } + } + if (i > 1){ + __memcpy_async(destination + ((i - 2) * size + taskId * multiple) * stride, src + (i - 2) % 3 * maxNum, multiple * othersize * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if(step) { + __memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize + for(int m = 0; m < step; m++){ + __bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize); + } + __bang_active_exp_less_0(src, src, step * othersize); + for(int m = 0; m < step; m++){ + __bang_mul(src + m * othersize, src + m * othersize, tmpSum, othersize); + } + __memcpy(destination + repeat * size * stride + indStart * stride, src, step * othersize * sizeof(T), NRAM2GDRAM); + } + } +} +template +__mlu_global__ void softmaxKernelAxis_m(T *destination, T const *source, int frontsize, int dimsize, int stride, int strideS) { + // 0= maxNum){ + //-----------------------------------------allocate memory + T *src = (T *)nram_buffer; + T *tmpSum = src + 3 * maxNum; + T *tmpNewMax = tmpSum + maxNum; + T *tmpOldMax = tmpNewMax + maxNum; + //----------------------------------------- + int remain = stride % maxNum; + int repeat = (stride - remain) / maxNum; + + for(int ind = taskId; ind < frontsize; ind += taskDim){ + int frontIdx = ind * dimsize * stride; + for(int j = 0; j < repeat; j++){ + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + //__bang_write_zero(src, maxNum); + for(int i = 0; i < dimsize; i++){ + __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value + __bang_sub(src, src, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + if(i > 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM + } + __bang_active_reciphp(tmpSum, tmpSum, maxNum);//计算1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM); + for(int i = 0; i < dimsize - 1; i++){ + __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM); + } + } + if(remain){ + + __bang_write_value(tmpNewMax, maxNum, -INFINITY); + __bang_write_zero(tmpSum, maxNum); + __bang_write_value(src, maxNum, -INFINITY); + for(int i = 0; i < dimsize; i++){ + __memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum); + __bang_sub(src, src, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + if(i > 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM + } + //------------------- + __bang_active_reciphp(tmpSum, tmpSum, maxNum);//计算1/sum + //Start exponential transformation and write back to GDRAM + __bang_mul(src, src, tmpSum, maxNum);//The data stored in the src at the end of the loop above can be utilized + __memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM); + for(int i = 0; i < dimsize - 1; i++){ + __memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM); + __bang_sub(src, src, tmpNewMax, maxNum);//x - M + __bang_active_exp_less_0(src, src, maxNum);//exp(x - M) + __bang_mul(src, src, tmpSum, maxNum); + __memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM); + } + //--------------------- + } + } + } + else if(stride < maxNum && dimsize * stride >= maxNum){ + + //-----------------------------------------allocate memory + T* src = (T *)nram_buffer; + T* tmp = src + 3 * maxNum; + T* tmpOldMax = tmp + strideS; + T* tmpNewMax = tmpOldMax + strideS; + T* tmpSum = tmpNewMax + strideS; + //----------------------------------------- + int multiple = maxNum / stride; + int size = multiple * stride;//The maximum amount of data that can be stored in an SRC + int remain = dimsize % multiple;//If it cannot be divisible, this part of the data needs special processing + int repeat = (dimsize - remain) / multiple;//The total number of loops required to load the entire dimsize + + int taskRemain = frontsize % taskDim; + int stepEasy = (frontsize - taskRemain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < taskRemain ? stepHard : stepEasy);//The number of frontsize processed per taskId + int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy); + source = source + indStart * dimsize * stride; + destination = destination + indStart * dimsize * stride; + //printf("maxNum:%d, dimsize * stride:%d, multiple:%d, size:%d, repeat:%d,remain:%d\n",maxNum, dimsize * stride, multiple, size, repeat,remain); + for(int ind = 0; ind < step; ind++){ + int frontIdx = ind * dimsize * stride; + + __bang_write_value(tmpNewMax, strideS, -INFINITY);//Must be initialized to negative infinity + __bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity + __bang_write_zero(tmpSum, strideS);//Must be initialized to zero + + for(int j = 0; j < repeat + 1; j++){ + if(j < repeat){ + __memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + for(int m = 0; m < multiple; m++){ + __memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(T), NRAM2NRAM); + + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//Although the stream S stream section after tmpNewMax is 0, there is no need to write back to GDRAM, which does not affect the result + + __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0 + __bang_active_exp_less_0(tmp, tmp, strideS); + if(j != 1 || m != 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) + + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM + } + } + __sync_all_ipu(); + } + + if(remain){ + __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(T), GDRAM2NRAM); + for(int m = 0; m < remain; m++){ + __memcpy(tmp, src + m * stride, stride * sizeof(T), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//The stripe S stripe section after tmp is 0 + __bang_active_exp_less_0(tmp, tmp, strideS); + if(repeat != 0 || m != 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM + } + } + + //At this point, tmpNewMax stores the maximum value of the data corresponding to a fixed frontIdx and bedsize, while tmpSum stores the corresponding value sum + + __bang_active_reciphp(tmpSum, tmpSum, strideS); + + if(remain){ + for(int m = 0; m < remain; m++){ + __memcpy(tmp, src + m * stride, stride * sizeof(T), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS); + __bang_active_exp_less_0(tmp, tmp, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS); + __memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(T), NRAM2GDRAM); + } + + } + for(int j = 0 ; j < repeat + 2; j++){ + if(j < repeat){ + __memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(T), GDRAM2NRAM); + } + if(j > 0 && j < repeat + 1){ + for(int m = 0; m < multiple; m++){ + __memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(T), NRAM2NRAM); + + __bang_sub(tmp, tmp, tmpNewMax, strideS); + __bang_active_exp_less_0(tmp, tmp, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS); + __memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(T), NRAM2NRAM); + } + } + if(j > 1){ + __memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + } + } + else if(dimsize * stride < maxNum){ + //-----------------------------------------allocate memory + T* src = (T *)nram_buffer; + T* tmp = src + 3 * maxNum; + T* tmpOldMax = tmp + strideS; + T* tmpNewMax = tmpOldMax + strideS; + T* tmpSum = tmpNewMax + strideS; + //----------------------------------------- + int behindsize = dimsize * stride; + int multiple = maxNum / behindsize;//Represents the amount that a maxNum can share in frontsize + + int remainF = frontsize % (taskDim * multiple); + int remainT = remainF % taskDim; + int stepEasy = (remainF - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int taskRepeat = (frontsize - remainF) / (taskDim * multiple); + //At this point, corresponding to frontsize, the amount of data processed by each taskId is taskRepeat * multiple+step + int startHard = taskId * (taskRepeat * multiple + stepHard); + int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy); + int indStart = (taskId < remainT ? startHard: startEasy); + source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds + destination = destination + indStart * behindsize; + int tid; + for(int s = 0; s < taskRepeat + 2; s++){ + if(s < taskRepeat){ + tid = s * multiple * behindsize; + __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(T), GDRAM2NRAM); + } + if(s > 0 && s < taskRepeat + 1){ + for(int m = 0; m < multiple; m++){ + __bang_write_zero(tmpSum, strideS); + __bang_write_value(tmp, strideS, -INFINITY); + __bang_write_value(tmpNewMax, strideS, -INFINITY); + for(int i = 0; i < dimsize; i++){ + __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + if(i > 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM + } + __bang_active_reciphp(tmpSum, tmpSum, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized + + __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2NRAM); + for(int i = 0; i < dimsize - 1; i++){ + __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + __bang_mul(tmp, tmp, tmpSum, strideS); + + __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2NRAM); + } + } + } + if(s > 1){ + tid = (s - 2) * multiple * behindsize; + __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize); + if(step){ + tid = taskRepeat * multiple * behindsize; + __memcpy(src, source + tid, step * behindsize * sizeof(T), GDRAM2NRAM); + for(int m = 0; m < step; m++){ + __bang_write_zero(tmpSum, strideS); + __bang_write_value(tmp, strideS, -INFINITY); + __bang_write_value(tmpNewMax, strideS, -INFINITY); + for(int i = 0; i < dimsize; i++){ + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM); + __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + if(i > 0){ + __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM + __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM) + __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS); //sum = sum * exp(oldM - newM) + } + __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M) + __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM + } + //__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]); + __bang_active_reciphp(tmpSum, tmpSum, strideS); + __bang_mul(tmp, tmp, tmpSum, strideS);//The data stored in tmp at the end of the loop above can be utilized + //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2GDRAM); + __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2NRAM); + for(int i = 0; i < dimsize - 1; i++){ + __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM); + __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M + __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M) + __bang_mul(tmp, tmp, tmpSum, strideS); + //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2GDRAM); + __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2NRAM); + } + } + __memcpy(destination + tid, src, step * behindsize * sizeof(T), NRAM2GDRAM); + } + } + +} + +template +void softmaxUnion1(cnrtQueue_t queue, void *workspace, + uint64_t workspace_size, void const *input, void *output, int othersize, int dimsize, int frontsize, int stride, int axis, int ndim) { + auto mlu_destination = reinterpret_cast(output); + auto mlu_src = reinterpret_cast(input); + const int wSize = 128 / sizeof(T); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + k_type = CNRT_FUNC_TYPE_UNION1; + + + if(axis == ndim - 1){ + int dimS; + float mi = log2(dimsize); + if (floor(mi) == mi) + { + dimS = dimsize; + } + else + { + dimS = static_cast(pow(2, floor(mi) + 1)); + } + if (dimS < wSize) + { + dimS = wSize; + } + softmaxKernelAxis_e<<>>(mlu_destination, mlu_src, othersize, dimsize, dimS); + } + else if(axis == 0){ + T *tmpGdram = reinterpret_cast(workspace); + softmaxKernelAxis_s<<>>(mlu_destination, mlu_src, tmpGdram, othersize, dimsize, stride); + + } + else{ + float mi = log2(stride); + int strideS; + if(floor(mi) == mi){ + strideS = stride; + } + else{ + strideS = static_cast(pow(2,floor(mi) + 1)); + } + softmaxKernelAxis_m<<>>(mlu_destination, mlu_src, frontsize, dimsize, stride, strideS); + } + + cnrtQueueSync(queue); +} + + + +void softmax_bang(SoftmaxBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, void const *input, void *output, void *stream) { + auto queue = reinterpret_cast(stream); + + int ndim = desc->ndim; + int axis = desc->axis; + int stride = desc->stride; + int dimsize = desc->dimsize; + int othersize = desc->othersize; + int frontsize = desc->frontsize; + + + if (dtype_eq(desc->dtype, F16)) { + softmaxUnion1(queue, workspace, workspace_size, input, output, othersize, dimsize, frontsize, stride, axis, ndim); + } + else if (dtype_eq(desc->dtype, F32)) { + softmaxUnion1(queue, workspace, workspace_size, input, output, othersize, dimsize, frontsize, stride, axis, ndim); + } + +} + +infiniopStatus_t bangSoftmax(SoftmaxBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream) { + if (cnrtSetDevice(desc->device_id) != cnrtSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + softmax_bang(desc, workspace, workspace_size, input, output, stream); + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/softmax/bang/softmax_cnnl.cc b/src/ops/softmax/bang/softmax_cnnl.cc new file mode 100644 index 00000000..647110b1 --- /dev/null +++ b/src/ops/softmax/bang/softmax_cnnl.cc @@ -0,0 +1,154 @@ +#include "softmax_cnnl.h" +#include "../../../devices/bang/bang_handle.h" +#include "../../../devices/bang/common_bang.h" +#include "../../utils.h" +#include "cnrt.h" +infiniopStatus_t cnnlCreateSoftmaxDescriptor(BangHandle_t handle, + SoftmaxCnnlDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, + int axis, + infiniopTensorDescriptor_t output_desc) { + if (input_desc->ndim != output_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + + int ndim = input_desc->ndim; + + for (int i = 0; i < ndim; i++) { + if (input_desc->shape[i] != output_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + + int *shape = new int[ndim]; + for (int i = 0; i < ndim; i++) { + shape[i] = static_cast(input_desc->shape[i]); + } + cnnlSoftmaxMode_t mode; + std::vector inDim = {1, 1, 1}; + std::vector outDim = inDim; + + if (ndim >= 3) { + if (axis == 0) { + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim[0] = shape[0]; + inDim[1] = shape[1]; + for (int i = 2; i < ndim; ++i) { + inDim[2] *= shape[i]; + } + outDim = inDim; + } else if (axis == ndim - 1) { + mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION; + inDim[0] = shape[0]; + for (int i = 1; i < axis; ++i) { + inDim[1] *= shape[i]; + } + inDim[2] = shape[axis]; + outDim = inDim; + } else { + mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; + for (int i = 0; i < axis; ++i) { + inDim[0] *= shape[i]; + } + inDim[1] = shape[axis]; + for (int i = axis + 1; i < ndim; ++i) { + inDim[2] *= shape[i]; + } + outDim = inDim; + } + } else if (ndim == 2) { + if (axis == 0) { + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim[0] = shape[0]; + inDim[1] = shape[1]; + + outDim = inDim; + } else { + mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION; + inDim[1] = shape[0]; + inDim[2] = shape[1]; + + outDim = inDim; + } + } else { + mode = CNNL_SOFTMAX_MODE_HIGH_DIMENSION; + inDim[0] = shape[0]; + + outDim = inDim; + } + cnnlTensorDescriptor_t aDesc, cDesc; + cnnlCreateTensorDescriptor(&aDesc); + cnnlCreateTensorDescriptor(&cDesc); + if (dtype_eq(input_desc->dt, F16)) { + cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + inDim.size(), inDim.data()); + cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + outDim.size(), outDim.data()); + } else if (dtype_eq(input_desc->dt, F32)) { + cnnlSetTensorDescriptor( + aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + inDim.size(), inDim.data()); + cnnlSetTensorDescriptor( + cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + outDim.size(), outDim.data()); + } + + float alpha = 1.0; + float beta = 0.0; + *desc_ptr = new SoftmaxCnnlDescriptor{ + handle->device, + handle->device_id, + input_desc->dt, + handle->cnnl_handles, + mode, + aDesc, + cDesc, + alpha, + beta}; + return STATUS_SUCCESS; +} +infiniopStatus_t cnnlGetSoftmaxWorkspaceSize(SoftmaxCnnlDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t cnnlDestroySoftmaxDescriptor(SoftmaxCnnlDescriptor_t desc) { + desc->cnnl_handles = nullptr; + cnnlDestroyTensorDescriptor(desc->aDesc); + cnnlDestroyTensorDescriptor(desc->cDesc); + delete desc; + return STATUS_SUCCESS; +} + +void softmax_cnnl(SoftmaxCnnlDescriptor_t desc, void const *input, void *output, void *stream) { + float alpha = desc->alpha; + float beta = desc->beta; + cnnlSoftmaxMode_t mode = desc->mode; + cnnlTensorDescriptor_t aDesc = desc->aDesc; + cnnlTensorDescriptor_t cDesc = desc->cDesc; + + use_cnnl(desc->cnnl_handles, desc->device_id, (cnrtQueue_t) stream, + [&](cnnlHandle_t handle) { + cnnlSoftmaxForward_v2(handle, CNNL_SOFTMAX_ACCURATE, + mode, CNNL_COMPUTATION_ULTRAHIGH_PRECISION, + &alpha, aDesc, input, &beta, cDesc, output); + }); +} +infiniopStatus_t cnnlSoftmax(SoftmaxCnnlDescriptor_t desc, void *workspace, + uint64_t workspace_size, void const *input, void *output, void *stream) { + if (cnrtSetDevice(desc->device_id) != cnrtSuccess) { + return STATUS_BAD_DEVICE; + } + + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + softmax_cnnl(desc, input, output, stream); + + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/softmax/bang/softmax_cnnl.h b/src/ops/softmax/bang/softmax_cnnl.h new file mode 100644 index 00000000..d2faaefc --- /dev/null +++ b/src/ops/softmax/bang/softmax_cnnl.h @@ -0,0 +1,34 @@ +#ifndef __CNNL_SOFTMAX_H__ +#define __CNNL_SOFTMAX_H__ +#include "../../../devices/bang/bang_handle.h" +#include "cnnl.h" +#include "cnnl_extra.h" +#include "operators.h" + +struct SoftmaxCnnlDescriptor { + Device device; + int device_id; + DT dtype; + std::shared_ptr> cnnl_handles; + cnnlSoftmaxMode_t mode; + cnnlTensorDescriptor_t aDesc; + cnnlTensorDescriptor_t cDesc; + float alpha; + float beta; +}; +typedef struct SoftmaxCnnlDescriptor *SoftmaxCnnlDescriptor_t; + +infiniopStatus_t cnnlCreateSoftmaxDescriptor(BangHandle_t handle, + SoftmaxCnnlDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, + int axis, + infiniopTensorDescriptor_t output_desc); + +infiniopStatus_t cnnlGetSoftmaxWorkspaceSize(SoftmaxCnnlDescriptor_t desc, unsigned long int *size); +infiniopStatus_t cnnlSoftmax(SoftmaxCnnlDescriptor_t desc, void *workspace, + uint64_t workspace_size, void const *input, void *output, void *stream); + +infiniopStatus_t cnnlDestroySoftmaxDescriptor(SoftmaxCnnlDescriptor_t desc); + + +#endif// __CNNL_SOFTMAX_H__ diff --git a/src/ops/softmax/cpu/softmax_cpu.cc b/src/ops/softmax/cpu/softmax_cpu.cc new file mode 100644 index 00000000..1142bd49 --- /dev/null +++ b/src/ops/softmax/cpu/softmax_cpu.cc @@ -0,0 +1,121 @@ +#include "softmax_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" +#include + +infiniopStatus_t cpuCreateSoftmaxDescriptor(infiniopHandle_t handle, + SoftmaxCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc) { + if (input_desc->ndim != output_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + + if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + + int ndim = input_desc->ndim; + + for (int i = 0; i < ndim; i++) { + if (input_desc->shape[i] != output_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + int *shape = new int[ndim]; + + for (int i = 0; i < ndim; i++) { + shape[i] = static_cast(input_desc->shape[i]); + } + *desc_ptr = new SoftmaxCpuDescriptor{ + handle->device, + input_desc->dt, + ndim, + axis, + shape}; + + return STATUS_SUCCESS; +} + + +infiniopStatus_t cpuDestroySoftmaxDescriptor(SoftmaxCpuDescriptor_t desc) { + delete[] desc->shape; + delete desc; + return STATUS_SUCCESS; +} +infiniopStatus_t cpuGetSoftmaxWorkspaceSize(SoftmaxCpuDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} +void softmax_cpu(SoftmaxCpuDescriptor_t desc, + void const *input, void *output) { + int ndim = desc->ndim; + int axis = desc->axis; + auto shape = desc->shape; + int dimsize = shape[axis]; + int othersize = 1; + int stride = 1; + + for (int s = ndim - 1; s >= 0; s--) { + + if (s > axis) { + stride *= shape[s]; + } + if (s != axis) { + othersize *= shape[s]; + } + } + if (dtype_eq(desc->dtype, F16)) { + auto source = reinterpret_cast(input); + auto destination = reinterpret_cast(output); + //假设[I, J, K, S], axis = 1, othersize = IKS + for (int ind = 0; ind < othersize; ind++) { //ind = i(KS) + k(S) + s + int tid = ind % stride + (ind - ind % stride) * dimsize;//now, tid = i(JKS) + k(S) + s; + float localM = -__FLT_MAX__; + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + localM = fmax(localM, f16_to_f32(source[index])); + } + float localS = 0.0f; + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + localS += std::exp(f16_to_f32(source[index]) - localM); + } + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + destination[index] = f32_to_f16(std::exp(f16_to_f32(source[index]) - localM) / localS); + } + } + } else if (dtype_eq(desc->dtype, F32)) { + auto source = reinterpret_cast(input); + auto destination = reinterpret_cast(output); + //假设[I, J, K, S], axis = 1, othersize = IKS + for (int ind = 0; ind < othersize; ind++) { //ind = i(KS) + k(S) + s + int tid = ind % stride + (ind - ind % stride) * dimsize;//now, tid = i(JKS) + k(S) + s; + float localM = -__FLT_MAX__; + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + localM = fmax(localM, source[index]); + } + float localS = 0.0f; + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + localS += std::exp(source[index] - localM); + } + for (int j = 0; j < dimsize; j++) { + int index = tid + j * stride; + destination[index] = std::exp(source[index] - localM) / localS; + } + } + } +} +infiniopStatus_t cpuSoftmax(SoftmaxCpuDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, void *output, + void *stream) { + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + softmax_cpu(desc, input, output); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/softmax/cpu/softmax_cpu.h b/src/ops/softmax/cpu/softmax_cpu.h new file mode 100644 index 00000000..64b3faa7 --- /dev/null +++ b/src/ops/softmax/cpu/softmax_cpu.h @@ -0,0 +1,33 @@ +#ifndef __CPU_SOFTMAX_H__ +#define __CPU_SOFTMAX_H__ + +#include "../../../devices/cpu/cpu_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct SoftmaxCpuDescriptor { + Device device; + DT dtype; + int ndim; + int axis; + int *shape; +}; + +typedef struct SoftmaxCpuDescriptor *SoftmaxCpuDescriptor_t; + +infiniopStatus_t cpuCreateSoftmaxDescriptor(infiniopHandle_t handle, + SoftmaxCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc); + +infiniopStatus_t cpuGetSoftmaxWorkspaceSize(SoftmaxCpuDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t cpuSoftmax(SoftmaxCpuDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream); + +infiniopStatus_t cpuDestroySoftmaxDescriptor(SoftmaxCpuDescriptor_t desc); + + +#endif diff --git a/src/ops/softmax/cuda/softmax.cu b/src/ops/softmax/cuda/softmax.cu new file mode 100644 index 00000000..0bd8810c --- /dev/null +++ b/src/ops/softmax/cuda/softmax.cu @@ -0,0 +1,335 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "softmax.cuh" +#include +#include +struct __align__(8) DataMaxSum {// update the global max and sum, store the + // output at max_tmp and sum_tmp + float max_tmp; // store max + float sum_tmp; // store sum +}; +__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a, + DataMaxSum b) { + bool a_bigger = (a.max_tmp > b.max_tmp); + DataMaxSum bigger = a_bigger ? a : b; + DataMaxSum smaller = a_bigger ? b : a; + bigger.sum_tmp = bigger.sum_tmp + + smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp); + + return bigger; +} +template +__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel( + T const *input, T *output, int dimsize, + int stride) {// if set axis = 1, inputShape=[I,J,K,S] + // tid = i(JKS) + j(KS) + k(S) + s + + // blockDim.x = othersize = size/dimsize = IKS + // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s + + int tid = + blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * + dimsize;// now, tid = i(JKS) + k(S) + s; + + DataMaxSum dms_partial; + dms_partial.max_tmp = -__FLT_MAX__; + dms_partial.sum_tmp = 0.0f; + DataMaxSum dms_input; + int remain = dimsize % BLOCK_DIM; + int step = (dimsize - remain) / BLOCK_DIM + 1;// step <= numPerThread + + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + dms_input.max_tmp = + static_cast(input[tid + (threadIdx.x * step + ind) * stride]); + + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input);// reduce the data to one block + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + dms_input.max_tmp = + static_cast(input[tid + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]); + + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input);// reduce the data to one block + } + } + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ DataMaxSum dms_total; + DataMaxSum dms_block = + BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op); + if (threadIdx.x == + 0) {// must set threadIdx.x = 0 write the output to memory + dms_total = dms_block; + } + __syncthreads(); + //----------------- + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + + output[tid + (threadIdx.x * step + ind) * stride] = static_cast( + __expf(static_cast( + input[tid + (threadIdx.x * step + ind) * stride]) - + dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp)); + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + + output[tid + + (remain * step + (threadIdx.x - remain) * (step - 1) + ind) * + stride] = static_cast(__expf(static_cast(input[tid + + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]) - + dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp)); + } + } +} + +template +__global__ void +_blockSoftmaxKernel(T const *input, T *output, + int dimsize, + int stride) {// if set axis = 1, inputShape=[I,J,K,S] + // tid = i(JKS) + j(KS) + k(S) + s + + // blockDim.x = othersize = size/dimsize = IKS + // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s + + int tid = + blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * + dimsize;// now, tid = i(JKS) + k(S) + s; + int remain = dimsize % BLOCK_DIM; + int step = (dimsize - remain) / BLOCK_DIM + 1;// step <= numPerThread + float dataPerThread[numPerThread]; + + DataMaxSum dms_partial; + dms_partial.max_tmp = -__FLT_MAX__; + dms_partial.sum_tmp = 0.0f; + DataMaxSum dms_input; + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + dataPerThread[ind] = + static_cast(input[tid + (threadIdx.x * step + ind) * stride]); + dms_input.max_tmp = dataPerThread[ind]; + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input);// reduce the data to one block + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + dataPerThread[ind] = + static_cast(input[tid + (remain * step + + (threadIdx.x - remain) * (step - 1) + ind) * + stride]); + dms_input.max_tmp = dataPerThread[ind]; + dms_input.sum_tmp = 1.0f; + dms_partial = + reduce_dms_op(dms_partial, + dms_input);// reduce the data to one block + } + } + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ DataMaxSum dms_total; + DataMaxSum dms_block = + BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op); + if (threadIdx.x == + 0) {// must set threadIdx.x = 0 write the output to memory + dms_total = dms_block; + } + __syncthreads(); + //----------------- + if (threadIdx.x < remain) { + for (int ind = 0; ind < step; ind++) { + output[tid + (threadIdx.x * step + ind) * stride] = static_cast( + __expf(dataPerThread[ind] - dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp)); + } + } else { + for (int ind = 0; ind < step - 1; ind++) { + output[tid + + (remain * step + (threadIdx.x - remain) * (step - 1) + ind) * + stride] = static_cast(__expf(dataPerThread[ind] - dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp)); + } + } +} + +template +struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } +}; + +template +struct MaxOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return max(a, b); + } +}; +template class ReductionOp, typename T, + int thread_group_width> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); + } + return val; +} + +template +__global__ void _warpSoftmaxKernel(T const *input, T *output, + int othersize, int dimsize, int stride) { + int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; + + int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize; + float dataPerThreadx[numPerThreadx]; + if (otherIdx < othersize) { + + __shared__ float max_total[BLOCK_DIM_y]; + __shared__ float sum_total[BLOCK_DIM_y]; + float max_data = -__FLT_MAX__; + + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { + dataPerThreadx[ph] = + static_cast(input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride]); + max_data = max(max_data, dataPerThreadx[ph]); + } + + max_data = WarpAllReduce(max_data); + + if (threadIdx.x == 0) + max_total[threadIdx.y] = max_data; + + //-------------------------------------------- + float sum_data = 0.0f; + + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { + dataPerThreadx[ph] = + __expf(dataPerThreadx[ph] - max_total[threadIdx.y]); + sum_data += dataPerThreadx[ph]; + } + + sum_data = WarpAllReduce(sum_data); + + if (threadIdx.x == 0) + sum_total[threadIdx.y] = sum_data; + + //-------------------------------------------- + + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++) { + output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] = static_cast( + dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y])); + } + } +} +//----------------- + +//------------------ +template +void softmax_nv_gpu(SoftmaxCudaDescriptor_t desc, void const *input, void *output, void *stream) { + + int dimsize = desc->dimsize; + int stride = desc->stride; + + int num_blocks = desc->othersize; + if (dimsize > 1024 * 128) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 1024 * 64) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 1024 * 32) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 1024 * 16) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 1024 * 4) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 1024) { + + int BLOCK_DIM = 1024; + _blockSoftmaxKernel + <<>>((T *) input, (T *) output, dimsize, stride); + } else if (dimsize > 31) { + int BLOCK_DIM_x = 32; + int BLOCK_DIM_y = 32; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>((T *) input, (T *) output, num_blocks, dimsize, stride); + } else if (dimsize > 15) { + int BLOCK_DIM_x = 16; + int BLOCK_DIM_y = 64; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>((T *) input, (T *) output, num_blocks, dimsize, stride); + } else if (dimsize > 7) { + int BLOCK_DIM_x = 8; + int BLOCK_DIM_y = 128; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>((T *) input, (T *) output, num_blocks, dimsize, stride); + } else { + int BLOCK_DIM_x = 4; + int BLOCK_DIM_y = 256; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + _warpSoftmaxKernel + <<>>((T *) input, (T *) output, num_blocks, dimsize, stride); + } +} +infiniopStatus_t cudaSoftmax(SoftmaxCudaDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream) { + if (cudaSetDevice(desc->device_id) != cudaSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16)) { + softmax_nv_gpu(desc, input, output, stream); + return STATUS_SUCCESS; + } + if (dtype_eq(desc->dtype, F32)) { + softmax_nv_gpu(desc, input, output, stream); + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/softmax/cuda/softmax.cuh b/src/ops/softmax/cuda/softmax.cuh new file mode 100644 index 00000000..55192532 --- /dev/null +++ b/src/ops/softmax/cuda/softmax.cuh @@ -0,0 +1,33 @@ +#ifndef __CUDA_SOFTMAX_H__ +#define __CUDA_SOFTMAX_H__ + +#include "../../../devices/cuda/cuda_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct SoftmaxCudaDescriptor { + Device device; + int device_id; + DT dtype; + int dimsize; + int stride; + int othersize; +}; + +typedef struct SoftmaxCudaDescriptor *SoftmaxCudaDescriptor_t; + +infiniopStatus_t cudaCreateSoftmaxDescriptor(CudaHandle_t handle, + SoftmaxCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc); + +infiniopStatus_t cudaGetSoftmaxWorkspaceSize(SoftmaxCudaDescriptor_t desc, unsigned long int *size); +infiniopStatus_t cudaSoftmax(SoftmaxCudaDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *input, + void *output, + void *stream); + +infiniopStatus_t cudaDestroySoftmaxDescriptor(SoftmaxCudaDescriptor_t desc); + + +#endif diff --git a/src/ops/softmax/cuda/softmax_cuda.cc b/src/ops/softmax/cuda/softmax_cuda.cc new file mode 100644 index 00000000..f00d283b --- /dev/null +++ b/src/ops/softmax/cuda/softmax_cuda.cc @@ -0,0 +1,54 @@ +#include "../../utils.h" +#include "softmax.cuh" + +infiniopStatus_t cudaCreateSoftmaxDescriptor(CudaHandle_t handle, + SoftmaxCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc) { + + if (input_desc->ndim != output_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(input_desc->dt, F16) && !dtype_eq(input_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + + int ndim = input_desc->ndim; + + for (int i = 0; i < ndim; i++) { + if (input_desc->shape[i] != output_desc->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + int dimsize = static_cast(input_desc->shape[axis]); + int stride = 1; + int size = 1; + for (int i = ndim - 1; i >= 0; i -= 1) { + size *= static_cast(input_desc->shape[i]); + } + for (int i = ndim - 1; i >= 0; i -= 1) { + if (i == axis) { + break; + } + stride *= static_cast(input_desc->shape[i]); + } + int othersize = size / dimsize; + *desc_ptr = new SoftmaxCudaDescriptor{ + handle->device, + handle->device_id, + input_desc->dt, + dimsize, + stride, + othersize}; + + return STATUS_SUCCESS; +} +infiniopStatus_t cudaGetSoftmaxWorkspaceSize(SoftmaxCudaDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroySoftmaxDescriptor(SoftmaxCudaDescriptor_t desc) { + + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/softmax/operator.cc b/src/ops/softmax/operator.cc new file mode 100644 index 00000000..8e6760e2 --- /dev/null +++ b/src/ops/softmax/operator.cc @@ -0,0 +1,109 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/softmax/softmax.h" + +#ifdef ENABLE_CPU +#include "cpu/softmax_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/common_cuda.h" +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/softmax.cuh" +#endif +#ifdef ENABLE_CAMBRICON_MLU +#include "../../devices/bang/bang_handle.h" +#include "bang/softmax_bang.h" +#include "bang/softmax_cnnl.h" +#endif + + +__C infiniopStatus_t infiniopCreateSoftmaxDescriptor( + infiniopHandle_t handle, + infiniopSoftmaxDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, int axis, infiniopTensorDescriptor_t output_desc) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateSoftmaxDescriptor(handle, (SoftmaxCpuDescriptor_t *) desc_ptr, input_desc, axis, output_desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateSoftmaxDescriptor((CudaHandle_t) handle, (SoftmaxCudaDescriptor_t *) desc_ptr, input_desc, axis, output_desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangCreateSoftmaxDescriptor((BangHandle_t) handle, (SoftmaxBangDescriptor_t *) desc_ptr, input_desc, axis, output_desc); + return cnnlCreateSoftmaxDescriptor((BangHandle_t) handle, (SoftmaxCnnlDescriptor_t *) desc_ptr, input_desc, axis, output_desc); + } +#endif + } + return STATUS_BAD_DEVICE; +} +__C infiniopStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, uint64_t *size) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGetSoftmaxWorkspaceSize((SoftmaxCpuDescriptor_t) desc, size); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaGetSoftmaxWorkspaceSize((SoftmaxCudaDescriptor_t) desc, size); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangGetSoftmaxWorkspaceSize((SoftmaxBangDescriptor_t) desc, size); + return cnnlGetSoftmaxWorkspaceSize((SoftmaxCnnlDescriptor_t) desc, size); + } +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, + uint64_t workspace_size, void const *input, void *output, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuSoftmax((SoftmaxCpuDescriptor_t) desc, workspace, workspace_size, input, output, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaSoftmax((SoftmaxCudaDescriptor_t) desc, workspace, workspace_size, input, output, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangSoftmax((SoftmaxBangDescriptor_t) desc, workspace, workspace_size, input, output, stream); + return cnnlSoftmax((SoftmaxCnnlDescriptor_t) desc, workspace, workspace_size, input, output, stream); + } +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroySoftmaxDescriptor((SoftmaxCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroySoftmaxDescriptor((SoftmaxCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangDestroySoftmaxDescriptor((SoftmaxBangDescriptor_t) desc); + return cnnlDestroySoftmaxDescriptor((SoftmaxCnnlDescriptor_t) desc); + } +#endif + } + return STATUS_BAD_DEVICE; +}