PyTorch保存和加載多GPU模型和單GPU模型

初次使用PyTorch這種深度學習框架由單GPU過渡到多GPU時,需要注意如下問題。

1、當我們使用單GPU訓練模型時,通過下面的方法保存和加載模型參數:

<code>(1)torch.save(the_model.state_dict(), path) #第一種方法

(2)the_model.load_state_dict(torch.load(path))/<code>

2、當我們使用多GPU訓練模型時,在保存模型參數時需做如下修改:

<code>torch.save(the_model.module.state_dict(), path) #第二種方法/<code>

3、如果我們已經按照第一種方法保存多GPU模型參數,加載模型參數時需做如下修改:

<code>kwargs={'map_location':lambda storage, loc: storage.cuda(GPU_ID)}

def load_GPUS(the_model, path, kwargs):
state_dict = torch.load(path,**kwargs)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:]
new_state_dict[name] = v
the_model.load_state_dict(new_state_dict)
return the_model

the_model = load_GPUs(the_model, path, kwargs)/<code>


分享到:


相關文章: