Pytorch_Part7_模型使用
共同贡献 PyTorch常见错误与坑汇总文档 : 《PyTorch常见报错/坑汇总》 一、模型保存与加载 1. 序列化与反序列化 net = LeNet2(classes=2019) # 法1: 保存整个Module,不仅保存参数,也保存结构 torch.save(net, path) net_load = torch.load(path_model) # 网络名称、结构、模型参数、优化器参数均保留 # 法2: 保存模型参数(推荐,占用资源少) state_dict = net.state_dict() torch.save(state_dict , path) net_new = LeNet2(classes=2019) net_new.load_state_dict(state_dict_load) 2. 断点续训练 保存: checkpoint = { "model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch } path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch) torch.save(checkpoint, path_checkpoint) 恢复: # ======