Note: All code for a self-contained example to reproduce my problem can be found below.
I have a tf.keras.models.Model instance and need to train it with a training loop written in the low-level TensorFlow API.
Replacing the low-level TF loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))
by its Keras equivalent
loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model_output, from_logits=True))
does the trick. Now the low-level TensorFlow training loop behaves just like model.fit().
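For reference, this is the quantity both loss calls are meant to compute when they are given logits: the per-example loss -sum(labels * log(softmax(logits))), averaged over the batch by tf.reduce_mean. The sketch below is plain Python with no TensorFlow. The helper names `softmax`, `softmax_cross_entropy_from_logits` and `mean_loss` are hypothetical. It reproduces the math only, not TensorFlow's implementation, and uses the log-sum-exp trick for numerical stability.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating to avoid overflow.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_cross_entropy_from_logits(labels, logits):
    # log(softmax(z))_i = z_i - logsumexp(z). Computing it this way avoids
    # taking the log of a probability that may underflow to 0.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - log_sum_exp) for y, z in zip(labels, logits))

def mean_loss(batch_labels, batch_logits):
    # Analogue of wrapping the per-example loss in tf.reduce_mean.
    losses = [softmax_cross_entropy_from_logits(y, z)
              for y, z in zip(batch_labels, batch_logits)]
    return sum(losses) / len(losses)

labels = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]   # one-hot targets
logits = [[2.0, 1.0, 0.1], [0.5, 0.2, 3.0]]   # raw, unscaled model outputs

# Same value via the naive formula: -sum(y * log(softmax(z))).
naive = sum(-sum(y * math.log(p) for y, p in zip(ys, softmax(zs)))
            for ys, zs in zip(labels, logits)) / len(labels)
print(mean_loss(labels, logits), naive)
```

The stable version and the naive formula agree on ordinary inputs. The stable one also handles very large logits, where exp(z) on its own would overflow.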
However, I don't know why this is. If anyone knows why tf.keras.backend.categorical_crossentropy() behaves well while tf.nn.softmax_cross_entropy_with_logits_v2() doesn't work at all, please post an answer.
Another important note:
In order to train a tf.keras model with a low-level TF training loop and a tf.data.Dataset object, one generally shouldn't call the model on the iterator output. That is, one shouldn't do this:
model_output = model(features)
Instead, one should create a model whose input layer is built directly on the iterator output rather than on a newly created placeholder, like so:
input_tensor = tf.keras.layers.Input(tensor=features)
This doesn't matter in this example, but it becomes relevant if any layers in the model have internal updates that need to be run during training (e.g. BatchNormalization).
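For context on what such an "internal update" is: a BatchNormalization layer keeps moving averages of the batch statistics, which Keras documents as `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` (default `momentum=0.99`, moving mean initialized to zeros). The toy class below is a hypothetical, simplified pure-Python stand-in (`ToyBatchNormStats`, tracking only the mean). It shows the consequence: if the training loop runs the forward pass but never runs the update step, the moving statistic never leaves its initial value.

```python
class ToyBatchNormStats:
    """Hypothetical, simplified stand-in for the moving mean that a
    BatchNormalization layer maintains (variance omitted)."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.moving_mean = 0.0  # zeros, like Keras' default initializer

    def forward(self, batch):
        # Training-mode forward pass: normalize with the *batch* mean.
        batch_mean = sum(batch) / len(batch)
        return [x - batch_mean for x in batch], batch_mean

    def update(self, batch_mean):
        # The "internal update" that must be executed explicitly each step.
        self.moving_mean = (self.momentum * self.moving_mean
                            + (1.0 - self.momentum) * batch_mean)

batch = [4.0, 5.0, 6.0]  # batch mean is 5.0

skipped = ToyBatchNormStats()
for _ in range(500):
    skipped.forward(batch)          # update step never run

applied = ToyBatchNormStats()
for _ in range(500):
    _, m = applied.forward(batch)
    applied.update(m)               # update step run every iteration

print(skipped.moving_mean, applied.moving_mean)
```

Training-mode outputs look the same in both cases, because they use the batch mean. Only the stored moving mean, which is what inference uses, differs. That is why a missing update can go unnoticed until evaluation.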
You apply a softmax activation on your last layer
x = tf.keras.layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(x)
and then a softmax is applied again inside tf.nn.softmax_cross_entropy_with_logits_v2, since that function expects unscaled logits. From the documentation:
WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.
Thus, remove the softmax activation from your last layer and it should work:
x = tf.keras.layers.Dense(num_classes, activation=None, kernel_initializer='he_normal')(x)
[...]
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))
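A small pure-Python sketch (no TensorFlow; the helper names `softmax` and `xent_with_logits` are hypothetical) of why the double softmax makes training stall. Softmax outputs lie in [0, 1], so when they are treated as logits, the second softmax can give the true class at most e / (e + C - 1) probability for C classes. The loss therefore can never drop below log(1 + (C - 1)/e), which is about 0.551 for C = 3, even for a perfectly confident and correct prediction.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def xent_with_logits(label_index, logits):
    # Softmax cross-entropy with a one-hot label:
    # -log(softmax(logits)[label_index]).
    return -math.log(softmax(logits)[label_index])

num_classes = 3
logits = [10.0, 0.0, 0.0]   # confident, correct prediction for class 0
probs = softmax(logits)     # what a Dense(..., activation='softmax') layer outputs

correct = xent_with_logits(0, logits)  # logits passed to the loss: close to 0
double = xent_with_logits(0, probs)    # probabilities passed as "logits"

# Lower bound on the double-softmax loss for inputs in [0, 1].
floor = math.log(1.0 + (num_classes - 1) / math.e)
print(correct, double, floor)
```

With raw logits the loss can approach zero as the prediction improves. With the extra softmax it stays stuck near the floor, and the gradients flowing back through the squashed inputs are correspondingly weak. If you still need probabilities at prediction time, you can apply tf.nn.softmax to the logits separately, outside the loss.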