The asterisk in tf.gather_nd raises a syntax error in Python 2.7

小蘑菇 2021-01-21 08:35

I am using Python 2.7 and can't update it. The following line of code raises a syntax error at the asterisk, and I don't know why. How can I fix it?

inp = tf.         


        
2 Answers
  •  孤街浪徒
    2021-01-21 08:54

    The syntax error in your question has already been explained by BoarGules. As for the problem you are trying to solve, you can get the result you want with something like this:

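    # Must be the first statement; makes print(*..., sep=...) valid in Python 2.7
    from __future__ import print_function
    import tensorflow as tf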
    
    with tf.Graph().as_default(), tf.Session() as sess:
        # In TF 2.x: tf.random.set_seed
        tf.random.set_random_seed(0)
        # Input data
        inp = tf.random.uniform(shape=[4, 6, 2], maxval=100, dtype=tf.int32)
    
        # Find index of greatest value in last two dimensions
        s = tf.shape(inp)
        inp_res = tf.reshape(inp, [s[0], -1])
        max_idx = tf.math.argmax(inp_res, axis=1, output_type=s.dtype)
        # Get row index dividing by number of columns
        max_row_idx = max_idx // s[2]
        # Get rows with max values
        res = tf.gather_nd(inp, tf.expand_dims(max_row_idx, axis=1), batch_dims=1)
        # Print input and result
        print(*sess.run((inp, res)), sep='\n')
    

    Output:

    [[[22 78]
      [75 70]
      [31 10]
      [67  9]
      [70 45]
      [ 5 33]]
    
     [[82 83]
      [82 81]
      [73 58]
      [18 18]
      [57 11]
      [50 71]]
    
     [[84 55]
      [80 72]
      [93  1]
      [98 27]
      [36  6]
      [10 95]]
    
     [[83 24]
      [19  9]
      [46 48]
      [90 87]
      [50 26]
      [55 62]]]
    [[22 78]
     [82 83]
     [98 27]
     [90 87]]
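    To check the index arithmetic without running TensorFlow, here is a minimal pure-Python sketch of the same steps applied to the input printed above. The function name `rows_with_max` is made up for illustration; the comments map each step to the TensorFlow op it imitates. It runs under both Python 2.7 and Python 3, and picks the first maximum on ties (tf.math.argmax does not guarantee which index it returns on ties; there are none in this data).

```python
# Input copied from the printed output above: 4 matrices of 6 rows x 2 columns
inp = [
    [[22, 78], [75, 70], [31, 10], [67, 9], [70, 45], [5, 33]],
    [[82, 83], [82, 81], [73, 58], [18, 18], [57, 11], [50, 71]],
    [[84, 55], [80, 72], [93, 1], [98, 27], [36, 6], [10, 95]],
    [[83, 24], [19, 9], [46, 48], [90, 87], [50, 26], [55, 62]],
]


def rows_with_max(batch):
    """For each matrix in batch, return the row holding its largest element."""
    result = []
    for matrix in batch:
        num_cols = len(matrix[0])
        # Flatten row-major, like tf.reshape(inp, [s[0], -1])
        flat = [v for row in matrix for v in row]
        # Index of the first maximum in the flattened matrix, like tf.math.argmax
        max_idx = max(range(len(flat)), key=flat.__getitem__)
        # Row index = flat index // number of columns, like max_idx // s[2]
        max_row_idx = max_idx // num_cols
        # Take that row from this matrix, like gather_nd with batch_dims=1
        result.append(matrix[max_row_idx])
    return result


print(rows_with_max(inp))
```

    This prints [[22, 78], [82, 83], [98, 27], [90, 87]], the same four rows as the TensorFlow result.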
    
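    On the syntax error itself: the line in the question is cut off, so its exact form is unknown. For context, Python 2.7 accepts `*` unpacking only in function calls and function definitions. Using it inside a list or tuple display such as `[a, *b]` requires Python 3.5+ (PEP 448). Assuming the original line built an index list that way, list concatenation is a replacement that works in both versions. The helper `build_index` and the values below are hypothetical:

```python
def build_index(batch_idx, row_col):
    """Return [batch_idx, row, col] without `*` unpacking.

    Python 3.5+ could write [batch_idx, *row_col]; Python 2.7 rejects
    that with a SyntaxError, so concatenate lists instead.
    """
    return [batch_idx] + list(row_col)


# Hypothetical example: batch 2, row 3, column 0
idx = build_index(2, (3, 0))
print(idx)  # [2, 3, 0]
```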
