PyTorch 权重初始化（2）
# Weight initialization


def weights_normal_init(model, dev=0.01):
    """Initialise Conv2d and Linear weights in place from a normal distribution.

    Args:
        model: an ``nn.Module`` (the module itself and all of its submodules,
            as yielded by ``model.modules()``, are visited), or a list/tuple
            of such modules, which is processed recursively element by element.
        dev: standard deviation of the zero-mean normal distribution used for
            the weights.

    Returns:
        None. Parameters are modified in place.

    Behavior per layer type:
        * ``nn.Conv2d``: weight ~ N(0, dev); bias (if present) set to 0.
        * ``nn.Linear``: weight ~ N(0, dev); bias left unchanged.
        * any other module: untouched.
    """
    # Containers of models are handled recursively; tuples are accepted too.
    if isinstance(model, (list, tuple)):
        for sub_model in model:
            weights_normal_init(sub_model, dev)
        return

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # nn.init.* runs under torch.no_grad(), the supported replacement
            # for mutating ``param.data`` directly.
            nn.init.normal_(m.weight, 0.0, dev)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            # NOTE(review): unlike Conv2d, the Linear bias is not reset and
            # keeps PyTorch's default init; preserved from the original —
            # confirm this asymmetry is intended.
            nn.init.normal_(m.weight, 0.0, dev)


# Network structure


class Conv2d(nn.Module):
    # NOTE(review): only the beginning of this class is visible in this chunk;
    # the remainder of ``__init__`` (and any further methods) may continue
    # beyond it. The lines below reproduce exactly what is shown.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, same_padding=False, bn=False):
        super(Conv2d, self).__init__()
        padding = int((kernel_size - 1) / 2) if same_padding else 0