@@ -9,12 +9,13 @@ using namespace refactor;
99using namespace kernel ;
1010using namespace hardware ;
1111
12+ template <decltype (DataType::internal) T>
1213void testBinaryCuda (SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
1314 // Create Tensor and build kernels
14- using T_ = primitive<DataType::I8 >::type;
15- auto aTensor = Tensor::share (DataType::I8 , dimA, LayoutType::NCHW);
16- auto bTensor = Tensor::share (DataType::I8 , dimB, LayoutType::NCHW);
17- auto cTensor = Tensor::share (DataType::I8 , dimC, LayoutType::NCHW);
15+ using T_ = primitive<T >::type;
16+ auto aTensor = Tensor::share (T , dimA, LayoutType::NCHW);
17+ auto bTensor = Tensor::share (T , dimB, LayoutType::NCHW);
18+ auto cTensor = Tensor::share (T , dimC, LayoutType::NCHW);
1819
1920 auto cpuKernel = BinaryCpu::build (binaryOPT, *aTensor, *bTensor),
2021 cudaKernel = BinaryCuda::build (binaryOPT, *aTensor, *bTensor);
@@ -24,8 +25,8 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di
2425 auto cudaRoutine = cudaKernel->lower (res).routine ;
2526
2627 // Init inputs and outputs
27- std::vector<T_> a (aTensor->elementsSize (), 3 . 0f );
28- std::vector<T_> b (bTensor->elementsSize (), 2 . 0f );
28+ std::vector<T_> a (aTensor->elementsSize (), 3 );
29+ std::vector<T_> b (bTensor->elementsSize (), 2 );
2930 std::vector<T_> c (cTensor->elementsSize ());
3031 auto &dev = *device::init (Device::Type::Nvidia, 0 , " " );
3132 auto aGPU = dev.malloc (aTensor->bytesSize ()),
@@ -53,35 +54,56 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di
5354}
5455
5556TEST (kernel, BinaryCudaAdd) {
56- testBinaryCuda (SimpleBinaryType::Add,
57- Shape{2 , 5 , 10 , 20 , 3 , 4 },
58- Shape{2 , 5 , 10 , 20 , 3 , 4 },
59- Shape{2 , 5 , 10 , 20 , 3 , 4 });
57+ testBinaryCuda<DataType::I8> (SimpleBinaryType::Add,
58+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
59+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
60+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
6061}
6162
6263TEST (kernel, BinaryCudaMul) {
63- testBinaryCuda (SimpleBinaryType::Mul,
64- Shape{2 , 5 , 10 , 20 , 3 , 4 },
65- Shape{2 , 5 , 10 , 20 , 3 , 4 },
66- Shape{2 , 5 , 10 , 20 , 3 , 4 });
64+ testBinaryCuda<DataType::I8> (SimpleBinaryType::Mul,
65+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
66+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
67+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
6768}
6869
6970TEST (kernel, BinaryCudaSub) {
70- testBinaryCuda (SimpleBinaryType::Sub,
71- Shape{2 , 5 , 10 , 20 , 3 , 4 },
72- Shape{2 , 5 , 10 , 20 , 3 , 4 },
73- Shape{2 , 5 , 10 , 20 , 3 , 4 });
71+ testBinaryCuda<DataType::I8> (SimpleBinaryType::Sub,
72+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
73+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
74+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
7475}
7576
7677TEST (kernel, BinaryCudaDiv) {
77- testBinaryCuda (SimpleBinaryType::Div,
78- Shape{2 , 5 , 10 , 20 , 3 , 4 },
79- Shape{2 , 5 , 10 , 20 , 3 , 4 },
80- Shape{2 , 5 , 10 , 20 , 3 , 4 });
78+ testBinaryCuda<DataType::I8>(SimpleBinaryType::Div,
79+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
80+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
81+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
82+ }
83+
84+ TEST (kernel, BinaryCudaMod) {
85+ testBinaryCuda<DataType::I8>(SimpleBinaryType::Mod,
86+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
87+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
88+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
89+ }
90+
91+ TEST (kernel, BinaryCudaFmodI8) {
92+ testBinaryCuda<DataType::I8>(SimpleBinaryType::Fmod,
93+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
94+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
95+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
96+ }
97+
98+ TEST (kernel, BinaryCudaFmodF32) {
99+ testBinaryCuda<DataType::F32>(SimpleBinaryType::Fmod,
100+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
101+ Shape{2 , 5 , 10 , 20 , 3 , 4 },
102+ Shape{2 , 5 , 10 , 20 , 3 , 4 });
81103}
82104
83105TEST (kernel, BinaryCudaBroadcast) {
84- testBinaryCuda (SimpleBinaryType::Add, Shape{1 , 2 , 3 , 4 , 5 , 6 }, Shape{}, Shape{1 , 2 , 3 , 4 , 5 , 6 });
106+ testBinaryCuda<DataType::I8> (SimpleBinaryType::Add, Shape{1 , 2 , 3 , 4 , 5 , 6 }, Shape{}, Shape{1 , 2 , 3 , 4 , 5 , 6 });
85107}
86108
87109#endif
0 commit comments