Skip to content

Commit 16c8ba2

Browse files
fix uniform range for dit stage (#1215)
* fix uniform range for dit stage * fix for ddp logging block
1 parent de9f833 commit 16c8ba2

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

docs/zh/examples/fundiff.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
| 预训练模型 | 指标 |
3636
|:--| :--|
37-
| [fundiff_turbulence_mass_transfer_dit_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/fundiff/fundiff_turbulence_mass_transfer_dit_pretrained.pdparams) | Mean relative p error: 0.0651<br>Max relative p error: 0.1329<br>Min relative p error: 0.0345<br>Std relative p error: 0.0300<br>Mean relative sdf error: 0.0684<br>Max relative sdf error: 0.1443<br>Min relative sdf error: 0.0335<br>Std relative sdf error: 0.0377 |
37+
| [fundiff_turbulence_mass_transfer_dit_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/fundiff/fundiff_turbulence_mass_transfer_dit_pretrained.pdparams) | Mean relative p error: 0.066<br>Max relative p error: 0.159<br>Min relative p error: 0.029<br>Std relative p error: 0.027<br>Mean relative sdf error: 0.085<br>Max relative sdf error: 0.307<br>Min relative sdf error: 0.022<br>Std relative sdf error: 0.0499 |
3838

3939
## 1. 背景简介
4040

examples/fundiff/conf/diffusion.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ FAE:
6161
num_mlp_layers: 2
6262
out_dim: 1
6363
layer_norm_eps: 1e-05
64-
pretrained_model_path: null
64+
pretrained_model_path: ???
6565

6666
DIT:
6767
# input_keys: [u, v, p, sdf]

examples/fundiff/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ def forward(self, batch: Dict[str, paddle.Tensor]):
6666
z_sdf = self.encoder(sdf)
6767
z_1 = paddle.concat([z_p, z_sdf], axis=-1)
6868
z_0 = paddle.randn(z_1.shape) # (b, 200, 512)
69-
t = paddle.uniform([z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]])
69+
t = paddle.uniform(
70+
[z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]],
71+
min=0.0,
72+
max=1.0,
73+
)
7074
z_t = t * (z_1 - z_0) + z_0
7175
v_t = z_1 - z_0
7276
else:

ppsci/utils/logger.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,19 @@ def scalar(
213213
wandb_writer (Optional[wandb.run]): Run object of WandB to record metrics. Defaults to None.
214214
tbd_writer (Optional[tbd.SummaryWriter]): Run object of WandB to record metrics. Defaults to None.
215215
"""
216-
if vdl_writer is not None:
217-
with misc.RankZeroOnly() as is_master:
218-
if is_master:
219-
for name, value in metric_dict.items():
220-
vdl_writer.add_scalar(name, value, step)
221-
222-
if wandb_writer is not None:
223-
with misc.RankZeroOnly() as is_master:
224-
if is_master:
225-
wandb_writer.log({"step": step, **metric_dict})
226-
227-
if tbd_writer is not None:
228-
with misc.RankZeroOnly() as is_master:
229-
if is_master:
230-
for name, value in metric_dict.items():
231-
tbd_writer.add_scalar(name, value, global_step=step)
216+
with misc.RankZeroOnly() as is_master:
217+
if vdl_writer is not None and is_master:
218+
for name, value in metric_dict.items():
219+
vdl_writer.add_scalar(name, value, step)
220+
221+
with misc.RankZeroOnly() as is_master:
222+
if wandb_writer is not None and is_master:
223+
wandb_writer.log({"step": step, **metric_dict})
224+
225+
with misc.RankZeroOnly() as is_master:
226+
if tbd_writer is not None and is_master:
227+
for name, value in metric_dict.items():
228+
tbd_writer.add_scalar(name, value, global_step=step)
232229

233230

234231
def advertise():

0 commit comments

Comments
 (0)