How to get the dimensions of a tensor (in TensorFlow) at graph construction time?

你的背包 2021-01-30 20:54

I am trying an Op that is not behaving as expected.

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embe         


        
6 Answers
  •  野的像风
    2021-01-30 21:48

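    Let's keep it simple. If you just want a single number for the number of dimensions (2, 3, 4, etc.), use tf.rank(). If you want the exact shape of the tensor, use tensor.get_shape(). Note that tf.rank() returns a tensor that has to be evaluated in a session, whereas tensor.get_shape() returns the static shape and is available at graph construction time.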

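    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.random_normal)

    with tf.Session() as sess:
        arr = tf.random_normal(shape=(10, 32, 32, 128))
        a = tf.random_gamma(shape=(3, 3, 1), alpha=0.1)
        # tf.rank() returns a tensor, so it has to be evaluated in a session
        print(sess.run([tf.rank(arr), tf.rank(a)]))
        # get_shape() returns the static shape; no sess.run() needed
        print(arr.get_shape(), ",", a.get_shape())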
    
    
    # Output of tf.rank()
    [4, 3]

    # Output of tensor.get_shape()
    (10, 32, 32, 128) , (3, 3, 1)
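
Since the question is about graph construction time, here is a minimal sketch of the difference between the static shape (known while the graph is being built) and the dynamic shape (only known when the graph runs). It assumes TensorFlow 2.x is installed and uses the `tensorflow.compat.v1` module to get TF1-style graph mode. The placeholder `x`, its `[None, 2]` shape and the fed array are made up for illustration; they are not from the original question.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 2.x, used in TF1 graph mode

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical placeholder; the batch dimension is left unknown on purpose.
    x = tf.placeholder(tf.float32, shape=[None, 2])

    # Static information: plain Python values, available without a session.
    static_shape = x.get_shape().as_list()  # unknown dimensions show up as None
    static_rank = x.get_shape().ndims       # a Python int

    # Dynamic shape: a tensor that only has a value once the graph is run.
    dynamic_shape = tf.shape(x)

with tf.Session(graph=graph) as sess:
    feed = {x: np.zeros((5, 2), dtype=np.float32)}
    runtime_shape = sess.run(dynamic_shape, feed_dict=feed)

print(static_shape, static_rank, runtime_shape)
```

If every dimension is fixed when the graph is built, get_shape() alone is enough. tf.shape() is only needed for dimensions that are left as None until run time.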
    
