Conditional assignment of tensor values in TensorFlow

栀梦 2020-11-29 07:57

I want to replicate the following NumPy code in TensorFlow. For example, I want to assign a 0 to all tensor indices that previously ha
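
For reference, the NumPy idiom being translated is boolean-mask assignment. Judging from the comments in the second answer, the example array is [1, 2, 3, 1] and the value to replace is 1; this is a minimal sketch under that assumption:

```python
import numpy as np

a = np.array([1, 2, 3, 1])
# Boolean mask: True wherever the element equals 1
mask = a == 1
# Assign 0 in place to every element selected by the mask
a[mask] = 0
print(a)  # [0 2 3 0]
```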

2 Answers
  • 2020-11-29 08:05

    I'm also just starting to use TensorFlow. Maybe someone will find my approach more intuitive.

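    import tensorflow.compat.v1 as tf  # TF1-style graph API, also usable under TensorFlow 2.x
    tf.disable_v2_behavior()
    
    conditionVal = 1
    init_a = tf.constant([1, 2, 3, 1], dtype=tf.int32, name='init_a')
    a = tf.Variable(init_a, dtype=tf.int32, name='a')
    # Tensor with the same shape as `a`, filled with the value to be replaced
    target = tf.fill(a.get_shape(), conditionVal, name='target')
    
    init = tf.global_variables_initializer()  # tf.initialize_all_variables() is deprecated
    condition = tf.not_equal(a, target)       # True where the element differs from conditionVal
    defaultValues = tf.zeros(a.get_shape(), dtype=a.dtype)
    # Keep `a` where condition is True, otherwise use 0 (tf.select was replaced by tf.where)
    calculate = tf.where(condition, a, defaultValues)
    
    with tf.Session() as session:
        session.run(init)
        # Builds a new tensor; the variable `a` itself is not modified
        print(session.run(calculate))  # [0 2 3 0]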
    

    The main difficulty is implementing "custom logic": if you cannot express your logic in terms of TensorFlow's existing tensor operations, you need to write a "custom op" library for TensorFlow (more details here).
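Before writing a custom op, it may be worth checking whether the condition can be composed from TensorFlow's element-wise logical operations. This is a minimal sketch, assuming TensorFlow 2.x with eager execution; the compound condition 1 < a < 3 is a hypothetical example, not taken from the question:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

a = tf.constant([1, 2, 3, 1])
# Compound element-wise condition built from existing ops: 1 < a < 3 (hypothetical example)
mask = tf.logical_and(tf.greater(a, 1), tf.less(a, 3))
# Zero out the elements where the mask is True, keep the rest unchanged
result = tf.where(mask, tf.zeros_like(a), a)
print(result.numpy())  # [1 0 3 1]
```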

  • 2020-11-29 08:12

    Several comparison operators are available in the TensorFlow API.

    However, there is nothing equivalent to the concise NumPy syntax for manipulating tensors directly. You have to combine individual comparison, where, and assign operators to perform the same action.

    Here is code equivalent to your NumPy example:

    import tensorflow.compat.v1 as tf  # TF1-style graph API, also usable under TensorFlow 2.x
    tf.disable_v2_behavior()
    
    a = tf.Variable([1, 2, 3, 1])
    start_op = tf.global_variables_initializer()
    comparison = tf.equal(a, tf.constant(1))
    # Replace elements where the comparison is True with 0, then write the result back into `a`
    conditional_assignment_op = a.assign(tf.where(comparison, tf.zeros_like(a), a))
    
    with tf.Session() as session:
        # Equivalent to: a = np.array([1, 2, 3, 1])
        session.run(start_op)
        print(a.eval())
        # Equivalent to: a[a == 1] = 0
        session.run(conditional_assignment_op)
        print(a.eval())
    
    # Output is:
    # [1 2 3 1]
    # [0 2 3 0]
    

    The print statements are optional; they are only there to demonstrate that the code works correctly.
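In TensorFlow 2.x, eager execution is the default, so the same comparison/where/assign combination can run without a session or an initializer op. This is a minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

a = tf.Variable([1, 2, 3, 1])
print(a.numpy())  # [1 2 3 1]

# Equivalent to: a[a == 1] = 0
comparison = tf.equal(a, 1)
a.assign(tf.where(comparison, tf.zeros_like(a), a))
print(a.numpy())  # [0 2 3 0]
```

Here `a.assign` updates the variable in place; the `tf.where` result on its own is just a new tensor, which is why the first answer's version leaves `a` unchanged.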
