There is a cycle in PyTorch:
- Forward pass, when we get the output (or y_hat) from the input,
- Loss calculation, where loss = loss_fn(y_hat, y),
- loss.backward(), when we calculate the gradients,
- optimizer.step(), when we update the parameters.
Or in code:
for mb in range(10):  # 10 mini batches
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
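The loop above omits the surrounding setup. A minimal self-contained version could look like the sketch below, assuming a toy linear model, MSE loss, SGD, and random data; all of these are illustrative placeholder choices, not fixed by the loop itself:

import torch
import torch.nn as nn

model = nn.Linear(3, 1)                                   # toy model
loss_fn = nn.MSELoss()                                    # mean squared error loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # plain SGD

x = torch.randn(8, 3)   # one mini-batch of inputs
y = torch.randn(8, 1)   # matching targets

for mb in range(10):           # 10 iterations (reusing the same batch for brevity)
    y_pred = model(x)          # forward pass
    loss = loss_fn(y_pred, y)  # compute the loss
    optimizer.zero_grad()      # clear gradients from the previous iteration
    loss.backward()            # backpropagate: compute fresh gradients
    optimizer.step()           # update the parameters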
If we did not clear the gradients after optimizer.step() (which is the natural place to do it) or, at the latest, just before the next backward() call, the gradients would accumulate.
Here is an example showing accumulation:
import torch
w = torch.rand(5)
w.requires_grad_()   # track gradients for w
print(w)
s = w.sum()          # scalar output; ds/dw_i = 1 for every element
s.backward()         # gradients are added into w.grad on every backward() call
print(w.grad) # tensor([1., 1., 1., 1., 1.])
s.backward()
print(w.grad) # tensor([2., 2., 2., 2., 2.])
s.backward()
print(w.grad) # tensor([3., 3., 3., 3., 3.])
s.backward()
print(w.grad) # tensor([4., 4., 4., 4., 4.])
loss.backward() has no option to control this behaviour:
torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
None of the options you can specify there zeros the gradients for you; you have to do it yourself, as in the previous mini example:
w.grad.zero_()
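Continuing the mini example, once the gradient has been zeroed, the count starts over with the next backward() call:

s.backward()
print(w.grad) # tensor([1., 1., 1., 1., 1.])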
There was some discussion about having backward() zero the previous gradients automatically every time, with an option like preserve_grads=True to keep them when accumulation is wanted, but this idea never came to life.
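That behaviour is easy to emulate by hand. Here is a minimal sketch of what it could look like; the helper name backward_and_zero and the preserve_grads flag are hypothetical, mirroring the proposal rather than any real PyTorch API:

def backward_and_zero(loss, params, preserve_grads=False):
    # Hypothetical helper: zero the previous gradients before
    # backpropagating, unless the caller asks to keep (accumulate) them.
    if not preserve_grads:
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
    loss.backward()

# Usage inside the training loop, replacing optimizer.zero_grad() + loss.backward():
# backward_and_zero(loss, model.parameters())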