Clarification on tf.Tensor.set_shape()

生来不讨喜 2020-12-13 13:27

I have an image tensor holding 478 x 717 x 3 = 1028178 values, with a rank of 1. I verified this by calling tf.shape and tf.rank.

When I call image.set_shape([478, 717, 3]),

1 answer
  • 2020-12-13 14:12

    As far as I know (and I wrote that code), there isn't a bug in Tensor.set_shape(). I think the misunderstanding stems from the confusing name of that method.

    To elaborate on the FAQ entry you quoted, Tensor.set_shape() is a pure-Python function that improves the shape information for a given tf.Tensor object. By "improves", I mean "makes more specific".

    For example, suppose you have a Tensor object t with shape (?,), that is, a one-dimensional tensor of unknown length. You can call t.set_shape((1028178,)), and then t will have shape (1028178,) when you call t.get_shape(). This doesn't affect the underlying storage, or indeed anything on the backend: it merely means that subsequent shape inference using t can rely on the assertion that it is a vector of length 1028178.
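A minimal sketch of that refinement, assuming a TensorFlow 2.x install where graph mode and tf.compat.v1.placeholder are available (the placeholder is only my way of getting a tensor whose length is unknown; the variable names shape_before and shape_after are illustrative):

```python
import tensorflow as tf

# Build in graph mode so that a dimension can be left unknown.
with tf.Graph().as_default():
    # A vector of unknown length, i.e. static shape (?,).
    t = tf.compat.v1.placeholder(tf.float32, shape=[None])
    shape_before = t.shape.as_list()

    # Refine the static shape. Only the Python-side shape
    # information changes; no data is moved or copied.
    t.set_shape((1028178,))
    shape_after = t.shape.as_list()

print(shape_before, shape_after)
```

The rank stays the same throughout: only the unknown length is replaced by a known one.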

    If t has shape (?,), a call to t.set_shape((478, 717, 3)) will fail, because TensorFlow already knows that t is a vector, so it cannot have shape (478, 717, 3). If you want to make a new Tensor with that shape from the contents of t, you can use reshaped_t = tf.reshape(t, (478, 717, 3)). This creates a new tf.Tensor object in Python; the actual implementation of tf.reshape() does this using a shallow copy of the tensor buffer, so it is inexpensive in practice.
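The contrast between the two calls can be sketched as follows, again assuming TensorFlow 2.x in graph mode with a tf.compat.v1.placeholder standing in for the flat image tensor (set_shape_failed and reshaped_shape are illustrative names, not part of any API):

```python
import tensorflow as tf

with tf.Graph().as_default():
    # Known to be a vector, but of unknown length: shape (?,).
    t = tf.compat.v1.placeholder(tf.float32, shape=[None])

    # set_shape() can only make the shape more specific, so a
    # rank-3 shape is rejected for a tensor known to be rank 1.
    try:
        t.set_shape((478, 717, 3))
        set_shape_failed = False
    except ValueError:
        set_shape_failed = True

    # tf.reshape() creates a new tensor with the requested shape.
    reshaped_t = tf.reshape(t, (478, 717, 3))
    reshaped_shape = reshaped_t.shape.as_list()

print(set_shape_failed, reshaped_shape)
```

Here t itself is left as a vector, and reshaped_t is a separate tf.Tensor object carrying the three-dimensional shape.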

    One analogy is that Tensor.set_shape() is like a run-time cast in an object-oriented language like Java. For example, if you have a reference to an Object but know that, in fact, it is a String, you might do the cast (String) obj in order to pass obj to a method that expects a String argument. However, if you have a String s and try to cast it to a java.util.Vector, the compiler will give you an error, because these two types are unrelated.
