Binary search and interpolation in TensorFlow

日久生厌 2021-01-07 07:09

I'm trying to interpolate a 1D tensor in TensorFlow (I effectively want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I had to perform the i

2 Answers
  •  悲哀的现实
    2021-01-07 07:45

    I don't know what is causing your error, but tf.while_loop is likely to be very slow here. You can implement linear interpolation without loops like this:

    import numpy as np
    import tensorflow as tf
    
    # TensorFlow 1.x graph-mode API: placeholders are fed when the session runs
    xaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')
    yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')
    query = tf.placeholder(tf.float32, name='query')
    
    # Add additional elements at the beginning and end for extrapolation
    xaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)
    yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)
    
    # Find the index of the interval containing query
    cmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)
    diff = cmp[1:] - cmp[:-1]
    idx = tf.argmin(diff)
    
    # Interpolate
    alpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
    res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]
    
    # Test with f(x) = 2 * x
    q = 5.4
    x = np.arange(100)
    y = 2 * x
    with tf.Session() as sess:
        q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})
    print(q_interp)
    # Output: 10.8
    

    The padding only guards against queries that fall outside the range of xaxis. Apart from that, the method simply compares query against every element and finds the position where the values start to be bigger than query.
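To make the pad / compare / argmin steps easier to trace without a session, here is a minimal pure-Python sketch of the same logic. The function name `interp_by_comparison` is made up for illustration and is not part of TensorFlow or NumPy. It assumes `xs` is sorted in ascending order, and `diff.index(min(diff))` stands in for `tf.argmin`.

```python
def interp_by_comparison(query, xs, ys):
    """Trace of the pad / compare / argmin steps; xs must be sorted ascending."""
    # Pad both axes so a query outside [xs[0], xs[-1]] still lands in an interval.
    xs_pad = [min(query - 1, xs[0])] + list(xs) + [max(query + 1, xs[-1])]
    ys_pad = [ys[0]] + list(ys) + [ys[-1]]

    # 1 where query >= x, else 0: a run of ones followed by a run of zeros.
    cmp = [1 if query >= x else 0 for x in xs_pad]

    # The single -1 in the difference marks the interval containing query.
    diff = [b - a for a, b in zip(cmp[:-1], cmp[1:])]
    idx = diff.index(min(diff))  # plays the role of tf.argmin

    # Linear interpolation inside [xs_pad[idx], xs_pad[idx + 1]].
    alpha = (query - xs_pad[idx]) / (xs_pad[idx + 1] - xs_pad[idx])
    return alpha * ys_pad[idx + 1] + (1 - alpha) * ys_pad[idx]


if __name__ == "__main__":
    xs = list(range(100))
    ys = [2 * x for x in xs]
    print(interp_by_comparison(5.4, xs, ys))    # close to 10.8, up to float rounding
    print(interp_by_comparison(-3.0, xs, ys))   # below the range: clamps to ys[0]
    print(interp_by_comparison(150.0, xs, ys))  # above the range: clamps to ys[-1]
```

Because the padded y values repeat the first and last entries, out-of-range queries return the boundary value, which matches the default behaviour of np.interp. The comparison touches every element, so the cost per query is linear in the length of the axis. A true binary search would be logarithmic, but it would not vectorise as simply.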
