Why does TensorFlow example fail when increasing batch size?


I was looking at the TensorFlow MNIST example for beginners and found that in this part:

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

if I increase the batch size here from 100 to anything above roughly 204, the model fails to converge. Why does a larger batch size break training?
        
4 Answers
  • 2020-12-03 05:59

    @dga nicely explained the reason for this behavior (the cross_entropy becomes too large, so the algorithm cannot converge). There are a couple of ways to fix this. He already suggested decreasing the learning rate.

    Gradient descent is the most basic algorithm. Almost any of the other optimizers will work properly:

    train_step = tf.train.AdagradOptimizer(0.01).minimize(cross_entropy)
    train_step = tf.train.AdamOptimizer().minimize(cross_entropy)
    train_step = tf.train.FtrlOptimizer(0.01).minimize(cross_entropy)
    train_step = tf.train.RMSPropOptimizer(0.01, 0.1).minimize(cross_entropy)
    

    Another approach is to use tf.nn.softmax_cross_entropy_with_logits, which handles the numerical instabilities for you.
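
    For example, a minimal sketch of that approach, assuming the unnormalized scores are computed as logits = tf.matmul(x, W) + b (i.e. the final softmax is dropped from the model and folded into the loss):

    # Keep the model output as raw logits; do not apply tf.nn.softmax here.
    logits = tf.matmul(x, W) + b
    # softmax_cross_entropy_with_logits applies a numerically stable
    # log-softmax internally, so log(0) can never occur.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)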

  • 2020-12-03 06:07

    I assume you're using the very basic linear model from the beginners example?

    Here's a trick to debug it - watch the cross-entropy as you increase the batch size (the first line is from the example, the second I just added):

    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    cross_entropy = tf.Print(cross_entropy, [cross_entropy], "CrossE")
    

    At a batch size of 204, you'll see:

    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[92.37558]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[90.107414]
    

    But at 205, you'll see a sequence like this, from the start:

    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[472.02966]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[475.11697]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1418.6655]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1546.3833]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1684.2932]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1420.02]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[1796.0872]
    I tensorflow/core/kernels/logging_ops.cc:64] CrossE[nan]
    

    Ack, NaNs are showing up. Basically, because the cross-entropy is summed over every example in the batch, the large batch size creates such a huge gradient that your model spirals out of control: the updates it applies are far too large, overshooting the direction it should move in by a huge margin.
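
    If you want to see that directly, one way (just a sketch, not part of the original debugging trick) is to also print the global norm of the gradients and watch it jump as the batch size grows:

    # Inspect the raw gradient magnitude produced by the summed loss.
    grads = tf.gradients(cross_entropy, tf.trainable_variables())
    grad_norm = tf.global_norm(grads)
    grad_norm = tf.Print(grad_norm, [grad_norm], "grad norm: ")
    # Evaluate it alongside the training op, e.g.:
    #   sess.run([train_step, grad_norm], feed_dict={x: batch_xs, y_: batch_ys})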

    In practice, there are a few ways to fix this. You could reduce the learning rate from .01 to, say, .005, which results in a final accuracy of 0.92.

    train_step = tf.train.GradientDescentOptimizer(0.005).minimize(cross_entropy)
    

    Or you could use a more sophisticated optimization algorithm (Adam, Momentum, etc.) that tries to do more to figure out the direction of the gradient. Or you could use a more complex model that has more free parameters across which to disperse that big gradient.
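
    As a rough illustration of that last option, here is a minimal sketch of a one-hidden-layer version of the model (the hidden size of 128, the ReLU activation, and the initializers are arbitrary illustrative choices, not something from the tutorial):

    # One hidden layer instead of the tutorial's single linear layer.
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])

    W1 = tf.Variable(tf.truncated_normal([784, 128], stddev=0.1))
    b1 = tf.Variable(tf.zeros([128]))
    hidden = tf.nn.relu(tf.matmul(x, W1) + b1)

    W2 = tf.Variable(tf.truncated_normal([128, 10], stddev=0.1))
    b2 = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(hidden, W2) + b2)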

  • 2020-12-03 06:08

    @dga gave a great answer, but I wanted to expand a little.

    When I wrote the beginners tutorial, I implemented the cost function like so:

    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    I wrote it that way because that looks most similar to the mathematical definition of cross-entropy. But it might actually be better to do something like this:

    cross_entropy = -tf.reduce_mean(y_*tf.log(y))

    Why might it be nicer to use a mean instead of a sum? Well, if we sum, then doubling the batch size doubles the cost and also doubles the magnitude of the gradient. Unless we adjust our learning rate (or use an algorithm that adjusts it for us, as @dga suggested), our training will explode! But if we use a mean, then our learning rate becomes more or less independent of our batch size, which is nice.
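
    A closely related formulation sums over the classes for each example and then averages over the batch, which keeps the familiar per-example scale of cross-entropy while the gradient still stays independent of the batch size. A sketch (the axis argument and the small epsilon are illustrative additions; the epsilon is the same stability trick mentioned in another answer):

    # Sum the 10 class terms per example, then average across the batch.
    cross_entropy = tf.reduce_mean(
        -tf.reduce_sum(y_ * tf.log(y + 1e-10), axis=1))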

    I'd encourage you to check out Adam (tf.train.AdamOptimizer()). It's often more tolerant of fiddling with things than SGD.

  • 2020-12-03 06:08

    NaN occurs when 0*log(0) is evaluated, i.e. when some predicted probability in y becomes exactly 0 (0 * log(0) = 0 * -inf = NaN):

    replace:

    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    

    with:

    cross_entropy = -tf.reduce_sum(y_*tf.log(y + 1e-10))
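
    A common alternative (not from this answer, just another standard way to avoid log(0)) is to clip the predicted probabilities away from zero:

    # Clip probabilities into [1e-10, 1.0] so tf.log never sees an exact zero.
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))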
    