What does the copy_initial_weights documentation mean in the higher library for Pytorch?

猫巷女王i 2021-01-17 09:42

I was trying to use the higher library for meta-learning and I was having issues understanding what copy_initial_weights means. The docs say:

2 Answers
  •  感情败类
    2021-01-17 10:12

    I think it's more or less clear to me now what this means.

    First I'd like to make some notation clear, especially with respect to indices for the inner time step and the outer time step (also known as episodes):

    W^<inner_i, outer_i> = the value a tensor has at inner step inner_i and outer step outer_i.
    

    At the beginning of training a neural net has params:

    W^<0,0>
    

    and they are held inside its module. For the sake of explanation the specific tensor (for the base model) will be denoted:

    W = the tensor holding the weights for the model. This can be thought of as the initialization of the model.
    

    and it will be updated with an in-place operation by the outer optimizer (this is important, since W is the placeholder for all W^<0,outer_i> for all outer step values during "normal" meta-learning). I want to emphasize that W is the tensor for the normal PyTorch neural net base model. By changing this in-place with an outer optimizer (like Adam) we are effectively training the initialization. The outer optimizer will use the gradients wrt this tensor to do the update through the whole unrolled inner loop process.
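    The structure above can be sketched in plain PyTorch, without higher's internals (all names here are illustrative, not higher's API): W is updated in-place by the outer optimizer, while each episode unrolls differentiable inner steps starting from a clone of W.

    ```python
    import torch

    # Minimal sketch of the outer/inner loop structure described above.
    # W is the base model's parameter tensor; the outer optimizer updates it
    # in-place, so W holds W^<0, outer_i> at the start of each episode.
    W = torch.nn.Parameter(torch.randn(3))
    outer_opt = torch.optim.Adam([W], lr=1e-2)

    for outer_i in range(2):                      # outer loop (episodes)
        w = W.clone()                             # differentiable copy, akin to copy_initial_weights=False
        for inner_i in range(3):                  # unrolled inner loop
            inner_loss = (w ** 2).sum()
            g, = torch.autograd.grad(inner_loss, w, create_graph=True)
            w = w - 0.1 * g                       # differentiable SGD step
        outer_loss = (w ** 2).sum()
        outer_opt.zero_grad()
        outer_loss.backward()                     # gradients flow back to W through the unrolled steps
        outer_opt.step()                          # in-place update of W: the initialization is trained
    ```

    The toy losses stand in for real task losses; the point is only that outer_loss.backward() reaches W through the whole unrolled inner loop.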

    When we say copy_initial_weights=False we mean that we will have a gradient path directly to W with whatever value it currently has. Usually the context manager is entered before an inner loop, after an outer step has been done, so W will hold W^<0,outer_i> for the current step. In particular this is the code that runs for copy_initial_weights=False:

    params = [ p.clone() if device is None else p.clone().to(device) for p in module.parameters() ]
    

    this might look confusing if you're not familiar with clone, but what it's doing is making a copy of the current value of W. The unusual thing is that clone also keeps the gradient history of the tensor it came from (.clone() acts like the identity for autograd). Its main use is to add an extra layer of safety against the user doing dangerous in-place ops in the differentiable optimizer. Assuming the user never did anything crazy with in-place ops, one could in theory remove the .clone(). The reason this is confusing, imho, is that "copying" in PyTorch (cloning) does not automatically block gradient flow, which is what a "real" copy would do (i.e. create a 100% totally separate tensor). That is not what clone does, and that is not what copy_initial_weights=False does.
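    A tiny sketch of that point: a clone copies the value but stays on the autograd graph, so gradients still reach the original tensor.

    ```python
    import torch

    W = torch.nn.Parameter(torch.tensor([2.0]))
    w = W.clone()              # copy of the current value; gradient path to W is kept
    loss = (w * 3).sum()
    loss.backward()
    print(W.grad)              # tensor([3.]) -- the clone did not block the gradient
    ```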

    When copy_initial_weights=True what really happens is that the weights are cloned and detached. See the code it eventually runs (here and here):

    params = [_copy_tensor(p, safe_copy, device) for p in module.parameters()]
    

     which runs _copy_tensor (assuming a safe copy, i.e. doing the extra clone):

     t = t.clone().detach().requires_grad_(t.requires_grad)
    

    Note that .detach() does not allocate new memory. It shares memory with the original tensor, which is why the .clone() is needed for this op to be "safe" (usually wrt in-place ops).
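    This memory-sharing claim is easy to verify directly: detach returns a view on the same storage, while clone allocates fresh memory, which is exactly why the safe copy chains the two.

    ```python
    import torch

    t = torch.randn(4, requires_grad=True)
    d = t.detach()
    assert d.data_ptr() == t.data_ptr()           # detach shares the original storage
    safe = t.clone().detach().requires_grad_(t.requires_grad)
    assert safe.data_ptr() != t.data_ptr()        # clone allocated new memory
    assert safe.requires_grad == t.requires_grad  # flag preserved, but the graph is cut
    assert safe.grad_fn is None
    ```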

    So with copy_initial_weights=True we are copying and detaching the current value of W. This is usually W^<0,outer_i> when doing usual meta-learning in the inner adaptation loop. So that is the intended semantics of copy_initial_weights: by "initial weights" they simply mean W. The important thing to note is that the intermediate tensors for the net in the inner loop are not denoted in my notation, but they are fmodel.parameters(t=inner_i). Also, in usual meta-learning we have fmodel.parameters(t=0) = W, and it gets updated in-place by the outer optimizer.
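    The consequence of the detached copy can be sketched on a single tensor (a minimal sketch of the semantics, not higher's internals verbatim): the copy itself gets gradients, but nothing reaches W.

    ```python
    import torch

    # What copy_initial_weights=True amounts to: the fast weights start from
    # a detached copy, so no gradient reaches W through the initial value.
    W = torch.nn.Parameter(torch.tensor([1.0]))
    p0 = W.clone().detach().requires_grad_(W.requires_grad)   # the safe _copy_tensor recipe
    (p0 * 2).sum().backward()
    assert p0.grad is not None     # the copy itself receives gradients
    assert W.grad is None          # but W does not -- the path was cut
    ```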

    Note that because of the outer optimizer's in-place op and the freeing of the graphs, we never take the derivative with respect to the initial value W^<0,0>. Which was something I initially thought we were doing.
