使用 Colab 的正確姿勢之一，就是要能無縫地接續訓練（Colab 的執行階段可能隨時中斷，因此需要定期儲存 checkpoint）。在 PyTorch 中恢復訓練（resume training）的方法如下。
Saving a Checkpoint
import torch
import shutil
def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    """Save a training checkpoint and optionally mirror it as the best model.

    The checkpoint is always written to ``<checkpoint_dir>/checkpoint.pt``,
    overwriting any previous file there. When ``is_best`` is true, that file
    is also copied to ``<best_model_dir>/best_model.pt``.

    Args:
        state: Object to serialize with ``torch.save`` (e.g. a dict holding
            the epoch and the model/optimizer state dicts).
        is_best: Whether this checkpoint should also be kept as the best model.
        checkpoint_dir: Directory for the rolling checkpoint, as ``str`` or
            ``pathlib.Path``. Created (with parents) if it does not exist.
        best_model_dir: Directory for the best-model copy, as ``str`` or
            ``pathlib.Path``. Only touched (and created if needed) when
            ``is_best`` is true.
    """
    from pathlib import Path

    # Normalize to Path so both str and Path arguments work with the ``/`` join.
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    f_path = checkpoint_dir / 'checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        best_model_dir = Path(best_model_dir)
        best_model_dir.mkdir(parents=True, exist_ok=True)
        best_fpath = best_model_dir / 'best_model.pt'
        # Copy the file just written instead of serializing ``state`` twice.
        shutil.copyfile(f_path, best_fpath)
# Bundle what is needed to resume: the epoch to start from next time
# (hence ``epoch + 1``) plus the model and optimizer state dicts.
# NOTE(review): epoch, model, optimizer, is_best, checkpoint_dir and model_dir
# are placeholders that must come from the surrounding training loop.
checkpoint = dict(
    epoch=epoch + 1,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
附註：PyTorch 的慣例是使用 .pt 或 .pth 作為副檔名，兩者皆可，有些人偏好 .pth（torch.save 本身並不檢查副檔名）。
Loading a Checkpoint
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model and optimizer state from a saved checkpoint file.

    The checkpoint must be a dict with the keys ``'state_dict'`` (model state),
    ``'optimizer'`` (optimizer state) and ``'epoch'``.

    Args:
        checkpoint_fpath: Path to the checkpoint file.
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        optimizer: Optimizer whose state is overwritten in place via
            ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` (or a
            ``torch.device``) to resume on a machine that lacks the device the
            checkpoint was saved from, e.g. a CPU-only Colab runtime.
            ``None`` keeps torch's default behaviour.

    Returns:
        ``(model, optimizer, epoch)``: the same model and optimizer objects
        that were passed in, plus the stored ``'epoch'`` value.

    Raises:
        KeyError: If the checkpoint lacks any of the expected keys.
    """
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']
# Rebuild the model and optimizer the same way as during training, then
# overwrite their state from the checkpoint and get the epoch to resume at.
# NOTE(review): MyModel, args, kwargs, learning_rate and `optim`
# (presumably torch.optim) are placeholders not defined in this snippet.
ckp_path = "path/to/checkpoint/checkpoint.pt"
model = MyModel(*args, **kwargs)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
model, optimizer, start_epoch = load_ckp(
    checkpoint_fpath=ckp_path,
    model=model,
    optimizer=optimizer,
)