Question
My question is related to this one.
I am working to implement the method described in this article: https://drive.google.com/file/d/1s-qs-ivo_fJD9BU_tM5RY8Hv-opK4Z-H/view . The final algorithm to use is on page 6 of the article; in its notation:
- d is a unit vector
- xi is a non-zero number
- D is the loss function (sparse cross-entropy in my case)
The idea is to do adversarial training: modify the data in the direction where the network is most sensitive to small changes, then train the network on the modified data but with the same labels as the original data.
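A rough sketch of that perturbation step, under my reading of the article (the names adversarial_direction, _unit and the xi and eps values are mine, not the article's; D is sparse cross-entropy, as in my case):

import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

def _unit(v):
    # Normalize each sample in the batch to a unit vector
    flat = tf.reshape(v, [tf.shape(v)[0], -1])
    flat = flat / (tf.norm(flat, axis=1, keepdims=True) + 1e-12)
    return tf.reshape(flat, tf.shape(v))

def adversarial_direction(model, x, y, xi=1e-6, eps=2.0):
    d = _unit(tf.random.normal(tf.shape(x)))      # d: random unit vector
    r = xi * d                                    # small probe perturbation
    with tf.GradientTape() as tape:
        tape.watch(r)
        loss = loss_fn(y, model(x + r, training=False))
    g = tape.gradient(loss, r)                    # gradient of D w.r.t. r
    return eps * _unit(g)                         # direction of max sensitivity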
The loss function used to train the model is l + alpha * Rvadv, where:
- l is a loss measure on the labelled data
- Rvadv is the expression inside the gradient in the picture of Algorithm 1
- the article chose alpha = 1
The idea is to incorporate the model's performance on the labelled dataset into the loss.
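Putting it together, the total loss I am trying to minimize looks roughly like this (reusing loss_fn and adversarial_direction from the sketch above; total_loss is my name for it):

alpha = 1.0   # as chosen in the article

def total_loss(model, x, y):
    r_adv = adversarial_direction(model, x, y)
    l_labelled = loss_fn(y, model(x, training=True))      # l: supervised loss
    r_vadv = loss_fn(y, model(x + r_adv, training=True))  # Rvadv: loss on perturbed data, same labels
    return l_labelled + alpha * r_vadv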
I am trying to implement this method in Keras on the MNIST dataset with mini-batches of 100 samples. When I do the final gradient descent to update the weights, NaN values appear after some iterations and I don't know why. I posted the notebook in a Colab session (I don't know for how long it will stay up, so I also posted the code in a gist):
- Colab session: https://colab.research.google.com/drive/1lowajNWD-xvrJDEcVklKOidVuyksFYU3?usp=sharing
- gist: https://gist.github.com/DridriLaBastos/e82ec90bd699641124170d07e5a8ae4c
Answer 1:
This is a fairly standard NaN-in-training problem. I suggest you read this answer about the NaN issue with the Adam solver for the causes and solutions in the common cases.
Basically I just made the following two changes and the code ran without NaNs in the gradients:
- Reduce the learning rate of the optimizer in model.compile to:
  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)
- Replace:
  C = [loss(label,pred) for label, pred in zip(yBatchTrain,dumbModel(dataNoised,training=False))]
  with:
  C = loss(yBatchTrain,dumbModel(dataNoised,training=False))
If you still have this kind of error, the next few things you could try are (a sketch follows the list):
- Clip the loss or the gradients
- Switch all tensors from tf.float32 to tf.float64
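For example, both can be done in a line each (the clipnorm value is illustrative):

import tensorflow as tf

# Clip each gradient's norm inside the optimizer itself
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

# Make Keras create all layers and computations in float64
# (call this before building the model)
tf.keras.backend.set_floatx('float64')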
Next time you face this kind of error, you can use tf.debugging.check_numerics to find the root cause of the NaN.
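A minimal demonstration (the values here are only to force a NaN):

import tensorflow as tf

t = tf.math.log(tf.constant([1.0, -1.0]))   # log(-1) produces a NaN
t = tf.debugging.check_numerics(t, message="NaN/Inf found in t")  # raises InvalidArgumentError naming the op

In your training step, wrap the loss and/or each gradient this way to see where the first NaN appears.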
Source: https://stackoverflow.com/questions/65643459/keras-nan-value-when-computing-the-loss