@@ -37,10 +37,30 @@ int benchmark_threads = 1;
37
37
int benchmark_model = -1 ;
38
38
int benchmark_cluster = 0 ;
39
39
int benchmark_mask = 0xFFFF ;
40
+ int benchmark_data_type = 0 ;
40
41
std::string benchmark_device = " " ;
41
42
context_t s_context;
42
43
43
44
45
+ int get_tenser_element_size (int data_type)
46
+ {
47
+ switch (data_type)
48
+ {
49
+ case TENGINE_DT_FP32:
50
+ case TENGINE_DT_INT32:
51
+ return 4 ;
52
+ case TENGINE_DT_FP16:
53
+ case TENGINE_DT_INT16:
54
+ return 2 ;
55
+ case TENGINE_DT_INT8:
56
+ case TENGINE_DT_UINT8:
57
+ return 1 ;
58
+ default :
59
+ return 0 ;
60
+ }
61
+ }
62
+
63
+
44
64
int benchmark_graph (options_t * opt, const char * name, const char * file, int height, int width, int channel, int batch)
45
65
{
46
66
// create graph, load tengine model xxx.tmfile
@@ -55,9 +75,9 @@ int benchmark_graph(options_t* opt, const char* name, const char* file, int heig
55
75
int input_size = height * width * channel;
56
76
int shape[] = { batch, channel, height, width }; // nchw
57
77
58
- std::vector<float > inout_buffer ( input_size);
78
+ std::vector<unsigned char > input_buffer (batch * input_size * get_tenser_element_size (benchmark_data_type) );
59
79
60
- memset (inout_buffer .data (), 1 , inout_buffer .size () * sizeof ( float ));
80
+ memset (input_buffer .data (), 1 , input_buffer .size ());
61
81
62
82
tensor_t input_tensor = get_graph_input_tensor (graph, 0 , 0 );
63
83
if (input_tensor == nullptr )
@@ -79,7 +99,7 @@ int benchmark_graph(options_t* opt, const char* name, const char* file, int heig
79
99
}
80
100
81
101
// prepare process input data, set the data mem to input tensor
82
- if (set_tensor_buffer (input_tensor, inout_buffer .data (), (int )(inout_buffer .size () * sizeof ( float ) )) < 0 )
102
+ if (set_tensor_buffer (input_tensor, input_buffer .data (), (int )input_buffer .size ()) < 0 )
83
103
{
84
104
fprintf (stderr, " Tengine Benchmark: Set input tensor buffer failed\n " );
85
105
return -1 ;
@@ -147,6 +167,9 @@ int main(int argc, char* argv[])
147
167
cmd.add <int >(" model" , ' s' , " benchmark which model, \" -1\" means all models" , false , -1 );
148
168
cmd.add <int >(" cpu_mask" , ' a' , " benchmark on masked cpu core(s)" , false , -1 );
149
169
cmd.add <std::string>(" device" , ' d' , " device name (should be upper-case)" , false );
170
+ cmd.add <std::string>(" model_file" , ' m' , " path to a model file" , false );
171
+ cmd.add <std::string>(" input_shape" , ' i' , " shape of input (n,c,h,w)" , false );
172
+ cmd.add <int >(" input_dtype" , ' f' , " data type of input" , false );
150
173
151
174
cmd.parse_check (argc, argv);
152
175
@@ -156,6 +179,9 @@ int main(int argc, char* argv[])
156
179
benchmark_cluster = cmd.get <int >(" cpu_cluster" );
157
180
benchmark_mask = cmd.get <int >(" cpu_mask" );
158
181
benchmark_device = cmd.get <std::string>(" device" );
182
+ std::string benchmark_model_file = cmd.get <std::string>(" model_file" );
183
+ std::string input_shape = cmd.get <std::string>(" input_shape" );
184
+ benchmark_data_type = cmd.get <int >(" input_dtype" );
159
185
if (benchmark_device.empty ())
160
186
{
161
187
benchmark_device = " CPU" ;
@@ -263,6 +289,20 @@ int main(int argc, char* argv[])
263
289
benchmark_graph (&opt, " mobilefacenets" , " ./models/mobilefacenets_benchmark.tmfile" , 112 , 112 , 3 , 1 );
264
290
break ;
265
291
default :
292
+ if (!benchmark_model_file.empty ()) {
293
+ int n = 1 , c = 3 , h = 224 , w = 224 ;
294
+ if (!input_shape.empty ()) {
295
+ char ch;
296
+ int count = sscanf (input_shape.c_str (), " %u%c%u%c%u%c%u" , &n, &ch, &c, &ch, &h, &ch, &w);
297
+ if (count == 3 ) {
298
+ w = h, h = c, c = n, n = 1 ;
299
+ } else if (count == 2 ) {
300
+ w = c, h = n, c = 3 , n = 1 ;
301
+ }
302
+ }
303
+ benchmark_graph (&opt, benchmark_model_file.c_str (), benchmark_model_file.c_str (), w, h, c, n);
304
+ break ;
305
+ }
266
306
benchmark_graph (&opt, " squeezenet_v1.1" , " ./models/squeezenet_v1.1_benchmark.tmfile" , 227 , 227 , 3 , 1 );
267
307
benchmark_graph (&opt, " mobilenetv1" , " ./models/mobilenet_benchmark.tmfile" , 224 , 224 , 3 , 1 );
268
308
benchmark_graph (&opt, " mobilenetv2" , " ./models/mobilenet_v2_benchmark.tmfile" , 224 , 224 , 3 , 1 );
0 commit comments