-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: validate.py
50 lines (39 loc) · 1.73 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from copy import deepcopy
from pathlib import Path
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf, open_dict
import torch
import wandb
@hydra.main(version_base=None, config_path="config", config_name="validate")  # only hydra config and overrides
def main(overrides):
    """Download a W&B model checkpoint, rebuild its training config, and run validation.

    Args:
        overrides: Hydra config providing at least ``wandb.entity``,
            ``wandb.project``, ``runid`` (the W&B run id), and ``checkpoint``
            (the artifact alias/version to fetch, e.g. ``best`` or ``v3``).
            Everything else is merged on top of the run's stored training config.
    """
    # set to prevent warning (Lightning complains about the default matmul precision)
    torch.set_float32_matmul_precision("high")

    # locate and download the checkpoint artifact from W&B
    api = wandb.Api()
    project_path = f"{overrides.wandb.entity}/{overrides.wandb.project}"
    checkpoint_path = f"{project_path}/model-{overrides.runid}:{overrides.checkpoint}"
    checkpoint = Path(api.artifact(checkpoint_path).download()) / "model.ckpt"

    # recover the original training config from the run and overlay the CLI overrides;
    # open_dict allows adding keys (e.g. runid/checkpoint) the training config lacks
    run = api.run(f"{project_path}/{overrides.runid}")
    config = OmegaConf.create(deepcopy(run.config))
    with open_dict(config):
        config.merge_with(overrides)

    # dataset + dataloader = lightning datamodule
    datamodule = instantiate(config.datamodule)

    # network + loss + optimizer = lightning module
    # not strict state dict loading because not using compiled network params
    network = instantiate(config.network)
    loss_fns = instantiate(config.loss_fns)
    litmodule = instantiate(config.litmodule, network, loss_fns, optimizer=None)
    litmodule.load_state_dict(
        torch.load(checkpoint, weights_only=True, map_location="cpu")["state_dict"],
        strict=False,
    )
    litmodule.eval()
    litmodule.freeze()

    # callbacks: drop the checkpoint callback — validation should never write checkpoints.
    # pop with a default so a config without a "checkpoint" entry doesn't KeyError.
    callbacks = instantiate(config.callbacks)
    callbacks.pop("checkpoint", None)

    # trainer and validate!
    trainer = instantiate(config.trainer, logger=False, callbacks=list(callbacks.values()))
    trainer.validate(litmodule, datamodule=datamodule)


if __name__ == "__main__":
    main()