How to update a subset of a 2D tensor in TensorFlow?

梦谈多话 2021-01-19 11:59

I want to set one element of a 2D tensor to 0: data is a 2D tensor, and the value at its 2nd row, 2nd column should be replaced with 0. However, I am getting a type error.

2 Answers
  •  小蘑菇 (original poster)
    2021-01-19 12:29

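    tf.scatter_update can only be applied to a Variable. In your code, data is a Variable but data2 is not, because tf.reshape returns a plain Tensor. The fix is to work on the Variable directly: extract the target row with tf.gather, rebuild it with the chosen element replaced by 0, and write it back with tf.scatter_update.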
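    If you want to keep working with data2 as a plain Tensor instead, it cannot be mutated in place, but it can be updated out of place. A minimal sketch, assuming TensorFlow 2.x with eager execution (not covered by the original answer): tf.tensor_scatter_nd_update accepts an ordinary Tensor and returns a new Tensor with the given coordinates replaced, so it also works on the output of tf.reshape. The helper name zero_element and the tf.range/tf.reshape input are hypothetical stand-ins, since the asker's code is not shown.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def zero_element(tensor, row, col):
    """Hypothetical helper: return a copy of a 2D tensor with tensor[row, col] set to 0."""
    indices = tf.constant([[row, col]])          # one (row, col) coordinate
    updates = tf.zeros([1], dtype=tensor.dtype)  # one replacement value per coordinate
    # Builds a new Tensor; the input tensor is left unchanged.
    return tf.tensor_scatter_nd_update(tensor, indices, updates)


# tf.reshape returns a plain Tensor, not a Variable.
data2 = tf.reshape(tf.range(1, 16), [3, 5])
result = zero_element(data2, 2, 2)
print(result.numpy())
```

    Because the result is a new Tensor, you need to keep the returned value (here result); data2 itself still holds the old value.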

    Solution:

    For TensorFlow 1.0 and later:

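    import tensorflow as tf  # TensorFlow 1.x, graph mode
    data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
    row = tf.gather(data, 2)  # the row at index 2
    new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)  # column 2 replaced by 0
    sparse_update = tf.scatter_update(data, tf.constant(2), new_row)  # write the row back into data
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(sparse_update))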
    

    For TensorFlow versions before 1.0, where tf.concat takes the axis as its first argument:

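    import tensorflow as tf  # TensorFlow 0.x
    data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
    row = tf.gather(data, 2)  # the row at index 2
    new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])  # axis first, then the values
    sparse_update = tf.scatter_update(data, tf.constant(2), new_row)  # write the row back into data
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print(sess.run(sparse_update))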
    
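    On TensorFlow 2.x neither snippet runs unchanged: tf.Session and tf.scatter_update are only available under tf.compat.v1, and eager execution is the default. A minimal sketch, assuming TensorFlow 2.x (not part of the original answer), of the same in-place update using the Variable's own sliced assignment, which sets the element at row index 2, column index 2 to 0 without rebuilding the row:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

data = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0], [1, 2, 3, 4, 5]])

# Sliced assignment updates the Variable in place (row index 2, column index 2).
data[2, 2].assign(0)

print(data.numpy())
```

    Variable.scatter_nd_update(indices=[[2, 2]], updates=[0]) is an equivalent in-place alternative that stays closer to the scatter-based approach of the original answer.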
