权重初始化
def weights_normal_init(model, dev=0.01):
    """Re-initialise conv/linear weights in place from a zero-mean normal.

    Args:
        model: An ``nn.Module``, or a list/tuple of modules. Containers may
            be nested; each element is processed recursively.
        dev: Standard deviation of the normal distribution the weights are
            drawn from.

    Behaviour per sub-module (as returned by ``model.modules()``):
        * ``nn.Conv2d``: weight ~ N(0, dev); bias, when present, set to 0.
        * ``nn.Linear``: weight ~ N(0, dev); bias is left untouched.
        * anything else: unchanged.

    Returns:
        None. The parameters are modified in place.
    """
    if isinstance(model, (list, tuple)):
        for sub_model in model:
            weights_normal_init(sub_model, dev)
        return

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # nn.init.* runs under no_grad, so the parameters are updated
            # in place without going through the discouraged `.data` handle.
            nn.init.normal_(module.weight, mean=0.0, std=dev)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=dev)
            # NOTE(review): unlike the Conv2d case the Linear bias keeps its
            # existing values - confirm this asymmetry is intended.
网络结构
class Conv2d(nn.Module):
    """2-D convolution followed by optional batch norm and optional ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int or a ``(height, width)``
            pair.
        stride: Convolution stride.
        relu: When true, apply ``nn.ReLU`` last.
        same_padding: When true, pad each spatial dimension by
            ``(k - 1) // 2``; this keeps the spatial size unchanged for an
            odd kernel with stride 1. When false, no padding is used.
        bn: When true, apply ``nn.BatchNorm2d`` between the convolution and
            the activation.

    Attributes ``bn`` and ``relu`` are ``None`` when the corresponding stage
    is disabled.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False, bn=False):
        super(Conv2d, self).__init__()
        if not same_padding:
            padding = 0
        elif isinstance(kernel_size, (tuple, list)):
            # Rectangular kernels get per-dimension padding.
            padding = tuple((k - 1) // 2 for k in kernel_size)
        else:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)
        # NOTE(review): momentum=0 means the running mean/variance are never
        # moved away from their initial values - confirm this is intended.
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        """Apply conv, then batch norm and ReLU when they are enabled."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class MCNN(nn.Module):
    """Three parallel convolutional columns fused by a 1x1 convolution.

    Each column takes the same single-channel input and stacks four
    same-padded ``Conv2d`` layers, with a 2x2 max-pool after each of the
    first two. The columns differ only in kernel sizes and channel widths:

        branch1: kernels 9, 7, 7, 7 -> channels 16, 32, 16, 8
        branch2: kernels 7, 5, 5, 5 -> channels 20, 40, 20, 10
        branch3: kernels 5, 3, 3, 3 -> channels 24, 48, 24, 12

    The 8 + 10 + 12 = 30 output channels are concatenated and reduced to a
    single channel by ``fuse``. Because of the two pooling stages the output
    has one quarter of the input's height and width (rounded down); no
    upsampling is applied.

    Args:
        bn: Forwarded to every ``Conv2d`` to enable batch normalisation.
    """

    def __init__(self, bn=False):
        super(MCNN, self).__init__()
        self.branch1 = self._build_column((16, 32, 16, 8), (9, 7, 7, 7), bn)
        self.branch2 = self._build_column((20, 40, 20, 10), (7, 5, 5, 5), bn)
        self.branch3 = self._build_column((24, 48, 24, 12), (5, 3, 3, 3), bn)
        self.fuse = nn.Sequential(Conv2d(30, 1, 1, same_padding=True, bn=bn))

    @staticmethod
    def _build_column(widths, kernels, bn):
        """Build one column: conv, pool, conv, pool, conv, conv."""
        stages = []
        channels_in = 1
        for index, (channels_out, kernel) in enumerate(zip(widths, kernels)):
            stages.append(Conv2d(channels_in, channels_out, kernel, same_padding=True, bn=bn))
            # Only the first two convolutions are followed by pooling.
            if index < 2:
                stages.append(nn.MaxPool2d(2))
            channels_in = channels_out
        return nn.Sequential(*stages)

    def forward(self, im_data):
        """Run all columns on ``im_data`` and fuse them along the channel axis."""
        columns = [column(im_data) for column in (self.branch1, self.branch2, self.branch3)]
        stacked = torch.cat(columns, 1)
        return self.fuse(stacked)
使用
# Example: build the network with its default arguments (bn=False), then
# overwrite the weights of its conv/linear layers with N(0, 0.01) samples.
model = MCNN()
weights_normal_init(model, dev=0.01)
来源:CSDN
作者:CS_lcylmh
链接:https://blog.csdn.net/qq_19598705/article/details/80935786