Commit 2cf9487

feat(benchmark): make model and input shape customizable (#1282)
1 parent 50dc355 commit 2cf9487

File tree

1 file changed: +43 −3 lines

benchmark/tm_benchmark.cc (43 additions, 3 deletions)
@@ -37,10 +37,30 @@ int benchmark_threads = 1;
 int benchmark_model = -1;
 int benchmark_cluster = 0;
 int benchmark_mask = 0xFFFF;
+int benchmark_data_type = 0;
 std::string benchmark_device = "";
 context_t s_context;
 
 
+int get_tenser_element_size(int data_type)
+{
+    switch (data_type)
+    {
+    case TENGINE_DT_FP32:
+    case TENGINE_DT_INT32:
+        return 4;
+    case TENGINE_DT_FP16:
+    case TENGINE_DT_INT16:
+        return 2;
+    case TENGINE_DT_INT8:
+    case TENGINE_DT_UINT8:
+        return 1;
+    default:
+        return 0;
+    }
+}
+
+
 int benchmark_graph(options_t* opt, const char* name, const char* file, int height, int width, int channel, int batch)
 {
     // create graph, load tengine model xxx.tmfile
@@ -55,9 +75,9 @@ int benchmark_graph(options_t* opt, const char* name, const char* file, int height, int width, int channel, int batch)
     int input_size = height * width * channel;
     int shape[] = { batch, channel, height, width }; // nchw
 
-    std::vector<float> inout_buffer(input_size);
+    std::vector<unsigned char> input_buffer(batch * input_size * get_tenser_element_size(benchmark_data_type));
 
-    memset(inout_buffer.data(), 1, inout_buffer.size() * sizeof(float));
+    memset(input_buffer.data(), 1, input_buffer.size());
 
     tensor_t input_tensor = get_graph_input_tensor(graph, 0, 0);
     if (input_tensor == nullptr)
@@ -79,7 +99,7 @@ int benchmark_graph(options_t* opt, const char* name, const char* file, int height, int width, int channel, int batch)
     }
 
     // prepare process input data, set the data mem to input tensor
-    if (set_tensor_buffer(input_tensor, inout_buffer.data(), (int)(inout_buffer.size() * sizeof(float))) < 0)
+    if (set_tensor_buffer(input_tensor, input_buffer.data(), (int)input_buffer.size()) < 0)
     {
         fprintf(stderr, "Tengine Benchmark: Set input tensor buffer failed\n");
         return -1;
@@ -147,6 +167,9 @@ int main(int argc, char* argv[])
     cmd.add<int>("model", 's', "benchmark which model, \"-1\" means all models", false, -1);
     cmd.add<int>("cpu_mask", 'a', "benchmark on masked cpu core(s)", false, -1);
     cmd.add<std::string>("device", 'd', "device name (should be upper-case)", false);
+    cmd.add<std::string>("model_file", 'm', "path to a model file", false);
+    cmd.add<std::string>("input_shape", 'i', "shape of input (n,c,h,w)", false);
+    cmd.add<int>("input_dtype", 'f', "data type of input", false);
 
     cmd.parse_check(argc, argv);
 
@@ -156,6 +179,9 @@ int main(int argc, char* argv[])
     benchmark_cluster = cmd.get<int>("cpu_cluster");
     benchmark_mask = cmd.get<int>("cpu_mask");
     benchmark_device = cmd.get<std::string>("device");
+    std::string benchmark_model_file = cmd.get<std::string>("model_file");
+    std::string input_shape = cmd.get<std::string>("input_shape");
+    benchmark_data_type = cmd.get<int>("input_dtype");
     if (benchmark_device.empty())
     {
         benchmark_device = "CPU";
@@ -263,6 +289,20 @@ int main(int argc, char* argv[])
        benchmark_graph(&opt, "mobilefacenets", "./models/mobilefacenets_benchmark.tmfile", 112, 112, 3, 1);
        break;
    default:
+        if (!benchmark_model_file.empty()) {
+            int n = 1, c = 3, h = 224, w = 224;
+            if (!input_shape.empty()) {
+                char ch;
+                int count = sscanf(input_shape.c_str(), "%u%c%u%c%u%c%u", &n, &ch, &c, &ch, &h, &ch, &w);
+                if (count == 3) {
+                    w = h, h = c, c = n, n = 1;
+                } else if (count == 2) {
+                    w = c, h = n, c = 3, n = 1;
+                }
+            }
+            benchmark_graph(&opt, benchmark_model_file.c_str(), benchmark_model_file.c_str(), w, h, c, n);
+            break;
+        }
        benchmark_graph(&opt, "squeezenet_v1.1", "./models/squeezenet_v1.1_benchmark.tmfile", 227, 227, 3, 1);
        benchmark_graph(&opt, "mobilenetv1", "./models/mobilenet_benchmark.tmfile", 224, 224, 3, 1);
        benchmark_graph(&opt, "mobilenetv2", "./models/mobilenet_v2_benchmark.tmfile", 224, 224, 3, 1);

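Note on the buffer sizing introduced here: benchmark_graph() no longer allocates a float buffer, but a raw byte buffer of batch * height * width * channel * element size, where the element size comes from get_tenser_element_size() for the data type chosen with --input_dtype. Below is a minimal standalone sketch of that arithmetic, not part of the commit; the numeric enum values are assumptions about the Tengine data-type constants (the diff only shows the byte sizes), and the names element_size/DT_* are placeholders.

#include <cstdio>
#include <vector>

// Assumed numeric values for the Tengine data-type enum (hypothetical);
// only the byte sizes below mirror the diff's get_tenser_element_size().
enum { DT_FP32 = 0, DT_FP16 = 1, DT_INT8 = 2, DT_UINT8 = 3, DT_INT32 = 4, DT_INT16 = 5 };

static int element_size(int data_type)
{
    switch (data_type)
    {
    case DT_FP32:
    case DT_INT32:
        return 4;
    case DT_FP16:
    case DT_INT16:
        return 2;
    case DT_INT8:
    case DT_UINT8:
        return 1;
    default:
        return 0;
    }
}

int main()
{
    // 1,3,224,224 is the default shape the new code falls back to
    // when --input_shape is not given.
    int n = 1, c = 3, h = 224, w = 224;
    int dtype = DT_FP32;

    // Same sizing rule as the new input_buffer in benchmark_graph():
    // batch * (h * w * c) * element size of the chosen dtype.
    size_t bytes = (size_t)n * c * h * w * element_size(dtype);
    std::vector<unsigned char> input_buffer(bytes);

    std::printf("input buffer: %zu bytes\n", input_buffer.size()); // 602112 bytes for fp32
    return 0;
}

Under the same assumptions, benchmarking a custom model would look roughly like ./tm_benchmark -m ./my_model.tmfile -i 1,3,224,224 -f 0, where the binary name and model path are placeholders; only the -m/-i/-f option letters come from the cmd.add calls in the diff.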
0 commit comments