conditional assignment of tf.variable in Tensorflow 2

ぃ、小莉子 提交于 2021-01-05 07:21:24

问题


For numpy we have

threshold = 3
a = np.array([1,2,3,4,5,6])
a[a>=3] = 199

# a is [1, 2, 199, 199, 199, 199]

How to write a similar code in tensorflow 2

b = tf.Variable(a)

Thanks.


回答1:


Sure, you can use tf.where to conditionally set values:

b = tf.Variable(a)
tf.where(b >= 3, 199, b)
# <tf.Tensor: shape=(6,), dtype=int64, numpy=array([  1,   2, 199, 199, 199, 199])>


来源:https://stackoverflow.com/questions/65449671/conditional-assignment-of-tf-variable-in-tensorflow-2

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!