Derivative of neural network with respect to input

忘掉有多难 2021-01-20 18:04

I trained a neural network to do a regression on the sine function and would like to compute the first and second derivative with respect to the input. I tried using the tf.

3 answers
  • 2021-01-20 18:12

    I don't think you can compute second-order derivatives with a single tf.gradients call. Take a look at tf.hessians (what you really want is the diagonal of the Hessian matrix), e.g. [1].

    An alternative is to use tf.GradientTape: [2].
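    A minimal sketch of the GradientTape approach, assuming TensorFlow 2.x and a network that maps one scalar input to one scalar output. The small tanh `model` below is a hypothetical stand-in for your trained regressor, and `input_derivatives` is an illustrative helper name. Two nested tapes give the first and second derivative with respect to the input:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained regressor: a small smooth (tanh) MLP.
# Replace `model` with your own trained network.
tf.random.set_seed(0)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(32, activation="tanh"),
    tf.keras.layers.Dense(32, activation="tanh"),
    tf.keras.layers.Dense(1),
])


def input_derivatives(model, x):
    """Return y, dy/dx and d2y/dx2 of a scalar-in, scalar-out model at points x."""
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = model(x)
        # Computed while `outer` is still recording, so it can be differentiated again.
        # Each output row depends only on its own input row, so this is the
        # per-sample first derivative.
        dy_dx = inner.gradient(y, x)
    # Differentiating dy/dx once more gives the per-sample second derivative,
    # i.e. the diagonal of the Hessian.
    d2y_dx2 = outer.gradient(dy_dx, x)
    return y, dy_dx, d2y_dx2


x = np.linspace(-np.pi, np.pi, 200, dtype=np.float32).reshape(-1, 1)
y, dy_dx, d2y_dx2 = input_derivatives(model, x)
print(dy_dx.shape, d2y_dx2.shape)
```

    Because each sample's output depends only on its own input, differentiating the summed output gives the per-sample derivative, which is why the second pass yields the Hessian diagonal rather than the full matrix.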

    [1] https://github.com/gknilsen/pyhessian

    [2] https://www.tensorflow.org/api_docs/python/tf/GradientTape

  • 2021-01-20 18:26

    One possible explanation for what you observed is that your function is not twice differentiable. It looks as if there are jumps in the first derivative around the extrema. If so, the second derivative of the function doesn't really exist at those points, and the plot you get depends heavily on how the library handles them.

    Consider the following picture of a non-smooth function that jumps from 0.5 to -0.5 at every x in {1, 2, ...}. Its slope is 1 everywhere except where x is an integer. If you tried to plot its derivative, you would probably see a straight line at y = 1, which is easy to misinterpret: someone looking only at that plot could conclude that the function is linear and extends from -infinity to +infinity.
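    A small sketch of this effect, assuming TensorFlow 2.x: `sawtooth` (a hypothetical name) implements the function described above as x - floor(x) - 0.5. Automatic differentiation treats `tf.math.floor` as contributing no gradient, so the derivative it reports is 1 at every sample point, including the integers where the function actually jumps:

```python
import numpy as np
import tensorflow as tf


def sawtooth(x):
    # Slope 1 everywhere, but drops from 0.5 to -0.5 at every integer.
    return x - tf.math.floor(x) - 0.5


# Sample points 0.0, 0.5, ..., 4.0, including the integers where the jumps happen.
x = tf.constant(np.linspace(0.0, 4.0, 9, dtype=np.float32))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = sawtooth(x)
dy_dx = tape.gradient(y, x)

# Autodiff sees only the "x" term, so the jumps are invisible in the derivative.
print(dy_dx.numpy())
```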

    If your results come from a neural net that uses ReLU, try the same with the sigmoid activation function. I suspect you won't see as many spikes with it.
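    To illustrate, here is a sketch assuming TensorFlow 2.x, with hypothetical one-hidden-layer networks (`relu_net`, `sigmoid_net`) that share the same random weights and differ only in the activation. A ReLU network is piecewise linear, so autodiff reports a second derivative of zero everywhere (the kinks are simply ignored), while the sigmoid version has a genuine, smooth second derivative:

```python
import numpy as np
import tensorflow as tf


def second_derivative(f, x):
    """Per-sample d2f/dx2 for a function applied independently to each row of x."""
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = f(x)
        dy_dx = inner.gradient(y, x)
    d2 = outer.gradient(dy_dx, x)
    # If the first derivative were not connected to x at all, the tape would return None.
    return tf.zeros_like(x) if d2 is None else d2


# Hypothetical random weights shared by both networks.
rng = np.random.default_rng(0)
W1 = tf.constant(rng.normal(size=(1, 16)).astype(np.float32))
b1 = tf.constant(rng.normal(size=(16,)).astype(np.float32))
W2 = tf.constant(rng.normal(size=(16, 1)).astype(np.float32))


def relu_net(x):
    return tf.nn.relu(x @ W1 + b1) @ W2


def sigmoid_net(x):
    return tf.nn.sigmoid(x @ W1 + b1) @ W2


x = tf.constant(np.linspace(-3.0, 3.0, 101, dtype=np.float32).reshape(-1, 1))
relu_d2 = second_derivative(relu_net, x).numpy()
sigmoid_d2 = second_derivative(sigmoid_net, x).numpy()
print("max |d2y/dx2| ReLU:   ", np.abs(relu_d2).max())
print("max |d2y/dx2| sigmoid:", np.abs(sigmoid_d2).max())
```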

  • 2021-01-20 18:27

    What you learned was the sine function, not its derivative. During training, your cost function controls the error only on the values; it does not constrain the slope at all. You could have learned a very noisy function that still matches the data points exactly.

    If you only use the data points in your cost function, you have no guarantee about the derivative you've learned. However, with some more advanced training techniques, you can also learn the derivative: https://arxiv.org/abs/1706.04859
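    A minimal sketch of the idea in that paper (Sobolev training), assuming TensorFlow 2.x and that the true derivative of the target is known, which holds here since d/dx sin(x) = cos(x). The loss penalizes errors on both the values and the input derivatives; the network size, `deriv_weight`, learning rate and step count are arbitrary illustrative choices, not values from the paper:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
x_train = tf.constant(np.linspace(-np.pi, np.pi, 64, dtype=np.float32).reshape(-1, 1))
y_train = tf.sin(x_train)
dy_train = tf.cos(x_train)  # assumed available: the known derivative of the target

model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(32, activation="tanh"),
    tf.keras.layers.Dense(32, activation="tanh"),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
deriv_weight = 1.0  # hypothetical weight of the derivative term


def losses():
    # Inner tape: derivative of the prediction with respect to the input.
    with tf.GradientTape() as inner:
        inner.watch(x_train)
        y_pred = model(x_train)
    dy_pred = inner.gradient(y_pred, x_train)
    value_loss = tf.reduce_mean(tf.square(y_pred - y_train))
    deriv_loss = tf.reduce_mean(tf.square(dy_pred - dy_train))
    return value_loss, deriv_loss


def train_step():
    # Outer tape: gradient of the combined loss with respect to the weights.
    with tf.GradientTape() as tape:
        value_loss, deriv_loss = losses()
        loss = value_loss + deriv_weight * deriv_loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return value_loss, deriv_loss


history = [train_step() for _ in range(300)]
print("first step (value, deriv):", [float(v) for v in history[0]])
print("last step  (value, deriv):", [float(v) for v in history[-1]])
```

    Without the `deriv_loss` term this is ordinary regression on the values, which is exactly the situation described above where the slope is left unconstrained.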

    So, in summary, this is not a code issue but a theoretical one.
