I'm trying to interpolate a 1D tensor in TensorFlow (I effectively want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I had to perform the i
I don't know the source of your error, but I can tell you that tf.while_loop is likely to be very slow for this. You can implement linear interpolation without any loops like this:
import numpy as np
import tensorflow as tf
xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
query = tf.placeholder(tf.float32, name='query')
# Add additional elements at the beginning and end for extrapolation
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)
# Find the index of the interval containing query
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
diff = cmp[1:] - cmp[:-1]
idx = tf.argmin(diff)
# Interpolate
alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]
# Test with f(x) = 2 * x
q = 5.4
x = np.arange(100)
y = 2 * x
with tf.Session() as sess:
    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
    print(q_interp)
>>> 10.8
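The padding is only there to handle query values outside the range of xaxis: such values get the y value of the nearest endpoint (constant extrapolation, which is also what np.interp does by default). Otherwise, it is just a matter of comparing and finding where the values of xaxis start to be bigger than query. Note that this assumes xaxis is sorted in increasing order.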
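If you want to check the index-finding trick without building a graph, here is a minimal plain-Python sketch that mirrors the same steps (pad, compare, difference, argmin, blend) for a single query. It does not use TensorFlow; the helper name interp_padded is hypothetical, and it assumes xaxis is sorted in increasing order and has the same length as yaxis.

```python
def interp_padded(query, xaxis, yaxis):
    """Plain-Python mirror of the TensorFlow graph above (hypothetical helper).

    Assumes xaxis is sorted in increasing order and len(xaxis) == len(yaxis).
    """
    # Pad both ends so that any query falls inside some interval; the padded
    # y values repeat the endpoints, which gives constant extrapolation.
    x_pad = [min(query - 1, xaxis[0])] + list(xaxis) + [max(query + 1, xaxis[-1])]
    y_pad = [yaxis[0]] + list(yaxis) + [yaxis[-1]]

    # cmp is 1 while query >= x_pad[i], then 0 from the first x above query.
    cmp = [1 if query >= xp else 0 for xp in x_pad]
    # diff is -1 exactly at the single 1 -> 0 transition, 0 everywhere else.
    diff = [b - a for a, b in zip(cmp[:-1], cmp[1:])]
    # argmin: position of the -1, i.e. the left end of the bracketing interval,
    # so x_pad[idx] <= query < x_pad[idx + 1].
    idx = min(range(len(diff)), key=diff.__getitem__)

    # Linear blend between the two neighbouring y values.
    alpha = (query - x_pad[idx]) / (x_pad[idx + 1] - x_pad[idx])
    return alpha * y_pad[idx + 1] + (1 - alpha) * y_pad[idx]


if __name__ == "__main__":
    xs = list(range(100))
    ys = [2 * x for x in xs]
    print(interp_padded(5.4, xs, ys))
```

With the same f(x) = 2 * x test data, interp_padded(5.4, xs, ys) returns approximately 10.8, and queries below 0 or above 99 return the endpoint values 0 and 198. Because the padded x values are always strictly below and strictly above the query at the two ends, there is exactly one transition in cmp and the denominator in alpha is never zero.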