Strange convergence in simple Neural Network

予麋鹿 2021-01-25 07:31

I've been struggling for some time with building a simplistic NN in Java. I've been working on and off on this project for a few months and I wanna finish it. My main issue is

1 Answer
  • 2021-01-25 07:56

    I've seen several problems with your code; your weight updates are incorrect, for example. I'd also strongly recommend organizing your code more cleanly by introducing methods.

    Backpropagation is usually hard to implement efficiently, but the formal definitions are easily translated into any language. I wouldn't recommend studying code to learn about neural nets. Look at the math and try to understand that; it makes you far more flexible when implementing one from scratch.

    I can give you some hints by describing the forward and backward passes in pseudocode.

    As a matter of notation, I use i for the input layer, j for the hidden layer and k for the output layer. The bias of the input layer is then bias_i, the weight connecting node m to node n is w_mn, the activation function is a(x) and its derivative is a'(x).

    Forward pass:

    for each n of j
           dot = 0
           for each m of i
                  dot += m*w_mn
           n = a(dot + bias_i)
    

    The same applies between the hidden layer j and the output layer k; just replace j by k and i by j for this step.
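
    As a minimal Java sketch of this forward pass (the array and method names here are my own hypothetical choices, assuming a logistic activation and one bias value per layer as in the pseudocode above):

    static double activate(double x) {
        return 1.0 / (1.0 + Math.exp(-x));            // logistic a(x) = 1/(1+e^-x)
    }

    // input -> hidden -> output; wIH[m][n] connects input node m to hidden node n,
    // wHO[m][n] connects hidden node m to output node n (hypothetical names).
    static void forward(double[] input, double[] hidden, double[] output,
                        double[][] wIH, double[][] wHO,
                        double biasI, double biasJ) {
        for (int n = 0; n < hidden.length; n++) {
            double dot = 0.0;
            for (int m = 0; m < input.length; m++) {
                dot += input[m] * wIH[m][n];          // dot += m * w_mn
            }
            hidden[n] = activate(dot + biasI);        // n = a(dot + bias_i)
        }
        for (int n = 0; n < output.length; n++) {     // same pattern, j replaced by k and i by j
            double dot = 0.0;
            for (int m = 0; m < hidden.length; m++) {
                dot += hidden[m] * wHO[m][n];
            }
            output[n] = activate(dot + biasJ);
        }
    }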

    Backward pass:

    Calculate delta for output nodes:

    for each n of k
           d_n = a'(n)(n - target)
    

    Here, target is the expected output and n the output of the current output node; d_n is the delta of this node. An important note here is that the derivatives of the logistic and the tanh function contain the output of the original function, so these values don't have to be re-evaluated. The logistic function is f(x) = 1/(1+e^(-x)) and its derivative is f'(x) = f(x)(1-f(x)). Since the value at each output node n was previously evaluated with f(x), one can simply apply n(1-n) as the derivative. In the case above this would calculate the delta as follows:

    d_n = n(1-n)(n - target)
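
    In Java, using the same hypothetical arrays as in the forward-pass sketch above (and assuming the logistic activation, so a'(n) becomes n*(1-n)):

    // deltaK[n] = a'(n) * (n - target), with a'(n) = n*(1-n) for the logistic function
    static void outputDeltas(double[] output, double[] target, double[] deltaK) {
        for (int n = 0; n < output.length; n++) {
            deltaK[n] = output[n] * (1.0 - output[n]) * (output[n] - target[n]);
        }
    }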
    

    In the same fashion, calculate the deltas for the hidden nodes.

    for each n of j
          d_n = 0
          for each m of k
                 d_n += d_m*w_nm
          d_n = a'(n)*d_n
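
    And a corresponding Java sketch for the hidden deltas, again with my hypothetical names from above (wHO[n][m] is the weight from hidden node n to output node m) and the logistic derivative:

    static void hiddenDeltas(double[] hidden, double[][] wHO,
                             double[] deltaK, double[] deltaJ) {
        for (int n = 0; n < hidden.length; n++) {
            double sum = 0.0;
            for (int m = 0; m < deltaK.length; m++) {
                sum += deltaK[m] * wHO[n][m];                     // d_n += d_m * w_nm
            }
            deltaJ[n] = hidden[n] * (1.0 - hidden[n]) * sum;      // d_n = a'(n) * d_n
        }
    }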
    

    The next step is to perform the weight update using the gradients. This is done with an algorithm called gradient descent. Without going into much detail, it can be accomplished as follows:

    for each n of j
          for each m of k
                w_nm -= learning_rate*n*d_m
    

    The same applies to the weights between the input and hidden layers; just replace j by i and k by j.

    To update the biases, just sum up the deltas of the connected nodes, multiply this by the learning rate and subtract this product from the specific bias.
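
    A Java sketch of that update step for the hidden-to-output weights and the bias feeding the output layer, again with the hypothetical names from above (the method returns the new bias value):

    static double updateHiddenToOutput(double[] hidden, double[][] wHO,
                                       double[] deltaK, double biasJ,
                                       double learningRate) {
        for (int n = 0; n < hidden.length; n++) {
            for (int m = 0; m < deltaK.length; m++) {
                wHO[n][m] -= learningRate * hidden[n] * deltaK[m];  // w_nm -= learning_rate*n*d_m
            }
        }
        double deltaSum = 0.0;           // bias: subtract learning_rate * sum of connected deltas
        for (double d : deltaK) {
            deltaSum += d;
        }
        return biasJ - learningRate * deltaSum;
    }

    The input-to-hidden weights and their bias follow the same pattern, using input, wIH, deltaJ and biasI instead.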
