How can I load and use a PyTorch (.pth.tar) model

予麋鹿 2021-01-20 23:21

I am not very familiar with Torch; I primarily use TensorFlow. However, I need to use an Inception model that was retrained in Torch. Due to the large amount of co

1 answer
  •  爱一瞬间的悲伤
    2021-01-21 00:04

    You basically need to do the same as in TensorFlow. When you store a network, only its parameters (the learned tensors, plus buffers such as batch-norm running statistics) are stored, not the "glue", i.e. all the logic you need to actually run the trained model. So if you have a .pth.tar file, you can load it and use it to overwrite the parameter values of a model you have already defined in code.
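
    To see what actually gets stored, you can print a model's state_dict(): it is an ordered mapping from parameter names to tensors, with no forward logic attached. The tiny network below is an arbitrary stand-in, not the Inception model from the question.

```python
import torch
import torch.nn as nn

# Arbitrary example network; any nn.Module behaves the same way.
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))

# state_dict() holds only named tensors. ReLU has no parameters, so index 1 is absent.
state = net.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

    Because the keys are derived from attribute names in the module definition, the loading code must build a network with the same structure for the names to match.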

    That means that the general procedure of saving/loading a model is as follows:

    • write your network definition (i.e. your nn.Module subclass)
    • train or otherwise change the network's parameters in the way you want
    • save the parameters (the model's state_dict()) using torch.save
    • when you want to use that network, first instantiate a PyTorch network from the same nn.Module definition
    • then read the saved values with torch.load and copy them into the network with load_state_dict
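
    The five steps above can be sketched end to end as follows. TinyNet is a hypothetical stand-in for your own network definition, and an in-memory io.BytesIO buffer replaces the file on disk so the sketch leaves nothing behind; with a real file you would pass a path such as 'filename.pth.tar' instead.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical network definition; the same class is needed to save and to load."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


# Steps 1-2: define the network (training is skipped in this sketch).
trained = TinyNet()

# Step 3: save only the parameters; the buffer stands in for a file on disk.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# Step 4: later, instantiate a fresh network from the same definition.
restored = TinyNet()

# Step 5: read the saved tensors back and copy them into the fresh network.
buffer.seek(0)
restored.load_state_dict(torch.load(buffer))

x = torch.randn(5, 8)
print(torch.equal(trained(x), restored(x)))  # True
```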

    Here's a discussion with some references on how to do this: PyTorch forums

    And here's a minimal working example (MWE):

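    import torch
    import torch.nn as nn
    
    # hypothetical stand-ins: use your own network definition and optimizer here
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    # to store
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, 'filename.pth.tar')
    
    # to load: build model and optimizer from the same definitions first, then
    # overwrite their values; map_location='cpu' also works on CPU-only machines
    checkpoint = torch.load('filename.pth.tar', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])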
    
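    To actually use the restored model, it helps to first look at what the .pth.tar file contains, because checkpoints from other projects do not always follow the 'state_dict' layout above. The sketch below assumes the checkpoint is either a bare state_dict or a dict with a 'state_dict' entry; extract_state_dict is a hypothetical helper, and the 'module.' prefix it strips is what weights saved from an nn.DataParallel wrapper carry. A small nn.Linear and an in-memory dict stand in for the real network and for torch.load('filename.pth.tar', map_location='cpu'). For inference, the model is switched to eval mode and run under torch.no_grad().

```python
import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Hypothetical helper: return a plain state_dict from a loaded checkpoint.

    Assumes the checkpoint is either a bare state_dict or a dict that stores it
    under 'state_dict' (the layout used in the example above).
    """
    state = checkpoint.get('state_dict', checkpoint)
    # Weights saved from an nn.DataParallel wrapper carry a 'module.' prefix.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state.items()}


# Stand-ins for the real network definition and the loaded checkpoint file.
model = nn.Linear(4, 2)
checkpoint = {'state_dict': {'module.' + k: v for k, v in model.state_dict().items()}}
print(sorted(checkpoint['state_dict']))  # inspect the stored keys first

model.load_state_dict(extract_state_dict(checkpoint))

# Inference: eval() fixes dropout/batch-norm behaviour, no_grad() skips autograd.
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 4))
print(out.shape)
```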
