I was trying to build a CNN with PyTorch and had difficulty with max pooling. I have taken Stanford's CS231n course. As I recall, max pooling can be used as a dimensionality reduction step.
Would something like this work?
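```python
import torch.nn.functional as F
from torch.nn import MaxPool1d  # the class is MaxPool1d, not MaxPool1D


class ChannelPool(MaxPool1d):
    """Max-pool over the channel dimension of an (N, C, H, W) tensor.

    Note: return_indices=True is not supported here, because
    F.max_pool1d would then return a (values, indices) tuple.
    """

    def forward(self, input):
        n, c, h, w = input.size()
        # Flatten the spatial dims and move channels last: (N, H*W, C),
        # so max_pool1d slides its window along the channel axis.
        input = input.reshape(n, c, h * w).permute(0, 2, 1)
        pooled = F.max_pool1d(input, self.kernel_size, self.stride,
                              self.padding, self.dilation, self.ceil_mode,
                              self.return_indices)
        # Number of channels left after pooling.
        _, _, c = pooled.size()
        # Move channels back to dim 1 and restore the spatial dims.
        pooled = pooled.permute(0, 2, 1)
        return pooled.reshape(n, c, h, w)
```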
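To check the idea, here is a self-contained sketch (function names are hypothetical, not part of PyTorch) that pools an (N, C, H, W) tensor along the channel axis with `F.max_pool1d`, using the same flatten/permute trick as above. It compares the result against two plain-tensor equivalents: `amax` over `dim=1` when the window spans every channel, and a reshape into channel groups for non-overlapping windows. The reshape variant assumes C is divisible by the kernel size; it is an alternative technique, not what `ChannelPool` does internally.

```python
import torch
import torch.nn.functional as F


def channel_pool_via_max_pool1d(x, kernel_size):
    """Max-pool an (N, C, H, W) tensor along C with F.max_pool1d.

    stride defaults to kernel_size, so the windows do not overlap.
    """
    n, c, h, w = x.shape
    # (N, C, H, W) -> (N, H*W, C): channels become the pooled axis.
    flat = x.reshape(n, c, h * w).permute(0, 2, 1)
    pooled = F.max_pool1d(flat, kernel_size)
    c_out = pooled.shape[-1]
    # (N, H*W, C_out) -> (N, C_out, H, W)
    return pooled.permute(0, 2, 1).reshape(n, c_out, h, w)


def channel_pool_via_reshape(x, kernel_size):
    """Same result for non-overlapping windows, assuming C % kernel_size == 0:
    group the channels into blocks of kernel_size and take each block's max."""
    n, c, h, w = x.shape
    return x.reshape(n, c // kernel_size, kernel_size, h, w).amax(dim=2)


x = torch.randn(1, 20, 3, 5)  # batch of 1, 20 channels, 3x5 spatial map

# Pool all 20 channels into one: (1, 20, 3, 5) -> (1, 1, 3, 5).
full = channel_pool_via_max_pool1d(x, 20)
print(full.shape)  # torch.Size([1, 1, 3, 5])
print(torch.equal(full, x.amax(dim=1, keepdim=True)))  # True

# Pool in windows of 4 channels: (1, 20, 3, 5) -> (1, 5, 3, 5).
blocks = channel_pool_via_max_pool1d(x, 4)
print(torch.equal(blocks, channel_pool_via_reshape(x, 4)))  # True
```

If you only ever need the maximum over all channels, `x.amax(dim=1, keepdim=True)` gives the (N, 1, H, W) result directly, without subclassing a pooling module.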