How can I access layers in a PyTorch module by index?

轻奢々 2021-01-24 05:08

I am trying to write a PyTorch module with multiple layers. Since I need the intermediate outputs, I cannot put them all in a Sequential as usual. On the other hand, since there...

2 Answers
  • 2021-01-24 05:37

    If you put your layers in a plain Python list, PyTorch does not register them as submodules, so their parameters are missing from model.parameters() and never reach the optimizer. Use ModuleList instead (https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).

    ModuleList can be indexed like a regular Python list, but the modules it contains are properly registered and are visible to all Module methods.

    Your code should be something like:

    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
    
            self.layer_list = nn.ModuleList()  # << the only changed line! <<
    
            self.layer_list.append(nn.Linear(2,3))
            self.layer_list.append(nn.Linear(3,4))
            self.layer_list.append(nn.Linear(4,5))
    
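        def forward(self, x):
            # Keep the input and every intermediate output, in order.
            res_list = [x]
            for i in range(len(self.layer_list)):
                # Index into the ModuleList and feed it the latest result.
                res_list.append(self.layer_list[i](res_list[-1]))
            return res_list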
    

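    By using ModuleList you make sure all layers are registered as submodules of the model, so their parameters are returned by model.parameters() and move with the model in calls such as .to(device).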
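
    To see the difference registration makes, here is a minimal sketch that counts the parameters each variant exposes. The class names PlainListModel and ModuleListModel are hypothetical, made up for this comparison only.

```python
import torch
import torch.nn as nn


class PlainListModel(nn.Module):
    """Hypothetical counter-example: layers kept in a plain Python list."""

    def __init__(self):
        super().__init__()
        # A plain list is invisible to nn.Module's bookkeeping.
        self.layer_list = [nn.Linear(2, 3), nn.Linear(3, 4)]


class ModuleListModel(nn.Module):
    """Same layers, but held in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        self.layer_list = nn.ModuleList([nn.Linear(2, 3), nn.Linear(3, 4)])


# Count the parameter tensors each model reports.
plain_count = len(list(PlainListModel().parameters()))
registered_count = len(list(ModuleListModel().parameters()))

print(plain_count)       # nothing is registered
print(registered_count)  # one weight and one bias per Linear layer
```

    An optimizer built from PlainListModel().parameters() would therefore have nothing to update, which is the usual symptom of this mistake.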

    There is also a ModuleDict that you can use if you want to index your layers by name. You can check PyTorch's containers here: https://pytorch.org/docs/master/nn.html#containers
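
    As a rough illustration of the ModuleDict variant, the sketch below indexes layers by name and still returns the intermediate output. The class name NamedLayersModel and the keys "encoder" and "head" are assumptions chosen for the example, not anything from the question.

```python
import torch
import torch.nn as nn


class NamedLayersModel(nn.Module):
    """Hypothetical model whose layers are looked up by name."""

    def __init__(self):
        super().__init__()
        # ModuleDict registers every value as a submodule, like ModuleList.
        self.layers = nn.ModuleDict({
            "encoder": nn.Linear(2, 3),
            "head": nn.Linear(3, 4),
        })

    def forward(self, x):
        # Look layers up by key and keep the intermediate result.
        hidden = self.layers["encoder"](x)
        return hidden, self.layers["head"](hidden)


model = NamedLayersModel()
hidden, out = model(torch.randn(5, 2))
print(hidden.shape, out.shape)
```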

  • 2021-01-24 05:40

    You can list all top-level layers of a network with named_children(), which yields (name, module) pairs:

    list_layers = list(model.named_children())
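
    A small sketch of how that list can then be indexed, assuming a hypothetical TinyNet with three submodules (the class and attribute names are invented for the example). named_children() returns a lazy iterator, so it is wrapped in list() before indexing.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical network used only to demonstrate named_children()."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(3, 4)


model = TinyNet()

# Materialize the (name, module) pairs so they can be indexed.
list_layers = list(model.named_children())
names = [name for name, _ in list_layers]

# Pick the third top-level layer by position.
last_name, last_layer = list_layers[2]
print(names, last_name, last_layer)
```

    Note that this only walks direct children, in the order they were assigned in __init__; nested submodules would need named_modules() instead.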
    

    In the first case, you can use:

    parameters = list(Model1.parameters()) + list(Model2.parameters())
    optimizer = optim.Adam(parameters, lr=1e-3)
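
    For a runnable version of the combined optimizer, here is a sketch in which two small nn.Linear layers stand in for Model1 and Model2 (the real models are not shown in the question, so these are assumptions). It runs one optimization step to show that both sets of parameters receive gradients.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins for Model1 and Model2.
model1 = nn.Linear(2, 3)
model2 = nn.Linear(3, 4)

# Concatenate both parameter lists and hand them to a single optimizer.
parameters = list(model1.parameters()) + list(model2.parameters())
optimizer = optim.Adam(parameters, lr=1e-3)

# One dummy training step through both models.
loss = model2(model1(torch.randn(8, 2))).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Every parameter from both models now has a gradient.
all_have_grads = all(p.grad is not None for p in parameters)
print(all_have_grads)
```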
    

    In the second case, you never instantiated the model, so you can try this:

    model = VAE()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    

    By the way, you can start by modifying the VAE example provided by PyTorch.

    Perhaps you are missing the __init__ function or initializing the model incorrectly. See the init function here.
