forked from YuelangX/AvatarMAV
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
44 lines (35 loc) · 1.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from config.config import config_headnerf
import argparse
import torch
import os
from lib.utils.util_seed import seed_everything
from lib.dataset.NeRFDataset import NeRFDataset
from lib.module.HeadModule import HeadModule
from lib.module.NeuralCameraModule import NeuralCameraModule
from lib.recorder.Recorder import TrainRecorder
from lib.trainer.Trainer import Trainer
def checkpoint_available(path):
    """Return True if *path* is a non-empty path that exists on disk.

    ``None`` and ``''`` are treated as "no checkpoint configured" and yield
    False rather than raising (``os.path.exists(None)`` raises ``TypeError``).
    """
    return bool(path) and os.path.exists(path)


def main():
    """Train the head NeRF model using the YAML config given via ``--config``.

    Steps: seed RNGs, parse the command line, load the config, build the
    training dataloader and a second dataset instance for the recorder, build
    ``HeadModule`` on the configured GPU (optionally restoring weights from
    ``cfg.load_headnerf_checkpoint``), wrap it in ``NeuralCameraModule``, set
    up an Adam optimizer with per-component learning rates, and run
    ``Trainer.train``.
    """
    seed_everything(2645647)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/train.yaml')
    args = parser.parse_args()

    cfg = config_headnerf()
    cfg.load(args.config)
    cfg = cfg.get_cfg()

    dataset = NeRFDataset(cfg.dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=True, pin_memory=True)
    # A separate dataset instance built from the same cfg.dataset; it is
    # handed to the recorder below rather than reusing `dataset`.
    dataset_eval = NeRFDataset(cfg.dataset)

    device = torch.device('cuda:%d' % cfg.gpu_id)
    headnerf = HeadModule(cfg.headmodule).to(device)

    checkpoint_path = cfg.load_headnerf_checkpoint
    if checkpoint_available(checkpoint_path):
        # Deserialize onto the CPU; load_state_dict then copies the values
        # into the module's parameters, which were already moved to `device`.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        headnerf.load_state_dict(state_dict)
    elif checkpoint_path:
        # A path was configured but does not exist: say so instead of
        # silently starting from randomly initialised weights.
        print('Warning: checkpoint %s not found; training from scratch.' % checkpoint_path)

    neu_camera = NeuralCameraModule(headnerf, cfg.neuralcamera)

    # One param group per component so each gets its own learning rate.
    optimizer = torch.optim.Adam([
        {'params': headnerf.feature_volume, 'lr': cfg.lr_feat_vol},
        {'params': headnerf.density_linear.parameters(), 'lr': cfg.lr_feat_net},
        {'params': headnerf.color_linear.parameters(), 'lr': cfg.lr_feat_net},
        {'params': headnerf.deform_bs_volume, 'lr': cfg.lr_deform_vol},
        {'params': headnerf.deform_mean_volume, 'lr': cfg.lr_deform_vol},
        {'params': headnerf.deform_linear.parameters(), 'lr': cfg.lr_deform_net},
    ])

    recorder = TrainRecorder(cfg.recorder, dataset_eval)
    trainer = Trainer(dataloader, headnerf, neu_camera, optimizer, recorder, cfg.gpu_id)
    # NOTE(review): (0, 100) presumably means (start epoch, end epoch) —
    # confirm against Trainer.train before changing.
    trainer.train(0, 100)


if __name__ == '__main__':
    main()