How does TensorFlow handle non-differentiable nodes during gradient calculation?


I understood the concept of automatic differentiation, but couldn't find any explanation of how TensorFlow calculates the error gradient for non-differentiable functions, as for example tf.where or tf.cond.

1 Answer

  • 2020-12-31 19:32

    In the case of tf.where, you have a function with three inputs, condition C, value on true T and value on false F, and one output Out. The gradient function receives one incoming gradient value and has to return three values, one per input. Currently, no gradient is computed for the condition (that would hardly make sense), so you only need to produce gradients for T and F. Assuming the inputs and the output are vectors, imagine C[0] is True. Then Out[0] comes from T[0], and its gradient should propagate back; F[0] was discarded, so its gradient should be made zero. If C[1] were False instead, then Out[1] would come from F[1], so the gradient should propagate for F[1] but not for T[1]. In short: for T, propagate the given gradient where C is True and zero it where C is False, and do the opposite for F. The implementation of the gradient of tf.where (the Select operation) does exactly that:

    # From TensorFlow's math_grad.py; imports shown for completeness:
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    @ops.RegisterGradient("Select")
    def _SelectGrad(op, grad):
      c = op.inputs[0]  # condition C
      x = op.inputs[1]  # value on true T (used only for its shape)
      zeros = array_ops.zeros_like(x)
      # No gradient for the condition; route the incoming gradient to T
      # where C is True and to F where C is False, zeroing the rest.
      return (None, array_ops.where(c, grad, zeros),
              array_ops.where(c, zeros, grad))
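    As a quick sanity check, here is a minimal sketch from the user-facing side, using the TF 2.x eager API with tf.GradientTape (my addition, not part of the original answer; the snippet above is the graph-mode internal it wraps). It shows that the gradient of tf.where flows only through the selected branch:

    import tensorflow as tf

    c = tf.constant([True, False, True])
    t = tf.constant([1.0, 2.0, 3.0])
    f = tf.constant([10.0, 20.0, 30.0])

    with tf.GradientTape() as tape:
        tape.watch([t, f])
        out = tf.where(c, t, f)    # picks t[0], f[1], t[2]
        loss = tf.reduce_sum(out)

    grad_t, grad_f = tape.gradient(loss, [t, f])
    print(grad_t.numpy())  # [1. 0. 1.] -- nonzero only where C is True
    print(grad_f.numpy())  # [0. 1. 0.] -- nonzero only where C is False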
    

    Note that the input values themselves are not used in this computation; that part is handled by the gradients of the operations producing those inputs. For tf.cond, the code is a bit more complicated, because the same operation (Merge) is used in different contexts, and tf.cond also uses Switch operations internally. However, the idea is the same: Switch operations are used for each input, so the input that was activated (the first if the condition was True, the second otherwise) gets the received gradient, and the other input gets a "switched off" gradient (like None) that does not propagate back any further.
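    The same effect is visible with tf.cond. A minimal sketch (again assuming the TF 2.x eager API with tf.GradientTape, which is my addition): only the branch that was actually taken contributes to the gradient.

    import tensorflow as tf

    x = tf.constant(2.0)

    with tf.GradientTape() as tape:
        tape.watch(x)
        # The predicate is True, so only the x * x branch runs;
        # the -x branch is "switched off" and contributes no gradient.
        out = tf.cond(x > 0.0, lambda: x * x, lambda: -x)

    print(tape.gradient(out, x).numpy())  # 4.0 == d(x*x)/dx at x = 2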
