Question
I am trying to use the convolutional part of the ResNet50() model, as follows:
#generate batches
def get_batches(dirname, gen=image.ImageDataGenerator(), shuffle=True, batch_size=4, class_mode='categorical',
                target_size=(224,224)):
    return gen.flow_from_directory(dirname, target_size=target_size,
                                   class_mode=class_mode, shuffle=shuffle, batch_size=batch_size)
trn_batches = get_batches("path_to_dirctory", shuffle=False,batch_size=4)
#create model
rn_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape((1,1,3))
inp_resnet = Input((224,224,3))
preproc = Lambda(lambda x: (x - rn_mean)[:, :, :, ::-1])(inp_resnet)
resnet_model = ResNet50(include_top=False, input_tensor=preproc)
res5b_branch2a = resnet_model.get_layer('res5b_branch2a')
last_conv_layer = resnet_model.layers[resnet_model.layers.index(res5b_branch2a)-1].output
resnet_model_conv = Model(inp_resnet, Flatten()(AveragePooling2D((7,7))(last_conv_layer)))
#feed batches to model
trn_conv_features_resnet = resnet_model_conv.predict_generator(trn_batches, trn_batches.samples)
The model summary is pretty long, but the last part shows that the output should have a shape of (None, 2048).
So I assume that if I feed 200 images to this model, I should get an output with a shape of (200, 2048). Am I correct?
But in fact, with 200 images in, I got an output with a shape of (800, 2048). I am wondering why this happens.
I checked another topic, but it seems to be a different issue. Please kindly help! BTW, this is done in Keras 2.
Update:
I realized that if I set batch_size=4, I get a (800, 2048) output from 200 input images, and if I change to batch_size=2, I get a (400, 2048) output from the same 200 images. Is this how the batch_size setting works? Should I just use batch_size=1? I thought that batch_size is the number of pictures fed to the model at a time, and that no matter what batch_size is, the total number of images should be 200, right? For example, with batch_size=4, 50 batches would be fed to the model, and with batch_size=2, 100 batches would be fed to the model.
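The numbers in the update fit one likely explanation: in Keras 2, the second argument of predict_generator is `steps`, the number of batches to draw, not the number of samples. Passing trn_batches.samples (200) would then draw 200 batches, and because the directory iterator loops over the data indefinitely, the output has 200 * batch_size rows. The sketch below reproduces only that arithmetic in plain Python. The names `cycling_batches` and `rows_returned` are hypothetical stand-ins for illustration and are not part of Keras.

```python
import math
from itertools import islice


def cycling_batches(n_samples, batch_size):
    """Hypothetical stand-in for a directory iterator.

    Yields batches of sample indices forever, wrapping around
    to the first sample once the data set is exhausted.
    """
    while True:
        for start in range(0, n_samples, batch_size):
            yield list(range(start, min(start + batch_size, n_samples)))


def rows_returned(n_samples, batch_size, steps):
    """Count the rows produced when exactly `steps` batches are consumed."""
    batches = islice(cycling_batches(n_samples, batch_size), steps)
    return sum(len(batch) for batch in batches)


n_samples = 200

# steps set to the number of samples: every image is seen batch_size times
print(rows_returned(n_samples, batch_size=4, steps=n_samples))  # 800
print(rows_returned(n_samples, batch_size=2, steps=n_samples))  # 400

# steps set to the number of batches: each image is seen exactly once
steps = math.ceil(n_samples / 4)
print(steps, rows_returned(n_samples, batch_size=4, steps=steps))  # 50 200
```

If that reading is right, then for the code in the question, passing something like math.ceil(trn_batches.samples / trn_batches.batch_size) as the second argument should give the expected (200, 2048) output for any batch size.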
Source: https://stackoverflow.com/questions/44475827/keras-cnn-model-output-shape-doesnt-match-model-summary