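One of the right ways to use Colab is to make sure training can resume seamlessly, because a Colab runtime can be disconnected or reset before training finishes. Here is how to resume training in PyTorch.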
Saving a checkpoint
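```python
import shutil
from pathlib import Path

import torch


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    # Overwrite the latest checkpoint every time
    f_path = Path(checkpoint_dir) / 'checkpoint.pt'
    torch.save(state, f_path)
    # Keep a separate copy of the best model seen so far
    if is_best:
        best_fpath = Path(best_model_dir) / 'best_model.pt'
        shutil.copyfile(f_path, best_fpath)


# Inside the training loop, at the end of each epoch:
checkpoint = {
    'epoch': epoch + 1,  # the next epoch to run when resuming
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
```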
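The snippet above assumes `is_best` already exists. A minimal sketch of where it could come from, assuming "best" means the lowest validation loss seen so far; the toy `nn.Linear` model, the random data, `best_val_loss`, and the temporary directories are all hypothetical stand-ins. On Colab you would typically point `checkpoint_dir` at a folder on Google Drive (mounted with `google.colab.drive.mount`) so the files survive a disconnect.

```python
import shutil
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    # Same helper as above: overwrite the latest checkpoint, copy it if it is the best so far
    f_path = Path(checkpoint_dir) / 'checkpoint.pt'
    torch.save(state, f_path)
    if is_best:
        shutil.copyfile(f_path, Path(best_model_dir) / 'best_model.pt')


def train(num_epochs, checkpoint_dir, best_model_dir):
    torch.manual_seed(0)
    # Hypothetical toy regression data standing in for a real dataset
    x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
    x_val, y_val = torch.randn(16, 4), torch.randn(16, 1)

    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(x_train), y_train)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(x_val), y_val).item()

        # A checkpoint is "best" when it beats every earlier validation loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_ckp(checkpoint, is_best, checkpoint_dir, best_model_dir)
    return best_val_loss


with tempfile.TemporaryDirectory() as ckpt_dir, tempfile.TemporaryDirectory() as best_dir:
    best = train(3, ckpt_dir, best_dir)
    last_epoch = torch.load(Path(ckpt_dir) / 'checkpoint.pt')['epoch']
    has_best = (Path(best_dir) / 'best_model.pt').exists()

print(last_epoch, has_best)  # 3 True
```

Copying the file instead of calling `torch.save` twice keeps the "latest" and "best" checkpoints byte-identical whenever the current epoch is the best one.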
Note: some people use `.pth` as the file extension instead; `torch.save` does not care which extension you choose.
Loading a checkpoint
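```python
import torch
import torch.optim as optim


def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    # Restore model weights and optimizer state (e.g. Adam's moment estimates)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']


# Build the model and optimizer the same way as in the original run,
# then load the saved state into them
model = MyModel(*args, **kwargs)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
ckp_path = "path/to/checkpoint/checkpoint.pt"
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)
# Continue the training loop from start_epoch instead of 0
```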
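To check that resuming really picks up where training stopped, here is a minimal, self-contained round trip: train two epochs, save, build a fresh model and optimizer as a new session would, load, and continue the loop from `start_epoch`. This is a sketch under assumptions: `nn.Linear` stands in for `MyModel`, a temporary file replaces the real checkpoint path, and `make_model_and_optimizer` and `run_epochs` are hypothetical helpers.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim


def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']


def make_model_and_optimizer():
    # Hypothetical helper: the same architecture and optimizer type must be rebuilt before loading
    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    return model, optimizer


def run_epochs(model, optimizer, start_epoch, end_epoch, x, y):
    # Hypothetical helper: one full-batch step per epoch, from start_epoch up to end_epoch
    loss_fn = nn.MSELoss()
    for epoch in range(start_epoch, end_epoch):
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    return end_epoch  # the next epoch to run


torch.manual_seed(0)
x, y = torch.randn(32, 4), torch.randn(32, 1)
num_epochs = 5

with tempfile.TemporaryDirectory() as tmp:
    ckp_path = Path(tmp) / 'checkpoint.pt'

    # First session: train 2 epochs, save, then pretend the runtime disconnected
    model, optimizer = make_model_and_optimizer()
    next_epoch = run_epochs(model, optimizer, 0, 2, x, y)
    torch.save({'epoch': next_epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}, ckp_path)

    # Second session: fresh objects, restore the saved state, continue the loop
    model2, optimizer2 = make_model_and_optimizer()
    model2, optimizer2, start_epoch = load_ckp(ckp_path, model2, optimizer2)
    weights_match = torch.equal(model.weight, model2.weight)
    resumed_step = float(optimizer2.state[model2.weight]['step'])
    run_epochs(model2, optimizer2, start_epoch, num_epochs, x, y)

print(start_epoch, weights_match)  # 2 True
```

Restoring the optimizer state matters as much as the weights: without it Adam would restart its step count and moment estimates from zero, so the resumed run would not match an uninterrupted one.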