-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathmodel.lua
28 lines (27 loc) · 840 Bytes
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
-- Build or restore the network and store it in the global `model`.
--
-- When opt.begin_epoch == 1, a fresh network is built. The definition file
-- selected by opt.model is run with dofile(), and it must define a global
-- `create_model()` function that this script then calls.
-- Otherwise the checkpoint for epoch (opt.begin_epoch - 1) is loaded from
-- <opt.result_path>/model_<epoch>.t7.
-- `model` is deliberately global. NOTE(review): presumably it is read by
-- the script that dofile()s this one; confirm against the caller.
if opt.begin_epoch == 1 then
  -- Supported values of opt.model mapped to their definition files.
  local model_definitions = {
    c3d = 'models/c3d.lua',
    resnet = 'models/resnet.lua',
  }
  local definition_path = model_definitions[opt.model]
  if definition_path == nil then
    -- Fail here with a clear message. Otherwise create_model() below would
    -- be a nil call (opaque error) or a stale global from elsewhere.
    error('unknown model type: ' .. tostring(opt.model) ..
      " (expected 'c3d' or 'resnet')")
  end
  dofile(definition_path)
  model = create_model()
  if not opt.no_cuda then
    model = model:cuda()
    -- cudnn.convert is called only for its effect on `model`; its return
    -- value is not used (per cudnn.torch docs it converts layers in place).
    cudnn.convert(model, cudnn)
  end
else
  local previous_epoch = opt.begin_epoch - 1
  local model_file_path = paths.concat(
    opt.result_path, 'model_' .. previous_epoch .. '.t7')
  assert(paths.filep(model_file_path),
    'pretrained model at epoch ' .. previous_epoch .. ' does not exist')
  model = torch.load(model_file_path)
  -- Report success only once torch.load has actually returned.
  print('pretrained model at epoch ' .. previous_epoch .. ' is loaded')
  if not opt.no_cuda then
    model = utils.make_data_parallel(model, opt.gpu_id, opt.n_gpus)
  end
end
print(model)
-- Training loss, stored in the global `criterion`.
-- It is a ClassNLLCriterion, moved to CUDA unless opt.no_cuda is set.
-- NOTE(review): ClassNLLCriterion expects log-probabilities as input; confirm
-- that the networks built in models/*.lua end in a LogSoftMax-style layer.
do
  local nll_loss = nn.ClassNLLCriterion()
  if opt.no_cuda then
    criterion = nll_loss
  else
    criterion = nll_loss:cuda()
  end
end