Saving and Loading Your Model to Resume Training in PyTorch


One of the right ways to use Colab is to make training resume seamlessly. Here is how to resume training in PyTorch.

Saving a Checkpoint

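import shutil
from pathlib import Path

import torch

def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    # Always overwrite the latest checkpoint (accepts str or pathlib.Path)
    f_path = Path(checkpoint_dir) / 'checkpoint.pt'
    torch.save(state, f_path)
    # Additionally keep a copy of the best model seen so far
    if is_best:
        best_fpath = Path(best_model_dir) / 'best_model.pt'
        shutil.copyfile(f_path, best_fpath)

# Inside the training loop, at the end of each epoch:
checkpoint = {
    'epoch': epoch + 1,  # the next epoch to run when resuming
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict()
}
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)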
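The snippet above leaves it to the caller to decide when `is_best` is true. Below is a minimal sketch of one common choice, assuming "best" means the lowest validation loss seen so far. The toy `nn.Linear` regression on random data and the helper `run_training` are hypothetical stand-ins for a real model and training loop, and `save_ckp` is repeated so the block runs on its own.

```python
import shutil
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    # Same logic as above: overwrite the latest checkpoint every time,
    # and copy it aside when it is the best one so far
    f_path = Path(checkpoint_dir) / 'checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        shutil.copyfile(f_path, Path(best_model_dir) / 'best_model.pt')


def run_training(n_epochs, checkpoint_dir, model_dir):
    # Hypothetical toy setup: linear regression on fixed random data
    torch.manual_seed(0)
    x_train, y_train = torch.randn(64, 3), torch.randn(64, 1)
    x_val, y_val = torch.randn(16, 3), torch.randn(16, 1)
    model = nn.Linear(3, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()

    best_val_loss = float('inf')
    history = []
    for epoch in range(n_epochs):
        # One full-batch training step per epoch
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        # Validation loss decides whether this checkpoint is the best so far
        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(x_val), y_val).item()
        is_best = val_loss < best_val_loss
        best_val_loss = min(best_val_loss, val_loss)

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
        history.append((epoch + 1, val_loss, is_best))
    return history


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        for epoch, val_loss, is_best in run_training(5, tmp, tmp):
            print(epoch, round(val_loss, 4), is_best)
```

Because `checkpoint.pt` is overwritten every epoch, it always holds the most recent state for resuming, while `best_model.pt` only changes when the validation loss improves.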

Note: some people use .pth as the file extension instead; PyTorch does not require a particular extension.


Loading a Checkpoint

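import torch
import torch.optim as optim

def load_ckp(checkpoint_fpath, model, optimizer):
    # Load the dictionary written by save_ckp
    checkpoint = torch.load(checkpoint_fpath)
    # Restore the model weights and the optimizer state (e.g. Adam's moment estimates)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # 'epoch' was saved as epoch + 1, so it is the next epoch to run
    return model, optimizer, checkpoint['epoch']

# Rebuild the model and optimizer with the same architecture and settings as before,
# then restore their saved states (MyModel, args, kwargs and learning_rate come from your own code)
model = MyModel(*args, **kwargs)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
ckp_path = "path/to/checkpoint/checkpoint.pt"
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)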
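To sanity-check the resume path end to end, here is a self-contained sketch under stated assumptions. A toy `nn.Linear` model on random data stands in for `MyModel`, and the helpers `make_model_and_optimizer`, `train_epochs` and `resume_demo` are hypothetical. The sketch trains for a few epochs, saves a checkpoint with the same keys as above, restores it into a fresh model and optimizer with `load_ckp`, continues from `start_epoch`, and returns the result alongside an uninterrupted reference run.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def load_ckp(checkpoint_fpath, model, optimizer):
    # Same logic as above: restore model and optimizer state, return next epoch
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']


def make_model_and_optimizer():
    # Hypothetical stand-in for MyModel(*args, **kwargs) and optim.Adam(...)
    model = nn.Linear(3, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    return model, optimizer


def train_epochs(model, optimizer, start_epoch, end_epoch, x, y):
    # One full-batch gradient step per epoch, for epochs start_epoch..end_epoch-1
    loss_fn = nn.MSELoss()
    for _ in range(start_epoch, end_epoch):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()


def resume_demo(ckp_dir, total_epochs=6, interrupt_at=3):
    torch.manual_seed(0)
    x, y = torch.randn(32, 3), torch.randn(32, 1)

    # Session 1: train until the "interruption", then save a checkpoint
    torch.manual_seed(1)
    model, optimizer = make_model_and_optimizer()
    train_epochs(model, optimizer, 0, interrupt_at, x, y)
    ckp_path = Path(ckp_dir) / 'checkpoint.pt'
    torch.save({
        'epoch': interrupt_at,  # next epoch to run, like epoch + 1 above
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, ckp_path)

    # Session 2: fresh objects, restore saved state, continue from start_epoch
    resumed, resumed_opt = make_model_and_optimizer()
    resumed, resumed_opt, start_epoch = load_ckp(ckp_path, resumed, resumed_opt)
    train_epochs(resumed, resumed_opt, start_epoch, total_epochs, x, y)

    # Reference: the same training without interruption (same seed as session 1)
    torch.manual_seed(1)
    reference, reference_opt = make_model_and_optimizer()
    train_epochs(reference, reference_opt, 0, total_epochs, x, y)
    return resumed, reference, start_epoch


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        resumed, reference, start_epoch = resume_demo(tmp)
        same = all(torch.allclose(p, q) for p, q in
                   zip(resumed.parameters(), reference.parameters()))
        print(start_epoch, same)
```

Since `'epoch'` stores the next epoch to run, `range(start_epoch, total_epochs)` continues without repeating or skipping an epoch. Restoring the optimizer state matters here: Adam keeps per-parameter moment estimates and a step count, so reloading only the model weights would generally make the resumed run drift away from the uninterrupted one.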

Source:
https://medium.com/analytics-vidhya/saving-and-loading-your-model-to-resume-training-in-pytorch-cb687352fa61
