I am trying to write a PyTorch module with multiple layers. Since I need the intermediate outputs, I cannot put them all in a Sequential as usual. On the other hand, since there
If you put your layers in a plain Python list, PyTorch does not register them as submodules of your model. You have to use nn.ModuleList instead
(https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).
ModuleList can be indexed like a regular Python list, but the modules it contains are properly registered and will be visible to all Module methods.
Your code should be something like:
```python
import torch
import torch.nn as nn
import torch.optim as optim


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer_list = nn.ModuleList()  # << the only changed line! <<
        self.layer_list.append(nn.Linear(2, 3))
        self.layer_list.append(nn.Linear(3, 4))
        self.layer_list.append(nn.Linear(4, 5))

    def forward(self, x):
        # Keep the input and every intermediate output.
        res_list = [x]
        for layer in self.layer_list:
            res_list.append(layer(res_list[-1]))
        return res_list
```
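As a quick check, here is a minimal sketch that calls such a model and inspects what it returns: the input followed by every intermediate output. The model is repeated so the snippet runs on its own, and the batch of 8 random two-dimensional inputs is an arbitrary assumption for illustration.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Same three-layer model as above, repeated so this snippet is self-contained.
    def __init__(self):
        super().__init__()
        self.layer_list = nn.ModuleList(
            [nn.Linear(2, 3), nn.Linear(3, 4), nn.Linear(4, 5)]
        )

    def forward(self, x):
        res_list = [x]
        for layer in self.layer_list:
            res_list.append(layer(res_list[-1]))
        return res_list


model = MyModel()
outputs = model(torch.randn(8, 2))  # assumed batch: 8 samples with 2 features each
shapes = [tuple(o.shape) for o in outputs]
print(shapes)  # [(8, 2), (8, 3), (8, 4), (8, 5)]
```

Every element of `outputs` is still part of the autograd graph, so you can compute a loss on any intermediate output, not just the last one.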
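By using ModuleList you make sure all layers are registered as submodules of your model, so their parameters are returned by model.parameters() (and therefore reach the optimizer), are moved by model.to(device), and are included in model.state_dict().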
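To see the difference, the sketch below compares a model that stores its layers in a plain list with one that uses nn.ModuleList. The two class names are hypothetical and exist only for this comparison.

```python
import torch.nn as nn


class PlainListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: nn.Module does not register the layers inside it.
        self.layer_list = [nn.Linear(2, 3), nn.Linear(3, 4)]


class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule.
        self.layer_list = nn.ModuleList([nn.Linear(2, 3), nn.Linear(3, 4)])


plain = PlainListModel()
registered = ModuleListModel()

plain_count = sum(p.numel() for p in plain.parameters())
registered_count = sum(p.numel() for p in registered.parameters())
print(plain_count)       # 0
print(registered_count)  # (2*3 + 3) + (3*4 + 4) = 25
print(list(registered.state_dict().keys()))
# ['layer_list.0.weight', 'layer_list.0.bias', 'layer_list.1.weight', 'layer_list.1.bias']
```

With the plain list, an optimizer built from `plain.parameters()` would receive no parameters at all, so the layers would never be trained, and saving `state_dict()` would silently skip their weights.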
There is also nn.ModuleDict, which you can use if you want to index your layers by name. You can check PyTorch's containers here: https://pytorch.org/docs/master/nn.html#containers
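A minimal sketch of the same idea with nn.ModuleDict, returning the intermediate outputs keyed by layer name. The class name and the layer names `"encoder"`, `"hidden"` and `"head"` are hypothetical, chosen only for illustration; the sketch relies on ModuleDict preserving insertion order when iterating.

```python
import torch
import torch.nn as nn


class NamedLayersModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layer names; ModuleDict registers each value as a submodule.
        self.layers = nn.ModuleDict({
            "encoder": nn.Linear(2, 3),
            "hidden": nn.Linear(3, 4),
            "head": nn.Linear(4, 5),
        })

    def forward(self, x):
        # Apply the layers in insertion order and keep each output by name.
        outputs = {}
        for name, layer in self.layers.items():
            x = layer(x)
            outputs[name] = x
        return outputs


model = NamedLayersModel()
outputs = model(torch.randn(8, 2))
shapes = {name: tuple(t.shape) for name, t in outputs.items()}
print(shapes)  # {'encoder': (8, 3), 'hidden': (8, 4), 'head': (8, 5)}
```

A single layer can also be accessed directly, e.g. `model.layers["head"]`.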