How to construct a network with two inputs in PyTorch

[愿得一人] 2021-02-14 03:18

Suppose I want to have the general neural network architecture:

Input1 --> CNNLayer
                    \
                     ---> FCLayer ---> Output
                    /
Input2 --> FCLayer

1 Answer
  • 2021-02-14 04:06

    By "combine them" I assume you mean concatenating the features produced by the two branches.
    Assuming you concatenate along the second dimension (dim=1, the feature dimension after the batch dimension):

    import torch
    from torch import nn
    
    class TwoInputsNet(nn.Module):
      def __init__(self):
        super(TwoInputsNet, self).__init__()
        self.conv = nn.Conv2d( ... )  # set up your layer here
        self.fc1 = nn.Linear( ... )  # set up first FC layer
        self.fc2 = nn.Linear( ... )  # set up the other FC layer
    
      def forward(self, input1, input2):
        c = self.conv(input1)
        f = self.fc1(input2)
        # now we can reshape `c` and `f` to 2D and concat them
        combined = torch.cat((c.view(c.size(0), -1),
                              f.view(f.size(0), -1)), dim=1)
        out = self.fc2(combined)
        return out
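    For a version that actually runs, here is the same network with the `...` placeholders filled in. All layer sizes below are hypothetical choices for illustration (3x32x32 images for input1, 10 non-image features for input2, 16 conv channels, 32 hidden units, 2 outputs), not values from the question:

```python
import torch
from torch import nn


class TwoInputsNet(nn.Module):
    """Hypothetical sizes: input1 is a batch of 3x32x32 images,
    input2 is a batch of 10-dimensional non-image feature vectors."""

    def __init__(self):
        super().__init__()
        # Image branch: 3 -> 16 channels; kernel 3 with padding 1 keeps 32x32
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        # Non-image branch: 10 features -> 32 features
        self.fc1 = nn.Linear(in_features=10, out_features=32)
        # Combined head: flattened conv output (16 * 32 * 32) plus fc1 output (32)
        self.fc2 = nn.Linear(in_features=16 * 32 * 32 + 32, out_features=2)

    def forward(self, input1, input2):
        c = self.conv(input1)  # shape (N, 16, 32, 32)
        f = self.fc1(input2)   # shape (N, 32)
        # Flatten everything except the batch dimension, then concatenate features
        combined = torch.cat((c.view(c.size(0), -1),
                              f.view(f.size(0), -1)), dim=1)
        return self.fc2(combined)  # shape (N, 2)


net = TwoInputsNet()
images = torch.randn(4, 3, 32, 32)   # batch of 4 hypothetical images
features = torch.randn(4, 10)        # matching batch of 4 feature vectors
out = net(images, features)
print(out.shape)  # torch.Size([4, 2])
```

    As in the answer, no activation functions are shown; a real model would normally apply a non-linearity (e.g. `torch.relu`) to each branch before concatenating.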
    

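    Note that when you set the number of inputs (in_features) of self.fc2, you need to account for the out_channels of self.conv and the output spatial dimensions (height and width) of c, since the flattened conv output has out_channels * height * width elements, plus the out_features of self.fc1.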
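    To make that bookkeeping concrete, the sketch below computes the flattened size of a conv output from the output-size formula in the `nn.Conv2d` documentation and cross-checks it with a dummy forward pass. The image size, kernel, stride, and the fc1 width are hypothetical, and `conv2d_out_size` is a helper written for this example, not a PyTorch function:

```python
import torch
from torch import nn


def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # Output-size formula from the nn.Conv2d documentation, for one spatial dim
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Hypothetical configuration: 3x64x48 images, 8 output channels, 5x5 kernel, stride 2
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2)
h_in, w_in = 64, 48

h_out = conv2d_out_size(h_in, kernel_size=5, stride=2)
w_out = conv2d_out_size(w_in, kernel_size=5, stride=2)
flat_by_formula = conv.out_channels * h_out * w_out

# Cross-check with a dummy forward pass (batch size 1, no gradients needed)
with torch.no_grad():
    flat_by_probe = conv(torch.zeros(1, 3, h_in, w_in)).view(1, -1).size(1)

fc1_out = 32  # hypothetical out_features of self.fc1
fc2_in_features = flat_by_formula + fc1_out
print(h_out, w_out, flat_by_formula, flat_by_probe, fc2_in_features)
# 30 22 5280 5280 5312
```

    The dummy-pass approach saves hand-deriving shapes when several conv or pooling layers are stacked; recent PyTorch versions also provide `nn.LazyLinear`, which infers `in_features` on the first forward call.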
