PyTorch: access the weights of a specific module inside nn.Sequential()

执笔经年 2021-01-12 09:51

This should be a quick one. When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrap the module in nn.Sequential() first?

3 Answers
  • 2021-01-12 10:08

    From the PyTorch forum, this is the recommended way:

    model_2.layer[0].weight
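
    For context, here is a minimal, self-contained sketch of the setup this one-liner assumes (the model_2 below is a hypothetical stand-in for the question's model, with the wrapped layers living in an nn.Sequential attribute named layer):

    import torch.nn as nn

    class Model2(nn.Module):
        def __init__(self):
            super(Model2, self).__init__()
            # The wrapped module sits inside an nn.Sequential container.
            self.layer = nn.Sequential(
                nn.Linear(4, 8),
                nn.ReLU(),
            )

        def forward(self, x):
            return self.layer(x)

    model_2 = Model2()
    # nn.Sequential supports integer indexing, so [0] reaches the Linear.
    print(model_2.layer[0].weight)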
    
  • 2021-01-12 10:17

    You can access modules by name using _modules:

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 3, 3)

        def forward(self, input):
            return self.conv1(input)

    model = Net()
    # _modules is the internal OrderedDict that maps child names to modules.
    print(model._modules['conv1'])
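
    The same lookup works when the module is wrapped in nn.Sequential(). A sketch, assuming you name the children with an OrderedDict (otherwise the keys are just '0', '1', ...):

    import torch.nn as nn
    from collections import OrderedDict

    # Naming the children makes the _modules keys human-readable.
    model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(3, 3, 3)),
        ('relu1', nn.ReLU()),
    ]))
    print(model._modules['conv1'].weight)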
    
  • 2021-01-12 10:29

    An easy way to access the weights is to use the state_dict() of your model.

    This should work in your case:

    # state_dict() maps parameter names to tensors; items() replaces the
    # Python 2-only iteritems().
    for k, v in model_2.state_dict().items():
        print("Layer {}".format(k))
        print(v)
    

    Another option is to iterate over modules(). If you know the type of your layers beforehand, this should also work:

    for layer in model_2.modules():
        if isinstance(layer, nn.Linear):
            print(layer.weight)
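
    Putting it together, a runnable sketch (the model_2 here is a hypothetical stand-in for the question's model):

    import torch.nn as nn

    # Hypothetical model_2: layers wrapped directly in nn.Sequential().
    model_2 = nn.Sequential(
        nn.Linear(4, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
    )

    # state_dict() keys encode each module's position in the Sequential,
    # e.g. '0.weight', '0.bias', '2.weight', '2.bias'.
    for k, v in model_2.state_dict().items():
        print("Layer {}".format(k))
        print(v.shape)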
    