Binary search and interpolation in tensorflow

日久生厌 2021-01-07 07:09

I'm trying to interpolate a 1D tensor in TensorFlow (I effectively want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I had to perform the i…
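For reference, np.interp(q, xp, fp) evaluates the piecewise-linear interpolant through the points (xp, fp) at q, assuming xp is increasing, and clamps queries outside the range to fp[0] or fp[-1]. A minimal pure-Python sketch of those semantics for a single scalar query, using the standard-library `bisect` module for the binary search (the name `interp_scalar` is hypothetical, not a NumPy or TensorFlow function):

```python
from bisect import bisect_right


def interp_scalar(q, xp, fp):
    """Sketch of np.interp semantics for one scalar query (hypothetical helper).

    Assumes xp is sorted in increasing order and len(xp) == len(fp) >= 2.
    Queries outside [xp[0], xp[-1]] are clamped to fp[0] or fp[-1].
    """
    if q <= xp[0]:
        return fp[0]
    if q >= xp[-1]:
        return fp[-1]
    # Binary search: i is the index with xp[i] <= q < xp[i + 1]
    i = bisect_right(xp, q) - 1
    # Fraction of the way from xp[i] to xp[i + 1]
    alpha = (q - xp[i]) / (xp[i + 1] - xp[i])
    return (1 - alpha) * fp[i] + alpha * fp[i + 1]


xs = list(range(100))
ys = [2 * x for x in xs]
print(interp_scalar(5.4, xs, ys))  # about 10.8, up to floating-point rounding
```

The two steps, locating the bracketing interval and blending the two neighbouring y values, are exactly what the answers below rebuild out of TensorFlow ops.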

2 answers
  • 2021-01-07 07:37

    Found the issue: TensorFlow does not like Python integers as parameters to cond; they need to be wrapped in a constant first. This code works:

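    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    # Assumes xaxis (sorted 1-D float tensor) and query (scalar float tensor)
    # are already defined, as in the question.
    with tf.name_scope("binsearch"):
        # -1 ("not found") as an int32 tensor instead of a Python int
        m_one = tf.constant(-1, dtype=tf.int32, name='minus_one')

        def cond(up, down, mid, done):
            # Keep searching while nothing is found and the range is wider than 1
            return tf.logical_and(done < 0, up - down > 1)

        def body(up, down, mid, done):

            def fn1():
                # xaxis[mid] > query: move the upper bound down to mid
                return mid, down, (mid + down) // 2, tf.cond(tf.gather(xaxis, mid - 1) < query, lambda: mid - 1, lambda: m_one)

            def fn2():
                # xaxis[mid] <= query: move the lower bound up to mid
                return up, mid, (up + mid) // 2, tf.cond(tf.gather(xaxis, mid + 1) > query, lambda: mid, lambda: m_one)

            return tf.cond(tf.gather(xaxis, mid) > query, fn1, fn2)

        # All initial loop values are int32 tensors rather than Python ints
        n = tf.shape(xaxis)[0]
        up, down, mid, done = tf.while_loop(cond, body, (n - 1, tf.constant(0), (n - 1) // 2, m_one))
        # done: index i of the interval xaxis[i] .. xaxis[i + 1] containing query, or -1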
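To see what the loop computes, here is a plain-Python trace of the same cond/body logic. It is an illustration only: `binsearch` is a hypothetical helper, and it assumes a sorted list with at least three elements and a query strictly inside its range. On exit, done is the index i with xaxis[i] .. xaxis[i + 1] bracketing query, or -1 if the search range collapsed first.

```python
def binsearch(xaxis, query):
    """Plain-Python mirror of the tf.while_loop above (hypothetical helper)."""
    n = len(xaxis)
    up, down, mid, done = n - 1, 0, (n - 1) // 2, -1
    # cond: continue while nothing is found and the range is wider than 1
    while done < 0 and up - down > 1:
        if xaxis[mid] > query:
            # fn1: move the upper bound down to mid
            new_done = mid - 1 if xaxis[mid - 1] < query else -1
            up, down, mid, done = mid, down, (mid + down) // 2, new_done
        else:
            # fn2: move the lower bound up to mid
            new_done = mid if xaxis[mid + 1] > query else -1
            up, down, mid, done = up, mid, (up + mid) // 2, new_done
    return up, down, mid, done


print(binsearch(list(range(100)), 5.4))
```

Each iteration halves [down, up] and also peeks at one neighbour of mid, so it can stop early as soon as query is bracketed.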
    
  • 2021-01-07 07:45

    I don't know the source of your error, but tf.while_loop is likely to be slow here. You can implement linear interpolation without loops like this:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph-mode API (placeholders, Session)

    xaxis = tf.placeholder(tf.float32, shape=[100], name='xaxis')  # sorted sample points
    yaxis = tf.placeholder(tf.float32, shape=[100], name='yaxis')  # values at those points
    query = tf.placeholder(tf.float32, shape=[], name='query')     # scalar query point

    # Add additional elements at the beginning and end for extrapolation
    # (out-of-range queries get the first/last y value)
    xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
    yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)

    # Find the index of the interval containing query:
    # cmp is 1 while query >= xaxis_pad[i] and 0 afterwards, so diff is -1 at the switch
    cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
    diff = cmp[1:] - cmp[:-1]
    idx = tf.argmin(diff, axis=0, output_type=tf.int32)

    # Interpolate linearly between xaxis_pad[idx] and xaxis_pad[idx + 1]
    alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
    res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

    # Test with f(x) = 2 * x
    q = 5.4
    x = np.arange(100)
    y = 2 * x
    with tf.Session() as sess:
        q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
    print(q_interp)
    # prints approximately 10.8

    The padding only guards against query values outside the range of xaxis: such queries get the first or last y value instead of failing. Otherwise it is just a matter of comparing query against every element of xaxis and finding where the values start to be bigger than query.
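The index-finding trick can be traced by hand. This plain-Python sketch mirrors the padding, `cmp`, `diff` and `argmin` lines and then applies the same `alpha` formula; lists stand in for the tensors, and `find_interval` is a hypothetical helper name. Like tf.argmin, `list.index(min(...))` returns the first position of the minimum.

```python
def find_interval(x, y, q):
    """Plain-Python mirror of the padded compare/argmin interpolation (hypothetical)."""
    # Pad so that x_pad[0] <= q < x_pad[-1] always holds
    x_pad = [min(q - 1, x[0])] + list(x) + [max(q + 1, x[-1])]
    y_pad = [y[0]] + list(y) + [y[-1]]
    # cmp is 1 while q >= x_pad[i], then 0 for the rest (x is sorted)
    cmp = [1 if q >= v else 0 for v in x_pad]
    diff = [b - a for a, b in zip(cmp[:-1], cmp[1:])]
    # The single -1 in diff marks the 1 -> 0 switch; "argmin" finds it
    idx = diff.index(min(diff))
    alpha = (q - x_pad[idx]) / (x_pad[idx + 1] - x_pad[idx])
    return idx, alpha * y_pad[idx + 1] + (1 - alpha) * y_pad[idx]


xs = list(range(100))
ys = [2 * v for v in xs]
print(find_interval(xs, ys, 5.4))
```

For q = 5.4 the padded x starts 0, 0, 1, 2, 3, 4, 5, 6, ..., so idx = 6 brackets 5.4 between 5 and 6, alpha = 0.4, and the result is 0.4 * 12 + 0.6 * 10 = 10.8, up to floating-point rounding. The comparison touches every element, so it does O(n) work per query instead of the O(log n) of a binary search, but it avoids the per-iteration overhead of tf.while_loop.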
