What exactly does tf.expand_dims do to a vector and why can results be added together even when matrix shapes are different?

Posted by 老子叫甜甜 on 2019-12-19 20:47:52

Question


I'm adding two vectors that I thought had been 'reshaped', and I get a 2D matrix as the result. I expected some kind of error here, but didn't get one. I think I understand what's happening: it treated each vector as if it were repeated two more times, one horizontally and the other vertically. But I don't understand why the results of a and b aren't different. And if they're not meant to be, why does this work at all?

import tensorflow as tf
import numpy as np

start_vec = np.array((83,69,45))
a = tf.expand_dims(start_vec, 0)
b = tf.expand_dims(start_vec, 1)
ab_sum = a + b
# TensorFlow 1.x graph mode. The graph has no variables, so no initializer is needed.
with tf.Session() as sess:
    a, b, ab_sum = sess.run([a, b, ab_sum])

print(a)
print(b)
print(ab_sum)

=================================================

[[83 69 45]]

[[83]
 [69]
 [45]]

[[166 152 128]
 [152 138 114]
 [128 114  90]]

Answer 1:


This question is really about broadcasting in TensorFlow, which follows the same rules as NumPy (Broadcasting). Broadcasting relaxes the requirement that the operands of an element-wise operation have identical shapes, provided the shapes meet certain conditions.

General Broadcasting Rules:

When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing dimensions, and works its way forward. Two dimensions are compatible when

1. they are equal, or

2. one of them is 1
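
The rule above can be sketched in plain Python. The helper name broadcast_shape is hypothetical (it is not a TensorFlow or NumPy function), and the sketch assumes the usual NumPy convention that a dimension missing from the shorter shape counts as 1:

```python
def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: return the broadcast result shape,
    or raise ValueError if the two shapes are incompatible."""
    # Compare dimensions starting from the trailing one.
    rev_a, rev_b = list(reversed(shape_a)), list(reversed(shape_b))
    result = []
    for i in range(max(len(rev_a), len(rev_b))):
        # Assumption: a dimension missing from the shorter shape is treated as 1.
        dim_a = rev_a[i] if i < len(rev_a) else 1
        dim_b = rev_b[i] if i < len(rev_b) else 1
        if dim_a == dim_b or dim_a == 1 or dim_b == 1:
            # The dimension of size 1 is stretched to match the other one.
            result.append(dim_b if dim_a == 1 else dim_a)
        else:
            raise ValueError(f"incompatible dimensions: {dim_a} vs {dim_b}")
    return tuple(reversed(result))


print(broadcast_shape((1, 3), (3, 1)))  # (3, 3)
print(broadcast_shape((3,), ()))        # (3,)
```

Shapes such as (3,) and (4,) satisfy neither condition, so the helper raises ValueError for them.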

A simple example is multiplying a one-dimensional tensor by a scalar: the scalar is stretched to match the tensor's shape (3,).

import tensorflow as tf

start_vec = tf.constant((83,69,45))
b = start_vec * 2

with tf.Session() as sess:
    print(sess.run(b))

[166 138  90]

Back to the question: tf.expand_dims() inserts a dimension of size 1 into a tensor's shape at the specified axis position. Your original data has shape (3,). With axis=0, a = tf.expand_dims(start_vec, 0) has shape (1, 3). With axis=1, b = tf.expand_dims(start_vec, 1) has shape (3, 1).
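
The shapes can be checked without starting a session. This is a minimal sketch that uses np.expand_dims as a stand-in for tf.expand_dims, on the assumption that the two treat the axis argument the same way for this case:

```python
import numpy as np

start_vec = np.array((83, 69, 45))
a = np.expand_dims(start_vec, 0)  # new axis inserted in front
b = np.expand_dims(start_vec, 1)  # new axis inserted at the end

print(start_vec.shape)  # (3,)
print(a.shape)          # (1, 3)
print(b.shape)          # (3, 1)
```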

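Comparing the two shapes, (1, 3) and (3, 1), from the trailing dimension forward, each pair of dimensions contains a 1, so the second condition is satisfied. Both operands are stretched to shape (3, 3), and the operation actually performed is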

83,83,83     83,69,45
69,69,69  +  83,69,45
45,45,45     83,69,45
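
The stretched operands shown above can be built explicitly to confirm that they give the same result as a + b. This is a sketch in NumPy rather than TensorFlow (an assumption made so that it runs without a session); np.broadcast_to performs the stretching that broadcasting does implicitly:

```python
import numpy as np

start_vec = np.array((83, 69, 45))
a = np.expand_dims(start_vec, 0)  # shape (1, 3)
b = np.expand_dims(start_vec, 1)  # shape (3, 1)

# Stretch each operand to the common shape (3, 3) explicitly.
a_full = np.broadcast_to(a, (3, 3))  # every row is [83 69 45]
b_full = np.broadcast_to(b, (3, 3))  # rows are [83 83 83], [69 69 69], [45 45 45]

explicit_sum = b_full + a_full
implicit_sum = a + b  # broadcasting happens automatically here

print(np.array_equal(explicit_sum, implicit_sum))  # True
print(implicit_sum)
```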


Source: https://stackoverflow.com/questions/54165937/what-exactly-does-tf-expand-dims-do-to-a-vector-and-why-can-results-be-added-tog
