How to get the dimensions of a tensor (in TensorFlow) at graph construction time?

你的背包 2021-01-30 20:54

I am trying out an Op that is not behaving as expected:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embe

6 Answers
  •  粉色の甜心
    2021-01-30 21:33

    A helper that returns the tensor's static shape (the dimensions known at graph construction time) as a tuple:

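    def shape(tensor):
        # Static shape, available at graph construction time; unknown dims are None.
        # as_list() works in TF 1.x and 2.x, whereas s[i].value only works where
        # indexing a TensorShape returns a Dimension object (TF 1.x behavior).
        return tuple(tensor.get_shape().as_list())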
    

    Example:

    batch_size, num_feats = shape(logits)
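A minimal runnable sketch of the helper applied to the placeholder from the question. It assumes TensorFlow 2.x, so `tf.compat.v1.placeholder` stands in for the TF 1.x `tf.placeholder`. The second placeholder `features` and its `[None, 64]` shape are hypothetical, added only to show what the static shape looks like when a dimension is not known while the graph is being built. For such a dimension, `tf.shape()` returns a tensor whose value only becomes available when the graph runs.

```python
import tensorflow as tf

def shape(tensor):
    """Return the static shape of `tensor` as a tuple; unknown dims are None."""
    return tuple(tensor.get_shape().as_list())

graph = tf.Graph()
with graph.as_default():
    # Fully specified placeholder (from the question): every dim is known statically.
    train_dataset = tf.compat.v1.placeholder(tf.int32, shape=[128, 2])
    # Hypothetical placeholder whose leading dim is only known once data is fed.
    features = tf.compat.v1.placeholder(tf.float32, shape=[None, 64])

    batch_size, num_feats = shape(train_dataset)  # plain Python ints
    print(batch_size, num_feats)                  # prints: 128 2
    print(shape(features))                        # prints: (None, 64)

    # For dims that are None statically, tf.shape() builds an op that yields
    # the actual size at run time (a scalar tensor here, not a Python int).
    dynamic_batch = tf.shape(features)[0]
```

The static values are ordinary Python ints, so they can be used directly to size variables or reshape ops while building the graph; only dimensions left as `None` need the run-time `tf.shape()` route.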
    
