No shape error in tensorflow graph construction but getting shape mismatch error during graph computation

和自甴很熟 submitted on 2019-12-11 05:32:31

Question


No error occurs during TensorFlow graph construction, but I get a shape mismatch error during graph execution, in tf.gradients (I suspect the error is in backpropagation).

This is the error I get:

InvalidArgumentError (see above for traceback):
Input to reshape is a tensor with 16777216 values, but the requested shape has 4096
[[Node: gradients/truediv_grad/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](gradients/truediv_grad/Sum, gradients/truediv_grad/Shape)]]
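The two numbers in the message are related: 16777216 = 4096 × 4096. One plausible explanation (an assumption, since the question does not show the graph) is that a gradient of shape [4096] was combined with a tensor of shape [4096, 1], and broadcasting silently produced a [4096, 4096] tensor that the truediv gradient then tried to reshape back to the 4096-element input shape. The NumPy sketch below reproduces that arithmetic; NumPy stands in for TensorFlow here, whose broadcasting follows the same rules for this case.

```python
import numpy as np

N = 4096
upstream = np.empty(N, dtype=np.float32)        # shape (4096,): what the input expects
local = np.empty((N, 1), dtype=np.float32)      # shape (4096, 1): one stray extra axis

# Broadcasting (4096, 1) against (4096,) silently gives (4096, 4096) instead of an error.
# np.broadcast only computes the result shape, so no 64 MB array is allocated.
shape = np.broadcast(local, upstream).shape
size = shape[0] * shape[1]
print(shape, size)  # (4096, 4096) 16777216

# Reshaping such a result to the input's size fails, just as the Reshape node did.
# A small n keeps the demo cheap; the failure mode is the same for n = 4096.
n = 4
small = np.ones((n, 1)) * np.ones(n)            # shape (4, 4), 16 values
try:
    small.reshape(n)
except ValueError as err:
    print("reshape failed:", err)
```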


Answer 1:


I solved the issue using two techniques:

1. Apparently, if you are creating custom ops and gradients, you need to be very explicit in providing shape information to TensorFlow, using set_shape or tf.reshape.

2. Also, when you register your gradient with tf.RegisterGradient (the gradient function takes op and grad as inputs), you need to be careful when chaining the gradients, i.e. dy/dx = dy/dz * dz/dx, per the chain rule of differentiation. Say dy/dz is the custom gradient we have created and dz/dx is the gradient coming from the rest of the graph, which TensorFlow passes in as grad.
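To illustrate the first technique outside TensorFlow: in TensorFlow, set_shape only refines the static shape information a tensor carries, while tf.reshape changes the shape at run time and fails if the element counts differ. The NumPy sketch below is an analogue of the reshape side only. `as_input_shape` is a hypothetical helper (not part of any library) that forces a custom gradient into the shape of the op input it belongs to, and raises a readable error as soon as the element count is wrong, rather than letting broadcasting produce a surprise later in the graph.

```python
import numpy as np

def as_input_shape(grad, input_shape):
    """Hypothetical helper: return grad reshaped to input_shape, or fail loudly."""
    grad = np.asarray(grad)
    expected = int(np.prod(input_shape))
    if grad.size != expected:
        raise ValueError(
            f"gradient has {grad.size} values, but the input of shape "
            f"{tuple(input_shape)} needs {expected}"
        )
    return grad.reshape(input_shape)

g = as_input_shape(np.arange(6.0), (2, 3))
print(g.shape)  # (2, 3)

try:
    as_input_shape(np.ones((4, 4)), (4,))
except ValueError as err:
    print(err)  # gradient has 16 values, but the input of shape (4,) needs 4
```

Checking the element count explicitly moves the failure to the place where the wrong gradient is built, which is usually easier to debug than a Reshape error deep inside an auto-generated gradient node.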

import tensorflow as tf

@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # Do stuff with op.inputs and calculate the custom gradient, say cust_grad (dy/dz).
    return cust_grad * grad
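The first version multiplies the custom gradient by grad elementwise, which is the chain rule for an elementwise op. The NumPy sketch below checks that pattern numerically against finite differences. The op z = x ** 3 and the names `forward` and `my_grad` are hypothetical stand-ins, not the answerer's op; `my_grad` only mimics the shape of a registered gradient function (op input and incoming gradient in, input gradient out).

```python
import numpy as np

# Hypothetical stand-in "custom op": z = x ** 3, elementwise.
def forward(x):
    return x ** 3

# Mimics a gradient function: receives the op input x and the incoming
# gradient grad = dL/dz, and returns dL/dx.
def my_grad(x, grad):
    local = 3.0 * x ** 2          # dz/dx, the op's own (custom) derivative
    return local * grad           # chain rule, elementwise: dL/dx = dL/dz * dz/dx

x = np.array([0.5, -1.0, 2.0])
w = np.array([1.0, 2.0, -3.0])    # loss L = sum(w * z), so dL/dz = w

analytic = my_grad(x, w)

# Central finite differences on L(x) = sum(w * x**3) as a sanity check.
eps = 1e-6
numeric = np.array([
    (np.sum(w * forward(x + eps * e)) - np.sum(w * forward(x - eps * e))) / (2 * eps)
    for e in np.eye(len(x))
])
print(np.allclose(analytic, numeric))  # True
```

An elementwise product is only correct here because the op's Jacobian is diagonal and both factors have exactly the input's shape; if either factor carried an extra axis, broadcasting would hide the mistake.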

I changed the first version of Mygrad to the following:

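@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # Do stuff with op.inputs and calculate the custom gradient, say cust_grad (dy/dz).
    # calculated_shape and expected_shape are placeholders: pick them so the matmul is
    # well-defined and its result has the same shape as op.inputs[0].
    return tf.matmul(tf.reshape(cust_grad, calculated_shape), tf.reshape(grad, expected_shape))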
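Whether an elementwise product or a matrix product is the right way to chain the two factors depends on the op. For an op like y = x @ W, the chain rule becomes dL/dx = dL/dy @ W^T, and the result must have the same shape as x. The NumPy sketch below illustrates that with made-up shapes; it is not the answerer's op, and the random inputs exist only for the finite-difference check.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))     # op input
W = rng.standard_normal((3, 4))     # op parameter; output y = x @ W has shape (2, 4)
grad = rng.standard_normal((2, 4))  # incoming gradient dL/dy, same shape as y

# Chain rule for a matrix product: (2, 4) @ (4, 3) -> (2, 3), the shape of x.
dx = grad @ W.T
print(dx.shape)  # (2, 3)

# Finite-difference check on L(x) = sum(grad * (x @ W)).
def loss(x_):
    return np.sum(grad * (x_ @ W))

eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    step = np.zeros_like(x)
    step[idx] = eps
    numeric[idx] = (loss(x + step) - loss(x - step)) / (2 * eps)
print(np.allclose(dx, numeric))  # True
```

Writing out the shapes of each factor, as in the comment on the dx line, is a quick way to choose the reshape targets before calling tf.matmul.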


Source: https://stackoverflow.com/questions/49335990/no-shape-error-in-tensorflow-graph-construction-but-getting-shape-mismatch-error
