Differentiable round function in Tensorflow?

面向向阳花 2021-02-04 20:27

So the output of my network is a list of probabilities, which I then round using tf.round() to be either 0 or 1; this is crucial for this project. I then found out that tf.roun

6 answers
  •  别那么骄傲
    2021-02-04 21:02

    Rounding is fundamentally non-differentiable in any useful sense: it is piecewise constant, so its gradient is zero almost everywhere and undefined at the jumps, and nothing can be backpropagated through it. The normal procedure in this situation is to work with the probabilities directly, say by using them to calculate an expected value, or by taking the highest output probability as the network's prediction. If you aren't using the rounded output to calculate your loss, though, you can go ahead and apply the rounding to the result, and it doesn't matter that it isn't differentiable. If you want an informative loss function for training the network, consider keeping the output as probabilities (it will likely make training smoother) and converting the probabilities to hard 0/1 estimates outside the network, after training.
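To make the advice above concrete, here is a minimal sketch in plain Python rather than TensorFlow, so it runs without any dependencies. It assumes a toy one-feature logistic model; the names `train`, `soft_loss`, `hard_loss`, `finite_diff` and the data are made up for illustration and are not from the question. The loss is computed on the probabilities, and the 0/1 decision is applied only after training. The finite-difference check shows why a loss built on rounded outputs gives the optimizer nothing to work with.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def soft_loss(w, b, xs, ys):
    """Binary cross-entropy on the probabilities (differentiable in w and b)."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = sigmoid(w * x + b)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(xs)

def hard_loss(w, b, xs, ys):
    """Error rate after thresholding at 0.5: piecewise constant in w and b."""
    wrong = sum(abs(int(sigmoid(w * x + b) >= 0.5) - y) for x, y in zip(xs, ys))
    return wrong / len(xs)

def finite_diff(loss, w, b, xs, ys, eps=1e-4):
    """Central-difference estimate of d(loss)/dw."""
    return (loss(w + eps, b, xs, ys) - loss(w - eps, b, xs, ys)) / (2 * eps)

def train(xs, ys, lr=0.5, steps=2000):
    """Gradient descent on the cross-entropy; no rounding anywhere in here."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        grad_w = grad_b = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(w * x + b)
            # For sigmoid + cross-entropy, d(loss)/dz simplifies to (p - y).
            grad_w += (p - y) * x / n
            grad_b += (p - y) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

def predict_label(w, b, x):
    """Hard 0/1 decision, applied outside training."""
    return int(sigmoid(w * x + b) >= 0.5)

# Hypothetical toy data: negative inputs are class 0, positive are class 1.
xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
ys = [0, 0, 0, 1, 1, 1]

print("d(hard_loss)/dw:", finite_diff(hard_loss, 0.3, 0.1, xs, ys))
print("d(soft_loss)/dw:", finite_diff(soft_loss, 0.3, 0.1, xs, ys))

w, b = train(xs, ys)
print("labels after training:", [predict_label(w, b, x) for x in xs])
```

In TensorFlow the same split would presumably mean computing the loss on the sigmoid outputs and calling tf.round() only when producing final predictions, never inside the loss.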
