diff --git a/include/ops/concat/concat.h b/include/ops/concat/concat.h
new file mode 100644
index 00000000..20ca6339
--- /dev/null
+++ b/include/ops/concat/concat.h
@@ -0,0 +1,27 @@
+#ifndef CONCAT_H
+#define CONCAT_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ConcatDescriptor {
+    Device device;
+} ConcatDescriptor;
+
+typedef ConcatDescriptor *infiniopConcatDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle,
+                                                             infiniopConcatDescriptor_t *desc_ptr,
+                                                             infiniopTensorDescriptor_t y,
+                                                             infiniopTensorDescriptor_t *x,
+                                                             uint64_t num_inputs,
+                                                             int64_t axis);
+
+__C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc,
+                                             void *y,
+                                             void const **x,
+                                             void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc);
+
+#endif
diff --git a/operatorspy/liboperators.py b/operatorspy/liboperators.py
index 868cc88d..fb58d6a7 100644
--- a/operatorspy/liboperators.py
+++ b/operatorspy/liboperators.py
@@ -10,7 +10,6 @@
 
 LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"), "lib")
 
-
 class TensorDescriptor(Structure):
     _fields_ = [
         ("dt", DataLayout),
@@ -19,10 +18,8 @@ class TensorDescriptor(Structure):
         ("pattern", POINTER(c_int64)),
     ]
 
-
 infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor)
 
-
 class CTensor:
     def __init__(self, desc, data):
         self.descriptor = desc
diff --git a/operatorspy/tests/concat.py b/operatorspy/tests/concat.py
new file mode 100644
index 00000000..96f34088
--- /dev/null
+++ b/operatorspy/tests/concat.py
@@ -0,0 +1,212 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
+import ctypes
+import sys
+import os
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+from enum import Enum, auto
+import torch
+
+
+class Inplace(Enum):
+    OUT_OF_PLACE = auto()
+
+
+class ConcatDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopConcatDescriptor_t = POINTER(ConcatDescriptor)
+
+
+def concat_py(*tensors, dim=0):
+    return torch.cat(tensors, dim=dim)
+
+
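+# Illustration only (not used by the test flow below): the C API takes the
+# input descriptors as a contiguous ctypes array of infiniopTensorDescriptor_t
+# and the raw data pointers as a c_void_p array, built like so for two
+# hypothetical tensors a and b:
+#
+#     descs = (infiniopTensorDescriptor_t * 2)(a.descriptor, b.descriptor)
+#     ptrs = (c_void_p * 2)(a.data, b.data)
+#
+# test() below constructs exactly these arrays for any number of inputs.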
) + ) + + input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors]) + check_error( + lib.infiniopConcat( + descriptor, + c_tensor.data, + ctypes.cast(input_data_ptrs, POINTER(c_void_p)), + None + ) + ) + + assert torch.allclose(c, ans, atol=0, rtol=0), "Concat result does not match PyTorch's result." + + check_error(lib.infiniopDestroyConcatDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "cpu", c_shape, axis, input_shapes, tensor_dtype = torch.float16, inplace = inplace) + test(lib, handle, "cpu", c_shape, axis, input_shapes, tensor_dtype = torch.float32, inplace = inplace) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "cuda", c_shape, axis, input_shapes, tensor_dtype = torch.float16, inplace = inplace) + test(lib, handle, "cuda", c_shape, axis, input_shapes, tensor_dtype = torch.float32, inplace = inplace) + destroy_handle(lib, handle) + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for c_shape, axis, input_shapes, inplace in test_cases: + test(lib, handle, "mlu", c_shape, axis, input_shapes, inplace=inplace) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + + test_cases = [ + #output_tensor, axis, inputs_tensors, inplace + + ((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE), + + ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), + ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), + ((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), + ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), + ((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE), + ((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE), + ((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE), + ((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE), + + ((6,), -1, [(2,), (4,)], Inplace.OUT_OF_PLACE), + ((6, 3), -2, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE), + ((3, 6), -1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE), + ((3, 7), -1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE), + ((3, 3, 10), -1, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE), + ((4, 3, 6), -3, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE), + ((2, 6, 3), -2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((2, 3, 6), -1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE), + ((4, 3, 5, 6), -4, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE), + ((2, 5, 5, 6), -3, [(2, 3, 5, 
+if __name__ == "__main__":
+    test_cases = [
+        # output_shape, axis, input_shapes, inplace
+        # positive axes
+        ((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE),
+        ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
+        ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
+        ((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE),
+        ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
+        ((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE),
+        ((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE),
+        ((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+        # negative axes
+        ((6,), -1, [(2,), (4,)], Inplace.OUT_OF_PLACE),
+        ((6, 3), -2, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
+        ((3, 6), -1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
+        ((3, 7), -1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE),
+        ((3, 3, 10), -1, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((4, 3, 6), -3, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 6, 3), -2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 6), -1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
+        ((4, 3, 5, 6), -4, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 5, 5, 6), -3, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 6), -2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 6), -1, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE),
+        ((2, 3, 5, 15), -1, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE),
+        ((4, 2, 3, 4, 5), -5, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((2, 4, 3, 2, 5), -4, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 4, 4, 5), -3, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 3, 8, 5), -2, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+        ((1, 2, 3, 4, 5), -1, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE),
+        ((4, 14, 3, 4, 5), -4, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE),
+    ]
+
+    args = get_args()
+    lib = open_lib()
+
+    lib.infiniopCreateConcatDescriptor.restype = c_int32
+    lib.infiniopCreateConcatDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopConcatDescriptor_t),
+        infiniopTensorDescriptor_t,
+        POINTER(infiniopTensorDescriptor_t),
+        c_uint64,  # num_inputs
+        c_int64,  # axis
+    ]
+
+    lib.infiniopConcat.restype = c_int32
+    lib.infiniopConcat.argtypes = [
+        infiniopConcatDescriptor_t,
+        c_void_p,
+        POINTER(c_void_p),
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyConcatDescriptor.restype = c_int32
+    lib.infiniopDestroyConcatDescriptor.argtypes = [
+        infiniopConcatDescriptor_t,
+    ]
+
+    if args.cpu:
+        test_cpu(lib, test_cases)
+    if args.cuda:
+        test_cuda(lib, test_cases)
+    if args.bang:
+        test_bang(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang):
+        test_cpu(lib, test_cases)
+
+    print("\033[92mConcat Test passed!\033[0m")
diff --git a/src/ops/concat/cpu/concat_cpu.cc b/src/ops/concat/cpu/concat_cpu.cc
new file mode 100644
index 00000000..6c9bd419
--- /dev/null
+++ b/src/ops/concat/cpu/concat_cpu.cc
@@ -0,0 +1,139 @@
+#include "concat_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+#include "../../utils.h"
+
+infiniopStatus_t cpuCreateConcatDescriptor(
+    infiniopHandle_t handle,
+    ConcatCpuDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t *x,
+    uint64_t num_inputs,
+    int64_t axis) {
+    if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) {
+        return STATUS_BAD_PARAM;
+    }
+
+    if (!is_contiguous(y)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    int64_t ndim = y->ndim;
+    if (axis >= ndim || axis < -ndim) {
+        return STATUS_BAD_PARAM;
+    }
+
+    // Normalize a negative axis (counted from the last dimension).
+    if (axis < 0) {
+        axis = axis + ndim;
+    }
+
+    uint64_t total_size = 0;
+    std::vector<std::vector<uint64_t>> input_shapes(num_inputs);
+    std::vector<uint64_t> output_shape(y->shape, y->shape + ndim);
+
+    for (size_t i = 0; i < num_inputs; ++i) {
+        if (!is_contiguous(x[i])) {
+            return STATUS_BAD_TENSOR_STRIDES;
+        }
+        if (x[i]->dt != y->dt) {
+            return STATUS_BAD_TENSOR_DTYPE;
+        }
+        if (x[i]->ndim != ndim) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+        // Every dimension except the concat axis must match the output.
+        for (int64_t j = 0; j < ndim; ++j) {
+            if (j != axis && x[i]->shape[j] != y->shape[j]) {
+                return STATUS_BAD_TENSOR_SHAPE;
+            }
+        }
+
+        input_shapes[i] = std::vector<uint64_t>(x[i]->shape, x[i]->shape + ndim);
+        total_size += x[i]->shape[axis];
+    }
+
+    // The input sizes along the concat axis must add up to the output size.
+    if (total_size != y->shape[axis]) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    *desc_ptr = new ConcatCpuDescriptor{
+        DevCpu,
+        y->dt,
+        axis,
+        num_inputs,
+        input_shapes,
+        output_shape,
+    };
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
+
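+// Index mapping used by concatCompute below, on a concrete example (shapes
+// chosen only for illustration): concatenating x0 = (2, 3) and x1 = (2, 4)
+// along axis 1 into y = (2, 7):
+//   blockOffsetInner = 1 (product of dims after the axis)
+//   blockOffset = 7 * 1 = 7 (one full output row)
+//   for x1: localBlockOffset = 4 (one input row), dimOffset = 3,
+//           innerOffset = 1 * 3 = 3
+// Element iOffset = 5 of x1 (row 1, col 1) lands at
+//   oOffset = 5 % 4 + 3 + 5 / 4 * 7 = 1 + 3 + 7 = 11,
+// i.e. row 1, col 4 of y: column 1 of x1 shifted right past x0's 3 columns.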
+template <typename T>
+infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t &desc,
+                               T *y,
+                               void const **x) {
+    int64_t axis = desc->axis;
+    uint64_t num_inputs = desc->num_inputs;
+    const std::vector<std::vector<uint64_t>> &input_shapes = desc->input_shapes;
+    const std::vector<uint64_t> &output_shape = desc->output_shape;
+
+    // Elements in one output "inner" block: product of dims after the axis.
+    size_t blockOffsetInner = 1;
+    for (size_t i = output_shape.size() - 1; i > static_cast<size_t>(axis); --i) {
+        blockOffsetInner *= output_shape[i];
+    }
+    // Elements in one output block spanning the whole concat axis.
+    size_t blockOffset = output_shape[axis] * blockOffsetInner;
+
+    for (size_t i = 0; i < num_inputs; ++i) {
+        const std::vector<uint64_t> &input_shape = input_shapes[i];
+
+        // Start position of this input along the axis: sum of the axis sizes
+        // of all preceding inputs.
+        size_t dimOffset = 0;
+        for (size_t j = 0; j < i; ++j) {
+            dimOffset += input_shapes[j][axis];
+        }
+
+        // Elements of this input per outer index: product of its dims from
+        // the axis to the end (the size_t(-1) check guards axis == 0).
+        size_t localBlockOffset = 1;
+        for (size_t j = input_shape.size() - 1; j >= static_cast<size_t>(axis) && j != static_cast<size_t>(-1); --j) {
+            localBlockOffset *= input_shape[j];
+        }
+
+        size_t innerOffset = blockOffsetInner * dimOffset;
+        size_t inSize = 1;
+        for (auto dim : input_shape) {
+            inSize *= dim;
+        }
+
+        T *input_data = static_cast<T *>(const_cast<void *>(x[i]));
+
+#pragma omp parallel for
+        for (size_t iOffset = 0; iOffset < inSize; ++iOffset) {
+            size_t oOffset = iOffset % localBlockOffset + innerOffset +
+                             iOffset / localBlockOffset * blockOffset;
+            y[oOffset] = input_data[iOffset];
+        }
+    }
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc,
+                           void *y,
+                           void const **x,
+                           void *stream) {
+    // Concat is a pure copy, so f16 can be moved as raw 16-bit words.
+    if (desc->dtype == F16) {
+        return concatCompute<uint16_t>(desc, reinterpret_cast<uint16_t *>(y), x);
+    }
+    if (desc->dtype == F32) {
+        return concatCompute<float>(desc, reinterpret_cast<float *>(y), x);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/concat/cpu/concat_cpu.h b/src/ops/concat/cpu/concat_cpu.h
new file mode 100644
index 00000000..a8d4d71d
--- /dev/null
+++ b/src/ops/concat/cpu/concat_cpu.h
@@ -0,0 +1,32 @@
+#ifndef __CPU_CONCAT_H__
+#define __CPU_CONCAT_H__
+#include "operators.h"
+#include <cstdint>
+#include <vector>
+
+struct ConcatCpuDescriptor {
+    Device device;
+    DT dtype;
+    int64_t axis;
+    uint64_t num_inputs;
+    std::vector<std::vector<uint64_t>> input_shapes;
+    std::vector<uint64_t> output_shape;
+};
+
+typedef struct ConcatCpuDescriptor *ConcatCpuDescriptor_t;
+
+infiniopStatus_t cpuCreateConcatDescriptor(infiniopHandle_t handle,
+                                           ConcatCpuDescriptor_t *desc_ptr,
+                                           infiniopTensorDescriptor_t y,
+                                           infiniopTensorDescriptor_t *x,
+                                           uint64_t num_inputs,
+                                           int64_t axis);
+
+infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc,
+                           void *y,
+                           void const **x,
+                           void *stream);
+
+infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc);
+
+#endif
diff --git a/src/ops/concat/cuda/concat.cc b/src/ops/concat/cuda/concat.cc
new file mode 100644
index 00000000..d99d167b
--- /dev/null
+++ b/src/ops/concat/cuda/concat.cc
@@ -0,0 +1,73 @@
+#include "concat.cuh"
+#include "../../../devices/cuda/common_cuda.h"
+#include "../../utils.h"
+
+infiniopStatus_t cudaCreateConcatDescriptor(CudaHandle_t handle,
+                                            ConcatCudaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t y,
+                                            infiniopTensorDescriptor_t *x,
+                                            uint64_t num_inputs,
+                                            int64_t axis) {
+    if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) {
+        return STATUS_BAD_PARAM;
+    }
+
+    if (!is_contiguous(y)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    int64_t ndim = y->ndim;
+    if (axis >= ndim || axis < -ndim) {
+        return STATUS_BAD_PARAM;
+    }
+
+    // Normalize a negative axis (counted from the last dimension).
+    if (axis < 0) {
+        axis = axis + ndim;
+    }
+
+    uint64_t total_size = 0;
+    std::vector<std::vector<uint64_t>> input_shapes(num_inputs);
+    std::vector<uint64_t> output_shape(y->shape, y->shape + ndim);
+
+    for (size_t i = 0; i < num_inputs; ++i) {
+        if (!is_contiguous(x[i])) {
+            return STATUS_BAD_TENSOR_STRIDES;
+        }
+        if (x[i]->dt != y->dt) {
+            return STATUS_BAD_TENSOR_DTYPE;
+        }
+        if (x[i]->ndim != ndim) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+        // Every dimension except the concat axis must match the output.
+        for (int64_t j = 0; j < ndim; ++j) {
+            if (j != axis && x[i]->shape[j] != y->shape[j]) {
+                return STATUS_BAD_TENSOR_SHAPE;
+            }
+        }
+
+        input_shapes[i] = std::vector<uint64_t>(x[i]->shape, x[i]->shape + ndim);
+        total_size += x[i]->shape[axis];
+    }
+
+    // The input sizes along the concat axis must add up to the output size.
+    if (total_size != y->shape[axis]) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    *desc_ptr = new ConcatCudaDescriptor{
+        DevNvGpu,
+        y->dt,
+        axis,
+        num_inputs,
+        input_shapes,
+        output_shape,
+    };
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cudaDestroyConcatDescriptor(ConcatCudaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
\ No newline at end of file
diff --git a/src/ops/concat/cuda/concat.cu b/src/ops/concat/cuda/concat.cu
new file mode 100644
index 00000000..2c3d8ad6
--- /dev/null
+++ b/src/ops/concat/cuda/concat.cu
@@ -0,0 +1,86 @@
+#include "../../../devices/cuda/common_cuda.h"
+#include "../../utils.h"
+#include "concat.cuh"
+
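+// Launch configuration note (a simple heuristic, not tuned per device): the
+// kernel below uses one thread per input element with 256 threads per block,
+// so an input with inSize = 1000 elements launches ceil(1000 / 256) = 4
+// blocks; each input tensor gets its own launch on the caller's stream.
+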
+// Kernel function to perform concatenation on GPU: each thread copies one
+// element of input x to its final position in output y.
+template <typename T>
+__global__ void concatKernel(const T *x, T *y,
+                             size_t inSize,
+                             size_t localBlockOffset,
+                             size_t innerOffset,
+                             size_t blockOffset) {
+    size_t iOffset = blockIdx.x * blockDim.x + threadIdx.x;
+    if (iOffset < inSize) {
+        size_t oOffset = (iOffset % localBlockOffset) + innerOffset +
+                         (iOffset / localBlockOffset) * blockOffset;
+        y[oOffset] = x[iOffset];
+    }
+}
+
+template <typename T>
+infiniopStatus_t concatCompute(ConcatCudaDescriptor_t &desc,
+                               T *y,
+                               void const **x,
+                               cudaStream_t stream) {
+    int64_t axis = desc->axis;
+    uint64_t num_inputs = desc->num_inputs;
+    const std::vector<std::vector<uint64_t>> &input_shapes = desc->input_shapes;
+    const std::vector<uint64_t> &output_shape = desc->output_shape;
+
+    // Same block/offset decomposition as the CPU path (see concat_cpu.cc).
+    size_t blockOffsetInner = 1;
+    for (size_t i = output_shape.size() - 1; i > static_cast<size_t>(axis); --i) {
+        blockOffsetInner *= output_shape[i];
+    }
+    size_t blockOffset = output_shape[axis] * blockOffsetInner;
+
+    for (size_t i = 0; i < num_inputs; ++i) {
+        const std::vector<uint64_t> &input_shape = input_shapes[i];
+
+        size_t dimOffset = 0;
+        for (size_t j = 0; j < i; ++j) {
+            dimOffset += input_shapes[j][axis];
+        }
+
+        size_t localBlockOffset = 1;
+        for (size_t j = input_shape.size() - 1; j >= static_cast<size_t>(axis) && j != static_cast<size_t>(-1); --j) {
+            localBlockOffset *= input_shape[j];
+        }
+
+        size_t innerOffset = blockOffsetInner * dimOffset;
+        size_t inSize = 1;
+        for (auto dim : input_shape) {
+            inSize *= dim;
+        }
+
+        T *input_data = static_cast<T *>(const_cast<void *>(x[i]));
+
+        // Launch one CUDA kernel per input tensor.
+        int threads = 256;
+        int blocks = static_cast<int>((inSize + threads - 1) / threads);
+        concatKernel<<<blocks, threads, 0, stream>>>(input_data, y, inSize, localBlockOffset, innerOffset, blockOffset);
+
+        // Check for launch errors.
+        cudaError_t err = cudaGetLastError();
+        if (err != cudaSuccess) {
+            return STATUS_EXECUTION_FAILED;
+        }
+    }
+
+    return STATUS_SUCCESS;
+}
+
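+// Dispatch note: f16 data is moved as `half` purely for its 16-bit width;
+// concat never does arithmetic, so any two-byte type would work equally well.
+// Supporting a new dtype only requires a branch here and in cpuConcat.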
"../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include +#include +#include + +struct ConcatCudaDescriptor { + Device device; + DT dtype; + int64_t axis; + uint64_t num_inputs; + std::vector> input_shapes; + std::vector output_shape; +}; + +typedef struct ConcatCudaDescriptor *ConcatCudaDescriptor_t; + +infiniopStatus_t cudaCreateConcatDescriptor(CudaHandle_t handle, + ConcatCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t nums_input, + int64_t axis); + +infiniopStatus_t cudaConcat(ConcatCudaDescriptor_t desc, + void *y, + void const **x, + void *stream); + +infiniopStatus_t cudaDestroyConcatDescriptor(ConcatCudaDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/ops/concat/operator.cc b/src/ops/concat/operator.cc new file mode 100644 index 00000000..5f3cdae1 --- /dev/null +++ b/src/ops/concat/operator.cc @@ -0,0 +1,64 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/concat/concat.h" + +#ifdef ENABLE_CPU +#include "cpu/concat_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/concat.cuh" +#endif + +__C infiniopStatus_t infiniopCreateConcatDescriptor( + infiniopHandle_t handle, + infiniopConcatDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *x, + uint64_t num_inputs, + int64_t axis) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateConcatDescriptor(handle, (ConcatCpuDescriptor_t *) desc_ptr, y, x, num_inputs, axis); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateConcatDescriptor((CudaHandle_t) handle, (ConcatCudaDescriptor_t *) desc_ptr, y, x, num_inputs, axis); + } +#endif + } + return STATUS_BAD_DEVICE; +} + + +__C infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc, void *y, void const **x, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuConcat((ConcatCpuDescriptor_t) desc, y, x, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaConcat((ConcatCudaDescriptor_t) desc, y, x, stream); + } +#endif + } + return STATUS_BAD_DEVICE; +} + + +__C infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyConcatDescriptor((ConcatCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyConcatDescriptor((ConcatCudaDescriptor_t) desc); + } +#endif + } + return STATUS_BAD_DEVICE; +}