How to initialize weights in PyTorch?

暗喜 2020-11-28 01:10

How to initialize the weights and biases (for example, with He or Xavier initialization) in a network in PyTorch?

9 answers
  • 2020-11-28 01:40

    To initialize layers you typically don't need to do anything.

    PyTorch will do it for you. If you think about it, this makes a lot of sense: why should we initialize layers ourselves when PyTorch can do it following current best practices?

    Check, for instance, the Linear layer.

    In its __init__ method it calls the Kaiming He init function:

        def reset_parameters(self):
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)
    

    The same holds for other layer types; Conv2d, for instance, uses the same default initialization in its reset_parameters.

    To note: the gain of proper initialization is faster training. If your problem calls for a special initialization, you can still do it afterwards.
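
    As a quick sanity check of those defaults, here is a minimal sketch (layer sizes are arbitrary) confirming that a freshly constructed Linear layer already has Kaiming-initialized weights within the expected bound:

        import math
        import torch.nn as nn

        # a Linear layer is initialized on construction; no extra call needed
        layer = nn.Linear(128, 64)

        # Kaiming uniform with a = sqrt(5) draws weights from [-bound, bound],
        # where bound = sqrt(6 / ((1 + a**2) * fan_in)) = 1 / sqrt(fan_in)
        bound = 1 / math.sqrt(layer.in_features)
        print(bool((layer.weight >= -bound).all() and (layer.weight <= bound).all()))
        # prints True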

  • 2020-11-28 01:44
        import torch.nn as nn

        in_features, h_size = 10, 20  # example dimensions

        # a simple network
        rand_net = nn.Sequential(nn.Linear(in_features, h_size),
                                 nn.BatchNorm1d(h_size),
                                 nn.ReLU(),
                                 nn.Linear(h_size, h_size),
                                 nn.BatchNorm1d(h_size),
                                 nn.ReLU(),
                                 nn.Linear(h_size, 1),
                                 nn.ReLU())

        # initialization function: first check the module type,
        # then apply the desired change to the weights
        def init_uniform(m):
            if type(m) == nn.Linear:
                nn.init.uniform_(m.weight)

        # use the module's apply method to recursively apply the initialization
        rand_net.apply(init_uniform)
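
    The same apply pattern extends to per-type rules. A minimal sketch (sizes are made up) that He-initializes Linear weights and zeros their biases, leaving BatchNorm at its defaults:

        import torch.nn as nn

        net = nn.Sequential(nn.Linear(10, 20), nn.BatchNorm1d(20), nn.ReLU(),
                            nn.Linear(20, 1))

        def init_he(m):
            # He (Kaiming) init suits ReLU networks; biases start at zero
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        net.apply(init_he)
        print(net[0].bias.abs().sum().item())  # 0.0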
    
  • 2020-11-28 01:45

    If you see a deprecation warning for the non-in-place init functions (@Fábio Perez), use the trailing-underscore variants:

    import torch
    import torch.nn as nn

    def init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)
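
    For other activations, Xavier init can be scaled with calculate_gain; a small sketch (layer sizes are arbitrary) for a tanh layer:

        import torch.nn as nn

        layer = nn.Linear(64, 32)

        # scale Xavier init for tanh; calculate_gain('tanh') returns 5/3
        gain = nn.init.calculate_gain('tanh')
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        nn.init.zeros_(layer.bias)
        print(round(gain, 4))  # 1.6667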
    