This should be a quick one. When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the module in nn.Sequential first?
From the PyTorch forum, this is the recommended way:
model_2.layer[0].weight
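For context, a minimal sketch of the kind of wrapped model this indexing assumes; the attribute name layer and the layer sizes here are illustrative, not taken from the question:

import torch.nn as nn

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        # Layers wrapped in nn.Sequential; integer indexing returns the wrapped modules
        self.layer = nn.Sequential(nn.Linear(10, 5), nn.ReLU())

    def forward(self, x):
        return self.layer(x)

model_2 = Net2()
print(model_2.layer[0].weight)  # weight of the nn.Linear inside the Sequential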
You can access modules by name using _modules:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3)

    def forward(self, input):
        return self.conv1(input)

model = Net()
print(model._modules['conv1'])  # Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
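Once you have the module by name, its parameters are reachable the same way as on any module; a quick usage sketch:

print(model._modules['conv1'].weight)  # same tensor as model.conv1.weight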
An easy way to access the weights is to use the state_dict() of your model.
This should work in your case:
for k, v in model_2.state_dict().items():  # iteritems() is Python 2 only
    print("Layer {}".format(k))
    print(v)
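If you only need a single tensor, the state_dict can also be indexed by key; the key layer.0.weight below assumes the Sequential wrapper named layer from the question:

print(model_2.state_dict()['layer.0.weight'])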
Another option is to use the modules() iterator. If you know the type of your layers beforehand, this should also work:
for layer in model_2.modules():
    if isinstance(layer, nn.Linear):
        print(layer.weight)
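If you'd rather not hard-code the layer type, a related sketch (not from the original answers) uses named_parameters(), which yields every parameter together with its name:

for name, param in model_2.named_parameters():
    print(name, param.shape)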