保存:
PATH = 'path/to/dir/my_model.pth'
torch.save(trained_model.state_dict(), PATH)
之后以下面方式加载:
untrained_model = YourModelClass(*args, **kwargs) # 要先实例化你的模型,并和要加载的模型结构相同
untrained_model.load_state_dict(torch.load(PATH))
保存:
torch.save(trained_model, PATH)
加载:
untrained_model = torch.load(PATH)
- 模型已经训练完成,你需要保存以备以后使用,如推断等
torch.save(model.state_dict(), PATH) # 保存模型 untrained_model = YourModelClass(*args, **kwargs) # 要先实例化你的模型,并和要加载的模型结构相同 untrained_model.load_state_dict(torch.load(PATH)) # 加载到实例化了的模型 untrained_model.eval() # 加载之后默认为train mode,推断时先转换到eval mode.
- 训练未完成,你需要保存以从断点处恢复训练进度。除了要保存模型,你还要保存optimizer,
epoch, score等状态。
state = { # 保存恢复训练所需要的所有状态 'epoch': epoch, 'model_state_dit': your_model.state_dict(), 'optimizer': optimizer.state_dict(), # ... } torch.save(state, PATH) # 加载训练的所有状态 # 加载训练进度 state = torch.load(PATH) # 恢复进度 epoch = state['epoch'] your_model.load_state_dict(state['model_state_dit']) optimizer = state['optimizer']