I am trying to comprehend the inner workings of gradient accumulation in PyTorch. My question is somewhat related to these two:
Why do we need to call zero_
You are not actually accumulating gradients. Just leaving off `optimizer.zero_grad()` has no effect if you have a single `.backward()` call, as the gradients are already zero to begin with (technically `None`, but they will be automatically initialised to zero).
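Here is a minimal sketch of that claim. It assumes a hypothetical one-parameter model `y_hat = w * x`, not the question's exact model. The parameter's `.grad` is `None` before the first backward pass, so a single `.backward()` call gives the same gradient whether or not `optimizer.zero_grad()` is called first.

```python
import torch

def single_backward_grad(call_zero_grad: bool) -> torch.Tensor:
    # Hypothetical one-parameter model y_hat = w * x, with targets y = 2 * x
    w = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = torch.optim.SGD([w], lr=0.01)
    assert w.grad is None  # no gradient exists before the first backward pass
    if call_zero_grad:
        optimizer.zero_grad()
    x = torch.tensor([4.0, 2.0])
    loss = ((2 * x - w * x) ** 2).mean()
    loss.backward()  # the only backward pass
    return w.grad.clone()

with_zero = single_backward_grad(call_zero_grad=True)
without_zero = single_backward_grad(call_zero_grad=False)
print(with_zero, without_zero)  # tensor([-20.]) tensor([-20.])
```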
The only difference between your two versions is how you calculate the final loss. The for loop of the second example does the same calculations as PyTorch does in the first example, but you do them individually. PyTorch cannot optimise (parallelise and vectorise) your for loop, which makes an especially staggering difference on GPUs, provided the tensors aren't tiny.
Before getting to gradient accumulation, let's start with your question:
Finally to my question: what exactly happens 'under the hood'?
Every operation on tensors is tracked in a computational graph if and only if one of the operands is already part of a computational graph, and gradient tracking has not been disabled (e.g. with `torch.no_grad()`). When you set `requires_grad=True` on a tensor, it creates a computational graph with a single vertex, the tensor itself, which will remain a leaf in the graph. Any operation with that tensor will create a new vertex, which is the result of the operation. Hence there is an edge from the operands to it, tracking the operation that was performed.
```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(4.0)
c = a + b  # => tensor(6., grad_fn=<AddBackward0>)

a.requires_grad  # => True
a.is_leaf        # => True

b.requires_grad  # => False
b.is_leaf        # => True

c.requires_grad  # => True
c.is_leaf        # => False
```
Every intermediate tensor automatically requires gradients and has a `grad_fn`, which is the function to calculate the partial derivatives with respect to its inputs. Thanks to the chain rule, we can traverse the whole graph in reverse order to calculate the derivatives with respect to every single leaf, which are the parameters we want to optimise. That's the idea of backpropagation, also known as reverse mode differentiation. For more details I recommend reading Calculus on Computational Graphs: Backpropagation.
PyTorch uses exactly that idea. When you call `loss.backward()`, it traverses the graph in reverse order, starting from `loss`, and calculates the derivatives for each vertex. Whenever a leaf that requires gradients is reached, the calculated derivative for that tensor is added to its `.grad` attribute. The value is accumulated, not overwritten, and that is the mechanism gradient accumulation relies on.
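Here is a minimal sketch of that accumulation behaviour, using a standalone scalar tensor. Each `.backward()` call adds its derivative to whatever is already stored in `.grad`, until the gradient is reset.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(2 * w).backward()        # d(2w)/dw = 2
first = w.grad.item()     # 2.0

(2 * w).backward()        # a second backward pass on a freshly built graph
second = w.grad.item()    # 4.0: the new derivative was added to the old one

w.grad = None             # reset (what optimizer.zero_grad() does by default in recent PyTorch versions)
(2 * w).backward()
after_reset = w.grad.item()  # 2.0 again

print(first, second, after_reset)
```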
In your first example, that would lead to:
`MeanBackward -> PowBackward -> SubBackward -> MulBackward`
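You can inspect that chain by walking the graph through each node's `next_functions`. This sketch assumes a hypothetical one-parameter model `y_hat = w * x` in place of the question's model. The helper `backward_node_names` is also hypothetical and only handles tree-shaped graphs like this one.

Two details show up in the output:

- The printed node names carry a numeric suffix (e.g. `MeanBackward0`).
- The leaf parameter appears as an `AccumulateGrad` node, which is the node that writes into `.grad`.

```python
import torch

def backward_node_names(loss: torch.Tensor) -> list[str]:
    """Depth-first walk from loss.grad_fn back towards the leaves."""
    names = []
    stack = [loss.grad_fn]
    while stack:
        node = stack.pop()
        if node is None:  # inputs that don't require gradients have no backward node
            continue
        names.append(type(node).__name__)
        # next_functions holds (node, input_index) pairs, one per input of the operation
        for next_node, _ in node.next_functions:
            stack.append(next_node)
    return names

# Hypothetical stand-in for the question's model: y_hat = w * x
w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([4.0, 2.0])
y = 2 * x
loss = ((y - w * x) ** 2).mean()

names = backward_node_names(loss)
print(" -> ".join(names))
# MeanBackward0 -> PowBackward0 -> SubBackward0 -> MulBackward0 -> AccumulateGrad
```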
The second example is almost identical, except that you calculate the mean manually. Also, instead of having a single path for the loss, you have multiple paths, one for each element of the loss calculation. To clarify, the single path also calculates the derivatives of each element, but internally, which again opens up possibilities for optimisations.
```python
# Example 1 (x, y, y_hat and model2 as defined in the question)
loss = (y - y_hat) ** 2
# => tensor([16., 4.], grad_fn=<PowBackward0>)

# Example 2
loss = []
for k in range(len(y)):
    y_hat = model2(x[k])
    loss.append((y[k] - y_hat) ** 2)
loss
# => [tensor([16.], grad_fn=<PowBackward0>), tensor([4.], grad_fn=<PowBackward0>)]
```
In either case, a single graph is created and backpropagated exactly once, which is why it's not considered gradient accumulation.
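Here is a self-contained sketch comparing both versions. A hypothetical one-parameter model `y_hat = w * x` stands in for the question's `model` and `model2`. The vectorised loss and the loop-built loss each produce one graph, are backpropagated once, and yield the same gradient.

```python
import torch

x = torch.tensor([4.0, 2.0])
y = 2 * x

# Version 1: one vectorised expression, mean taken by PyTorch
w1 = torch.tensor([1.0], requires_grad=True)
loss1 = ((y - w1 * x) ** 2).mean()
loss1.backward()  # single backward pass over a single graph

# Version 2: per-element losses built in a Python loop, mean taken manually
w2 = torch.tensor([1.0], requires_grad=True)
losses = []
for k in range(len(y)):
    y_hat = w2 * x[k]
    losses.append((y[k] - y_hat) ** 2)
loss2 = sum(losses) / len(losses)
loss2.backward()  # still a single backward pass, over a wider graph

print(w1.grad, w2.grad)  # tensor([-20.]) tensor([-20.])
```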
Gradient accumulation refers to the situation where multiple backward passes are performed before updating the parameters. The goal is to use the same model parameters for multiple inputs (batches) and then update the model's parameters based on all these batches, instead of performing an update after every single batch.
Let's revisit your example. `x` has size `[2]`, which is the size of our entire dataset. For some reason, we need to calculate the gradients based on the whole dataset. That is naturally the case when using a batch size of 2, since we would have the whole dataset at once. But what happens if we can only have batches of size 1? We could run them individually and update the model after each batch as usual, but then we wouldn't calculate the gradients over the whole dataset.
What we need to do is run each sample individually with the same model parameters and calculate the gradients without updating the model. Now you might be thinking, isn't that what you did in the second version? Almost, but not quite. There is a crucial problem in your version: you are using the same amount of memory as in the first version, because you have the same calculations and therefore the same number of values in the computational graph.
How do we free memory? We need to get rid of the tensors of the previous batch and also of the computational graph, because the graph uses a lot of memory to keep track of everything that's necessary for backpropagation. The computational graph is automatically destroyed when `.backward()` is called (unless `retain_graph=True` is specified).
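Here is a minimal sketch of that behaviour, assuming a hypothetical one-parameter model `y_hat = w * x`. A second `.backward()` through the same graph fails, because the intermediate values saved for backpropagation were freed. With `retain_graph=True` those values are kept alive, and the second pass adds to `.grad` again.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([4.0, 2.0])

def build_loss() -> torch.Tensor:
    # Hypothetical model y_hat = w * x with targets y = 2 * x
    return ((2 * x - w * x) ** 2).mean()

loss = build_loss()
loss.backward()  # frees the saved intermediate values of this graph
try:
    loss.backward()  # tries to reuse the freed graph
    second_pass_failed = False
except RuntimeError:
    second_pass_failed = True

w.grad = None
loss = build_loss()
loss.backward(retain_graph=True)  # keep the graph for another pass
loss.backward()                   # now allowed; the gradient is accumulated twice
print(second_pass_failed, w.grad)  # True tensor([-40.])
```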
```python
import torch

# `model` and `optimizer` are the ones defined in the question.
def calculate_loss(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    y_hat = model(x)
    loss = (y - y_hat) ** 2
    return loss.mean()


# With multiple batches of size 1
batches = [torch.tensor([4.0]), torch.tensor([2.0])]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    # The loss needs to be scaled, because the mean should be taken across the whole
    # dataset, which requires the loss to be divided by the number of batches.
    loss = calculate_loss(batch) / len(batches)
    loss.backward()
    print(f"Batch size 1 (batch {i}) - grad: {model.weight.grad}")
    print(f"Batch size 1 (batch {i}) - weight: {model.weight}")

# Updating the model only after all batches
optimizer.step()
print(f"Batch size 1 (final) - grad: {model.weight.grad}")
print(f"Batch size 1 (final) - weight: {model.weight}")
```
Output (I removed the `Parameter containing:` messages for readability):

```
Batch size 1 (batch 0) - grad: tensor([-16.])
Batch size 1 (batch 0) - weight: tensor([1.], requires_grad=True)
Batch size 1 (batch 1) - grad: tensor([-20.])
Batch size 1 (batch 1) - weight: tensor([1.], requires_grad=True)
Batch size 1 (final) - grad: tensor([-20.])
Batch size 1 (final) - weight: tensor([1.2000], requires_grad=True)
```
As you can see, the model kept the same parameter for all batches while the gradients were accumulated, and there is a single update at the end. Note that the loss needs to be scaled per batch. That way each batch carries the same weight in the overall mean as it would if you used a single batch for the whole dataset.
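For reference, here is a self-contained version of the example above. The question's model isn't shown here, so two things are assumptions:

- `ScaleModel` is a hypothetical module that computes `x * weight`, with a single weight initialised to 1.0.
- The optimiser is plain SGD with `lr=0.01`.

Both were chosen because they are consistent with the printed gradients and the final weight of 1.2.

```python
import torch
from torch import nn

class ScaleModel(nn.Module):
    """Hypothetical stand-in for the question's model: y_hat = x * weight."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight

model = ScaleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumed learning rate

def calculate_loss(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    y_hat = model(x)
    return ((y - y_hat) ** 2).mean()

batches = [torch.tensor([4.0]), torch.tensor([2.0])]
grads_after_each_batch = []

optimizer.zero_grad()
for batch in batches:
    loss = calculate_loss(batch) / len(batches)  # scale to the mean over the dataset
    loss.backward()                              # adds to model.weight.grad
    grads_after_each_batch.append(model.weight.grad.item())
optimizer.step()  # single update using the accumulated gradient

print(grads_after_each_batch)  # [-16.0, -20.0]
print(model.weight.data)       # approximately tensor([1.2000])
```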
In this example the whole dataset is used before performing the update. You can easily change that to update the parameters after a certain number of batches, but you have to remember to zero out the gradients after each optimiser step. The general recipe would be:
```python
accumulation_steps = 10

optimizer.zero_grad()
for i, batch in enumerate(batches):
    # Scale the loss to the mean of the accumulated batch size
    loss = calculate_loss(batch) / accumulation_steps
    loss.backward()
    # i is 0-based, so step after every accumulation_steps batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        # Reset gradients for the next accumulated batches
        optimizer.zero_grad()
```
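Here is a quick sanity check of the loss scaling. It is a sketch with a hypothetical `nn.Linear` model and random data, not part of the original example. Accumulating the scaled losses of 8 micro-batches of size 1 gives the same gradient as a single backward pass over all 8 samples, up to floating-point rounding. `nn.MSELoss` averages over the batch by default, which is why dividing by `accumulation_steps` is the matching scale here.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical setup: a small linear model and random data
model = nn.Linear(3, 1)
data = torch.randn(8, 3)
targets = torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction by default

# Reference: one backward pass over the full batch of 8 samples
model.zero_grad()
loss_fn(model(data), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 8 micro-batches of size 1, each loss scaled by accumulation_steps
accumulation_steps = 8
model.zero_grad()
for i in range(accumulation_steps):
    loss = loss_fn(model(data[i:i + 1]), targets[i:i + 1]) / accumulation_steps
    loss.backward()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))  # True
```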
You can find that recipe and more techniques for working with large batch sizes in HuggingFace - Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups.