PyTorch: how to set .requires_grad to False

长发绾君心 2021-02-03 23:44

I want to freeze part of my model. Following the official docs, I tried:

import torch
import torch.nn as nn
with torch.no_grad():
    linear = nn.Linear(1, 1)
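A short check of what torch.no_grad() does and does not do, as a sketch (the input torch.ones(1, 1) is only an illustrative assumption): autograd stops recording operations run inside the block, but parameters created there keep requires_grad=True, so the layer is not actually frozen.

```python
import torch
import torch.nn as nn

with torch.no_grad():
    # Parameters are still created with requires_grad=True here.
    linear = nn.Linear(1, 1)
    # Operations inside the block are not recorded by autograd,
    # so their outputs do not require gradients.
    out = linear(torch.ones(1, 1))

print(linear.weight.requires_grad)  # True
print(out.requires_grad)            # False
```

In other words, no_grad() controls graph recording for computations, while freezing is a property of the parameters themselves.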


        
5 answers
  •  北海茫月
    2021-02-04 00:22

    To complete @Salih_Karagoz's answer: there is also the torch.set_grad_enabled() context manager (see its entry in the PyTorch documentation), which turns gradient computation on or off with a single flag, e.g. when switching between training and evaluation:

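    import torch
    import torch.nn as nn

    linear = nn.Linear(1, 1)

    is_train = False

    # Freeze (or unfreeze) the layer's parameters.
    for param in linear.parameters():
        param.requires_grad = is_train

    # Disable graph recording when not training; eval() only switches
    # layers such as Dropout/BatchNorm to inference behavior.
    with torch.set_grad_enabled(is_train):
        linear.eval()
        print(linear.weight.requires_grad)  # False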
    

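To freeze only part of a model, a common pattern is to switch off requires_grad on that submodule's parameters and pass the optimizer only the parameters that still need updates. This is a sketch under assumptions: the three-module nn.Sequential model, the random input, the squared-output loss, and plain SGD are all hypothetical choices for illustration. nn.Module.requires_grad_(False) does the same as the per-parameter loop in the answer above.

```python
import torch
import torch.nn as nn

# Hypothetical model: we freeze the first Linear layer only.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
frozen = model[0]
frozen.requires_grad_(False)  # sets requires_grad=False on all its parameters

# Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)

before = frozen.weight.detach().clone()

# One training step on random (illustrative) data.
loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(frozen.weight.grad is None)          # True: no gradient computed
print(torch.equal(frozen.weight, before))  # True: weights unchanged
print(model[2].weight.grad is not None)    # True: last layer still trains
```

Filtering the parameter list is not strictly required for plain SGD, since parameters without gradients are skipped, but it keeps optimizers with weight decay or momentum state from touching the frozen layer.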