@@ -91,7 +91,7 @@ def _update(engine, batch):
91
91
loss .backward ()
92
92
optimizer .step ()
93
93
94
- return loss .item (), ( pred , target )
94
+ return loss .item ()
95
95
96
96
trainer = Engine (_update )
97
97
@@ -100,13 +100,11 @@ def _update(engine, batch):
100
100
timer = Timer (average = True )
101
101
102
102
# attach running average metrics
103
- RunningAverage (output_transform = lambda x : x [0 ]).attach (trainer , 'loss' )
104
- cm = ConfusionMatrix (num_classes , output_transform = lambda x : x [1 ])
105
- mIoU (cm , ignore_index = 0 ).attach (trainer , 'mIoU' )
103
+ RunningAverage (output_transform = lambda x : x ).attach (trainer , 'loss' )
106
104
107
105
# attach progress bar
108
106
pbar = ProgressBar (persist = True )
109
- pbar .attach (trainer , metric_names = ['loss' , 'mIoU' ])
107
+ pbar .attach (trainer , metric_names = ['loss' ])
110
108
111
109
@trainer .on (Events .EPOCH_COMPLETED )
112
110
def save_checkpoint (engine ):
@@ -125,8 +123,8 @@ def _inference(engine, batch):
125
123
return pred , target
126
124
127
125
evaluator = Engine (_inference )
128
- cm2 = ConfusionMatrix (num_classes )
129
- mIoU (cm2 , ignore_index = 0 ).attach (evaluator , 'mIoU' )
126
+ cm = ConfusionMatrix (num_classes )
127
+ mIoU (cm , ignore_index = 0 ).attach (evaluator , 'mIoU' )
130
128
Loss (criterion ).attach (evaluator , 'loss' )
131
129
132
130
def _global_step_transform (engine , event_name ):
@@ -135,11 +133,11 @@ def _global_step_transform(engine, event_name):
135
133
tb_logger = TensorboardLogger (args .log_dir )
136
134
tb_logger .attach (trainer ,
137
135
log_handler = OutputHandler (tag = 'training' ,
138
- metric_names = ['loss' , 'mIoU' ]),
136
+ metric_names = ['loss' ]),
139
137
event_name = Events .ITERATION_COMPLETED )
140
138
141
139
tb_logger .attach (evaluator ,
142
- log_handler = OutputHandler (tag = 'validation_eval ' ,
140
+ log_handler = OutputHandler (tag = 'validation ' ,
143
141
metric_names = ['loss' , 'mIoU' ],
144
142
global_step_transform = _global_step_transform ),
145
143
event_name = Events .EPOCH_COMPLETED )
0 commit comments