Skip to content

Commit 28046e0

Browse files
committed Oct 24, 2024
feat: logging workflow improvements
1 parent faf1e0a commit 28046e0

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed
 

modules/lightning.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,27 @@ def training_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None:
8383
x, y = data
8484

8585
Y_hat = self(x, y)
86-
loss, energy, thermal = L(x, Y_hat, y, self._extra_out['thermal_y'],
87-
self._extra_out['thermal_x'])
86+
loss = L(x, Y_hat, y, self._extra_out['thermal_y'],
87+
self._extra_out['thermal_x'])
88+
# Channelwise evaluation.
89+
energy = L.energy(y, Y_hat).mean(0)
90+
thermal = L.evaluate(y[:, 6:], self._extra_out['thermal_y']).mean(0)
8891

8992
self._extra_out['x'] = x
9093
self._extra_out['y'] = y
9194
self._extra_out['Y_hat'] = Y_hat
9295

9396
batch_loss = loss.mean()
97+
# Batchwise evaluation.
9498
sample_loss = loss.detach().mean(-1)
95-
ssim = L.evaluate(y, Y_hat)
9699

97100
self.log("loss/train", batch_loss,
98101
logger=True, prog_bar=True, on_epoch=True,
99-
on_step=False, batch_size=sample_loss.size(0))
102+
on_step=True, batch_size=sample_loss.size(0))
100103

101-
self.log_dict({"thermal_loss/train": thermal,
102-
# Per date.
104+
self.log_dict({**{f"thermal_{i}/train": v for i, v in
105+
enumerate(thermal)},
106+
# Per date batch.
103107
**{f"{k}/train": v for k, v in zip(dates, sample_loss,
104108
strict=True)},
105109
# Per tile.
@@ -142,16 +146,20 @@ def validation_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None
142146
x, y = data
143147

144148
Y_hat = self(x, y)
145-
loss, energy, thermal = L(x, Y_hat, y, self._extra_out['thermal_y'],
146-
self._extra_out['thermal_x'])
149+
loss = L(x, Y_hat, y, self._extra_out['thermal_y'],
150+
self._extra_out['thermal_x'])
151+
# Channelwise evaluation.
152+
energy = L.energy(y, Y_hat).mean(0)
153+
thermal = L.evaluate(y[:, 6:], self._extra_out['thermal_y']).mean(0)
147154

148155
batch_loss = loss.mean()
149-
sample_loss = loss.mean(-1)
156+
sample_loss = loss.detach().mean(-1)
150157

151158
self.log("loss/val", batch_loss, prog_bar=True, logger=True,
152159
on_epoch=True, on_step=True, batch_size=sample_loss.size(0))
153160

154-
self.log_dict({"thermal_loss/val": thermal,
161+
self.log_dict({**{f"thermal_{i}/val": v
162+
for i, v in enumerate(thermal)},
155163
# Per date.
156164
**{f"{k}/val": v for k, v in zip(dates, sample_loss,
157165
strict=True)},
@@ -172,23 +180,30 @@ def test_step(self, batch, batch_idx) -> Tensor | Mapping[str, Any] | None:
172180
dates, tiles = metadata
173181
x, y = data
174182
Y_hat = self(x, y)
175-
loss, energy, thermal = L(x, Y_hat, y, self._extra_out['thermal_y'],
176-
self._extra_out['thermal_x'])
183+
184+
loss = L(x, Y_hat, y, self._extra_out['thermal_y'],
185+
self._extra_out['thermal_x'])
186+
187+
# Channelwise evaluation.
188+
energy = L.energy(y, Y_hat).mean(0)
189+
thermal = L.evaluate(y[:, 6:], self._extra_out['thermal_y']).mean(0)
190+
177191
batch_loss = loss.mean()
178-
sample_loss = loss.mean(-1)
192+
sample_loss = loss.detach().mean(-1)
179193

180194
self.log("loss/test", batch_loss, prog_bar=True, logger=True,
181-
on_step=False, on_epoch=True, batch_size=sample_loss.size(0))
195+
on_step=True, on_epoch=True, batch_size=sample_loss.size(0))
182196

183-
self.log_dict({"thermal_loss/test": thermal,
197+
self.log_dict({**{f"thermal_{i}/test": v
198+
for i, v in enumerate(thermal)},
184199
# Per date.
185200
**{f"{k}/test": v for k, v in zip(dates, sample_loss,
186201
strict=True)},
187202
# Per tile.
188203
**{f"{k}/test": v for k, v in zip(tiles, sample_loss,
189204
strict=True)}
190205
},
191-
on_step=False,
206+
on_step=True,
192207
on_epoch=True,
193208
prog_bar=False,
194209
logger=True,

0 commit comments

Comments
 (0)
Please sign in to comment.