@@ -83,23 +83,27 @@ def training_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None:
83
83
x , y = data
84
84
85
85
Y_hat = self (x , y )
86
- loss , energy , thermal = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
87
- self ._extra_out ['thermal_x' ])
86
+ loss = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
87
+ self ._extra_out ['thermal_x' ])
88
+ # Channelwise evaluation.
89
+ energy = L .energy (y , Y_hat ).mean (0 )
90
+ thermal = L .evaluate (y [:, 6 :], self ._extra_out ['thermal_y' ]).mean (0 )
88
91
89
92
self ._extra_out ['x' ] = x
90
93
self ._extra_out ['y' ] = y
91
94
self ._extra_out ['Y_hat' ] = Y_hat
92
95
93
96
batch_loss = loss .mean ()
97
+ # Batchwise evaluation.
94
98
sample_loss = loss .detach ().mean (- 1 )
95
- ssim = L .evaluate (y , Y_hat )
96
99
97
100
self .log ("loss/train" , batch_loss ,
98
101
logger = True , prog_bar = True , on_epoch = True ,
99
- on_step = False , batch_size = sample_loss .size (0 ))
102
+ on_step = True , batch_size = sample_loss .size (0 ))
100
103
101
- self .log_dict ({"thermal_loss/train" : thermal ,
102
- # Per date.
104
+ self .log_dict ({** {f"thermal_{ i } /train" : v for i , v in
105
+ enumerate (thermal )},
106
+ # Per date batch.
103
107
** {f"{ k } /train" : v for k , v in zip (dates , sample_loss ,
104
108
strict = True )},
105
109
# Per tile.
@@ -142,16 +146,20 @@ def validation_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None
142
146
x , y = data
143
147
144
148
Y_hat = self (x , y )
145
- loss , energy , thermal = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
146
- self ._extra_out ['thermal_x' ])
149
+ loss = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
150
+ self ._extra_out ['thermal_x' ])
151
+ # Channelwise evaluation.
152
+ energy = L .energy (y , Y_hat ).mean (0 )
153
+ thermal = L .evaluate (y [:, 6 :], self ._extra_out ['thermal_y' ]).mean (0 )
147
154
148
155
batch_loss = loss .mean ()
149
- sample_loss = loss .mean (- 1 )
156
+ sample_loss = loss .detach (). mean (- 1 )
150
157
151
158
self .log ("loss/val" , batch_loss , prog_bar = True , logger = True ,
152
159
on_epoch = True , on_step = True , batch_size = sample_loss .size (0 ))
153
160
154
- self .log_dict ({"thermal_loss/val" : thermal ,
161
+ self .log_dict ({** {f"thermal_{ i } /val" : v
162
+ for i , v in enumerate (thermal )},
155
163
# Per date.
156
164
** {f"{ k } /val" : v for k , v in zip (dates , sample_loss ,
157
165
strict = True )},
@@ -172,23 +180,30 @@ def test_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None:
172
180
dates , tiles = metadata
173
181
x , y = data
174
182
Y_hat = self (x , y )
175
- loss , energy , thermal = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
176
- self ._extra_out ['thermal_x' ])
183
+
184
+ loss = L (x , Y_hat , y , self ._extra_out ['thermal_y' ],
185
+ self ._extra_out ['thermal_x' ])
186
+
187
+ # Channelwise evaluation.
188
+ energy = L .energy (y , Y_hat ).mean (0 )
189
+ thermal = L .evaluate (y [:, 6 :], self ._extra_out ['thermal_y' ]).mean (0 )
190
+
177
191
batch_loss = loss .mean ()
178
- sample_loss = loss .mean (- 1 )
192
+ sample_loss = loss .detach (). mean (- 1 )
179
193
180
194
self .log ("loss/test" , batch_loss , prog_bar = True , logger = True ,
181
- on_step = False , on_epoch = True , batch_size = sample_loss .size (0 ))
195
+ on_step = True , on_epoch = True , batch_size = sample_loss .size (0 ))
182
196
183
- self .log_dict ({"thermal_loss/test" : thermal ,
197
+ self .log_dict ({** {f"thermal_{ i } /test" : v
198
+ for i , v in enumerate (thermal )},
184
199
# Per date.
185
200
** {f"{ k } /test" : v for k , v in zip (dates , sample_loss ,
186
201
strict = True )},
187
202
# Per tile.
188
203
** {f"{ k } /test" : v for k , v in zip (tiles , sample_loss ,
189
204
strict = True )}
190
205
},
191
- on_step = False ,
206
+ on_step = True ,
192
207
on_epoch = True ,
193
208
prog_bar = False ,
194
209
logger = True ,
0 commit comments