@@ -18,7 +18,8 @@ namespace refactor::kernel {
1818 auto K::build (Op op, Tensor const &a) noexcept -> KernelBox {
1919 static const std::unordered_set<Op>
2020 supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
21- Op::Sigmoid, Op::Tanh, Op::Neg};
21+ Op::Sigmoid, Op::Tanh, Op::Neg,
22+ Op::Erf};
2223#ifndef USE_CUDA
2324 return nullptr ;
2425#endif
@@ -140,6 +141,19 @@ extern "C" __global__ void kernel(
140141 {__ (Op::Neg, DT::BF16), " -x" },
141142 {__ (Op::Neg, DT::F32 ), " -x" },
142143 {__ (Op::Neg, DT::F64 ), " -x" },
144+
145+ {__ (Op::Erf, DT::F32 ), " erff(x)" },
146+ {__ (Op::Erf, DT::F64 ), " erf(x)" },
147+ {__ (Op::Erf, DT::U8 ), " erff(static_cast<float>(x))" },
148+ {__ (Op::Erf, DT::I8 ), " erff(static_cast<float>(x))" },
149+ {__ (Op::Erf, DT::U16 ), " erff(static_cast<float>(x))" },
150+ {__ (Op::Erf, DT::I16 ), " erff(static_cast<float>(x))" },
151+ {__ (Op::Erf, DT::U32 ), " erf(static_cast<double>(x))" },
152+ {__ (Op::Erf, DT::I32 ), " erf(static_cast<double>(x))" },
153+ {__ (Op::Erf, DT::U64 ), " erf(static_cast<double>(x))" },
154+ {__ (Op::Erf, DT::I64 ), " erf(static_cast<double>(x))" },
155+ {__ (Op::Erf, DT::FP16), " __float2half(erff(__half2float(x)))" },
156+ {__ (Op::Erf, DT::BF16), " __float2bfloat16(erff(__bfloat162float(x)))" },
143157 };
144158 // clang-format on
145159
0 commit comments