Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mmgen/core/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self,
n_row: Optional[int] = 8,
message_hub_vis_kwargs: Optional[Tuple[str, dict, List[str],
List[Dict]]] = None,
save_at_test: bool = True,
save_at_test: bool = False,
test_vis_keys: Optional[Union[str, List[str]]] = None,
show: bool = False,
wait_time: float = 0):
Expand Down Expand Up @@ -205,10 +205,11 @@ def after_test_iter(self, runner: Runner, batch_idx, data_batch, outputs):

for key in target_keys:
name = key.replace('.', '_')
name = f'test_{name}'
self._visualizer.add_datasample(
name=name,
gen_samples=[sample],
batch_idx=curr_idx,
step=curr_idx,
target_keys=key,
n_row=1,
color_order=output_color_order,
Expand Down
15 changes: 8 additions & 7 deletions tests/test_core/test_hooks/test_visualizer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ def test_after_test_iter(self):
interval=10,
n_samples=2,
test_vis_keys=['ema', 'orig', 'new_model.x_t', 'gt_img'],
vis_kwargs_list=dict(type='GAN'))
vis_kwargs_list=dict(type='GAN'),
save_at_test=True)
mock_visualuzer = MagicMock()
mock_visualuzer.add_datasample = MagicMock()
hook._visualizer = mock_visualuzer
Expand Down Expand Up @@ -429,30 +430,30 @@ def test_after_test_iter(self):
_, called_kwargs = args
gen_samples = called_kwargs['gen_samples']
name = called_kwargs['name']
batch_idx = called_kwargs['batch_idx']
step = called_kwargs['step']
target_keys = called_kwargs['target_keys']

self.assertEqual(len(gen_samples), 1)
idx_in_outputs = idx // 4
self.assertEqual(batch_idx, idx_in_outputs + 42 * len(outputs))
self.assertEqual(step, idx_in_outputs + 42 * len(outputs))
self.assertEqual(outputs[idx_in_outputs], gen_samples[0])

# check ema
if idx % 4 == 0:
self.assertEqual(target_keys, 'ema')
self.assertEqual(name, 'ema')
self.assertEqual(name, 'test_ema')
# check orig
elif idx % 4 == 1:
self.assertEqual(target_keys, 'orig')
self.assertEqual(name, 'orig')
self.assertEqual(name, 'test_orig')
# check x_t
elif idx % 4 == 2:
self.assertEqual(target_keys, 'new_model.x_t')
self.assertEqual(name, 'new_model_x_t')
self.assertEqual(name, 'test_new_model_x_t')
# check gt
else:
self.assertEqual(target_keys, 'gt_img')
self.assertEqual(name, 'gt_img')
self.assertEqual(name, 'test_gt_img')

# test get target key automatically
hook.test_vis_keys_list = None
Expand Down
12 changes: 12 additions & 0 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def parse_args():
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument(
'--save-image',
action='store_true',
help='whether save generated image in test process.')
parser.add_argument(
'--cfg-options',
nargs='+',
Expand Down Expand Up @@ -64,6 +68,14 @@ def main():

cfg.load_from = args.checkpoint

if args.save_vis:
custom_hooks = cfg.custom_hooks
for hook in custom_hooks:
hook_type = hook['type']
if hook_type == 'GenVisualizationHook':
hook['save_at_test'] = True
break

# build the runner from config
runner = Runner.from_cfg(cfg)

Expand Down