I'm trying to interpolate a 1-D tensor in TensorFlow (effectively, I want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I had to perform the i
Found the issue: TensorFlow does not like a plain Python integer being passed to tf.cond; it needs to be wrapped in a constant first (m_one below). This code works:
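import tensorflow as tf  # TensorFlow 1.x, graph mode

# xaxis (a sorted 1-D float tensor) and query (a scalar tensor) are defined earlier
with tf.name_scope("binsearch"):
    # The fix: -1 is wrapped in a tensor instead of being returned as a Python int
    m_one = tf.constant(-1, dtype=tf.int32, name='minus_one')
    up = tf.Variable(0, dtype=tf.int32, name='up')
    mid = tf.Variable(0, dtype=tf.int32, name='mid')
    down = tf.Variable(0, dtype=tf.int32, name='down')
    done = tf.Variable(-1, dtype=tf.int32, name='done')

    def cond(up, down, mid, done):
        # Continue until the bracketing index is found or the interval has width 1
        return tf.logical_and(done < 0, up - down > 1)

    def body(up, down, mid, done):
        def fn1():
            # xaxis[mid] > query: search the lower half [down, mid]
            return mid, down, (mid + down) // 2, tf.cond(tf.gather(xaxis, mid - 1) < query, lambda: mid - 1, lambda: m_one)

        def fn2():
            # xaxis[mid] <= query: search the upper half [mid, up]
            return up, mid, (up + mid) // 2, tf.cond(tf.gather(xaxis, mid + 1) > query, lambda: mid, lambda: m_one)

        return tf.cond(tf.gather(xaxis, mid) > query, fn1, fn2)

    up, down, mid, done = tf.while_loop(cond, body, (xaxis.shape[0] - 1, 0, (xaxis.shape[0] - 1) // 2, -1))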
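The loop above is a bracketing binary search. Stripped of the graph-mode machinery and of the early-exit `done` flag, the idea can be sketched in plain Python as below. `find_interval` and `interp1` are hypothetical helper names, and the sketch assumes `xaxis` is sorted ascending with `xaxis[0] <= query < xaxis[-1]`; it is an illustration of the technique, not the author's exact loop.

```python
def find_interval(xaxis, query):
    """Return i with xaxis[i] <= query < xaxis[i + 1] by halving [down, up].

    Assumes xaxis is sorted ascending and xaxis[0] <= query < xaxis[-1].
    """
    down, up = 0, len(xaxis) - 1
    # Invariant: xaxis[down] <= query < xaxis[up]
    while up - down > 1:
        mid = (up + down) // 2
        if xaxis[mid] > query:
            up = mid  # same move as fn1 in the TensorFlow version
        else:
            down = mid  # same move as fn2
    return down


def interp1(xaxis, yaxis, query):
    """Linear interpolation inside the interval found by find_interval."""
    i = find_interval(xaxis, query)
    alpha = (query - xaxis[i]) / (xaxis[i + 1] - xaxis[i])
    return (1 - alpha) * yaxis[i] + alpha * yaxis[i + 1]


xs = list(range(100))
ys = [2 * x for x in xs]
print(interp1(xs, ys, 5.4))  # approximately 10.8 (float rounding)
```

Each iteration halves the interval, so the search needs about log2(n) steps; in graph mode every one of those steps is a separate `tf.while_loop` iteration.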
I don't know the source of your error, but I can tell you that tf.while_loop is likely to be slow. You can implement linear interpolation without any loop, like this:
import numpy as np
import tensorflow as tf
xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
query = tf.placeholder(tf.float32, name='query')
# Add additional elements at the beginning and end for extrapolation
xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)
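# Find the index of the interval containing query:
# cmp is 1 up to the last element <= query and 0 after it, so diff is -1
# exactly at that switch and argmin returns the left end of the interval
cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
diff = cmp[1:] - cmp[:-1]
idx = tf.argmin(diff)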
# Interpolate
alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]
# Test with f(x) = 2 * x
q = 5.4
x = np.arange(100)
y = 2 * x
with tf.Session() as sess:
    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
    print(q_interp)  # prints 10.8
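For the test above (f(x) = 2x, query = 5.4) the steps can be checked by hand. The padded axis is [0, 0, 1, ..., 99, 99], so cmp is 1 for padded indices 0 to 6 and 0 afterwards, diff is -1 at index 6, and idx = 6:

$$
\begin{aligned}
x_{\text{pad}}[6] &= 5, \qquad x_{\text{pad}}[7] = 6,\\
\alpha &= \frac{5.4 - 5}{6 - 5} = 0.4,\\
\text{res} &= \alpha\, y_{\text{pad}}[7] + (1 - \alpha)\, y_{\text{pad}}[6] = 0.4 \cdot 12 + 0.6 \cdot 10 = 10.8.
\end{aligned}
$$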
The padding only guards against query values outside the range of xaxis (the result is then clamped to the first or last y value); otherwise it is just a matter of comparing and finding where the values start to be bigger than query.
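To see the comparison trick without a TensorFlow session, here is a NumPy mirror of the same pad / compare / diff / argmin steps, which can be checked against `np.interp`. `interp_no_loop` is a hypothetical helper name, and the sketch assumes a sorted 1-D `xaxis` and a scalar `query`; it illustrates the answer's technique rather than replacing `np.interp`.

```python
import numpy as np


def interp_no_loop(xaxis, yaxis, query):
    """NumPy sketch of the loop-free interpolation (sorted xaxis, scalar query)."""
    xaxis = np.asarray(xaxis, dtype=float)
    yaxis = np.asarray(yaxis, dtype=float)
    # Pad so that query always lies inside some interval; the padded y values
    # repeat the end points, so out-of-range queries are clamped.
    xpad = np.concatenate([[min(query - 1, xaxis[0])], xaxis, [max(query + 1, xaxis[-1])]])
    ypad = np.concatenate([yaxis[:1], yaxis, yaxis[-1:]])
    # cmp is 1 up to the last x <= query, then 0; the diff is -1 at that switch
    cmp = (query >= xpad).astype(int)
    idx = np.argmin(cmp[1:] - cmp[:-1])
    alpha = (query - xpad[idx]) / (xpad[idx + 1] - xpad[idx])
    return alpha * ypad[idx + 1] + (1 - alpha) * ypad[idx]


xs = np.arange(100)
ys = 2 * xs
print(interp_no_loop(xs, ys, 5.4), np.interp(5.4, xs, ys))
```

The comparison touches every element of xaxis, so this does O(n) work per query versus O(log n) steps for the binary search; the trade-off is that it needs no loop at all.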