Skip to content

Commit a41a09f

Browse files
authored
Merge pull request #72 from InfiniTensor/add_mod_kernel
添加 Mod cpu/cuda 算子
2 parents 4f148cd + 23ce522 commit a41a09f

File tree

9 files changed

+207
-27
lines changed

9 files changed

+207
-27
lines changed

src/04kernel/include/kernel/collectors/simple_binary.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace refactor::kernel {
1414
And,
1515
Or,
1616
Xor,
17+
Mod,
18+
Fmod,
1719
};
1820

1921
std::string_view opName(SimpleBinaryType type);

src/04kernel/src/collectors/simple_binary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ namespace refactor::kernel {
1919
CASE(And);
2020
CASE(Or);
2121
CASE(Xor);
22+
CASE(Mod);
23+
CASE(Fmod);
2224
default:
2325
UNREACHABLE();
2426
}

src/04kernel/src/kernels/simple_binary/cpu_kernel.cc

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "cpu_kernel.hh"
2+
#include <cmath>
23
#include <execution>
34

45
namespace refactor::kernel {
@@ -118,8 +119,38 @@ namespace refactor::kernel {
118119
UNREACHABLE();
119120
}
120121
}
121-
default:
122-
UNREACHABLE();
122+
case Op::Mod: {
123+
switch (dataType.internal) {
124+
CASE_DT(a % b, U8);
125+
CASE_DT(a % b, I8);
126+
CASE_DT(a % b, U16);
127+
CASE_DT(a % b, I16);
128+
CASE_DT(a % b, I32);
129+
CASE_DT(a % b, I64);
130+
CASE_DT(a % b, U32);
131+
CASE_DT(a % b, U64);
132+
default:
133+
UNREACHABLE();
134+
}
135+
}
136+
case Op::Fmod: {
137+
switch (dataType.internal) {
138+
CASE_DT(std::fmod(a, b), F32);
139+
CASE_DT(a % b, U8);
140+
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
141+
CASE_DT(a % b, U16);
142+
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
143+
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
144+
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
145+
CASE_DT(std::fmod(a, b), F64);
146+
CASE_DT(a % b, U32);
147+
CASE_DT(a % b, U64);
148+
default:
149+
UNREACHABLE();
150+
}
151+
default:
152+
UNREACHABLE();
153+
}
123154
}
124155
}
125156

src/04kernel/src/kernels/simple_binary/cuda_kernel.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,46 @@ extern "C" __global__ void kernel(
135135
case DataType::F32:
136136
return "powf(a, b)";
137137
case DataType::FP16:
138-
return "__float2half(__powf(__half2float(a), __half2float(b)))";
138+
return "__float2half(powf(__half2float(a), __half2float(b)))";
139139
case DataType::BF16:
140140
return "__float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b)))";
141141
default:
142142
return "pow(a, b)";
143143
}
144+
case SimpleBinaryType::Mod:
145+
switch (dt) {
146+
case DataType::U8:
147+
case DataType::I8:
148+
case DataType::U16:
149+
case DataType::I16:
150+
case DataType::I32:
151+
case DataType::I64:
152+
case DataType::U32:
153+
case DataType::U64:
154+
return "a % b";
155+
default:
156+
UNREACHABLE();
157+
}
158+
case SimpleBinaryType::Fmod:
159+
switch (dt) {
160+
case DataType::U8:
161+
case DataType::I8:
162+
case DataType::U16:
163+
case DataType::I16:
164+
case DataType::I32:
165+
case DataType::I64:
166+
case DataType::U32:
167+
case DataType::U64:
168+
return "a % b < 0 ? (a % b + b) : (a % b)";
169+
case DataType::F32:
170+
return "fmodf(a, b)";
171+
case DataType::FP16:
172+
return "__float2half(fmodf(__half2float(a), __half2float(b)))";
173+
case DataType::BF16:
174+
return "__float2bfloat16(fmodf(__bfloat162float(a), __bfloat162float(b)))";
175+
default:
176+
UNREACHABLE();
177+
}
144178
default:
145179
UNREACHABLE();
146180
}

src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "../src/kernels/simple_binary/cpu_kernel.hh"
2+
#include <cmath>
23
#include <gtest/gtest.h>
34

45
using namespace refactor;
@@ -27,11 +28,60 @@ void testBinaryCPU(SimpleBinaryType binaryOPT, std::function<float(float, float)
2728
}
2829
}
2930

31+
/// Checks the CPU binary kernel built for `binaryOPT` on I32 inputs.
///
/// Creates three I32 tensors of shape {10, 20, 30, 40}, fills `a` with -3 and `b` with 2,
/// runs the lowered CPU routine, and compares every output element with
/// `operation(a[i], b[i])`.
///
/// @param binaryOPT  binary operator the kernel is built for.
/// @param operation  host-side reference producing the expected value of one element.
void testModCPU(SimpleBinaryType binaryOPT, std::function<int(int, int)> operation) {
    // Create Tensor and build kernels
    auto aTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor);
    ASSERT_TRUE(cpuKernel);
    auto res = runtime::Resources();
    auto cpuRoutine = cpuKernel->lower(res).routine;
    // Init inputs and outputs; the negative dividend exercises the sign handling of the operator
    std::vector<int> a(aTensor->elementsSize(), -3);
    std::vector<int> b(bTensor->elementsSize(), 2);
    std::vector<int> c(cTensor->elementsSize());
    // Compute
    void const *inputs[]{a.data(), b.data()};
    void *outputs[]{c.data()};
    cpuRoutine(res, nullptr, inputs, outputs);
    // Compare: both sides are integers, so require exact equality rather than
    // EXPECT_FLOAT_EQ, which would convert the operands to float before comparing.
    for (auto i : range0_(c.size())) {
        EXPECT_EQ(c[i], operation(a[i], b[i]));
    }
}
53+
54+
/// Checks the CPU binary kernel built for `binaryOPT` (used with Fmod) on I32 inputs.
///
/// Creates three I32 tensors of shape {10, 20, 30, 40}, fills `a` with -3 and `b` with 2,
/// runs the lowered CPU routine, and compares every output element with
/// `operation(a[i], b[i])`.
///
/// @param binaryOPT  binary operator the kernel is built for.
/// @param operation  host-side reference producing the expected value of one element.
void testFmodWithI32CPU(SimpleBinaryType binaryOPT, std::function<int(int, int)> operation) {
    // Create Tensor and build kernels
    auto aTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW);
    auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor);
    ASSERT_TRUE(cpuKernel);
    auto res = runtime::Resources();
    auto cpuRoutine = cpuKernel->lower(res).routine;
    // Init inputs and outputs; the negative dividend exercises the sign handling of the operator
    std::vector<int> a(aTensor->elementsSize(), -3);
    std::vector<int> b(bTensor->elementsSize(), 2);
    std::vector<int> c(cTensor->elementsSize());
    // Compute
    void const *inputs[]{a.data(), b.data()};
    void *outputs[]{c.data()};
    cpuRoutine(res, nullptr, inputs, outputs);
    // Compare: both sides are integers, so require exact equality rather than
    // EXPECT_FLOAT_EQ, which would convert the operands to float before comparing.
    for (auto i : range0_(c.size())) {
        EXPECT_EQ(c[i], operation(a[i], b[i]));
    }
}
76+
3077
TEST(kernel, BinaryCpu) {
    // Host-side references for the float kernels.
    auto const addF = [](float a, float b) { return a + b; };
    auto const subF = [](float a, float b) { return a - b; };
    auto const mulF = [](float a, float b) { return a * b; };
    auto const divF = [](float a, float b) { return a / b; };
    auto const fmodF = [](float a, float b) { return std::fmod(a, b); };
    // Host-side references for the I32 kernels.
    auto const modI = [](int a, int b) { return a % b; };
    auto const fmodI = [](int a, int b) { return a % b < 0 ? (a % b + b) : (a % b); };

    // Arithmetic operators on float tensors.
    testBinaryCPU(SimpleBinaryType::Add, addF);
    testBinaryCPU(SimpleBinaryType::Sub, subF);
    testBinaryCPU(SimpleBinaryType::Mul, mulF);
    testBinaryCPU(SimpleBinaryType::Div, divF);
    // Remainder operators: integer Mod, integer Fmod, then float Fmod.
    testModCPU(SimpleBinaryType::Mod, modI);
    testFmodWithI32CPU(SimpleBinaryType::Fmod, fmodI);
    testBinaryCPU(SimpleBinaryType::Fmod, fmodF);
}
3686

3787
TEST(kernel, BinaryCpuBroadcast) {

src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ using namespace refactor;
99
using namespace kernel;
1010
using namespace hardware;
1111

12+
template<decltype(DataType::internal) T>
1213
void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
1314
// Create Tensor and build kernels
14-
using T_ = primitive<DataType::I8>::type;
15-
auto aTensor = Tensor::share(DataType::I8, dimA, LayoutType::NCHW);
16-
auto bTensor = Tensor::share(DataType::I8, dimB, LayoutType::NCHW);
17-
auto cTensor = Tensor::share(DataType::I8, dimC, LayoutType::NCHW);
15+
using T_ = primitive<T>::type;
16+
auto aTensor = Tensor::share(T, dimA, LayoutType::NCHW);
17+
auto bTensor = Tensor::share(T, dimB, LayoutType::NCHW);
18+
auto cTensor = Tensor::share(T, dimC, LayoutType::NCHW);
1819

1920
auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor),
2021
cudaKernel = BinaryCuda::build(binaryOPT, *aTensor, *bTensor);
@@ -24,8 +25,8 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di
2425
auto cudaRoutine = cudaKernel->lower(res).routine;
2526

2627
// Init inputs and outputs
27-
std::vector<T_> a(aTensor->elementsSize(), 3.0f);
28-
std::vector<T_> b(bTensor->elementsSize(), 2.0f);
28+
std::vector<T_> a(aTensor->elementsSize(), 3);
29+
std::vector<T_> b(bTensor->elementsSize(), 2);
2930
std::vector<T_> c(cTensor->elementsSize());
3031
auto &dev = *device::init(Device::Type::Nvidia, 0, "");
3132
auto aGPU = dev.malloc(aTensor->bytesSize()),
@@ -53,35 +54,56 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di
5354
}
5455

5556
TEST(kernel, BinaryCudaAdd) {
56-
testBinaryCuda(SimpleBinaryType::Add,
57-
Shape{2, 5, 10, 20, 3, 4},
58-
Shape{2, 5, 10, 20, 3, 4},
59-
Shape{2, 5, 10, 20, 3, 4});
57+
testBinaryCuda<DataType::I8>(SimpleBinaryType::Add,
58+
Shape{2, 5, 10, 20, 3, 4},
59+
Shape{2, 5, 10, 20, 3, 4},
60+
Shape{2, 5, 10, 20, 3, 4});
6061
}
6162

6263
TEST(kernel, BinaryCudaMul) {
63-
testBinaryCuda(SimpleBinaryType::Mul,
64-
Shape{2, 5, 10, 20, 3, 4},
65-
Shape{2, 5, 10, 20, 3, 4},
66-
Shape{2, 5, 10, 20, 3, 4});
64+
testBinaryCuda<DataType::I8>(SimpleBinaryType::Mul,
65+
Shape{2, 5, 10, 20, 3, 4},
66+
Shape{2, 5, 10, 20, 3, 4},
67+
Shape{2, 5, 10, 20, 3, 4});
6768
}
6869

6970
TEST(kernel, BinaryCudaSub) {
70-
testBinaryCuda(SimpleBinaryType::Sub,
71-
Shape{2, 5, 10, 20, 3, 4},
72-
Shape{2, 5, 10, 20, 3, 4},
73-
Shape{2, 5, 10, 20, 3, 4});
71+
testBinaryCuda<DataType::I8>(SimpleBinaryType::Sub,
72+
Shape{2, 5, 10, 20, 3, 4},
73+
Shape{2, 5, 10, 20, 3, 4},
74+
Shape{2, 5, 10, 20, 3, 4});
7475
}
7576

7677
TEST(kernel, BinaryCudaDiv) {
77-
testBinaryCuda(SimpleBinaryType::Div,
78-
Shape{2, 5, 10, 20, 3, 4},
79-
Shape{2, 5, 10, 20, 3, 4},
80-
Shape{2, 5, 10, 20, 3, 4});
78+
testBinaryCuda<DataType::I8>(SimpleBinaryType::Div,
79+
Shape{2, 5, 10, 20, 3, 4},
80+
Shape{2, 5, 10, 20, 3, 4},
81+
Shape{2, 5, 10, 20, 3, 4});
82+
}
83+
84+
// Mod on I8 tensors: both operands and the result share one shape (no broadcasting).
TEST(kernel, BinaryCudaMod) {
    Shape const dims{2, 5, 10, 20, 3, 4};
    testBinaryCuda<DataType::I8>(SimpleBinaryType::Mod, dims, dims, dims);
}
90+
91+
// Fmod on I8 tensors: both operands and the result share one shape (no broadcasting).
TEST(kernel, BinaryCudaFmodI8) {
    Shape const dims{2, 5, 10, 20, 3, 4};
    testBinaryCuda<DataType::I8>(SimpleBinaryType::Fmod, dims, dims, dims);
}
97+
98+
// Fmod on F32 tensors: both operands and the result share one shape (no broadcasting).
TEST(kernel, BinaryCudaFmodF32) {
    Shape const dims{2, 5, 10, 20, 3, 4};
    testBinaryCuda<DataType::F32>(SimpleBinaryType::Fmod, dims, dims, dims);
}
82104

83105
TEST(kernel, BinaryCudaBroadcast) {
84-
testBinaryCuda(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
106+
testBinaryCuda<DataType::I8>(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
85107
}
86108

87109
#endif

src/05computation/src/operators/simple_binary.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ namespace refactor::computation {
3939
static uint8_t ID = 8;
4040
return reinterpret_cast<size_t>(&ID);
4141
}
42+
case Ty::Mod: {
43+
static uint8_t ID = 9;
44+
return reinterpret_cast<size_t>(&ID);
45+
}
46+
case Ty::Fmod: {
47+
static uint8_t ID = 10;
48+
return reinterpret_cast<size_t>(&ID);
49+
}
4250
default:
4351
UNREACHABLE();
4452
}
@@ -64,6 +72,10 @@ namespace refactor::computation {
6472
return "Or";
6573
case Ty::Xor:
6674
return "Xor";
75+
case Ty::Mod:
76+
return "Mod";
77+
case Ty::Fmod:
78+
return "Fmod";
6779
default:
6880
UNREACHABLE();
6981
}

src/07onnx/src/operators/simple_binary.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace refactor::onnx {
1010
: Operator(), type(type_) {}
1111

1212
auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
13-
ASSERT(attributes.empty(), "Simple binary operator should not have attributes");
13+
auto fmod = defaultOr(attributes, "fmod", {0}).int_();
1414
// clang-format off
1515
auto type =
1616
opType == "onnx::Add" ? Ty::Add :
@@ -21,6 +21,7 @@ namespace refactor::onnx {
2121
opType == "onnx::And" ? Ty::And :
2222
opType == "onnx::Or" ? Ty::Or :
2323
opType == "onnx::Xor" ? Ty::Xor :
24+
opType == "onnx::Mod" ? (fmod == 0 ? Ty::Mod : Ty::Fmod) :
2425
UNREACHABLEX(Ty, "Unsupported binary operator: {}", opType);
2526
// clang-format on
2627
return OpBox(std::make_unique<Op>(type));
@@ -48,6 +49,26 @@ namespace refactor::onnx {
4849
static uint8_t ID = 5;
4950
return reinterpret_cast<size_t>(&ID);
5051
}
52+
case Ty::And: {
53+
static uint8_t ID = 6;
54+
return reinterpret_cast<size_t>(&ID);
55+
}
56+
case Ty::Or: {
57+
static uint8_t ID = 7;
58+
return reinterpret_cast<size_t>(&ID);
59+
}
60+
case Ty::Xor: {
61+
static uint8_t ID = 8;
62+
return reinterpret_cast<size_t>(&ID);
63+
}
64+
case Ty::Mod: {
65+
static uint8_t ID = 9;
66+
return reinterpret_cast<size_t>(&ID);
67+
}
68+
case Ty::Fmod: {
69+
static uint8_t ID = 10;
70+
return reinterpret_cast<size_t>(&ID);
71+
}
5172
default:
5273
UNREACHABLE();
5374
}
@@ -65,6 +86,8 @@ namespace refactor::onnx {
6586
case Ty::And: return "onnx::And";
6687
case Ty::Or : return "onnx::Or" ;
6788
case Ty::Xor: return "onnx::Xor";
89+
case Ty::Mod: return "onnx::Mod";
90+
case Ty::Fmod: return "onnx::Mod";
6891
default: UNREACHABLE();
6992
}
7093
// clang-format on
@@ -162,6 +185,8 @@ namespace refactor::onnx {
162185
case Ty::And : type_ = Ty_::And; break;
163186
case Ty::Or : type_ = Ty_::Or ; break;
164187
case Ty::Xor : type_ = Ty_::Xor; break;
188+
case Ty::Mod : type_ = Ty_::Mod; break;
189+
case Ty::Fmod : type_ = Ty_::Fmod; break;
165190
default: UNREACHABLE();
166191
}
167192
// clang-format on

src/07onnx/src/operators/simple_binary.hh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace refactor::onnx {
1515
And,
1616
Or,
1717
Xor,
18+
Mod,
19+
Fmod,
1820
};
1921

2022
struct SimpleBinary final : public Operator {

0 commit comments

Comments
 (0)