Question:
In NumPy I can write:
import numpy as np
threshold = 3
a = np.array([1, 2, 3, 4, 5, 6])
a[a >= threshold] = 199
# a is now [1, 2, 199, 199, 199, 199]
How can I write similar code in TensorFlow 2 when the array is stored in a variable?
import tensorflow as tf
b = tf.Variable(a)
Thanks.
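Answer 1:
You can use tf.where to select values conditionally. Note that tf.where returns a new tensor and leaves b unchanged; to update the variable itself, pass the result to b.assign(...).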
b = tf.Variable(a)
tf.where(b >= 3, 199, b)
# <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1, 2, 199, 199, 199, 199])>
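To mirror the NumPy snippet exactly, the variable has to be updated in place. A minimal sketch, assuming TensorFlow 2 in its default eager mode: compute the tf.where result and write it back with the variable's assign method. It reuses the question's threshold; replacement is a hypothetical name introduced here for the value 199.

```python
import numpy as np
import tensorflow as tf

threshold = 3
replacement = 199  # hypothetical name for the value written by a[a >= 3] = 199

a = np.array([1, 2, 3, 4, 5, 6])
b = tf.Variable(a)

# tf.where builds a NEW tensor: positions where the condition holds take
# `replacement` (broadcast from a scalar), the others keep the value from b.
# At this point the variable b itself is still [1, 2, 3, 4, 5, 6].
updated = tf.where(b >= threshold, tf.constant(replacement, dtype=b.dtype), b)

# Write the result back so the variable changes, like the in-place
# NumPy assignment a[a >= threshold] = 199.
b.assign(updated)

print(b.numpy())  # [  1   2 199 199 199 199]
```

Creating the scalar with dtype=b.dtype avoids relying on TensorFlow to infer the dtype of the Python literal from b, and b.assign requires the new value to match the variable's dtype and shape.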
Source: https://stackoverflow.com/questions/65449671/conditional-assignment-of-tf-variable-in-tensorflow-2