Question
Let's say I have a ResNet50 model and I wish to connect the output layer of this model to the input layer of a VGG model.
This is the ResNet model and the output tensor of ResNet50:
from tensorflow.keras.applications import ResNet50
img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights=None)
print(resnet50_model.output.shape)
I get the output:
TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])
Now I want to add a layer that reshapes this output tensor to (64, 64, 18).
Then I have a VGG16 model:
VGG_model = VGG16(include_top=False, weights=None)
I want the output of ResNet50 to be reshaped into the desired tensor and fed in as the input to the VGG model. So essentially I want to concatenate the two models into one. Can someone help me do that? Thank you!
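Answer 1:
There are multiple ways to do this. Here is one that uses the Sequential model API. A Conv2D layer with 3 filters sits between the reshape and VGG16, because VGG16 built this way expects a 3-channel input while the reshaped tensor has 18 channels.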
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

img_shape = (164, 164, 3)

model = tf.keras.Sequential()
# ResNet50 backbone: (164, 164, 3) -> (6, 6, 2048)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights=None))
# 6 * 6 * 2048 == 64 * 64 * 18, so the element counts match
model.add(tf.keras.layers.Reshape(target_shape=(64, 64, 18)))
# Reduce 18 channels to the 3 channels VGG16 takes as input
model.add(tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name='Conv2d'))
VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
The model summary is as follows:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489
_________________________________________________________________
vgg16 (Model)                multiple                  14714688
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________
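The shapes and parameter counts in the summary can be checked by hand. The sketch below is plain Python with no TensorFlow dependency; the helper names `conv_out` and `conv2d_params` are made up for illustration and are not part of Keras. It assumes the Conv2D layer uses the default unpadded ('valid') convolution with stride 1, as in the code above.

```python
from math import prod

resnet_out = (6, 6, 2048)   # ResNet50 feature map for a 164x164x3 input
target = (64, 64, 18)       # shape requested in the question

# A reshape is only valid when both shapes hold the same number of elements.
assert prod(resnet_out) == prod(target) == 73728


def conv_out(size, kernel, stride=1):
    """Output length along one axis of an unpadded ('valid') convolution."""
    return (size - kernel) // stride + 1


def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters


side = conv_out(64, 3)                 # spatial size after the 3x3 Conv2D
params = conv2d_params(3, 3, 18, 3)    # parameters of the Conv2d layer
total = 23587712 + params + 14714688   # ResNet50 + Conv2d + VGG16

print(side, params, total)
```

The three values correspond to the 62x62 output of `Conv2d`, its 489 parameters, and the 38,302,889 total reported in the summary.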
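Since there are multiple ways to connect the models, here is a rough sketch of the same pipeline written with the Keras functional API instead of Sequential. This is an alternative added for illustration, not the answer author's code; the variable names `resnet`, `vgg`, and `combined` are assumptions, and the layer settings simply mirror the Sequential version.

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

img_shape = (164, 164, 3)

# Both backbones are built without pretrained weights, as in the question.
resnet = ResNet50(include_top=False, input_shape=img_shape, weights=None)
vgg = VGG16(include_top=False, weights=None)

# Reshape the (6, 6, 2048) ResNet50 output to (64, 64, 18).
x = tf.keras.layers.Reshape((64, 64, 18))(resnet.output)
# Map 18 channels to the 3 channels VGG16 takes as input.
x = tf.keras.layers.Conv2D(3, kernel_size=(3, 3))(x)
# Call VGG16 as a layer on the intermediate tensor.
outputs = vgg(x)

combined = tf.keras.Model(inputs=resnet.input, outputs=outputs)
print(combined.output_shape)
```

With the functional API the intermediate tensors stay accessible, which makes it easier to branch off extra outputs or add skip connections later. The Sequential version is shorter when the models are simply stacked.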
Source: https://stackoverflow.com/questions/61446429/how-do-i-connect-two-keras-models-into-one-model