Why is this TensorFlow implementation vastly less successful than Matlab's NN?

一整个雨季 2021-01-30 11:07

As a toy example I'm trying to fit a function f(x) = 1/x from 100 no-noise data points. The Matlab default implementation is phenomenally successful, with mean squared …
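For concreteness, the toy setup amounts to something like the sketch below; the sampling interval is an assumption, since the question text above is truncated and does not show the exact range used:

    import numpy as np

    # Hypothetical reconstruction of the toy data: 100 no-noise samples of
    # f(x) = 1/x. The interval [0.1, 1.0] is assumed, not taken from the question.
    xTrain = np.linspace(0.1, 1.0, 100)
    yTrain = 1.0 / xTrain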

2 Answers
  •  囚心锁ツ
    2021-01-30 11:40

    btw, here's a slightly cleaned-up version of the above that fixes some of the shape issues and the unnecessary bouncing between tf and np. It achieves 3e-08 after 40k steps, or about 1.5e-5 after 4000:

    import tensorflow as tf
    import numpy as np
    
    def weight_variable(shape):
      # Initialize weights with small truncated-normal noise.
      initial = tf.truncated_normal(shape, stddev=0.1)
      return tf.Variable(initial)
    
    def bias_variable(shape):
      # Initialize biases to a small constant.
      initial = tf.constant(0.1, shape=shape)
      return tf.Variable(initial)
    
    # Training data: 1/x sampled at 101 evenly spaced points in [0.2, 0.8],
    # kept as a (1, N) row so it matches the shapes of the graph below.
    xTrain = np.linspace(0.2, 0.8, 101).reshape([1, -1])
    yTrain = (1/xTrain)
    
    x = tf.placeholder(tf.float32, [1,None])
    hiddenDim = 10
    
    b = bias_variable([hiddenDim,1])
    W = weight_variable([hiddenDim, 1])
    
    b2 = bias_variable([1])
    W2 = weight_variable([1, hiddenDim])
    
    hidden = tf.nn.sigmoid(tf.matmul(W, x) + b)
    y = tf.matmul(W2, hidden) + b2
    
    # Minimize the squared errors.
    loss = tf.reduce_mean(tf.square(y - yTrain))
    # Adam with a learning rate that decays from 0.15 by a factor of 0.9999
    # per step; `step` is the global step incremented by the train op.
    step = tf.Variable(0, trainable=False)
    rate = tf.train.exponential_decay(0.15, step, 1, 0.9999)
    optimizer = tf.train.AdamOptimizer(rate)
    train = optimizer.minimize(loss, global_step=step)
    init = tf.global_variables_initializer()
    
    # Launch the graph
    sess = tf.Session()
    sess.run(init)
    
    # Full-batch training: one Adam step per iteration, reporting the loss
    # every 500 steps.
    for i in range(40001):
        train.run({x: xTrain}, sess)
        if i % 500 == 0:
            print(loss.eval({x: xTrain}, sess))
    
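    For readers on current TensorFlow, roughly the same network and decayed-Adam schedule can be written with the Keras API. This is a sketch assuming TensorFlow 2.x, not part of the original answer:

    import numpy as np
    import tensorflow as tf

    xTrain = np.linspace(0.2, 0.8, 101).reshape(-1, 1).astype(np.float32)
    yTrain = 1.0 / xTrain

    # 1 input -> 10 sigmoid units -> 1 linear output, as in the graph above.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(10, activation="sigmoid"),
        tf.keras.layers.Dense(1),
    ])

    # Learning rate decays from 0.15 by a factor of 0.9999 per step, as above.
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.15, decay_steps=1, decay_rate=0.9999)
    model.compile(optimizer=tf.keras.optimizers.Adam(schedule), loss="mse")

    # With batch_size=101 each epoch is one full-batch step, so epoch counts
    # correspond to the step counts quoted above.
    model.fit(xTrain, yTrain, batch_size=101, epochs=4000, verbose=0)
    print(model.evaluate(xTrain, yTrain, verbose=0))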

    All that said, it's probably not too surprising that LMA (Levenberg-Marquardt) does better than a more general DNN-style optimizer at fitting a 2D curve. Adam and the rest are aimed at very high-dimensional problems, and LMA starts to get glacially slow for very large networks (see 12-15).
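    To make that comparison concrete, here is a minimal sketch of fitting the same 10-unit sigmoid network with a Levenberg-Marquardt-style least-squares solver. It uses scipy.optimize.least_squares with method='lm' as a stand-in for Matlab's trainlm; it is an illustration of the idea, not Matlab's implementation:

    import numpy as np
    from scipy.optimize import least_squares

    # Same training data as the TensorFlow version above.
    xTrain = np.linspace(0.2, 0.8, 101)
    yTrain = 1.0 / xTrain
    hiddenDim = 10

    def unpack(params):
        # params holds W (10), b (10), W2 (10), b2 (1) flattened together.
        W = params[:hiddenDim]
        b = params[hiddenDim:2 * hiddenDim]
        W2 = params[2 * hiddenDim:3 * hiddenDim]
        b2 = params[-1]
        return W, b, W2, b2

    def residuals(params):
        # One residual per training point: network output minus target.
        W, b, W2, b2 = unpack(params)
        hidden = 1.0 / (1.0 + np.exp(-(np.outer(W, xTrain) + b[:, None])))
        return W2 @ hidden + b2 - yTrain

    rng = np.random.default_rng(0)
    p0 = rng.normal(scale=0.1, size=3 * hiddenDim + 1)
    result = least_squares(residuals, p0, method='lm')  # Levenberg-Marquardt
    print("MSE:", np.mean(result.fun ** 2))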
