Binary search and interpolation in tensorflow

日久生厌 2021-01-07 07:09

I'm trying to interpolate a 1D tensor in TensorFlow (I effectively want the equivalent of np.interp). Since I couldn't find a similar TensorFlow op, I had to perform the i…
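For reference, np.interp(x, xp, fp) evaluates the piecewise-linear function through the points (xp, fp), assuming xp is increasing. With its default left/right arguments it returns fp[0] for x below the range and fp[-1] for x above it. Below is a minimal plain-Python sketch of that behaviour for a single scalar x. It uses the standard bisect module in place of a hand-written search, and interp_scalar is a hypothetical name, not an existing API.

```python
import bisect


def interp_scalar(x, xp, fp):
    """Scalar version of np.interp(x, xp, fp) with default left/right; xp must be increasing."""
    if x <= xp[0]:
        return fp[0]  # clamp on the left, like np.interp's default `left`
    if x >= xp[-1]:
        return fp[-1]  # clamp on the right, like np.interp's default `right`
    i = bisect.bisect_right(xp, x) - 1  # bracketing interval: xp[i] <= x < xp[i + 1]
    t = (x - xp[i]) / (xp[i + 1] - xp[i])  # fractional position inside the interval
    return fp[i] + t * (fp[i + 1] - fp[i])
```

A TensorFlow version has to reproduce the same two steps. First it finds the bracketing index i, which is the part the answer below implements with tf.while_loop. Then it applies the linear formula on the two neighbouring points. The clamping at the ends is a separate concern that a bare binary search does not handle.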

2 answers
  •  囚心锁ツ
    2021-01-07 07:37

    Found the issue: TensorFlow does not like a plain Python integer as a parameter to tf.cond (here, the -1 returned by one of the branches); it needs to be wrapped in a constant first. This code works:

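    import tensorflow as tf

    # Hypothetical example inputs: a sorted 1-D x-axis and a scalar query
    # with xaxis[0] <= query < xaxis[-1].
    xaxis = tf.constant([0.0, 1.0, 2.5, 4.0, 7.0])
    query = tf.constant(3.0)

    with tf.name_scope("binsearch"):
        # -1 wrapped in a constant so both tf.cond branches return int32 tensors.
        m_one = tf.constant(-1, dtype=tf.int32, name='minus_one')
        n = tf.shape(xaxis)[0]  # int32 length, works for static and dynamic shapes

        def cond(up, down, mid, done):
            # Continue until the interval is found or up and down are adjacent.
            return tf.logical_and(done < 0, up - down > 1)

        def body(up, down, mid, done):
            def fn1():  # xaxis[mid] > query: continue in the lower half
                return mid, down, (mid + down) // 2, tf.cond(
                    tf.gather(xaxis, mid - 1) <= query, lambda: mid - 1, lambda: m_one)

            def fn2():  # xaxis[mid] <= query: upper half (assumed mirror of fn1)
                return up, mid, (up + mid) // 2, tf.cond(
                    tf.gather(xaxis, mid + 1) > query, lambda: mid, lambda: m_one)

            return tf.cond(tf.gather(xaxis, mid) > query, fn1, fn2)

        up, down, mid, done = tf.while_loop(
            cond, body, (n - 1, 0, (n - 1) // 2, -1))
        # Index i with xaxis[i] <= query < xaxis[i + 1].
        idx = tf.where(done >= 0, done, down)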
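The control flow of that loop is easier to check outside the graph. The plain-Python function below is a hypothetical mirror of the tf.while_loop search. It keeps the same up/down/mid/done state, the same stopping test as cond, and the same lower-half/upper-half updates. It assumes xaxis is sorted ascending and xaxis[0] <= query < xaxis[-1]. It is only a sketch for reasoning about the search, not TensorFlow code.

```python
def binsearch(xaxis, query):
    """Return i with xaxis[i] <= query < xaxis[i + 1]; mirrors the up/down/mid/done loop."""
    up, down = len(xaxis) - 1, 0
    mid, done = (up + down) // 2, -1
    while done < 0 and up - down > 1:  # same test as cond()
        if xaxis[mid] > query:
            # lower half: the interval may end exactly at mid
            new_done = mid - 1 if xaxis[mid - 1] <= query else -1
            up, mid = mid, (mid + down) // 2
        else:
            # upper half: the interval may start exactly at mid
            new_done = mid if xaxis[mid + 1] > query else -1
            down, mid = mid, (up + mid) // 2
        done = new_done
    # If the loop stopped because up and down became adjacent, the answer is down.
    return done if done >= 0 else down
```

The invariant is xaxis[down] <= query < xaxis[up]. Each step moves either up or down to mid, and mid always lies strictly between them, so the loop terminates and the indices mid - 1 and mid + 1 stay in range. Once i is known, the np.interp-style value is yaxis[i] + (query - xaxis[i]) / (xaxis[i + 1] - xaxis[i]) * (yaxis[i + 1] - yaxis[i]). In the graph version, the two neighbours could be fetched with tf.gather at indices i and i + 1.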
    
