Skip to content

Commit a828bfd

Browse files
committed
feat(cpu): support concat negative axis
1 parent f00e101 commit a828bfd

File tree

6 files changed

+22
-41
lines changed

6 files changed

+22
-41
lines changed

include/ops/concat/concat.h

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,29 +4,24 @@
44
#include "../../export.h"
55
#include "../../operators.h"
66

7-
// Concat描述符结构
87
typedef struct ConcatDescriptor {
9-
Device device; // 设备类型(例如 DevCpu、DevNvGpu)
10-
uint64_t axis; // 拼接轴(从0开始)
8+
Device device;
119
} ConcatDescriptor;
1210

1311
typedef ConcatDescriptor *infiniopConcatDescriptor_t;
1412

15-
// 创建Concat描述符
1613
__C __export infiniopStatus_t infiniopCreateConcatDescriptor(infiniopHandle_t handle,
1714
infiniopConcatDescriptor_t *desc_ptr,
1815
infiniopTensorDescriptor_t y,
1916
infiniopTensorDescriptor_t *x,
2017
uint64_t num_inputs,
21-
uint64_t axis);
18+
int64_t axis);
2219

23-
// 执行Concat操作
2420
__C __export infiniopStatus_t infiniopConcat(infiniopConcatDescriptor_t desc,
2521
void *y,
2622
void const **x,
2723
void *stream);
28-
29-
// 销毁Concat描述符
24+
3025
__C __export infiniopStatus_t infiniopDestroyConcatDescriptor(infiniopConcatDescriptor_t desc);
3126

3227
#endif

operatorspy/liboperators.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
Device = c_int
99
Optype = c_int
1010

11-
LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"))
11+
LIB_OPERATORS_DIR = os.path.join(os.environ.get("INFINI_ROOT"), "lib")
1212

1313
class TensorDescriptor(Structure):
1414
_fields_ = [

operatorspy/tests/concat.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,6 @@ def test(
6464
c = torch.zeros(c_shape, dtype=tensor_dtype).to(torch_device)
6565

6666
ans = concat_py(*inputs, dim=axis)
67-
68-
print("ans:",ans)
69-
print("-" * 50)
7067

7168
input_tensors = [to_tensor(t, lib) for t in inputs]
7269
c_tensor = to_tensor(c, lib) if inplace == Inplace.OUT_OF_PLACE else to_tensor(c, lib)
@@ -97,11 +94,7 @@ def test(
9794
None
9895
)
9996
)
100-
101-
print("c2:",c)
102-
print("-" * 50)
10397

104-
# 验证结果
10598
assert torch.allclose(c, ans, atol=0, rtol=0), "Concat result does not match PyTorch's result."
10699

107100
check_error(lib.infiniopDestroyConcatDescriptor(descriptor))

src/ops/concat/cpu/concat_cpu.cc

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -8,22 +8,25 @@ infiniopStatus_t cpuCreateConcatDescriptor(
88
infiniopTensorDescriptor_t y,
99
infiniopTensorDescriptor_t *x,
1010
uint64_t num_inputs,
11-
uint64_t axis) {
11+
int64_t axis) {
1212
if (y == nullptr || x == nullptr || desc_ptr == nullptr || num_inputs == 0) {
1313
return STATUS_BAD_PARAM;
1414
}
1515

16-
uint64_t ndim = y->ndim; // 输出张量维度
17-
if (axis >= ndim) {
18-
return STATUS_BAD_TENSOR_SHAPE;
16+
int64_t ndim = y->ndim;
17+
if (axis >= ndim || axis < -ndim) {
18+
return STATUS_BAD_PARAM;
19+
}
20+
21+
if(axis < 0){
22+
axis = axis + ndim;
1923
}
2024

21-
uint64_t total_size = 0; // 拼接轴的总大小
22-
std::vector<std::vector<uint64_t>> input_shapes(num_inputs); // 输入张量形状
25+
uint64_t total_size = 0;
26+
std::vector<std::vector<uint64_t>> input_shapes(num_inputs);
2327

2428
std::vector<uint64_t> output_shape(y->shape, y->shape + ndim);
2529

26-
// 验证输入张量的形状和步长
2730
for (size_t i = 0; i < num_inputs; ++i) {
2831

2932
if (x[i]->dt != y->dt) {
@@ -41,12 +44,9 @@ infiniopStatus_t cpuCreateConcatDescriptor(
4144
}
4245

4346
input_shapes[i] = std::vector<uint64_t>(x[i]->shape, x[i]->shape + ndim);
44-
45-
// 累加拼接轴的总大小
4647
total_size += x[i]->shape[axis];
4748
}
4849

49-
// 验证输出张量形状是否匹配
5050
if (total_size != y->shape[axis]) {
5151
return STATUS_BAD_TENSOR_SHAPE;
5252
}
@@ -72,8 +72,7 @@ template <typename T>
7272
infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
7373
T* y,
7474
void const** x) {
75-
// 获取描述符中的信息
76-
uint64_t axis = desc->axis;
75+
int64_t axis = desc->axis;
7776
uint64_t num_inputs = desc->num_inputs;
7877
const std::vector<std::vector<uint64_t>>& input_shapes = desc->input_shapes;
7978
const std::vector<uint64_t>& output_shape = desc->output_shape;
@@ -84,7 +83,6 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
8483
}
8584
size_t blockOffset = output_shape[axis] * blockOffsetInner;
8685

87-
// concat
8886
for (size_t i = 0; i < num_inputs; ++i) {
8987
const std::vector<uint64_t>& input_shape = input_shapes[i];
9088

@@ -104,7 +102,6 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
104102
inSize *= dim;
105103
}
106104

107-
// 获取输入和输出的数据指针
108105
T* input_data = static_cast<T*>(const_cast<void*>(x[i]));
109106

110107
#pragma omp parallel for
@@ -120,16 +117,15 @@ infiniopStatus_t concatCompute(const ConcatCpuDescriptor_t& desc,
120117
return STATUS_SUCCESS;
121118
}
122119

123-
// 主拼接函数
124120
infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc,
125121
void *y,
126122
void const **x,
127123
void *stream) {
128-
// 根据数据类型调用相应的模板实例
124+
129125
switch (desc->dtype.size) {
130126
case sizeof(float): // FLOAT32
131127
return concatCompute<float>(desc, reinterpret_cast<float*>(y), x);
132-
// 可以根据需要添加更多数据类型
128+
// add other data.type
133129
default:
134130
return STATUS_SUCCESS;
135131
}

src/ops/concat/cpu/concat_cpu.h

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,26 +4,23 @@
44
#include <vector>
55
#include <cstring>
66

7-
// 支持高维拼接的CPU-specific Concat描述符
87
struct ConcatCpuDescriptor {
98
Device device;
109
DT dtype;
11-
uint64_t axis;
10+
int64_t axis;
1211
uint64_t num_inputs;
13-
std::vector<std::vector<uint64_t>> input_shapes; // 输入张量的形状
14-
std::vector<uint64_t> output_shape; // 输出张量的形状
12+
std::vector<std::vector<uint64_t>> input_shapes;
13+
std::vector<uint64_t> output_shape;
1514
};
1615

17-
18-
1916
typedef struct ConcatCpuDescriptor *ConcatCpuDescriptor_t;
2017

2118
infiniopStatus_t cpuCreateConcatDescriptor(infiniopHandle_t handle,
2219
ConcatCpuDescriptor_t *desc_ptr,
2320
infiniopTensorDescriptor_t y,
2421
infiniopTensorDescriptor_t *x,
2522
uint64_t num_inputs,
26-
uint64_t axis);
23+
int64_t axis);
2724

2825
infiniopStatus_t cpuConcat(ConcatCpuDescriptor_t desc,
2926
void *y,

src/ops/concat/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ __C infiniopStatus_t infiniopCreateConcatDescriptor(
1616
infiniopTensorDescriptor_t y,
1717
infiniopTensorDescriptor_t *x,
1818
uint64_t num_inputs,
19-
uint64_t axis) {
19+
int64_t axis) {
2020
switch (handle->device) {
2121
#ifdef ENABLE_CPU
2222
case DevCpu:

0 commit comments

Comments
 (0)