Question
How to detect the source of vanishing gradients in PyTorch?
By vanishing gradients, I mean that the training loss doesn't go below some value, even on small subsets of the data.
I am trying to train a network and have the above problem: I can't even get it to overfit, and I can't understand the source of the problem.
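A common sanity check for this situation (general practice, not something from this thread) is to see whether the model can memorize one small, fixed batch. If the loss on a single batch does not approach zero, the issue is in the model, loss, or optimization setup rather than in the amount of data. A minimal sketch, where the model, batch shape, step count, and learning rate are all hypothetical stand-ins for your own:

```python
import torch
import torch.nn as nn

def overfit_one_batch(model, x, y, loss_fn, steps=500, lr=1e-2):
    """Train on a single fixed batch and return the loss after every step."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
# Hypothetical small classifier and one batch of random data.
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
x = torch.randn(16, 20)
y = torch.randint(0, 3, (16,))

losses = overfit_one_batch(model, x, y, nn.CrossEntropyLoss())
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

With your real network, swap in one real batch; a loss that stalls well above zero here is the same symptom described above, reproduced in seconds instead of a full training run.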
I've spent a long time googling this and only found ways to prevent overfitting, but nothing about underfitting or, specifically, vanishing gradients.
What I did find:
- A PyTorch forum discussion about "bad gradients". It only covers exploding and NaN gradients, and links to here and here, which are more of the same.
- I know that making the network larger or more complex is the generally suggested way to cause overfitting (which is what I want right now).
- I also know that very deep networks can suffer from vanishing gradients. So it is not clear to me that a larger network would solve the problem: it could create a vanishing-gradient problem of its own, and again I would not know how to debug it, while still seeing roughly the same behavior.
- Switching to a ResNet-style architecture might help, or might not, since I haven't pinpointed network depth as the cause.
- Dead ReLUs can cause underfitting, and switching to LeakyReLU does help, but not enough.
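The dead-ReLU suspicion can be measured directly instead of inferred from the LeakyReLU result. A minimal sketch: forward hooks on every nn.ReLU module record the fraction of units that output zero for every sample in a batch. Assumptions: activations are nn.ReLU modules (functional F.relu calls are not seen by module hooks), each ReLU instance is used once per forward pass, and "dead" means zero across this particular batch. The demo model and the -100 bias used to force dead units are hypothetical.

```python
import torch
import torch.nn as nn

def dead_relu_fractions(model, inputs):
    """Fraction of each nn.ReLU's units that output zero for every sample in `inputs`."""
    fractions = {}
    handles = []

    def make_hook(name):
        def hook(module, args, out):
            # A unit counts as dead on this batch if it is zero for all samples.
            dead = (out <= 0).all(dim=0)
            fractions[name] = dead.float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always detach the hooks so they don't fire during training.
        for handle in handles:
            handle.remove()
    return fractions

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 1),
)
# Hypothetical: push the second layer's biases far negative to simulate dead units.
with torch.no_grad():
    model[2].bias.fill_(-100.0)

report = dead_relu_fractions(model, torch.randn(256, 10))
print(report)
```

Keys are module names from named_modules(), so in a real model you can see exactly which layer's units have died.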
How would one debug the sources of underfitting in PyTorch, specifically underfitting caused by vanishing gradients?
Instead of shooting blindly and trying things at random, I would like to properly visualize the gradients in my network, so that I know what I am actually trying to solve instead of guessing.
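The most direct way to see whether gradients vanish with depth is to print each layer's gradient norm after one backward pass. A minimal sketch on a hypothetical 20-layer sigmoid MLP with PyTorch's default initialization; in this example the first layer's gradient norm comes out orders of magnitude smaller than the output layer's, which is the pattern to look for in your own network.

```python
import torch
import torch.nn as nn

def grad_norms(model):
    """L2 norm of each parameter's gradient; call after loss.backward()."""
    return {
        name: p.grad.norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }

def deep_mlp(depth, width=64, act=nn.Sigmoid):
    """Hypothetical deep MLP: `depth` (Linear, activation) pairs plus an output layer."""
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), act()]
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)

torch.manual_seed(0)
model = deep_mlp(depth=20)
x = torch.randn(128, 64)
y = torch.randn(128, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

norms = grad_norms(model)
weight_names = [n for n in norms if n.endswith("weight")]
for name in weight_names:
    print(f"{name:>10}: {norms[name]:.3e}")

first, last = norms[weight_names[0]], norms[weight_names[-1]]
```

The same grad_norms helper works on any nn.Module; running it on your real model after a backward pass shows whether the small gradients are concentrated in the early layers (a depth problem) or spread everywhere.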
Surely I am not the first to need this, and there must be tools and methodologies for this purpose.
I would like to read about them, but don't know what to look for.
The specific network I have right now is irrelevant; this is a general question about methodology.
Answer 1:
You can use TensorBoard with PyTorch to visualize the training gradients: add each parameter's gradient to a TensorBoard histogram during training.
For example, let:
- model be your PyTorch model
- model_input be an example input to your model
- run_name be a string identifier for your training session
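```python
from torch.utils.tensorboard import SummaryWriter

summary_writer = SummaryWriter(comment=run_name)
summary_writer.add_graph(model, model_input, verbose=True)

# Training loop (num_steps is a placeholder for your own step count)
for step_index in range(num_steps):
    # Forward pass, loss.backward(), optimizer.step(), etc.
    ...
    # Log each parameter's gradient; skip parameters with no gradient
    for name, param in model.named_parameters():
        if param.grad is not None:
            summary_writer.add_histogram(f'{name}.grad', param.grad, step_index)

summary_writer.close()
```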
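Histograms show the full distribution but take time to inspect. A lighter per-step check, sketched here under the assumption that it runs right after loss.backward(), reports parameters whose gradients contain NaN or inf (the "bad gradients" case from the forum thread mentioned in the question) and one global gradient norm you could log with summary_writer.add_scalar. A global norm that keeps shrinking toward zero while the loss plateaus points to vanishing gradients. Both helper names are hypothetical.

```python
import torch
import torch.nn as nn

def nonfinite_grads(model):
    """Names of parameters whose gradient contains NaN or inf."""
    return [
        name
        for name, p in model.named_parameters()
        if p.grad is not None and not torch.isfinite(p.grad).all()
    ]

def total_grad_norm(model):
    """Global L2 norm over all parameter gradients (0.0 if none exist)."""
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    return torch.cat(grads).norm().item() if grads else 0.0

# Clean input: weight grad is the input row [1, 2, 0, 0], bias grad is [1].
model = nn.Linear(4, 1)
model(torch.tensor([[1.0, 2.0, 0.0, 0.0]])).sum().backward()
clean_bad = nonfinite_grads(model)
clean_norm = total_grad_norm(model)
print(clean_bad, clean_norm)

# A NaN in the input propagates into the weight gradient (the bias grad stays finite).
model.zero_grad()
model(torch.tensor([[1.0, float("nan"), 0.0, 0.0]])).sum().backward()
nan_bad = nonfinite_grads(model)
print(nan_bad)
```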
References:
- https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
- https://discuss.pytorch.org/t/is-there-a-way-to-visualize-the-gradient-path-of-the-back-propagation-of-the-entire-network/44322/4
- https://debuggercafe.com/track-your-pytorch-deep-learning-project-with-tensorboard/
Source: https://stackoverflow.com/questions/66137298/how-to-detect-source-of-under-fitting-and-vanishing-gradients-in-pytorch