How can I access layers in a pytorch module by index?

轻奢々 · 2021-01-24 05:08

I am trying to write a PyTorch module with multiple layers. Since I need the intermediate outputs, I cannot put them all in a Sequential as usual. On the other hand, since there

2 Answers
  •  深忆病人
    2021-01-24 05:37

    If you put your layers in a plain Python list, PyTorch does not register them as submodules, so their parameters are invisible to the optimizer and to methods like .to(). Use ModuleList instead (https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).
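You can see the difference directly by counting parameters; this is a minimal sketch (the class names `Broken` and `Fixed` are just illustrative):

```python
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        # plain Python list: the layers are NOT registered as submodules
        self.layers = [nn.Linear(2, 3), nn.Linear(3, 4)]

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: the layers ARE registered as submodules
        self.layers = nn.ModuleList([nn.Linear(2, 3), nn.Linear(3, 4)])

print(len(list(Broken().parameters())))  # 0 -- an optimizer would see nothing
print(len(list(Fixed().parameters())))   # 4 -- two weights + two biases
```

An optimizer built from `Broken().parameters()` would silently train nothing, which is why this bug is easy to miss.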

    ModuleList can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods.

    Your code should be something like:

    
    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()

            self.layer_list = nn.ModuleList()  # << the only changed line! <<

            self.layer_list.append(nn.Linear(2, 3))
            self.layer_list.append(nn.Linear(3, 4))
            self.layer_list.append(nn.Linear(4, 5))

        def forward(self, x):
            # keep every intermediate output, feeding each layer the previous one
            res_list = [x]
            for layer in self.layer_list:
                res_list.append(layer(res_list[-1]))
            return res_list
    

    By using ModuleList you make sure all layers are registered as submodules, so their parameters show up in model.parameters() and are moved along with the model by calls like model.to(device).
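To illustrate how the intermediate outputs come back, here is the class again (repeated so the snippet runs on its own) together with a call on a dummy batch; the batch size 8 is arbitrary:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # registered container; supports len() and integer indexing
        self.layer_list = nn.ModuleList(
            [nn.Linear(2, 3), nn.Linear(3, 4), nn.Linear(4, 5)]
        )

    def forward(self, x):
        # collect the input and every intermediate activation
        res_list = [x]
        for layer in self.layer_list:
            res_list.append(layer(res_list[-1]))
        return res_list

model = MyModel()
outs = model(torch.randn(8, 2))    # batch of 8 two-dimensional inputs
print([o.shape for o in outs])     # sizes (8,2), (8,3), (8,4), (8,5)
print(model.layer_list[1])         # index a single layer directly
```

Because ModuleList supports indexing, `model.layer_list[1]` answers the original question: layers are accessible by integer index just like in a list.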

    There is also a ModuleDict that you can use if you want to index your layers by name. You can check PyTorch's containers here: https://pytorch.org/docs/master/nn.html#containers
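A minimal ModuleDict sketch, for comparison (the keys "encoder" and "decoder" are just illustrative names):

```python
import torch
import torch.nn as nn

class NamedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleDict registers its values as submodules, like ModuleList,
        # but lets you index them by string key
        self.blocks = nn.ModuleDict({
            "encoder": nn.Linear(2, 3),
            "decoder": nn.Linear(3, 2),
        })

    def forward(self, x):
        return self.blocks["decoder"](self.blocks["encoder"](x))

model = NamedModel()
out = model(torch.randn(4, 2))
print(out.shape)                # torch.Size([4, 2])
print(model.blocks["encoder"])  # access a layer by name
```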
