@@ -33,13 +33,13 @@ void test(po::variables_map &vm, torch::Device &device, GAN_Encoder &enc, GAN_Ge
33
33
// (0) Initialization and Declaration
34
34
float ave_anomaly_score, ave_res_loss, ave_dis_loss;
35
35
double seconds, ave_time;
36
- std::string path, result_dir, fname;
36
+ std::string path, result_dir, output_dir, heatmap_dir, fname;
37
37
std::string dataroot;
38
38
std::ofstream ofs, ofs_score;
39
39
std::chrono::system_clock::time_point start, end;
40
40
std::tuple<torch::Tensor, std::vector<std::string>> data;
41
41
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> anomaly_score_with_alpha;
42
- torch::Tensor image, z, output;
42
+ torch::Tensor image, z, output, heatmap ;
43
43
torch::Tensor anomaly_score, res_loss, dis_loss;
44
44
datasets::ImageFolderWithPaths dataset;
45
45
DataLoader::ImageFolderWithPaths dataloader;
@@ -67,6 +67,8 @@ void test(po::variables_map &vm, torch::Device &device, GAN_Encoder &enc, GAN_Ge
67
67
gen->eval ();
68
68
dis->eval ();
69
69
result_dir = vm[" test_result_dir" ].as <std::string>(); fs::create_directories (result_dir);
70
+ output_dir = result_dir + " /output" ; fs::create_directories (output_dir);
71
+ heatmap_dir = result_dir + " /heatmap" ; fs::create_directories (heatmap_dir);
70
72
ofs.open (result_dir + " /loss.txt" , std::ios::out);
71
73
ofs_score.open (result_dir + " /anomaly_score.txt" , std::ios::out);
72
74
while (dataloader (data)){
@@ -96,9 +98,13 @@ void test(po::variables_map &vm, torch::Device &device, GAN_Encoder &enc, GAN_Ge
96
98
ofs << ' <' << std::get<1 >(data).at (0 ) << " > anomaly_score:" << anomaly_score.item <float >() << " res:" << res_loss.item <float >() << " dis:" << dis_loss.item <float >() << std::endl;
97
99
ofs_score << anomaly_score.item <float >() << std::endl;
98
100
99
- fname = result_dir + ' /' + std::get<1 >(data).at (0 );
101
+ fname = output_dir + ' /' + std::get<1 >(data).at (0 );
100
102
visualizer::save_image (output.detach (), fname, /* range=*/ output_range, /* cols=*/ 1 , /* padding=*/ 0 );
101
103
104
+ fname = heatmap_dir + ' /' + std::get<1 >(data).at (0 );
105
+ heatmap = visualizer::create_heatmap (torch::abs (image - output).mean (/* dim=*/ 1 , /* keepdim=*/ true ), /* range=*/ {0 , (output_range.second - output_range.first ) * vm[" heatmap_max" ].as <float >()});
106
+ visualizer::save_image (heatmap.detach (), fname, /* range=*/ {0.0 , 1.0 }, /* cols=*/ 1 , /* padding=*/ 0 );
107
+
102
108
}
103
109
104
110
// (5) Calculate Average
0 commit comments