Skip to content

Commit 4f148cd

Browse files
authored
Merge pull request #71 from InfiniTensor/dev-erf
feat: add Erf cpu/cuda kernel
2 parents 0f9fde5 + cc8a86d commit 4f148cd

File tree

5 files changed

+58
-1
lines changed

5 files changed

+58
-1
lines changed

src/04kernel/src/kernels/simple_unary/cpu_kernel.cc

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ namespace refactor::kernel {
1818
Op::Sigmoid,
1919
Op::Tanh,
2020
Op::Neg,
21+
Op::Erf,
2122
};
2223
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
2324
? std::make_unique<K>(op, a.dataType, a.elementsSize())
@@ -155,6 +156,21 @@ namespace refactor::kernel {
155156
default:
156157
UNREACHABLE();
157158
}
159+
case Op::Erf:
160+
switch (dataType) {
161+
CASE(std::erf, F32);
162+
CASE(std::erf, F64);
163+
CASE(std::erf, I8);
164+
CASE(std::erf, I16);
165+
CASE(std::erf, I32);
166+
CASE(std::erf, I64);
167+
CASE(std::erf, U8);
168+
CASE(std::erf, U16);
169+
CASE(std::erf, U32);
170+
CASE(std::erf, U64);
171+
default:
172+
UNREACHABLE();
173+
}
158174
default:
159175
UNREACHABLE();
160176
}

src/04kernel/src/kernels/simple_unary/cuda_kernel.cc

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,8 @@ namespace refactor::kernel {
1818
auto K::build(Op op, Tensor const &a) noexcept -> KernelBox {
1919
static const std::unordered_set<Op>
2020
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
21-
Op::Sigmoid, Op::Tanh, Op::Neg};
21+
Op::Sigmoid, Op::Tanh, Op::Neg,
22+
Op::Erf};
2223
#ifndef USE_CUDA
2324
return nullptr;
2425
#endif
@@ -140,6 +141,19 @@ extern "C" __global__ void kernel(
140141
{__(Op::Neg, DT::BF16), "-x"},
141142
{__(Op::Neg, DT::F32 ), "-x"},
142143
{__(Op::Neg, DT::F64 ), "-x"},
144+
145+
{__(Op::Erf, DT::F32 ), "erff(x)"},
146+
{__(Op::Erf, DT::F64 ), "erf(x)"},
147+
{__(Op::Erf, DT::U8 ), "erff(static_cast<float>(x))"},
148+
{__(Op::Erf, DT::I8 ), "erff(static_cast<float>(x))"},
149+
{__(Op::Erf, DT::U16 ), "erff(static_cast<float>(x))"},
150+
{__(Op::Erf, DT::I16 ), "erff(static_cast<float>(x))"},
151+
{__(Op::Erf, DT::U32 ), "erf(static_cast<double>(x))"},
152+
{__(Op::Erf, DT::I32 ), "erf(static_cast<double>(x))"},
153+
{__(Op::Erf, DT::U64 ), "erf(static_cast<double>(x))"},
154+
{__(Op::Erf, DT::I64 ), "erf(static_cast<double>(x))"},
155+
{__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},
156+
{__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},
143157
};
144158
// clang-format on
145159

src/04kernel/test/kernels/simple_unary/test_cpu.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,4 +31,5 @@ TEST(kernel, SimpleUnaryCpu) {
3131
testOp(SimpleUnaryType::Abs, std::abs);
3232
testOp(SimpleUnaryType::Sqrt, std::sqrt);
3333
testOp(SimpleUnaryType::Tanh, std::tanh);
34+
testOp(SimpleUnaryType::Erf, std::erf);
3435
}

src/04kernel/test/kernels/simple_unary/test_cuda.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@ TEST(kernel, SimpleUnaryCuda) {
5151
testOp(SimpleUnaryType::Sqrt);
5252
testOp(SimpleUnaryType::Sigmoid);
5353
testOp(SimpleUnaryType::Tanh);
54+
testOp(SimpleUnaryType::Erf);
5455
}
5556

5657
#endif
Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
#include "../src/operators/simple_unary.hh"
2+
#include "onnx/operators.h"
3+
#include <gtest/gtest.h>
4+
5+
using namespace refactor;
6+
using namespace onnx;
7+
8+
TEST(infer, SimpleUnary) {
9+
onnx::register_();
10+
11+
{
12+
// Erf Test
13+
auto edges = Edges{
14+
{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""},
15+
};
16+
count_t inputs[]{0};
17+
auto infered = SimpleUnary(SimpleUnaryType::Erf).infer(TensorRefs(edges, inputs), {true});
18+
ASSERT_TRUE(infered.isOk());
19+
auto outputs = std::move(infered.unwrap());
20+
ASSERT_EQ(outputs.size(), 1);
21+
auto y = std::move(outputs[0]);
22+
ASSERT_EQ(y->dataType, DataType::F32);
23+
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
24+
}
25+
}

0 commit comments

Comments (0)