Question
I ran into some GPU memory issues when running a large RNN, but I want to keep my batch size reasonable, so I wanted to try out gradient accumulation. In a network where you predict the output in one go that seems self-evident, but in an RNN you do multiple forward passes for each input step. Because of that, I fear that my implementation does not work as intended. I started from user albanD's excellent examples here, but I think they need to be modified when using an RNN, because you accumulate many more gradients, doing multiple forward passes per sequence.
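For reference, the basic (non-RNN) accumulation pattern I started from looks roughly like this; this is my own sketch of how I understood those examples, not a verbatim copy:

# Rough sketch of the standard accumulation pattern (no AMP, no RNN)
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    # divide so the accumulated gradient matches the mean over the accumulated batches
    loss = criterion(outputs, targets) / gradient_accumulation_steps
    loss.backward()
    if (i + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()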
My current implementation looks like this, while also allowing for AMP in PyTorch 1.6, which seems important since everything needs to be called in the right place. Note that this is just an abstract version; it might look like a lot of code, but it is mostly comments.
def train(epochs):
    """Main training loop. Loops for `epochs` number of epochs. Calls `process`."""
    for epoch in range(1, epochs + 1):
        train_loss = process("train")
        valid_loss = process("valid")
        # ... check whether we improved over earlier epochs
        if lr_scheduler:
            lr_scheduler.step(valid_loss)

def process(do):
    """Do a single epoch run through the dataloader of the training or validation set.
    Also takes care of optimizing the model after every `gradient_accumulation_steps` steps.
    Calls `step` for each batch where it gets the loss from."""
    if do == "train":
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    loss = 0.
    for batch_idx, batch in enumerate(dataloaders[do]):
        step_loss, avg_step_loss = step(batch)
        loss += avg_step_loss
        if do == "train":
            if amp:
                scaler.scale(step_loss).backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # clip in-place
                    clip_grad_norm_(model.parameters(), 2.0)
                    scaler.step(optimizer)
                    scaler.update()
                    model.zero_grad()
            else:
                step_loss.backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    clip_grad_norm_(model.parameters(), 2.0)
                    optimizer.step()
                    model.zero_grad()

    # return average loss
    return loss / len(dataloaders[do])

def step():
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # do stuff... init hidden state and first input etc.
    loss = torch.tensor([0.]).to(device)

    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # overwrite previous decoder_hidden
            output, decoder_hidden = model(decoder_input, decoder_hidden)

            # compute loss between predicted classes (bs x classes) and correct classes for _this word_
            item_loss = criterion(output, target_tensor[i])

            # We calculate the gradients for the average step so that when
            # we do take an optimizer.step, it takes into account the mean step_loss
            # across batches. So basically (A+B+C)/3 = A/3 + B/3 + C/3
            loss += (item_loss / gradient_accumulation_steps)

        topv, topi = output.topk(1)
        decoder_input = topi.detach()

    return loss, loss.item() / target_len
The above does not seem to work as I had hoped, i.e. it still runs into out-of-memory issues very quickly. I think the reason is that step already accumulates so much information, but I am not sure.
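To make the "accumulates so much information" part concrete: summing graph-attached losses keeps the autograd graph of every forward alive until backward is called. A toy sketch with hypothetical names and without the recurrent hidden state, just to show the difference:

# Toy sketch (hypothetical names, not my actual code)
running_loss = torch.tensor(0., device=device)
for i in range(target_len):
    with torch.cuda.amp.autocast(enabled=amp):
        output = model(inputs[i])                 # hypothetical per-step forward
        item_loss = criterion(output, targets[i])
    running_loss = running_loss + item_loss       # keeps every step's graph alive until backward
    # running_loss += item_loss.item()            # keeps only the number; the graphs can be freed,
    #                                             # but then there is nothing left to call backward on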
Answer 1:
For simplicity, I will only cover amp-enabled gradient accumulation; without amp the idea is the same. And your step as presented runs under amp, so let's stick to that.
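For reference, the gradient accumulation pattern from the amp docs looks roughly like this (a sketch from memory; check the current documentation for the exact version):

scaler = torch.cuda.amp.GradScaler()
for i, (inputs, targets) in enumerate(dataloader):
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / gradient_accumulation_steps
    scaler.scale(loss).backward()
    if (i + 1) % gradient_accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()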
step
In the PyTorch documentation about amp you have an example of gradient accumulation. You should do it inside step. Each time you run loss.backward(), the gradient is accumulated inside the tensor leaves, which can then be optimized by the optimizer. Hence, your step should look like this (see comments):
def step():
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # You should not accumulate loss on `GPU`, RAM and CPU is better for that
    # Use GPU only for calculations, not for gathering metrics etc.
    loss = 0
    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # where decoder_input is from?
            # I assume there is one in real code
            output, decoder_hidden = model(decoder_input, decoder_hidden)
            # Here you divide by accumulation steps
            item_loss = criterion(output, target_tensor[i]) / (
                gradient_accumulation_steps * target_len
            )

        scaler.scale(item_loss).backward()
        loss += item_loss.detach().item()

        # Not sure what was topv for here
        _, topi = output.topk(1)
        decoder_input = topi.detach()

    # No need to return loss now as we did backward above
    return loss / target_len
As you detach decoder_input anyway (so it is effectively a brand-new input with no history, and the parameters will be optimized based on that, not based on all runs), there is no need for backward in process. Also, you probably don't need to create decoder_hidden yourself; if it isn't passed to the network, a torch.tensor filled with zeros is passed implicitly.
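That default holds for the built-in recurrent layers. A quick check, assuming the decoder wraps something like nn.GRU (hypothetical shapes):

import torch

gru = torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)                            # (batch, seq, features)

out_default, _ = gru(x)                              # no initial hidden state given
out_zeros, _ = gru(x, torch.zeros(1, 4, 16))         # explicit zero hidden state
print(torch.allclose(out_default, out_zeros))        # True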
Also, we should divide by gradient_accumulation_steps * target_len, as that's how many backward passes we will run before a single optimization step.
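A tiny sanity check of that scaling with made-up numbers (toy example, not the training code):

import torch

w = torch.tensor(1.0, requires_grad=True)
per_item_losses = [w * 2.0, w * 4.0, w * 6.0]   # pretend three backwards happen before optimizer.step

for item_loss in per_item_losses:
    (item_loss / len(per_item_losses)).backward()

print(w.grad)  # tensor(4.) == d/dw of the mean loss (2w + 4w + 6w) / 3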
As some of your variables are ill-defined, I assume you just sketched a scheme of what's going on.
Also, if you want the history to be kept, you shouldn't detach decoder_input; in that case it would look like this:
def step():
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    loss = 0
    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            output, decoder_hidden = model(decoder_input, decoder_hidden)
            item_loss = criterion(output, target_tensor[i]) / (
                gradient_accumulation_steps * target_len
            )
        _, topi = output.topk(1)
        decoder_input = topi
        loss += item_loss

    scaler.scale(loss).backward()
    return loss.detach().cpu() / target_len
This effectively backpropagates through the RNN over the whole sequence of forwards and will probably raise OOM; I am not sure what you are after here. If that's the case, there's not much you can do AFAIK, as the unrolled RNN computations are simply too long to fit on the GPU.
process
Only the relevant part of this code is presented, so it would be:
loss = 0.0
for batch_idx, batch in enumerate(dataloaders[do]):
    # Here everything is detached from graph so we're safe
    avg_step_loss = step(batch)
    loss += avg_step_loss
    if do == "train":
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            # You can use unscale as in the example in PyTorch's docs
            # just like you did
            scaler.unscale_(optimizer)
            # clip in-place
            clip_grad_norm_(model.parameters(), 2.0)
            scaler.step(optimizer)
            scaler.update()
            # IMO in this case optimizer.zero_grad is more readable
            # but that's nitpicking
            optimizer.zero_grad()

# return average loss
return loss / len(dataloaders[do])
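For completeness, the snippets above assume roughly this kind of setup around them; the names follow your question, and the concrete model/optimizer/criterion choices are just placeholders:

import torch
from torch.nn.utils import clip_grad_norm_

amp = True
gradient_accumulation_steps = 4                      # placeholder value
device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.GRU(input_size=32, hidden_size=64).to(device)   # stand-in for the real decoder
criterion = torch.nn.CrossEntropyLoss()              # placeholder criterion
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=amp)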
Question-like
[...] in an RNN you do multiple forward passes for each input step. Because of that, I fear that my implementation does not work as intended.
It does not matter. For each forward you should usually do one backward (which seems to be the case here; see the step variants above for possible options). After that we (usually) don't need the loss connected to the graph, as we have already performed backpropagation, got our gradients, and are ready to optimize the parameters.
As for the idea that the loss needs to keep its history because it goes back to the process loop where backward will be called on it: there is no need to call backward in process as presented above.
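A minimal illustration of the "gradients accumulate in the leaves" point (toy example, unrelated to the model above):

import torch

w = torch.tensor(2.0, requires_grad=True)

(w * 3.0).backward()
(w * 5.0).backward()
print(w.grad)   # tensor(8.) - the two backward calls added their gradients together

w.grad = None   # reset the accumulation, roughly what optimizer.zero_grad() does for each parameter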
Source: https://stackoverflow.com/questions/63934070/gradient-accumulation-in-an-rnn