How to get the dimensions of a tensor (in TensorFlow) at graph construction time?

你的背包 2021-01-30 20:54

I am trying an Op that is not behaving as expected.

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embe         


        
6 answers
  •  轻奢々 (original poster)
    2021-01-30 21:31

    I see many people confuse tf.shape(tensor) with tensor.get_shape(). Let's make it clear:

    1. tf.shape

    tf.shape is used for dynamic shapes. If your tensor's shape can change from one run to the next, use it. For example: the input is an image with variable width and height, and we want to resize it to half its size. Then we can write something like the line below (note the integer division //, which keeps the result an integer tensor):
    new_height = tf.shape(image)[0] // 2
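To make the dynamic case concrete, here is a minimal sketch. It assumes TensorFlow 2.x with the `tf.compat.v1` graph API (aliased as `tf1` here); the 480x640 image is made up for illustration. The static shape shows `None` for the unknown dimensions, while `tf.shape` yields the real values once the graph runs.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TensorFlow 2.x with the v1 compatibility API

graph = tf.Graph()
with graph.as_default():
    # Height and width are unknown until an actual image is fed in.
    image = tf1.placeholder(tf.float32, shape=[None, None, 3])

    # Known at graph construction time; unknown dimensions are None.
    static_shape = image.get_shape().as_list()

    # An int32 tensor whose values exist only when the graph runs.
    dynamic_shape = tf.shape(image)
    new_height = dynamic_shape[0] // 2
    new_width = dynamic_shape[1] // 2

with tf1.Session(graph=graph) as sess:
    # Illustrative input: a blank 480x640 RGB image.
    half = sess.run(
        [new_height, new_width],
        feed_dict={image: np.zeros((480, 640, 3), np.float32)},
    )

print(static_shape)
print(half)
```

The halved sizes are tensors, so they can be passed on to other ops in the same graph, but they cannot be read as Python integers while the graph is being built.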

    2. tensor.get_shape

    tensor.get_shape() is used for static (fixed) shapes, meaning the tensor's shape can be deduced while the graph is being built. It returns a TensorShape immediately, without running a session.
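For the static case, a short sketch using the `[128, 2]` placeholder from the question might look like this. It again assumes TensorFlow 2.x with the `tf.compat.v1` API (aliased as `tf1`); the reshape is only there to show that the static sizes are plain Python integers usable at graph construction time.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TensorFlow 2.x with the v1 compatibility API

graph = tf.Graph()
with graph.as_default():
    train_dataset = tf1.placeholder(tf.int32, shape=[128, 2])

    # A plain Python list, available before any session exists.
    static_shape = train_dataset.get_shape().as_list()
    batch_size, width = static_shape

    # Static sizes are ordinary ints, so they can drive graph construction.
    flat = tf.reshape(train_dataset, [batch_size * width])
    flat_shape = flat.get_shape().as_list()

print(static_shape)
print(flat_shape)
```

No `Session` is created here: everything printed is known from the graph definition alone.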

    Conclusion: tf.shape can be used almost anywhere, but tensor.get_shape() only works for shapes that can be deduced from the graph; dimensions it cannot deduce are reported as unknown (None).
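One way to combine the two, sketched under the same assumptions (TensorFlow 2.x with `tf.compat.v1` aliased as `tf1`): take each dimension from the static shape when it is known and fall back to `tf.shape` otherwise. The helper name `shape_list` is hypothetical and not part of TensorFlow's API.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TensorFlow 2.x with the v1 compatibility API


def shape_list(tensor):
    """Hypothetical helper: Python ints where known, scalar tensors otherwise.

    Requires a tensor of known rank; as_list() raises ValueError otherwise.
    """
    static = tensor.get_shape().as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


graph = tf.Graph()
with graph.as_default():
    # Batch size is unknown at graph construction time; width is fixed.
    batch = tf1.placeholder(tf.float32, shape=[None, 2])
    rows, cols = shape_list(batch)   # rows is a tensor, cols is the int 2
    flat = tf.reshape(batch, [rows * cols])

with tf1.Session(graph=graph) as sess:
    # Illustrative input with 5 rows.
    n_rows, flat_value = sess.run(
        [rows, flat], feed_dict={batch: np.ones((5, 2), np.float32)}
    )

print(cols, n_rows, flat_value.shape)
```

The trade-off is that callers must handle both Python ints and tensors, but known dimensions stay available for checks at graph construction time.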
