The asterisk in tf.gather_nd raises a syntax error in Python 2.7

小蘑菇 2021-01-21 08:35

I am using Python 2.7 and I can't update it. This line of code raises a syntax error at the asterisk, and I don't know why. How can I fix it?

inp = tf.         

2 answers
  • 2021-01-21 08:53

    That asterisk syntax is not available in Python 2. Unpacking inside a list display was added in Python 3.5 (PEP 448), released in 2015.

    The Python 2 equivalent is

    o = tf.gather_nd(inp, [(i,j) for (i,j) in enumerate(am)])
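A minimal sketch of that equivalence in plain Python, with no TensorFlow involved. The list `am` below is a hypothetical stand-in for per-row column indices (for example an argmax result); `list(enumerate(am))` is a shorter Python 2-compatible spelling that builds the same list of `(row, column)` pairs:

```python
# Hypothetical per-row column indices, e.g. the result of an argmax over axis 1
am = [3, 0, 5, 1]

# Python 3.5+ only (PEP 448 unpacking inside a list display):
#     indices = [*enumerate(am)]
# Python 2.7-compatible equivalents that build the same list of (row, col) pairs:
indices_comprehension = [(i, j) for (i, j) in enumerate(am)]
indices_list = list(enumerate(am))

print(indices_list)  # [(0, 3), (1, 0), (2, 5), (3, 1)]
```

Either list can then be passed as the `indices` argument of `tf.gather_nd`.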
    

    But you really should not be using Python 2 or investing time in learning it. You don't have to "update" your existing Python 2 installation if you need it to run legacy code: you can have Python 3.8 running side by side with Python 2. For work reasons I have 3.8, 3.7, 3.6 and 2.7 side by side on my machine without problems.

  • 2021-01-21 08:54

    The syntax error in your question has been explained by BoarGules. With respect to the problem that you are trying to solve, you can get the result you want with something like this:

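    # Needed on Python 2.7: without it, print is a statement and print(*..., sep='\n') is a syntax error
    from __future__ import print_function
    import tensorflow as tf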
    
    with tf.Graph().as_default(), tf.Session() as sess:
        # In TF 2.x: tf.random.set_seed
        tf.random.set_random_seed(0)
        # Input data
        inp = tf.random.uniform(shape=[4, 6, 2], maxval=100, dtype=tf.int32)
    
        # Find index of greatest value in last two dimensions
        s = tf.shape(inp)
        inp_res = tf.reshape(inp, [s[0], -1])
        max_idx = tf.math.argmax(inp_res, axis=1, output_type=s.dtype)
        # Get row index dividing by number of columns
        max_row_idx = max_idx // s[2]
        # Get rows with max values
        res = tf.gather_nd(inp, tf.expand_dims(max_row_idx, axis=1), batch_dims=1)
        # Print input and result
        print(*sess.run((inp, res)), sep='\n')
    

    Output:

    [[[22 78]
      [75 70]
      [31 10]
      [67  9]
      [70 45]
      [ 5 33]]
    
     [[82 83]
      [82 81]
      [73 58]
      [18 18]
      [57 11]
      [50 71]]
    
     [[84 55]
      [80 72]
      [93  1]
      [98 27]
      [36  6]
      [10 95]]
    
     [[83 24]
      [19  9]
      [46 48]
      [90 87]
      [50 26]
      [55 62]]]
    [[22 78]
     [82 83]
     [98 27]
     [90 87]]
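To check the index arithmetic without TensorFlow, here is a rough pure-Python sketch of the same idea (the helper name `rows_with_max` is hypothetical, not part of any library): flatten each matrix, find the position of its largest value, and integer-divide by the number of columns to get the row. It runs on both Python 2.7 and 3, and applied to the input printed above it returns the same four rows:

```python
def rows_with_max(batch):
    """For each 2-D matrix in `batch`, return the row containing its largest value."""
    result = []
    for matrix in batch:
        n_cols = len(matrix[0])
        # Flatten the matrix row by row, like tf.reshape(inp, [s[0], -1])
        flat = [v for row in matrix for v in row]
        # Position of the maximum in the flattened list (first occurrence on ties)
        max_idx = max(range(len(flat)), key=flat.__getitem__)
        # Integer division by the column count recovers the row index
        result.append(matrix[max_idx // n_cols])
    return result


# The input printed above, shape [4, 6, 2]
inp = [
    [[22, 78], [75, 70], [31, 10], [67, 9], [70, 45], [5, 33]],
    [[82, 83], [82, 81], [73, 58], [18, 18], [57, 11], [50, 71]],
    [[84, 55], [80, 72], [93, 1], [98, 27], [36, 6], [10, 95]],
    [[83, 24], [19, 9], [46, 48], [90, 87], [50, 26], [55, 62]],
]

print(rows_with_max(inp))  # [[22, 78], [82, 83], [98, 27], [90, 87]]
```

If the maximum value appears more than once in a matrix, this sketch picks the first occurrence, whereas `tf.math.argmax` does not guarantee which index it returns on ties, so the two can differ in that case.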
    