Skip to content

add cpu concat #138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/ops/concat/concat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef CONCAT_H
#define CONCAT_H

#include "../../export.h"
#include "../../operators.h"

// Public view of a concat operator descriptor. Only the device tag is
// declared here; any backend-specific state is not visible through this header.
typedef struct ConcatDescriptor {
Device device;
} ConcatDescriptor;

// Pointer type used as the descriptor handle by the infiniop*Concat* functions.
typedef ConcatDescriptor *infiniopConcatDescriptor_t;

// Creates a concat descriptor and returns it through *desc_ptr.
//   handle     - library handle the descriptor is created for
//   desc_ptr   - out-parameter receiving the new descriptor
//   y          - descriptor of the output tensor
//   x          - array of input tensor descriptors
//   num_inputs - number of entries in x
//   axis       - axis along which the inputs are concatenated
// NOTE(review): the accepted axis range (e.g. negative values) and the exact
// validation rules are backend-defined — confirm against the implementation.
__C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle,
infiniopConcatDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t *x,
uint64_t num_inputs,
int64_t axis);

// Runs the concat described by desc.
//   desc   - descriptor created by infiniopCreateConcatDescriptor
//   y      - output data buffer
//   x      - array of input data pointers
//   stream - backend execution stream
// NOTE(review): presumably x must hold one pointer per input descriptor, in
// the same order as at creation time, and stream may be null on CPU — confirm.
__C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc,
void *y,
void const **x,
void *stream);

// Destroys a descriptor obtained from infiniopCreateConcatDescriptor.
__C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc);

#endif
3 changes: 0 additions & 3 deletions operatorspy/liboperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"), "lib")


class TensorDescriptor(Structure):
_fields_ = [
("dt", DataLayout),
Expand All @@ -19,10 +18,8 @@ class TensorDescriptor(Structure):
("pattern", POINTER(c_int64)),
]


infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor)


class CTensor:
def __init__(self, desc, data):
self.descriptor = desc
Expand Down
212 changes: 212 additions & 0 deletions operatorspy/tests/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)

from operatorspy.tests.test_utils import get_args
from enum import Enum, auto
import torch


class Inplace(Enum):
    """Buffer-placement modes a test case can request.

    Concat currently has a single mode: the output lives in its own buffer.
    """

    OUT_OF_PLACE = auto()

class ConcatDescriptor(Structure):
    """ctypes layout of the library's concat descriptor: one 32-bit device tag."""

    _fields_ = [
        ("device", c_int32),
    ]


# Handle type exchanged with the infiniop*Concat* C functions.
infiniopConcatDescriptor_t = POINTER(ConcatDescriptor)


def concat_py(*tensors, dim=0):
    """Reference implementation: join ``tensors`` along ``dim`` using ``torch.cat``."""
    joined = torch.cat(tensors, dim=dim)
    return joined


def test(
    lib,
    handle,
    torch_device,
    c_shape,
    axis,
    input_shapes,
    tensor_dtype=torch.float32,
    inplace=Inplace.OUT_OF_PLACE,
):
    """Run one concat case through the library and compare it with ``torch.cat``.

    Args:
        lib: Loaded operators library exposing the ``infiniop*Concat*`` functions.
        handle: Device handle the descriptor is created for.
        torch_device: Torch device string the tensors are placed on.
        c_shape: Shape of the output tensor.
        axis: Concatenation axis passed to the library and to ``torch.cat``.
        input_shapes: One shape per input tensor.
        tensor_dtype: Torch dtype used for all tensors.
        inplace: Placement mode; only reported in the log line, since the
            output is always allocated as a separate tensor.

    Raises:
        AssertionError: If the library output differs from the torch result.
    """
    print(
        f"Testing Concat on {torch_device} with output_shape:{c_shape}, input_shapes:{input_shapes}, axis:{axis}, dtype:{tensor_dtype}, inplace: {inplace.name}"
    )

    inputs = [torch.rand(shape, dtype=tensor_dtype).to(torch_device) for shape in input_shapes]
    c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)

    ans = concat_py(*inputs, dim=axis)

    input_tensors = [to_tensor(t, lib) for t in inputs]
    c_tensor = to_tensor(c, lib)

    descriptor = infiniopConcatDescriptor_t()

    num_inputs = len(input_tensors)
    input_desc_array = (infiniopTensorDescriptor_t * num_inputs)(
        *[t.descriptor for t in input_tensors]
    )

    check_error(
        lib.infiniopCreateConcatDescriptor(
            handle,
            ctypes.byref(descriptor),
            c_tensor.descriptor,
            input_desc_array,
            c_uint64(num_inputs),
            c_int64(axis),
        )
    )

    # From here on the descriptor exists, so release it even when the call or
    # the comparison below fails.
    try:
        input_data_ptrs = (c_void_p * num_inputs)(*[t.data for t in input_tensors])
        check_error(
            lib.infiniopConcat(
                descriptor,
                c_tensor.data,
                ctypes.cast(input_data_ptrs, POINTER(c_void_p)),
                None,
            )
        )

        # Concat only moves values, so the result must match exactly.
        assert torch.allclose(c, ans, atol=0, rtol=0), "Concat result does not match PyTorch's result."
    finally:
        check_error(lib.infiniopDestroyConcatDescriptor(descriptor))


def test_cpu(lib, test_cases):
    """Run every case on the CPU backend, first as float16 and then as float32."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    for out_shape, axis, in_shapes, mode in test_cases:
        for dtype in (torch.float16, torch.float32):
            test(lib, handle, "cpu", out_shape, axis, in_shapes, tensor_dtype=dtype, inplace=mode)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    """Run every case on the CUDA backend, first as float16 and then as float32."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    for out_shape, axis, in_shapes, mode in test_cases:
        for dtype in (torch.float16, torch.float32):
            test(lib, handle, "cuda", out_shape, axis, in_shapes, tensor_dtype=dtype, inplace=mode)
    destroy_handle(lib, handle)

def test_bang(lib, test_cases):
    """Run every case on the BANG backend using the default dtype of ``test``."""
    # Imported only when this backend is selected; the name itself is unused here.
    import torch_mlu

    handle = create_handle(lib, DeviceEnum.DEVICE_BANG)
    for out_shape, axis, in_shapes, mode in test_cases:
        test(lib, handle, "mlu", out_shape, axis, in_shapes, inplace=mode)
    destroy_handle(lib, handle)


if __name__ == "__main__":
    # Each case: (output_shape, axis, input_shapes, inplace), with axis >= 0.
    positive_axis_cases = [
        ((6,), 0, [(2,), (4,)], Inplace.OUT_OF_PLACE),
        ((6, 3), 0, [(2, 3), (4, 3)], Inplace.OUT_OF_PLACE),
        ((3, 6), 1, [(3, 2), (3, 4)], Inplace.OUT_OF_PLACE),
        ((3, 7), 1, [(3, 2), (3, 4), (3, 1)], Inplace.OUT_OF_PLACE),
        ((3, 3, 10), 2, [(3, 3, 4), (3, 3, 6)], Inplace.OUT_OF_PLACE),
        ((4, 3, 6), 0, [(3, 3, 6), (1, 3, 6)], Inplace.OUT_OF_PLACE),
        ((2, 6, 3), 1, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
        ((2, 3, 6), 2, [(2, 3, 3), (2, 3, 3)], Inplace.OUT_OF_PLACE),
        ((4, 3, 5, 6), 0, [(1, 3, 5, 6), (3, 3, 5, 6)], Inplace.OUT_OF_PLACE),
        ((2, 5, 5, 6), 1, [(2, 3, 5, 6), (2, 2, 5, 6)], Inplace.OUT_OF_PLACE),
        ((2, 3, 5, 6), 2, [(2, 3, 2, 6), (2, 3, 3, 6)], Inplace.OUT_OF_PLACE),
        ((2, 3, 5, 6), 3, [(2, 3, 5, 3), (2, 3, 5, 3)], Inplace.OUT_OF_PLACE),
        ((2, 3, 5, 15), 3, [(2, 3, 5, 3), (2, 3, 5, 3), (2, 3, 5, 9)], Inplace.OUT_OF_PLACE),
        ((4, 2, 3, 4, 5), 0, [(1, 2, 3, 4, 5), (3, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
        ((2, 4, 3, 2, 5), 1, [(2, 2, 3, 2, 5), (2, 2, 3, 2, 5)], Inplace.OUT_OF_PLACE),
        ((1, 2, 4, 4, 5), 2, [(1, 2, 2, 4, 5), (1, 2, 2, 4, 5)], Inplace.OUT_OF_PLACE),
        ((1, 2, 3, 8, 5), 3, [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5)], Inplace.OUT_OF_PLACE),
        ((1, 2, 3, 4, 5), 4, [(1, 2, 3, 4, 3), (1, 2, 3, 4, 2)], Inplace.OUT_OF_PLACE),
        ((4, 14, 3, 4, 5), 1, [(4, 3, 3, 4, 5), (4, 5, 3, 4, 5), (4, 6, 3, 4, 5)], Inplace.OUT_OF_PLACE),
    ]

    # The same cases addressed from the end: axis k of a rank-r tensor is k - r.
    negative_axis_cases = [
        (out_shape, axis - len(out_shape), in_shapes, mode)
        for out_shape, axis, in_shapes, mode in positive_axis_cases
    ]

    test_cases = positive_axis_cases + negative_axis_cases

    args = get_args()
    lib = open_lib()

    # Declare the C signatures so ctypes converts arguments and results correctly.
    create_fn = lib.infiniopCreateConcatDescriptor
    create_fn.restype = c_int32
    create_fn.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopConcatDescriptor_t),
        infiniopTensorDescriptor_t,
        POINTER(infiniopTensorDescriptor_t),
        c_uint64,  # num_inputs
        c_int64,  # axis
    ]

    run_fn = lib.infiniopConcat
    run_fn.restype = c_int32
    run_fn.argtypes = [
        infiniopConcatDescriptor_t,
        c_void_p,
        POINTER(c_void_p),
        c_void_p,
    ]

    destroy_fn = lib.infiniopDestroyConcatDescriptor
    destroy_fn.restype = c_int32
    destroy_fn.argtypes = [
        infiniopConcatDescriptor_t,
    ]

    # Run the selected backends in a fixed order; fall back to CPU when none is selected.
    selected_runners = [
        runner
        for enabled, runner in (
            (args.cpu, test_cpu),
            (args.cuda, test_cuda),
            (args.bang, test_bang),
        )
        if enabled
    ]
    for runner in selected_runners or [test_cpu]:
        runner(lib, test_cases)

    print("\033[92mConcat Test passed!\033[0m")




139 changes: 139 additions & 0 deletions src/ops/concat/cpu/concat_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#include "concat_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"

/// Validates a concat configuration and builds the CPU descriptor for it.
///
/// @param handle     Library handle (not used by the CPU implementation).
/// @param desc_ptr   Receives the newly allocated descriptor on success.
/// @param y          Output tensor descriptor; must be contiguous.
/// @param x          Array of `num_inputs` input tensor descriptors, each
///                   contiguous and with the same dtype and rank as `y`.
/// @param num_inputs Number of entries in `x`; must be non-zero.
/// @param axis       Concatenation axis in [-ndim, ndim); negative values
///                   count from the last dimension.
/// @return STATUS_SUCCESS, or STATUS_BAD_PARAM / STATUS_BAD_TENSOR_STRIDES /
///         STATUS_BAD_TENSOR_DTYPE / STATUS_BAD_TENSOR_SHAPE when validation
///         fails. Nothing is written to `*desc_ptr` on failure.
infiniopStatus_t cpuCreateConcatDescriptor(
    infiniopHandle_t handle,
    ConcatCpuDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t *x,
    uint64_t num_inputs,
    int64_t axis) {
    if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) {
        return STATUS_BAD_PARAM;
    }

    if (!is_contiguous(y)) {
        return STATUS_BAD_TENSOR_STRIDES;
    }

    const int64_t ndim = static_cast<int64_t>(y->ndim);
    if (axis >= ndim || axis < -ndim) {
        return STATUS_BAD_PARAM;
    }

    if (axis < 0) {
        axis += ndim;
    }

    // axis is now in [0, ndim); unsigned copies keep the index comparisons
    // below free of mixed signed/unsigned arithmetic.
    const uint64_t rank = static_cast<uint64_t>(ndim);
    const uint64_t concat_axis = static_cast<uint64_t>(axis);

    uint64_t total_size = 0;
    std::vector<std::vector<uint64_t>> input_shapes(num_inputs);
    std::vector<uint64_t> output_shape(y->shape, y->shape + rank);

    for (uint64_t i = 0; i < num_inputs; ++i) {
        // Each entry is dereferenced below, so reject null descriptors first.
        if (x[i] == nullptr) {
            return STATUS_BAD_PARAM;
        }

        if (!is_contiguous(x[i])) {
            return STATUS_BAD_TENSOR_STRIDES;
        }

        if (x[i]->dt != y->dt) {
            return STATUS_BAD_TENSOR_DTYPE;
        }

        if (x[i]->ndim != y->ndim) {
            return STATUS_BAD_TENSOR_SHAPE;
        }

        // Every dimension except the concat axis must match the output.
        for (uint64_t j = 0; j < rank; ++j) {
            if (j != concat_axis && x[i]->shape[j] != y->shape[j]) {
                return STATUS_BAD_TENSOR_SHAPE;
            }
        }

        input_shapes[i].assign(x[i]->shape, x[i]->shape + rank);
        total_size += x[i]->shape[concat_axis];
    }

    // The input extents along the concat axis must add up to the output extent.
    if (total_size != y->shape[concat_axis]) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    *desc_ptr = new ConcatCpuDescriptor{
        DevCpu,
        y->dt,
        axis,
        num_inputs,
        input_shapes,
        output_shape,
    };

    return STATUS_SUCCESS;
}

// Releases a descriptor allocated with `new` by cpuCreateConcatDescriptor.
// Always reports STATUS_SUCCESS.
infiniopStatus_t cpuDestroyConcatDescriptor(ConcatCpuDescriptor_t desc) {
delete desc;
return STATUS_SUCCESS;
}

/// Copies every input tensor into its slot of the contiguous output `y`.
///
/// Each input is viewed as [outer, axis_extent * inner], where `inner` is the
/// product of the dimensions after the concat axis. Row `r` of input `i` lands
/// in row `r` of the output, shifted by the summed axis extents of the inputs
/// before it.
///
/// @tparam T   Element storage type; elements are only copied, never interpreted.
/// @param desc Descriptor holding the axis, input count and shapes.
/// @param y    Output buffer laid out according to `desc->output_shape`.
/// @param x    Array of `desc->num_inputs` input buffers.
/// @return STATUS_SUCCESS.
template <typename T>
infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
                               T* y,
                               void const** x) {
    const std::vector<std::vector<uint64_t>>& input_shapes = desc->input_shapes;
    const std::vector<uint64_t>& output_shape = desc->output_shape;
    const size_t axis = static_cast<size_t>(desc->axis);
    const size_t num_inputs = static_cast<size_t>(desc->num_inputs);

    // Number of elements spanned by one step along the concat axis.
    size_t inner = 1;
    for (size_t d = axis + 1; d < output_shape.size(); ++d) {
        inner *= output_shape[d];
    }
    // Elements per outer row of the output.
    const size_t out_row = output_shape[axis] * inner;

    // Running sum of the axis extents of the inputs already placed.
    size_t axis_offset = 0;

    for (size_t i = 0; i < num_inputs; ++i) {
        const std::vector<uint64_t>& in_shape = input_shapes[i];

        // Elements per outer row of this input (axis extent times inner).
        size_t in_row = 1;
        for (size_t d = axis; d < in_shape.size(); ++d) {
            in_row *= in_shape[d];
        }

        // Total element count of this input.
        size_t in_size = in_row;
        for (size_t d = 0; d < axis; ++d) {
            in_size *= in_shape[d];
        }

        const size_t row_shift = inner * axis_offset;
        const T* src = static_cast<const T*>(x[i]);

        // An input with a zero-sized dimension has in_size == 0, so the loop
        // body (and its division by in_row) is never reached in that case.
#pragma omp parallel for
        for (size_t k = 0; k < in_size; ++k) {
            const size_t dst = (k / in_row) * out_row + row_shift + (k % in_row);
            y[dst] = src[k];
        }

        axis_offset += in_shape[axis];
    }

    return STATUS_SUCCESS;
}

/// Runs the concat on the CPU, selecting the element type from the descriptor.
///
/// @param desc   Descriptor created by cpuCreateConcatDescriptor.
/// @param y      Output buffer.
/// @param x      Array of input buffers.
/// @param stream Not used by this implementation.
/// @return Result of the typed copy, or STATUS_BAD_TENSOR_DTYPE when the
///         descriptor's dtype is neither F16 nor F32.
infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc,
                           void *y,
                           void const **x,
                           void *stream) {
    infiniopStatus_t status = STATUS_BAD_TENSOR_DTYPE;

    if (desc->dtype == F16) {
        // Elements are only moved, so half floats travel as raw 16-bit words.
        status = concatCompute<uint16_t>(desc, static_cast<uint16_t *>(y), x);
    } else if (desc->dtype == F32) {
        status = concatCompute<float>(desc, static_cast<float *>(y), x);
    }

    return status;
}
Loading
Loading