PyTorch : How to properly create a list of nn.Linear()

前端 未结 1 1519
佛祖请我去吃肉
佛祖请我去吃肉 2021-01-17 10:35

I have created a class that has nn.Module as subclass.

In my class, I have to create N number of linear transformation, where N is given as class parameters.

1条回答
  •  北海茫月
    2021-01-17 11:00

    You can use nn.ModuleList to wrap your list of linear layers as explained here

    self.list_1 = nn.ModuleList(self.list_1)
    

    0 讨论(0)
提交回复
热议问题