Getting around tf.argmax which is not differentiable

Submitted anonymously (unverified) on 2019-12-03 01:34:02

Question:

I've written a custom loss function for my neural network, but it can't compute any gradients. I think this is because I need the index of the highest value and am therefore using argmax to get that index.

Since argmax is not differentiable, I need to work around it, but I don't know how.
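A minimal snippet that reproduces the symptom (an illustration, not the original poster's code): TensorFlow registers no gradient for tf.argmax, so tf.gradients returns [None].

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None,))
idx = tf.argmax(x)            # returns an int64 index; no gradient is registered
print(tf.gradients(idx, x))   # -> [None], so nothing upstream can be trained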

Can anyone help?

Answer 1:

If you are cool with approximations,

import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))
beta = tf.placeholder(dtype=tf.float32)

# Pseudo-math for the below
# y = sum( i * exp(beta * x[i]) ) / sum( exp(beta * x[i]) )
y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * tf.exp(beta * x) / tf.reduce_sum(tf.exp(beta * x))) - 1

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.randn(10)
    print(data.argmax(), sess.run(y, feed_dict={x: data / np.linalg.norm(data), beta: 1e2}))

This uses the trick that computing the mean at low temperature gives an approximation of the maximum of the probability space. Low temperature in this case corresponds to beta being very large.

In fact, as beta approaches infinity, my algorithm will converge to the maximum (assuming the maximum is unique). Unfortunately, beta can't get too large before you run into numerical errors and get NaN, but there are tricks to solve that; one standard one is sketched after the sample output below.

The output looks something like,

0 2.24459
9 9.0
8 8.0
4 4.0
4 4.0
8 8.0
9 9.0
6 6.0
9 8.99995
1 1.0

So you can see that it messes up in some spots, but often gets the right answer. Depending on your algorithm, this might be fine.
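One standard stabilization, sketched here (the answer alludes to such tricks without spelling them out, so this is an assumption about what was meant): subtract the maximum of beta * x before exponentiating. Softmax is shift-invariant, so y is unchanged, but exp() can no longer overflow and beta can be pushed much higher.

import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))
beta = tf.placeholder(dtype=tf.float32)

z = beta * x
z = z - tf.reduce_max(z)                   # shift so exp(z) <= 1: no overflow
w = tf.exp(z) / tf.reduce_sum(tf.exp(z))   # numerically stable softmax weights
y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * w) - 1

data = np.random.randn(10)
print(data.argmax(), sess.run(y, feed_dict={x: data, beta: 1e4}))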



Answer 2:

As aidan suggested, it's just a softargmax stretched to the limits by beta. We can use tf.nn.softmax to get around the numerical issues:

import tensorflow as tf

def softargmax(x, beta=1e10):
    x = tf.convert_to_tensor(x)
    x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
    return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=-1)
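A quick usage check (a sketch; the placeholder shape and test data are illustrative): the result tracks np.argmax, and unlike tf.argmax it has a gradient.

import numpy as np

sess = tf.Session()
x = tf.placeholder(tf.float32, shape=(None, 5))
y = softargmax(x)
print("gradient:", tf.gradients(y, x))   # not [None]

data = np.random.randn(3, 5).astype(np.float32)
print(data.argmax(axis=-1), sess.run(y, feed_dict={x: data}))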


Answer 3:

If the values of your input are positive and you do not need the exact index of the maximum value, but its one-hot form is enough, you can use the sign function like this:

import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))

y = tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x)
y = (y - 1) * (-1)

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.random(10)
    print(data.argmax(), sess.run(y, feed_dict={x: data}))
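If you additionally want a numeric index or the max value out of the one-hot mask, a follow-up sketch (reusing the x placeholder from the snippet above; the variable names are mine, and note that tf.sign itself has zero gradient, so a useful gradient only flows through the explicit x factor in max_val):

one_hot = (tf.sign(tf.reduce_max(x, axis=-1, keepdims=True) - x) - 1) * (-1)
index = tf.reduce_sum(one_hot * (tf.cumsum(tf.ones_like(x)) - 1), axis=-1)  # 0-based argmax
max_val = tf.reduce_sum(one_hot * x, axis=-1)   # gradient flows via the x term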


Answer 4:

tf.argmax is not differentiable because it returns an integer index. tf.reduce_max and tf.maximum are differentiable.
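A minimal sketch of that point (the placeholders are illustrative, not from the answer): the max value is differentiable even though the max index is not.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None,))
y = tf.placeholder(tf.float32, shape=(None,))

print(tf.gradients(tf.reduce_max(x), x))        # gradient exists: one-hot at the argmax
print(tf.gradients(tf.maximum(x, y), [x, y]))   # gradients exist: elementwise indicators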


