本文目录
1. 简介
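Stacked Hourglass(Newell等,《Stacked Hourglass Networks for Human Pose Estimation》,ECCV 2016)利用多尺度特征来估计人体姿态。如下图所示,每个子网络称为hourglass网络,呈沙漏形结构;将多个这样的结构堆叠起来,即为stacked hourglass。堆叠的方式使得每个模块都能在整幅图像的范围内对姿态和特征进行重新估计。输入图像首先经过一个由卷积、残差模块和最大池化组成的前端网络提取特征(空间分辨率降为输入的1/4),随后依次通过多个堆叠的hourglass,得到最终的热图。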
单个Hourglass模块的结构如下图所示,其中每个方块都是一个残差模块(其结构见第5节)。
Stacked Hourglass采用了中间监督(Intermediate Supervision):每个hourglass之后都会输出一组热图(图中蓝色部分)。训练阶段,分别计算每组热图与真实热图之间的均方误差(MSE)损失,并将它们求和作为总损失;推断阶段,只使用最后一个hourglass输出的热图。
2. stacked hourglass
堆叠hourglass结构如下图所示(nChannels=256,nStack=2,nModules=2,numReductions=4, nJoints=17):
代码如下:
class StackedHourGlass(nn.Module):
    """Stacked hourglass network that predicts one heatmap set per stage.

    Pipeline:
      * stem: BN+ReLU+conv (7x7, stride 2), residual modules and a 2x2
        max-pool, i.e. 4x spatial down-sampling, ending with ``nChannels``
        feature maps;
      * ``nStack`` stages, each made of an hourglass, ``nModules`` residual
        modules and a BN+ReLU+1x1 conv. Each stage emits an ``nJoints``-channel
        heatmap via a 1x1 conv. The stage features and the heatmap are then
        remapped back to ``nChannels`` with 1x1 convs and added to the running
        features for the next stage.

    Args:
        nChannels: feature channels used inside the hourglass stages.
        nStack: number of stacked stages.
        nModules: residual modules per residual sequence.
        numReductions: number of down-sampling levels inside each hourglass.
        nJoints: number of heatmap channels produced per stage.

    forward(x) returns a list of ``nStack`` heatmap tensors, one per stage
    (for intermediate supervision). The input is expected to have 3 channels
    (the stem conv is built with 3 input channels).
    """

    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        # Stem: 3 -> 64 -> 128 -> nChannels, 4x spatial down-sampling overall.
        self.start = M.BnReluConv(3, 64, kernelSize=7, stride=2, padding=3)
        self.res1 = M.Residual(64, 128)  # channel change -> 1x1-conv shortcut
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.Residual(128, 128)  # identity shortcut
        self.res3 = M.Residual(128, self.nChannels)  # identity iff nChannels == 128

        # Modules are built stage by stage (not type by type) so the order of
        # parameter creation, and hence seeded initialisation, is unchanged.
        hourglasses, res_seqs, lin1s, to_joints, lin2s, from_joints = ([] for _ in range(6))
        for _ in range(self.nStack):
            hourglasses.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            res_seqs.append(nn.Sequential(
                *[M.Residual(self.nChannels, self.nChannels) for _ in range(self.nModules)]
            ))
            lin1s.append(M.BnReluConv(self.nChannels, self.nChannels))    # 1x1 BN+ReLU+conv
            to_joints.append(nn.Conv2d(self.nChannels, self.nJoints, 1))  # features -> heatmap
            lin2s.append(nn.Conv2d(self.nChannels, self.nChannels, 1))    # features remap
            from_joints.append(nn.Conv2d(self.nJoints, self.nChannels, 1))  # heatmap -> features

        # Attribute names are part of the state_dict keys; keep them as-is.
        self.hourglass = nn.ModuleList(hourglasses)
        self.Residual = nn.ModuleList(res_seqs)
        self.lin1 = nn.ModuleList(lin1s)
        self.chantojoints = nn.ModuleList(to_joints)
        self.lin2 = nn.ModuleList(lin2s)
        self.jointstochan = nn.ModuleList(from_joints)

    def forward(self, x):
        feats = self.res3(self.res2(self.mp(self.res1(self.start(x)))))

        heatmaps = []
        for stage in range(self.nStack):
            y = self.lin1[stage](self.Residual[stage](self.hourglass[stage](feats)))
            heatmap = self.chantojoints[stage](y)
            heatmaps.append(heatmap)
            # Next-stage input: running features + remapped stage features
            # + heatmap projected back to nChannels.
            feats = feats + self.lin2[stage](y) + self.jointstochan[stage](heatmap)

        return heatmaps
3. hourglass
当numReductions>1时,hourglass会在内部递归地嵌套一个numReductions-1层的hourglass;当numReductions=1时,则改用若干残差模块代替嵌套。结构如下:
代码如下:
class Hourglass(nn.Module):
    """One (possibly nested) hourglass module.

    The input goes through two branches whose outputs are summed:

    * upper branch: ``nModules`` channel-preserving residual modules at the
      input resolution;
    * lower branch: max-pool, ``nModules`` residual modules, then either a
      nested ``Hourglass`` with ``numReductions - 1`` (when
      ``numReductions > 1``) or ``nModules`` residual modules, then
      ``nModules`` residual modules and a 2x upsampling (``myUpsample()``).

    With the default 2x2 / stride-2 pooling, the input height and width must
    be divisible by ``2 ** numReductions`` so both branches have equal shapes.

    Args:
        nChannels: channels of input and output (preserved throughout).
        numReductions: number of nested pool/upsample levels.
        nModules: residual modules per residual sequence.
        poolKernel, poolStride: passed to ``nn.MaxPool2d`` and to nested
            hourglasses.
        upSampleKernel: stored on the module only; the upsampling layer does
            not read it and it is not passed to nested hourglasses.
    """

    def __init__(self, nChannels=256, numReductions=4, nModules=2,
                 poolKernel=(2, 2), poolStride=(2, 2), upSampleKernel=2):
        super(Hourglass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        # Upper (skip) branch at the input resolution.
        self.skip = self._residual_seq()

        # Lower branch: pool -> residuals -> recurse or bottom out -> residuals -> upsample.
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)
        self.afterpool = self._residual_seq()
        if numReductions > 1:
            self.hg = Hourglass(self.nChannels, self.numReductions - 1, self.nModules,
                                self.poolKernel, self.poolStride)
        else:
            self.num1res = self._residual_seq()  # innermost level: no further nesting
        self.lowres = self._residual_seq()
        self.up = myUpsample()  # nearest-neighbour 2x (alternative: nn.Upsample)

    def _residual_seq(self):
        """Build ``nModules`` residual modules mapping nChannels -> nChannels."""
        return nn.Sequential(
            *[M.Residual(self.nChannels, self.nChannels) for _ in range(self.nModules)]
        )

    def forward(self, x):
        upper = self.skip(x)

        lower = self.afterpool(self.mp(x))
        lower = self.hg(lower) if self.numReductions > 1 else self.num1res(lower)
        lower = self.up(self.lowres(lower))

        return lower + upper
4. 上采样myUpsample
上采样代码如下:
class myUpsample(nn.Module):
    """Nearest-neighbour upsampling by an integer factor.

    Every pixel of a 4-D ``(N, C, H, W)`` tensor is replicated ``scale_factor``
    times along both height and width, giving ``(N, C, H*s, W*s)``. The
    default ``scale_factor=2`` reproduces the original fixed 2x behaviour, so
    existing ``myUpsample()`` callers are unaffected.

    Args:
        scale_factor: positive integer replication factor (default 2).

    Raises:
        ValueError: if ``scale_factor`` is not a positive integer.
    """

    def __init__(self, scale_factor=2):
        super(myUpsample, self).__init__()
        # bool is a subclass of int; reject it explicitly.
        if isinstance(scale_factor, bool) or not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError("scale_factor must be a positive integer, got %r" % (scale_factor,))
        self.scale_factor = scale_factor

    def forward(self, x):
        s = self.scale_factor
        n, c, h, w = x.shape
        # (N, C, H, 1, W, 1) --expand (view, no copy)--> (N, C, H, s, W, s)
        # --reshape (copies, since the expanded view is non-contiguous)--> (N, C, H*s, W*s)
        return x[:, :, :, None, :, None].expand(-1, -1, -1, s, -1, s).reshape(n, c, h * s, w * s)

    def extra_repr(self):
        return "scale_factor={}".format(self.scale_factor)
其中x为形状(N, C, H, W)的张量;x[:, :, :, None, :, None]的形状为(N, C, H, 1, W, 1),expand之后变为(N, C, H, 2, W, 2)(expand只返回视图,不复制数据),最终reshape为(N, C, 2H, 2W)。这样,每个像素在水平和垂直方向各复制一次,变成4个数值相同的像素,相当于最近邻插值的2倍上采样。
5. 残差模块
残差模块结构如下:
代码如下:
class Residual(nn.Module):
    """Residual module computing ``ConvBlock(x) + SkipLayer(x)``.

    Main path: ``ConvBlock``, i.e. three BN+ReLU+conv stages (1x1 to
    ``outChannels // 2``, 3x3, 1x1 to ``outChannels``).
    Shortcut: ``SkipLayer``, the identity when ``inChannels == outChannels``
    and a 1x1 convolution otherwise.
    """

    def __init__(self, inChannels, outChannels):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.cb = ConvBlock(inChannels, outChannels)     # main (bottleneck) path
        self.skip = SkipLayer(inChannels, outChannels)   # shortcut path

    def forward(self, x):
        return self.cb(x) + self.skip(x)
其中SkipLayer的代码如下:
class SkipLayer(nn.Module):
    """Shortcut branch of a residual module.

    Returns the input unchanged when ``inChannels == outChannels``.
    Otherwise it applies a 1x1 convolution mapping ``inChannels`` to
    ``outChannels``.
    """

    def __init__(self, inChannels, outChannels):
        super(SkipLayer, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        # None marks the identity shortcut (no parameters registered).
        self.conv = (None if inChannels == outChannels
                     else nn.Conv2d(inChannels, outChannels, 1))

    def forward(self, x):
        return x if self.conv is None else self.conv(x)
6. BnReluConv(BN+ReLU+conv)
class BnReluConv(nn.Module):
    """Pre-activation unit: ``BatchNorm2d -> ReLU -> Conv2d``.

    Args:
        inChannels: channels of the input (also the BatchNorm feature count).
        outChannels: channels produced by the convolution.
        kernelSize, stride, padding: forwarded to ``nn.Conv2d``. The defaults
            (1, 1, 0) give a 1x1 convolution.
    """

    def __init__(self, inChannels, outChannels, kernelSize=1, stride=1, padding=0):
        super(BnReluConv, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.kernelSize = kernelSize
        self.stride = stride
        self.padding = padding

        # Registration order (bn, conv, relu) is kept for state_dict/printing.
        self.bn = nn.BatchNorm2d(inChannels)
        self.conv = nn.Conv2d(inChannels, outChannels, kernelSize, stride, padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.conv(self.relu(self.bn(x)))
7. ConvBlock
class ConvBlock(nn.Module):
    """Bottleneck main path of a residual module.

    Three BN+ReLU+conv stages:
      1. 1x1 conv, ``inChannels -> outChannels // 2``;
      2. 3x3 conv with padding 1 (spatial size preserved), width unchanged;
      3. 1x1 conv, ``outChannels // 2 -> outChannels``.
    """

    def __init__(self, inChannels, outChannels):
        super(ConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.outChannelsby2 = mid = outChannels // 2

        self.cbr1 = BnReluConv(inChannels, mid, 1, 1, 0)
        self.cbr2 = BnReluConv(mid, mid, 3, 1, 1)
        self.cbr3 = BnReluConv(mid, outChannels, 1, 1, 0)

    def forward(self, x):
        for stage in (self.cbr1, self.cbr2, self.cbr3):
            x = stage(x)
        return x
来源:https://blog.csdn.net/hejin_some/article/details/100980757