Skip to content

Commit 0ecd479

Browse files
committed
Add conditional sweeping
1 parent 9c3f0f8 commit 0ecd479

7 files changed

+108
-18
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ scripts/nautilus/persistent_storage.yaml
168168

169169
bio-diffusion/
170170
logs/
171+
outputs/
171172

172173
data/EDM/GEOM
173174
data/EDM/QM9

configs/mol_gen_eval_conditional_qm9.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ batch_size: 100
2424
debug_break: false
2525
sweep_property_values: false
2626
num_sweeps: 10
27+
experiment_name: ${.property}-conditioning
2728
output_dir: ""

src/models/__init__.py

+67
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from functools import partial
1212
from matplotlib.lines import Line2D
13+
from omegaconf import DictConfig
1314
from torch.utils.data import DataLoader
1415
from torch.distributions.categorical import Categorical
1516
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
@@ -18,6 +19,8 @@
1819
from torchtyping import TensorType, patch_typeguard
1920
from typeguard import typechecked
2021

22+
from src.models.components import save_xyz_file, visualize_mol_chain
23+
2124
patch_typeguard() # use before @typechecked
2225

2326
HALT_FILE_EXTENSION = "done"
@@ -194,6 +197,70 @@ def log_grad_flow_full(
194197
wandb_run.log({"Gradient flow": plt})
195198

196199

200+
@typechecked
201+
def sample_sweep_conditionally(
202+
model: nn.Module,
203+
props_distr: object,
204+
num_nodes: int = 19,
205+
num_frames: int = 100
206+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
207+
num_nodes_ = torch.tensor([num_nodes] * num_frames, device=model.device)
208+
209+
context = []
210+
for key in props_distr.distributions:
211+
min_val, max_val = props_distr.distributions[key][num_nodes]['params']
212+
mean, mad = props_distr.normalizer[key]['mean'], props_distr.normalizer[key]['mad']
213+
min_val = ((min_val - mean) / (mad)).cpu().numpy()
214+
max_val = ((max_val - mean) / (mad)).cpu().numpy()
215+
context_row = torch.tensor(np.linspace(min_val, max_val, num_frames)).unsqueeze(1)
216+
context.append(context_row)
217+
context = torch.cat(context, dim=-1).float().to(model.device)
218+
219+
x, one_hot, charges, batch_index = model.sample(
220+
num_samples=num_frames,
221+
num_nodes=num_nodes_,
222+
context=context,
223+
fix_noise=True
224+
)
225+
return x, one_hot, charges, batch_index
226+
227+
228+
@typechecked
229+
def save_and_sample_conditionally(
230+
cfg: DictConfig,
231+
model: nn.Module,
232+
props_distr: object,
233+
dataset_info: Dict[str, Any],
234+
epoch: int = 0,
235+
id_from: int = 0
236+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
237+
x, one_hot, charges, batch_index = sample_sweep_conditionally(
238+
model=model,
239+
props_distr=props_distr
240+
)
241+
242+
save_xyz_file(
243+
path=f"outputs/{cfg.experiment_name}/analysis/run{epoch}/",
244+
positions=x,
245+
one_hot=one_hot,
246+
charges=charges,
247+
dataset_info=dataset_info,
248+
id_from=id_from,
249+
name="conditional",
250+
batch_index=batch_index
251+
)
252+
253+
visualize_mol_chain(
254+
path=f"outputs/{cfg.experiment_name}/analysis/run{epoch}/",
255+
dataset_info=dataset_info,
256+
wandb_run=None,
257+
spheres_3d=True,
258+
mode="conditional"
259+
)
260+
261+
return x, one_hot, charges
262+
263+
197264
class NumNodesDistribution(nn.Module):
198265
"""
199266
Adapted from: https://github.com/ehoogeboom/e3_diffusion_for_molecules

src/models/components/variational_diffusion.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,10 @@ def sample_normal(
829829
) -> TensorType["batch_num_nodes", "num_x_dims_plus_num_node_scalar_features"]:
830830
"""Sample from a Normal distribution."""
831831
if fix_noise:
832-
raise NotImplementedError("The `fix_noise` option is currently not supported.")
833-
eps = self.sample_combined_position_feature_noise(batch_index, node_mask, generate_x_only=generate_x_only)
832+
batch_index_ = torch.zeros_like(batch_index) # broadcast same noise across batch
833+
eps = self.sample_combined_position_feature_noise(batch_index_, node_mask, generate_x_only=generate_x_only)
834+
else:
835+
eps = self.sample_combined_position_feature_noise(batch_index, node_mask, generate_x_only=generate_x_only)
834836
return mu + sigma[batch_index] * eps
835837

836838
@typechecked
@@ -1317,7 +1319,12 @@ def mol_gen_sample(
13171319
context = context * node_mask.float().unsqueeze(-1)
13181320

13191321
# sample from the noise distribution (i.e., p(z_T))
1320-
z = self.sample_combined_position_feature_noise(batch_index, node_mask, generate_x_only=generate_x_only)
1322+
if fix_noise:
1323+
batch_index_ = torch.zeros_like(batch_index) # broadcast same noise across batch
1324+
z = self.sample_combined_position_feature_noise(batch_index_, node_mask, generate_x_only=generate_x_only)
1325+
else:
1326+
z = self.sample_combined_position_feature_noise(batch_index, node_mask, generate_x_only=generate_x_only)
1327+
13211328
self.assert_mean_zero_with_mask(z[:, :self.num_x_dims], node_mask)
13221329

13231330
# iteratively sample p(z_s | z_t) for `t = 1, ..., T`, with `s = t - 1`.

src/models/geom_mol_gen_ddpm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,9 @@ def sample(
589589
num_nodes: Optional[TensorType["batch_size"]] = None,
590590
node_mask: Optional[TensorType["batch_num_nodes"]] = None,
591591
context: Optional[TensorType["batch_size", "num_context_features"]] = None,
592+
fix_noise: bool = False,
592593
num_timesteps: Optional[int] = None
593-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
594+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
594595
# node count-conditioning
595596
if num_nodes is None:
596597
num_nodes = self.ddpm.num_nodes_distribution.sample(num_samples)
@@ -605,7 +606,7 @@ def sample(
605606
context = None
606607

607608
# sampling
608-
xh, _, _ = self.ddpm.mol_gen_sample(
609+
xh, batch_index, _ = self.ddpm.mol_gen_sample(
609610
num_samples=num_samples,
610611
num_nodes=num_nodes,
611612
node_mask=node_mask,
@@ -618,7 +619,7 @@ def sample(
618619
one_hot = xh[:, self.num_x_dims:-1] if self.include_charges else xh[:, self.num_x_dims:]
619620
charges = xh[:, -1:] if self.include_charges else torch.zeros(0, device=self.device)
620621

621-
return x, one_hot, charges
622+
return x, one_hot, charges, batch_index
622623

623624
@torch.no_grad()
624625
@typechecked

src/models/qm9_mol_gen_ddpm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -590,8 +590,9 @@ def sample(
590590
num_nodes: Optional[TensorType["batch_size"]] = None,
591591
node_mask: Optional[TensorType["batch_num_nodes"]] = None,
592592
context: Optional[TensorType["batch_size", "num_context_features"]] = None,
593+
fix_noise: bool = False,
593594
num_timesteps: Optional[int] = None
594-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
595+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
595596
# node count-conditioning
596597
if num_nodes is None:
597598
num_nodes = self.ddpm.num_nodes_distribution.sample(num_samples)
@@ -610,7 +611,7 @@ def sample(
610611
context = None
611612

612613
# sampling
613-
xh, _, _ = self.ddpm.mol_gen_sample(
614+
xh, batch_index, _ = self.ddpm.mol_gen_sample(
614615
num_samples=num_samples,
615616
num_nodes=num_nodes,
616617
node_mask=node_mask,
@@ -623,7 +624,7 @@ def sample(
623624
one_hot = xh[:, self.num_x_dims:-1] if self.include_charges else xh[:, self.num_x_dims:]
624625
charges = xh[:, -1:] if self.include_charges else torch.zeros(0, device=self.device)
625626

626-
return x, one_hot, charges
627+
return x, one_hot, charges, batch_index
627628

628629
@torch.no_grad()
629630
@typechecked

src/mol_gen_eval_conditional_qm9.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from src.datamodules.components.edm import get_bond_length_arrays
1919
from src.datamodules.components.edm.datasets_config import QM9_WITH_H, QM9_WITHOUT_H
20-
from src.models import NumNodesDistribution, PropertiesDistribution, compute_mean_mad
20+
from src.models import NumNodesDistribution, PropertiesDistribution, compute_mean_mad, save_and_sample_conditionally
2121
from src.utils.pylogger import get_pylogger
2222

2323
from src import LR_SCHEDULER_MANUAL_INTERPOLATION_HELPER_CONFIG_ITEMS, LR_SCHEDULER_MANUAL_INTERPOLATION_PRIMARY_CONFIG_ITEMS, get_classifier, test_with_property_classifier, utils
@@ -96,7 +96,7 @@ def __iter__(self):
9696
def sample(self) -> Dict[str, Any]:
9797
num_nodes = self.nodes_distr.sample(self.num_samples).to(self.device)
9898
context = self.props_distr.sample_batch(num_nodes).to(self.device)
99-
x, one_hot, _ = self.model.sample(
99+
x, one_hot, _, _ = self.model.sample(
100100
num_samples=self.num_samples,
101101
num_nodes=num_nodes,
102102
context=context
@@ -165,14 +165,12 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
165165

166166
assert (
167167
os.path.exists(cfg.generator_model_filepath) and
168-
os.path.exists(cfg.classifier_model_dir) and
168+
(os.path.exists(cfg.classifier_model_dir) or cfg.sweep_property_values) and
169169
cfg.property in cfg.generator_model_filepath and
170-
cfg.property in cfg.classifier_model_dir
170+
(cfg.property in cfg.classifier_model_dir or cfg.sweep_property_values)
171171
)
172172

173-
log.info("Loading classifier model!")
174173
device = f"cuda:{cfg.trainer.devices[0]}" if torch.cuda.is_available() else "cpu"
175-
classifier = get_classifier(cfg.classifier_model_dir).to(device)
176174

177175
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
178176
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)
@@ -221,8 +219,6 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
221219
bonds[0], bonds[1], bonds[2]
222220
)
223221

224-
log.info("Creating dataloader with generator!")
225-
226222
splits = ["train", "valid", "test"]
227223
dataloaders = [
228224
datamodule.train_dataloader(),
@@ -250,8 +246,21 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
250246
nodes_distr = NumNodesDistribution(histogram)
251247

252248
if cfg.sweep_property_values:
253-
raise NotImplementedError()
249+
log.info(f"Sampling conditionally via a sweep!")
250+
251+
for i in range(cfg.num_sweeps):
252+
log.info(f"Sampling sweep {i + 1}/{cfg.num_sweeps}!")
253+
save_and_sample_conditionally(
254+
cfg=cfg,
255+
model=model,
256+
props_distr=props_distr,
257+
dataset_info=dataset_info,
258+
epoch=i,
259+
id_from=0
260+
)
254261
else:
262+
log.info("Creating dataloader with generator!")
263+
255264
conditional_diffusion_dataloader = ConditionalDiffusionDataLoader(
256265
model=model,
257266
nodes_distr=nodes_distr,
@@ -261,6 +270,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
261270
dataset_info=dataset_info,
262271
iterations=cfg.iterations
263272
)
273+
274+
log.info("Loading classifier model!")
275+
classifier = get_classifier(cfg.classifier_model_dir).to(device)
264276

265277
log.info("Evaluating classifier on generator's samples!")
266278
loss = test_with_property_classifier(

0 commit comments

Comments
 (0)