From 4d424473d53757a5a03efe26428fd5ceb8fd3ad2 Mon Sep 17 00:00:00 2001 From: LeoXing Date: Wed, 27 Jul 2022 12:03:59 +0800 Subject: [PATCH] support option to save image at test stage --- mmgen/core/hooks/visualization_hook.py | 5 +++-- .../test_core/test_hooks/test_visualizer_hook.py | 15 ++++++++------- tools/test.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/mmgen/core/hooks/visualization_hook.py b/mmgen/core/hooks/visualization_hook.py index 4872aff9f..25dda75fa 100644 --- a/mmgen/core/hooks/visualization_hook.py +++ b/mmgen/core/hooks/visualization_hook.py @@ -114,7 +114,7 @@ def __init__(self, n_row: Optional[int] = 8, message_hub_vis_kwargs: Optional[Tuple[str, dict, List[str], List[Dict]]] = None, - save_at_test: bool = True, + save_at_test: bool = False, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, wait_time: float = 0): @@ -205,10 +205,11 @@ def after_test_iter(self, runner: Runner, batch_idx, data_batch, outputs): for key in target_keys: name = key.replace('.', '_') + name = f'test_{name}' self._visualizer.add_datasample( name=name, gen_samples=[sample], - batch_idx=curr_idx, + step=curr_idx, target_keys=key, n_row=1, color_order=output_color_order, diff --git a/tests/test_core/test_hooks/test_visualizer_hook.py b/tests/test_core/test_hooks/test_visualizer_hook.py index 4018ca35a..7e61d8564 100644 --- a/tests/test_core/test_hooks/test_visualizer_hook.py +++ b/tests/test_core/test_hooks/test_visualizer_hook.py @@ -397,7 +397,8 @@ def test_after_test_iter(self): interval=10, n_samples=2, test_vis_keys=['ema', 'orig', 'new_model.x_t', 'gt_img'], - vis_kwargs_list=dict(type='GAN')) + vis_kwargs_list=dict(type='GAN'), + save_at_test=True) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() hook._visualizer = mock_visualuzer @@ -429,30 +430,30 @@ def test_after_test_iter(self): _, called_kwargs = args gen_samples = called_kwargs['gen_samples'] name = called_kwargs['name'] - batch_idx = called_kwargs['batch_idx'] + step = called_kwargs['step'] target_keys = called_kwargs['target_keys'] self.assertEqual(len(gen_samples), 1) idx_in_outputs = idx // 4 - self.assertEqual(batch_idx, idx_in_outputs + 42 * len(outputs)) + self.assertEqual(step, idx_in_outputs + 42 * len(outputs)) self.assertEqual(outputs[idx_in_outputs], gen_samples[0]) # check ema if idx % 4 == 0: self.assertEqual(target_keys, 'ema') - self.assertEqual(name, 'ema') + self.assertEqual(name, 'test_ema') # check orig elif idx % 4 == 1: self.assertEqual(target_keys, 'orig') - self.assertEqual(name, 'orig') + self.assertEqual(name, 'test_orig') # check x_t elif idx % 4 == 2: self.assertEqual(target_keys, 'new_model.x_t') - self.assertEqual(name, 'new_model_x_t') + self.assertEqual(name, 'test_new_model_x_t') # check gt else: self.assertEqual(target_keys, 'gt_img') - self.assertEqual(name, 'gt_img') + self.assertEqual(name, 'test_gt_img') # test get target key automatically hook.test_vis_keys_list = None diff --git a/tools/test.py b/tools/test.py index a6b9ed2ef..b01861c60 100644 --- a/tools/test.py +++ b/tools/test.py @@ -18,6 +18,10 @@ def parse_args(): parser.add_argument( '--work-dir', help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--save-image', + action='store_true', + help='whether save generated image in test process.') parser.add_argument( '--cfg-options', nargs='+', @@ -64,6 +68,14 @@ def main(): cfg.load_from = args.checkpoint + if args.save_vis: + custom_hooks = cfg.custom_hooks + for hook in custom_hooks: + hook_type = hook['type'] + if hook_type == 'GenVisualizationHook': + hook['save_at_test'] = True + break + # build the runner from config runner = Runner.from_cfg(cfg)