PyTorch max pooling over the channels dimension

被撕碎了的回忆 2021-01-15 19:51

I was trying to build a CNN with PyTorch and had difficulty with max pooling. I have taken cs231n, the course offered by Stanford. As I recall, max pooling can be used as a dimension

1 answer
  •  执念已碎
    2021-01-15 20:36

    Would something like this work?

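    import torch.nn.functional as F
    from torch.nn import MaxPool1d  # lowercase "d"; torch.nn has no MaxPool1D

    class ChannelPool(MaxPool1d):
        """Max-pool over the channel dimension of an (N, C, H, W) tensor."""

        def forward(self, input):
            n, c, h, w = input.size()
            # Flatten the spatial dims and move channels last: (N, H*W, C)
            input = input.reshape(n, c, h * w).permute(0, 2, 1)
            pooled = F.max_pool1d(input, self.kernel_size, self.stride,
                                  self.padding, self.dilation, self.ceil_mode,
                                  self.return_indices)
            _, _, c = pooled.size()  # number of channels after pooling
            # Move channels back to dim 1 and restore the spatial dims
            pooled = pooled.permute(0, 2, 1)
            return pooled.reshape(n, c, h, w)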
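
To check that a layer like this really pools across channels, here is a minimal self-contained sketch. The input size (1, 20, 8, 6) and kernel_size=2 are illustrative assumptions, not values from the question. It compares the output against a reference built by grouping channels and calling amax, and it also shows torch.amax as a simpler route when the goal is just the maximum over all channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelPool(nn.MaxPool1d):
    """Max pooling along the channel axis of an (N, C, H, W) tensor."""

    def forward(self, x):
        n, c, h, w = x.size()
        # Move channels last so max_pool1d slides over them: (N, H*W, C).
        x = x.reshape(n, c, h * w).permute(0, 2, 1)
        pooled = F.max_pool1d(x, self.kernel_size, self.stride,
                              self.padding, self.dilation, self.ceil_mode,
                              self.return_indices)
        c_out = pooled.size(2)
        # Restore the (N, C_out, H, W) layout.
        return pooled.permute(0, 2, 1).reshape(n, c_out, h, w)


# Assumed example sizes: batch 1, 20 channels, 8x6 feature map.
x = torch.randn(1, 20, 8, 6)
pool = ChannelPool(kernel_size=2)  # stride defaults to kernel_size
y = pool(x)                        # 20 channels -> 10 channels

# Reference: group channels in consecutive pairs, take the max of each pair.
expected = x.reshape(1, 10, 2, 8, 6).amax(dim=2)

# If only the max over ALL channels is needed, amax does it directly.
global_max = x.amax(dim=1, keepdim=True)

print(y.shape, torch.equal(y, expected), global_max.shape)
```

Note that this sketch only works with return_indices left at its default of False, since max_pool1d returns a tuple otherwise.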
    
