Skip to content

Commit c817ec8

Browse files
committed
Remove mIoU during training
1 parent b5cf272 commit c817ec8

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

train_kitti.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _update(engine, batch):
9191
loss.backward()
9292
optimizer.step()
9393

94-
return loss.item(), (pred, target)
94+
return loss.item()
9595

9696
trainer = Engine(_update)
9797

@@ -100,13 +100,11 @@ def _update(engine, batch):
100100
timer = Timer(average=True)
101101

102102
# attach running average metrics
103-
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
104-
cm = ConfusionMatrix(num_classes, output_transform=lambda x: x[1])
105-
mIoU(cm, ignore_index=0).attach(trainer, 'mIoU')
103+
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
106104

107105
# attach progress bar
108106
pbar = ProgressBar(persist=True)
109-
pbar.attach(trainer, metric_names=['loss', 'mIoU'])
107+
pbar.attach(trainer, metric_names=['loss'])
110108

111109
@trainer.on(Events.EPOCH_COMPLETED)
112110
def save_checkpoint(engine):
@@ -125,8 +123,8 @@ def _inference(engine, batch):
125123
return pred, target
126124

127125
evaluator = Engine(_inference)
128-
cm2 = ConfusionMatrix(num_classes)
129-
mIoU(cm2, ignore_index=0).attach(evaluator, 'mIoU')
126+
cm = ConfusionMatrix(num_classes)
127+
mIoU(cm, ignore_index=0).attach(evaluator, 'mIoU')
130128
Loss(criterion).attach(evaluator, 'loss')
131129

132130
def _global_step_transform(engine, event_name):
@@ -135,11 +133,11 @@ def _global_step_transform(engine, event_name):
135133
tb_logger = TensorboardLogger(args.log_dir)
136134
tb_logger.attach(trainer,
137135
log_handler=OutputHandler(tag='training',
138-
metric_names=['loss', 'mIoU']),
136+
metric_names=['loss']),
139137
event_name=Events.ITERATION_COMPLETED)
140138

141139
tb_logger.attach(evaluator,
142-
log_handler=OutputHandler(tag='validation_eval',
140+
log_handler=OutputHandler(tag='validation',
143141
metric_names=['loss', 'mIoU'],
144142
global_step_transform=_global_step_transform),
145143
event_name=Events.EPOCH_COMPLETED)

0 commit comments

Comments
 (0)