Question
I wrote a custom layer that merges batch_size with the first dimension and leaves the other dimensions unchanged, but compute_output_shape seems to have no effect, so downstream layers cannot get accurate shape information and raise an error. How do I make compute_output_shape take effect?
import keras
from keras import backend as K
class BatchMergeReshape(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(BatchMergeReshape, self).__init__(**kwargs)
    def build(self, input_shape):
        super(BatchMergeReshape, self).build(input_shape)
    def call(self, x):
        input_shape = K.shape(x)
        batch_size, seq_len = input_shape[0], input_shape[1]
        r = K.reshape(x, (batch_size*seq_len,)+input_shape[2:])
        print("call_shape:", r.shape)
        return r
    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            r = (None,)+input_shape[2:]
            print("compute_output_shape:", r)
            return r
        else:
            r = (input_shape[0]*input_shape[1],)+input_shape[2:]
            return r
a = keras.layers.Input(shape=(3,4,5))
b = BatchMergeReshape()(a)
print(b.shape)
# call_shape: (?, ?)
# compute_output_shape: (None, 4, 5)
# (?, ?)
I need to get (None, 4, 5) but get (None, None). Why didn't compute_output_shape work? My Keras version is 2.2.4.
Answer 1:
The problem is probably that K.shape returns a tensor, not a tuple. You can't do (batch_size*seq_len,) + input_shape[2:]; that mixes tensors with tuples, and the result will certainly be wrong.
The good thing is that if you know the other dimensions and only the batch size is unknown, this layer is all you need:
Lambda(lambda x: K.reshape(x, (-1,) + other_dimensions_tuple))
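For example, with the (3, 4, 5) input from the question (and assuming the TensorFlow backend), a minimal sketch of that Lambda fix could look like this:

a = keras.layers.Input(shape=(3, 4, 5))
# -1 merges the batch axis with the first dimension; 4 and 5 are the known trailing dims
b = keras.layers.Lambda(lambda x: K.reshape(x, (-1, 4, 5)))(a)
print(b.shape)  # expected: (?, 4, 5), i.e. (None, 4, 5)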
If you don't know the other dimensions:
input_shape = K.shape(x)  # dynamic shape, a 1-D int tensor
new_batch_size = input_shape[0:1] * input_shape[1:2]  # slicing keeps 1-D tensors, so new_batch_size has shape (1,)
new_shape = K.concatenate([new_batch_size, input_shape[2:]])  # target shape as a tensor, built by concatenating two tensors
r = K.reshape(x, new_shape)
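Putting the two pieces together, here is a minimal sketch of a corrected layer, assuming the TensorFlow backend; it keeps the question's class name, and the K.int_shape branch is only an illustration that reuses the Lambda idea inside the layer so that the known trailing dims stay static:

class BatchMergeReshape(keras.layers.Layer):
    def call(self, x):
        static_shape = K.int_shape(x)  # static shape as a Python tuple, e.g. (None, 3, 4, 5)
        if None not in static_shape[2:]:
            # Trailing dims are known: -1 lets the backend infer the merged batch axis
            # while keeping the static shape (?, 4, 5).
            return K.reshape(x, (-1,) + static_shape[2:])
        # Otherwise fall back to the fully dynamic reshape shown above.
        dyn_shape = K.shape(x)
        new_batch_size = dyn_shape[0:1] * dyn_shape[1:2]
        return K.reshape(x, K.concatenate([new_batch_size, dyn_shape[2:]]))
    def compute_output_shape(self, input_shape):
        # Work only with Python ints and None here, never with tensors.
        if input_shape[0] is None or input_shape[1] is None:
            return (None,) + tuple(input_shape[2:])
        return (input_shape[0] * input_shape[1],) + tuple(input_shape[2:])

a = keras.layers.Input(shape=(3, 4, 5))
b = BatchMergeReshape()(a)
print(b.shape)  # expected: (?, 4, 5), i.e. (None, 4, 5)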
Notice that this works in TensorFlow, but may not work in Theano.
Notice also that Keras will demand that the batch size of the model's output is equal to the batch size of the model's inputs. This means that you will need to restore the original batch size before the end of the model.
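A sketch of restoring it, assuming the merged tensor b came from the (None, 3, 4, 5) input above, so that seq_len = 3 and the trailing dims are known:

# undo the merge: (batch_size*3, 4, 5) -> (batch_size, 3, 4, 5)
restored = keras.layers.Lambda(lambda x: K.reshape(x, (-1, 3, 4, 5)))(b)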
Source: https://stackoverflow.com/questions/58072362/keras-compute-output-shape-not-working-for-custom-layer