pytorch how to set .requires_grad False

长发绾君心 2021-02-03 23:44

I want to freeze part of my model. Following the official docs:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)  # prints True


        
5 Answers
  •  离开以前
    2021-02-04 00:08

    Nice. The trick is to notice that when you define a Linear layer, its parameters have requires_grad=True by default, because we want to learn them, right?

    import torch.nn as nn

    l = nn.Linear(1, 1)
    for p in l.parameters():
        print(p)

    # Parameter containing:
    # tensor([[-0.3258]], requires_grad=True)
    # Parameter containing:
    # tensor([0.6040], requires_grad=True)
    
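    As a side note, if you only want to freeze part of a model, named_parameters() pairs every parameter with its name, which makes it easy to pick out the ones to freeze. A minimal sketch with a hypothetical two-layer model:

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1))
    for name, p in model.named_parameters():
        print(name, p.requires_grad)

    # 0.weight True
    # 0.bias True
    # 1.weight True
    # 1.bias True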

    The other construct,

    with torch.no_grad():
    

    means that gradient tracking is turned off inside the block: operations performed there are not recorded for backpropagation.

    So your code just shows that the parameter is still capable of learning (requires_grad=True), even though it was created inside torch.no_grad(): the context manager affects operations on tensors, not the flag of a newly created parameter.

    with torch.no_grad():
        linear = nn.Linear(1, 1)
        linear.eval()
        print(linear.weight.requires_grad)  # True
    
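    A minimal sketch of that distinction (the tensor names here are just for illustration): the result of an operation run under no_grad is detached from the graph, while the parameter's flag is untouched:

    import torch
    import torch.nn as nn

    linear = nn.Linear(1, 1)
    x = torch.randn(1, 1)

    with torch.no_grad():
        out = linear(x)                    # an operation inside no_grad
    print(out.requires_grad)               # False: the output is not tracked
    print(linear.weight.requires_grad)     # True: the parameter flag is unchanged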

    If you really want to turn off requires_grad for the weight parameter, you can also do it with:

    linear.weight.requires_grad_(False)
    

    or

    linear.weight.requires_grad = False
    

    So your code may become like this:

    with torch.no_grad():
        linear = nn.Linear(1, 1)
        linear.weight.requires_grad_(False)
        linear.eval()  # note: eval() only switches layers like dropout/batchnorm to eval mode; it does not affect requires_grad
        print(linear.weight.requires_grad)  # False
    

    If you want to switch off requires_grad for all parameters in a module:

    l = nn.Linear(1, 1)
    for p in l.parameters():
        p.requires_grad_(False)
        print(p)
    
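    Alternatively, nn.Module.requires_grad_ flips the flag for every parameter of a module in one call. A minimal end-to-end sketch (the model and data below are made up) showing that a frozen submodule receives no gradients while the rest still trains:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
    model[0].requires_grad_(False)  # freeze the first layer in one call

    # hand only the still-trainable parameters to the optimizer
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=0.1
    )

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    print(model[0].weight.grad)              # None: the frozen layer got no gradient
    print(model[1].weight.grad is not None)  # True: the second layer still trains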
