Question
My question is concerning the syntax of pytorch register_hook.
import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)   # prints dz/dx when it is computed
y.register_hook(print)   # prints dz/dy when it is computed
z.backward()
outputs:
tensor([2.])
tensor([4.])
This snippet simply prints the gradients of z w.r.t. y and x, respectively (the y hook fires first because backpropagation reaches y before x).
Now my (most likely trivial) question is how to return the intermediate gradients (rather than only printing)?
UPDATE:
It appears that calling retain_grad() solves the issue for leaf nodes, e.g. y.retain_grad(). However, retain_grad does not seem to solve it for non-leaf nodes. Any suggestions?
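For reference, a minimal sketch of the retain_grad() call the update refers to, assuming the same graph as in the snippet above:

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

y.retain_grad()   # ask autograd to keep .grad for the intermediate tensor y
z.backward()

print(x.grad)     # tensor([4.]) -- leaf tensor, populated by default
print(y.grad)     # tensor([2.]) -- populated only because of retain_grad()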
Answer 1:
I think you can use those hooks to store the gradients in a global variable:
import torch

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d: grads.append(d))  # store dz/dx
y.register_hook(lambda d: grads.append(d))  # store dz/dy
z.backward()
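For example, after z.backward() the list should hold the gradients in the order the hooks fired during backpropagation (y's hook runs before x's here):

print(grads)  # expected: [tensor([2.]), tensor([4.])] -- dz/dy, then dz/dx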
But you most likely also need to remember which tensor each gradient was computed for. In that case, we can slightly extend the above by using a dict instead of a list:
import torch

grads = {}
x = torch.tensor([1., 2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad, parent):
    print(grad, parent)
    grads[parent] = grad.clone()  # keep a copy, keyed by the tensor it belongs to

x.register_hook(lambda grad: store(grad, x))
y.register_hook(lambda grad: store(grad, y))
z.sum().backward()
Now you can, for example, access tensor y's grad simply using grads[y].
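A quick check with the extended snippet above (x = [1., 2.]):

print(grads[y])  # expected: tensor([2., 2.])  -- dz/dy
print(grads[x])  # expected: tensor([4., 8.])  -- dz/dx = 4*x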
Source: https://stackoverflow.com/questions/55305262/how-to-return-intermideate-gradients-for-non-leaf-nodes-in-pytorch