17
17
18
18
from src .datamodules .components .edm import get_bond_length_arrays
19
19
from src .datamodules .components .edm .datasets_config import QM9_WITH_H , QM9_WITHOUT_H
20
- from src .models import NumNodesDistribution , PropertiesDistribution , compute_mean_mad
20
+ from src .models import NumNodesDistribution , PropertiesDistribution , compute_mean_mad , save_and_sample_conditionally
21
21
from src .utils .pylogger import get_pylogger
22
22
23
23
from src import LR_SCHEDULER_MANUAL_INTERPOLATION_HELPER_CONFIG_ITEMS , LR_SCHEDULER_MANUAL_INTERPOLATION_PRIMARY_CONFIG_ITEMS , get_classifier , test_with_property_classifier , utils
@@ -96,7 +96,7 @@ def __iter__(self):
96
96
def sample (self ) -> Dict [str , Any ]:
97
97
num_nodes = self .nodes_distr .sample (self .num_samples ).to (self .device )
98
98
context = self .props_distr .sample_batch (num_nodes ).to (self .device )
99
- x , one_hot , _ = self .model .sample (
99
+ x , one_hot , _ , _ = self .model .sample (
100
100
num_samples = self .num_samples ,
101
101
num_nodes = num_nodes ,
102
102
context = context
@@ -165,14 +165,12 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
165
165
166
166
assert (
167
167
os .path .exists (cfg .generator_model_filepath ) and
168
- os .path .exists (cfg .classifier_model_dir ) and
168
+ ( os .path .exists (cfg .classifier_model_dir ) or cfg . sweep_property_values ) and
169
169
cfg .property in cfg .generator_model_filepath and
170
- cfg .property in cfg .classifier_model_dir
170
+ ( cfg .property in cfg .classifier_model_dir or cfg . sweep_property_values )
171
171
)
172
172
173
- log .info ("Loading classifier model!" )
174
173
device = f"cuda:{ cfg .trainer .devices [0 ]} " if torch .cuda .is_available () else "cpu"
175
- classifier = get_classifier (cfg .classifier_model_dir ).to (device )
176
174
177
175
log .info (f"Instantiating datamodule <{ cfg .datamodule ._target_ } >" )
178
176
datamodule : LightningDataModule = hydra .utils .instantiate (cfg .datamodule )
@@ -221,8 +219,6 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
221
219
bonds [0 ], bonds [1 ], bonds [2 ]
222
220
)
223
221
224
- log .info ("Creating dataloader with generator!" )
225
-
226
222
splits = ["train" , "valid" , "test" ]
227
223
dataloaders = [
228
224
datamodule .train_dataloader (),
@@ -250,8 +246,21 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
250
246
nodes_distr = NumNodesDistribution (histogram )
251
247
252
248
if cfg .sweep_property_values :
253
- raise NotImplementedError ()
249
+ log .info (f"Sampling conditionally via a sweep!" )
250
+
251
+ for i in range (cfg .num_sweeps ):
252
+ log .info (f"Sampling sweep { i + 1 } /{ cfg .num_sweeps } !" )
253
+ save_and_sample_conditionally (
254
+ cfg = cfg ,
255
+ model = model ,
256
+ props_distr = props_distr ,
257
+ dataset_info = dataset_info ,
258
+ epoch = i ,
259
+ id_from = 0
260
+ )
254
261
else :
262
+ log .info ("Creating dataloader with generator!" )
263
+
255
264
conditional_diffusion_dataloader = ConditionalDiffusionDataLoader (
256
265
model = model ,
257
266
nodes_distr = nodes_distr ,
@@ -261,6 +270,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
261
270
dataset_info = dataset_info ,
262
271
iterations = cfg .iterations
263
272
)
273
+
274
+ log .info ("Loading classifier model!" )
275
+ classifier = get_classifier (cfg .classifier_model_dir ).to (device )
264
276
265
277
log .info ("Evaluating classifier on generator's samples!" )
266
278
loss = test_with_property_classifier (
0 commit comments