Skip to content

Commit 18f9d21

Browse files
committed
Add missing context placement for QM9 train-time sampling
1 parent 22d187a commit 18f9d21

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

.gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ outputs/
173173
epoch_*/
174174
.cache/
175175

176-
data/EDM/GEOM
177-
data/EDM/QM9
176+
data/EDM/GEOM*
177+
data/EDM/QM9*
178178

179179
# NFS
180180
.nfs*

src/models/qm9_mol_gen_ddpm.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,11 @@ def sample_and_analyze(
783783
assert int(num_nodes.max()) <= max_num_nodes
784784

785785
# context-conditioning
786-
context = None
786+
if self.condition_on_context:
787+
if context is None:
788+
context = self.props_distr.sample_batch(num_nodes)
789+
else:
790+
context = None
787791

788792
xh, batch_index, _ = self.ddpm.mol_gen_sample(
789793
num_samples=num_samples_batch,

0 commit comments

Comments
 (0)