Question
No error occurs while the TensorFlow graph is being constructed, but when the graph is executed I get a shape mismatch error in the ops created by tf.gradients (I suspect the error is in backpropagation).
This is the error I get:
InvalidArgumentError (see above for traceback):
Input to reshape is a tensor with 16777216 values, but the requested shape has 4096
[[Node: gradients/truediv_grad/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](gradients/truediv_grad/Sum, gradients/truediv_grad/Shape)]]
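The numbers in this message are suggestive: 16777216 is exactly 4096 * 4096. In TensorFlow 1.x the gradient of truediv sums the incoming gradient over its inputs' broadcast axes and then reshapes the result to the input's shape (the Sum and Shape nodes named above), so an incoming gradient with too many elements surfaces at exactly this Reshape. One plausible cause, which is an assumption the traceback does not confirm, is that a [4096, 1] tensor was multiplied elementwise by a [4096] tensor and broadcasting expanded the product to [4096, 4096]. The sketch below reimplements the NumPy/TensorFlow broadcasting rule on plain shape tuples to show that arithmetic; `broadcast_shape` is a hypothetical helper, not a TensorFlow function.

```python
from math import prod


def broadcast_shape(a, b):
    """Shape produced by NumPy/TensorFlow-style broadcasting of shapes a and b.

    Hypothetical helper for illustration; raises ValueError if incompatible.
    """
    result = []
    # Compare dimensions from the right, padding the shorter shape with 1s.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible dimensions {da} and {db}")
        # A size-1 dimension stretches to match the other one.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# A column-shaped gradient times a flat upstream gradient of the same length:
shape = broadcast_shape((4096, 1), (4096,))
print(shape, prod(shape))  # (4096, 4096) 16777216
```

If the shapes in the real graph line up this way, making the shapes explicit (for example reshaping the custom gradient to the shape of grad before multiplying) keeps the product at 4096 elements.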
Answer 1:
I solved the issue using two techniques:
1. Apparently, if you are creating custom ops and gradients, you need to be very explicit about providing shape information to TensorFlow, using set_shape or tf.reshape.
2. Also, when you register your gradient with tf.RegisterGradient, the function you register takes op and grad as inputs, and you need to be careful when chaining the gradients, i.e. dy/dx = dy/dz * dz/dx.
Say dy/dz is the custom gradient we have computed, and dz/dx is the gradient passed in through grad from the rest of the graph, as per the chain rule of differentiation.
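For an elementwise op, this chaining reduces to multiplying the upstream gradient by the op's local derivative element by element, and the result must have the same shape as the op's input. A minimal pure-Python sketch, assuming a hypothetical op z = x**3 whose custom local gradient is 3*x**2 and a toy loss L = sum(w_i * z_i), checked against central finite differences:

```python
def cube(xs):
    """Hypothetical elementwise op: z_i = x_i ** 3."""
    return [x ** 3 for x in xs]


def cube_grad(xs, upstream):
    """Custom gradient: dL/dx_i = dL/dz_i * dz_i/dx_i.

    `upstream` plays the role of `grad` in a registered gradient function;
    the local derivative 3 * x_i ** 2 plays the role of cust_grad.
    """
    assert len(upstream) == len(xs), "upstream gradient must match the op's output shape"
    local = [3 * x ** 2 for x in xs]
    # Elementwise product keeps the result the same length as the input.
    return [u * d for u, d in zip(upstream, local)]


def loss(xs, ws):
    """Toy loss L = sum(w_i * z_i), so dL/dz_i = w_i."""
    return sum(w * z for w, z in zip(ws, cube(xs)))


xs = [0.5, -1.0, 2.0]
ws = [1.0, 2.0, 3.0]
analytic = cube_grad(xs, upstream=ws)

# Central finite differences as an independent check.
eps = 1e-6
numeric = []
for i in range(len(xs)):
    plus = xs[:i] + [xs[i] + eps] + xs[i + 1:]
    minus = xs[:i] + [xs[i] - eps] + xs[i + 1:]
    numeric.append((loss(plus, ws) - loss(minus, ws)) / (2 * eps))

print(analytic)  # [0.75, 6.0, 36.0]
```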
My original gradient function was:
import tensorflow as tf

@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # Use op.inputs to compute the custom gradient cust_grad (dy/dz) here.
    return cust_grad * grad
I changed this to the following:
@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # Compute cust_grad (dy/dz) from op.inputs as before, then reshape both
    # tensors explicitly so the matmul yields the intended gradient shape.
    return tf.matmul(tf.reshape(cust_grad, calculated_shape),
                     tf.reshape(grad, expected_shape))
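When the custom gradient is a full Jacobian rather than an elementwise factor, the chaining becomes a matrix product, which is presumably why the fix uses tf.matmul. A minimal plain-Python sketch, assuming (this is an interpretation, not something the answer states) that cust_grad holds the transposed Jacobian with shape [n_in, n_out] (playing the role of calculated_shape) and grad is reshaped to a column [n_out, 1] (playing the role of expected_shape). The product is then [n_in, 1] and still has to be flattened back to the input's shape [n_in]. The op z = A x is hypothetical; its Jacobian dz/dx is A, so the gradient is A^T times the upstream gradient.

```python
def matmul(a, b):
    """Plain-Python matrix product of an [m, k] and a [k, n] nested list."""
    k = len(b)
    assert all(len(row) == k for row in a), "inner dimensions must agree"
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a nested list."""
    return [list(col) for col in zip(*m)]


# Hypothetical op z = A @ x with x of shape [n_in] = [3] and z of shape [n_out] = [2].
A = [[1.0, 2.0, 0.0],
     [0.0, 1.0, 3.0]]

upstream = [1.0, -1.0]               # plays the role of grad, shape [n_out]
cust_grad = transpose(A)             # dz/dx transposed, shape [n_in, n_out] = [3, 2]
grad_col = [[g] for g in upstream]   # reshape grad to a column [n_out, 1]

prod_col = matmul(cust_grad, grad_col)   # shape [n_in, 1]
dx = [row[0] for row in prod_col]        # flatten back to the input shape [n_in]
print(dx)  # [1.0, 1.0, -3.0]
```

The same bookkeeping applies in TensorFlow: whatever intermediate reshapes are used, the tensor returned from the registered gradient function should end up with the shape of the corresponding entry in op.inputs.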
Source: https://stackoverflow.com/questions/49335990/no-shape-error-in-tensorflow-graph-construction-but-getting-shape-mismatch-error