TensorFlow: set a block within a 2D tensor to a constant value

遥遥无期 2021-01-22 08:04

Here's a minimal example of what I'm trying to do:


import numpy as np
import tensorflow as tf

map = tf.placeholder(tf.float32)
xmin = tf.placeholder(tf.int3         


        
1 Answer
  •  -上瘾入骨i
    2021-01-22 08:30

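    This should do the trick. Build a boolean mask by comparing a column of row indices and a row of column indices against the block bounds (the comparisons broadcast to the full 2D shape), then use tf.where to take NaN inside the block and the original values everywhere else: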

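    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder / tf.Session)
    
    map = tf.placeholder(tf.float32)
    xmin = tf.placeholder(tf.int32)
    xmax = tf.placeholder(tf.int32)
    ymin = tf.placeholder(tf.int32)
    ymax = tf.placeholder(tf.int32)
    
    post_operation_map = 2.0 * map + 1.0
    
    # Fill block with nan
    shape = tf.shape(post_operation_map)
    dtype = post_operation_map.dtype
    shape_x, shape_y = shape[0], shape[1]
    # Row indices as a column vector (shape [rows, 1]) and column indices as a
    # row vector (shape [1, cols]); comparing them broadcasts to [rows, cols]
    x_range = tf.range(shape_x)[:, tf.newaxis]
    y_range = tf.range(shape_y)[tf.newaxis, :]
    # True exactly for rows in [xmin, xmax) and columns in [ymin, ymax)
    mask = (xmin <= x_range) & (x_range < xmax) & (ymin <= y_range) & (y_range < ymax)
    # NaN where the mask is True, the original value everywhere else
    post_operation_map = tf.where(
        mask, tf.fill(shape, tf.constant(np.nan, dtype)), post_operation_map)
    
    with tf.Session() as sess:
        # Rows 1..3 and columns 2..4 of the 8x6 input fall inside the block
        feed = {map: np.random.rand(8, 6),
                xmin: 1,
                xmax: 4,
                ymin: 2,
                ymax: 5}
        print(sess.run(post_operation_map, feed_dict=feed))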
    

    Output:

    [[ 2.50152206  1.01042879  2.88725328  1.27295971  2.99401283  1.84210801]
     [ 2.98338175  2.26357031         nan         nan         nan  2.68635511]
     [ 1.00461781  2.00605297         nan         nan         nan  2.16447353]
     [ 2.15073347  1.64699006         nan         nan         nan  1.97648919]
     [ 1.7709868   1.65353572  1.6698066   2.26957846  2.75840473  1.23831809]
     [ 1.51848006  1.45277226  1.46150732  1.08112144  2.87904882  2.62266874]
     [ 1.86656547  1.5177052   1.36731267  2.70582867  1.57994771  2.48001719]
     [ 1.89354372  2.88848639  1.49879098  1.36527407  1.47415829  2.95422626]]
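tf.placeholder and tf.Session belong to the TensorFlow 1.x graph API; in TensorFlow 2.x they are only reachable under tf.compat.v1. Below is a minimal sketch of the same mask-and-tf.where idea in eager mode, assuming TensorFlow 2.x. fill_block is a hypothetical helper name, and passing a scalar fill value relies on TF 2.x's tf.where broadcasting its arguments (the 1.x version above needs tf.fill for that reason).

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)


def fill_block(x, xmin, xmax, ymin, ymax, value):
    """Hypothetical helper: return a copy of 2D tensor `x` with rows
    [xmin, xmax) and columns [ymin, ymax) set to `value`."""
    x = tf.convert_to_tensor(x)
    shape = tf.shape(x)
    rows = tf.range(shape[0])[:, tf.newaxis]  # shape [rows, 1]
    cols = tf.range(shape[1])[tf.newaxis, :]  # shape [1, cols]
    # Broadcasts to [rows, cols]; True only inside the half-open block
    mask = (xmin <= rows) & (rows < xmax) & (ymin <= cols) & (cols < ymax)
    # TF 2.x tf.where broadcasts, so a scalar fill value is enough here
    return tf.where(mask, tf.constant(value, x.dtype), x)


# Same setup as the question: 2 * map + 1, then NaN in rows 1..3, columns 2..4
m = 2.0 * tf.random.uniform([8, 6]) + 1.0
print(fill_block(m, 1, 4, 2, 5, np.nan).numpy())
```

Because the result is built from a mask rather than by item assignment, it works even though tf.Tensor objects are immutable.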
    
