TensorFlow: When to use tf.expand_dims?

甜味超标 2020-12-24 13:10

TensorFlow tutorials include the use of tf.expand_dims to add a "batch dimension" to a tensor. I have read the docs for this function but it still is rather m

2 answers
  •  孤城傲影
    2020-12-24 13:44

    To add to Da Tong's answer, you may want to expand more than one dimension at the same time. For instance, TensorFlow's conv1d operation expects a rank-3 input ([batch, width, channels]), so a rank-1 vector needs two extra dimensions. conv2d expects a rank-4 input, so the same vector needs three.
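
    As a minimal sketch of the conv1d case, assuming TensorFlow 2.x with eager execution: the signal and the averaging kernel below are illustrative values, not something from the question.

```python
import tensorflow as tf

# A rank-1 signal with 100 samples (illustrative data).
signal = tf.ones(100)

# tf.nn.conv1d expects [batch, width, channels], so add a leading batch
# axis and a trailing channel axis.
batched = tf.expand_dims(tf.expand_dims(signal, 0), -1)

# Hypothetical averaging kernel: [filter_width, in_channels, out_channels].
kernel = tf.ones([3, 1, 1]) / 3.0

# With stride 1 and SAME padding the output keeps the input width.
output = tf.nn.conv1d(batched, kernel, stride=1, padding="SAME")
print(batched.shape, output.shape)  # => (1, 100, 1) (1, 100, 1)
```

    Passing the rank-1 signal directly would fail, because conv1d has no batch or channel axis to work with.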

    Calling expand_dims several times is readable, but each call adds an op to the computational graph and may introduce some overhead. You can get the same result in a one-liner with reshape:

    import tensorflow as tf
    
    # a rank-1 tensor; it could be an audio signal, a word vector...
    tensor = tf.ones(100)
    print(tensor.get_shape()) # => (100,)
    
    # expand it to rank 4, the input rank that conv2d expects
    tensor_expand = tf.expand_dims(tensor, 0)
    tensor_expand = tf.expand_dims(tensor_expand, 0)
    tensor_expand = tf.expand_dims(tensor_expand, -1)
    print(tensor_expand.get_shape()) # => (1, 1, 100, 1)
    
    # do the same in one line with reshape
    tensor_reshape = tf.reshape(tensor, [1, 1, tensor.get_shape().as_list()[0], 1])
    print(tensor_reshape.get_shape()) # => (1, 1, 100, 1)
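
    The equivalence can be checked directly. The sketch below assumes TensorFlow 2.x with eager execution, and also shows two variations that are not part of the snippet above: reshape with -1, which infers the original length, and indexing with tf.newaxis.

```python
import tensorflow as tf

# Distinct values make the element-wise comparison meaningful.
tensor = tf.range(100, dtype=tf.float32)

# Chained expand_dims, as in the snippet above.
chained = tf.expand_dims(tf.expand_dims(tf.expand_dims(tensor, 0), 0), -1)

# reshape with -1 lets TensorFlow infer the length of the original axis.
reshaped = tf.reshape(tensor, [1, 1, -1, 1])

# Indexing with tf.newaxis inserts size-1 axes in a single expression.
indexed = tensor[tf.newaxis, tf.newaxis, :, tf.newaxis]

print(chained.shape, reshaped.shape, indexed.shape)

# All three tensors hold the same values in the same layout.
same = bool(tf.reduce_all(tf.equal(chained, reshaped))) and bool(
    tf.reduce_all(tf.equal(chained, indexed))
)
print(same)  # => True
```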
    

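    NOTE: If you get the error TypeError: Failed to convert object of type to Tensor., pass tf.shape(x)[0] instead of x.get_shape()[0]. This typically happens when the static dimension is unknown (None) at graph-construction time; tf.shape(x) returns the shape as a tensor that is evaluated at run time.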
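
    A minimal sketch of the dynamic-shape case, assuming TensorFlow 2.x. The function name to_rank4 and the input_signature are hypothetical, chosen only to force an unknown leading dimension during tracing.

```python
import tensorflow as tf

# The leading dimension is declared unknown (None) while tracing.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def to_rank4(x):
    # x.shape[0] is None here, so it cannot be placed in the target shape.
    # tf.shape(x)[0] is a scalar tensor that is evaluated at run time.
    length = tf.shape(x)[0]
    return tf.reshape(x, [1, 1, length, 1])

# The same traced function handles inputs of different lengths.
print(to_rank4(tf.ones(100)).shape)  # => (1, 1, 100, 1)
print(to_rank4(tf.ones(7)).shape)    # => (1, 1, 7, 1)
```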

    Hope it helps!
    Cheers,
    Andres
