Differentiable round function in TensorFlow?

面向向阳花 2021-02-04 20:27

So the output of my network is a list of probabilities, which I then round using tf.round() to either 0 or 1; this is crucial for this project. I then found out that tf.roun…
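
The premise in the title can be checked with a minimal sketch, assuming TensorFlow 2.x in eager mode and made-up input values. Rounding produces the expected 0/1 values in the forward pass, but tf.round has no gradient. As a result, tf.GradientTape returns None for its input, and a loss computed only from the rounded values cannot train anything upstream of the rounding.

```python
import tensorflow as tf

# Illustrative "probabilities"; values avoid 0.5 because tf.round rounds half to even.
x = tf.constant([0.2, 0.7, 0.49, 0.9])

with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so the tape must be told to track it
    # Plain rounding: the forward value is as expected...
    y = tf.reduce_sum(tf.round(x))

# ...but tf.round provides no gradient, so nothing flows back to x.
grad = tape.gradient(y, x)

print(y.numpy())  # 2.0
print(grad)       # None
```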

6 Answers
  •  深忆病人
    2021-02-04 21:24

    This works for me:

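    import tensorflow as tf
    # x: the network's output; forward value is tf.round(x), gradient w.r.t. x is 1
    x_rounded_NOT_differentiable = tf.round(x)
    x_rounded_differentiable = (x - (tf.stop_gradient(x) - x_rounded_NOT_differentiable))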
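
The expression in the answer is a straight-through estimator. In the forward pass, x - (tf.stop_gradient(x) - tf.round(x)) evaluates to tf.round(x). In the backward pass, only the leading x term carries a gradient: tf.stop_gradient blocks the middle term and tf.round contributes none. The gradient therefore passes through as if rounding were the identity. Below is a minimal sketch assuming TensorFlow 2.x. The names round_ste and round_ste_custom are hypothetical helpers, and the second one uses tf.custom_gradient as an alternative way to express the same trick. It is not taken from the answer.

```python
import tensorflow as tf


def round_ste(x):
    """Rounds in the forward pass; gradient w.r.t. x is 1 (straight-through).

    Same expression as the answer: x - (stop_gradient(x) - round(x)).
    """
    return x - (tf.stop_gradient(x) - tf.round(x))


@tf.custom_gradient
def round_ste_custom(x):
    """Alternative: tf.round forward, upstream gradient passed through unchanged."""
    def grad(upstream):
        return upstream
    return tf.round(x), grad


# Illustrative "probabilities"; values avoid 0.5 because tf.round rounds half to even.
x = tf.constant([0.2, 0.7, 0.49, 0.9])

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a constant, so the tape must be told to track it
    y_ste = tf.reduce_sum(round_ste(x))
    y_custom = tf.reduce_sum(round_ste_custom(x))

g_ste = tape.gradient(y_ste, x)
g_custom = tape.gradient(y_custom, x)
del tape  # release resources held by the persistent tape

print(y_ste.numpy(), y_custom.numpy())  # 2.0 2.0
print(g_ste.numpy())                    # [1. 1. 1. 1.]
print(g_custom.numpy())                 # [1. 1. 1. 1.]
```

Both variants give the same forward values and gradients. The stop_gradient form needs no decorator. The custom_gradient form makes the "identity in the backward pass" choice explicit and is easier to swap for a different surrogate gradient.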
    
