How do I connect two keras models into one model?

Submitted by て烟熏妆下的殇ゞ on 2021-02-11 05:55:39

Question


Let's say I have a ResNet50 model and I wish to connect the output layer of this model to the input layer of a VGG model.

Here is the ResNet50 model and its output tensor:

from tensorflow.keras.applications import ResNet50

img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights=None)

print(resnet50_model.output.shape)

I get the output:

TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])
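Where the 6 × 6 comes from: a rough plain-Python sketch of the spatial arithmetic, assuming the layer layout of `tf.keras.applications.ResNet50` (a 7×7 stride-2 stem convolution after 3 pixels of zero padding, a 3×3 stride-2 max pool after 1 pixel of padding, then three stages that each open with a 1×1 stride-2 convolution while the inner 3×3 convolutions use 'same' padding). The helper names `conv_out` and `resnet50_spatial` are hypothetical, not Keras functions.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output length of a 'valid' conv/pool layer applied after `pad` pixels of zero padding per side."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet50_spatial(size):
    """Approximate spatial size of ResNet50's include_top=False output (assumed layer layout)."""
    # Stem: ZeroPadding2D(3) followed by a 7x7 convolution with stride 2
    size = conv_out(size, kernel=7, stride=2, pad=3)
    # ZeroPadding2D(1) followed by a 3x3 max pool with stride 2
    size = conv_out(size, kernel=3, stride=2, pad=1)
    # Stages conv3, conv4 and conv5 each halve the size with a 1x1, stride-2 convolution;
    # the 3x3 convolutions inside the blocks use 'same' padding and keep the size unchanged
    for _ in range(3):
        size = conv_out(size, kernel=1, stride=2)
    return size


print(resnet50_spatial(164))  # 6 for the 164x164 input in the question
print(resnet50_spatial(224))  # 7 for the standard 224x224 ImageNet input
```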

Now I want a new layer that reshapes this output tensor to (64, 64, 18).
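A Reshape layer only works when the total number of elements stays the same, and here it does: 6 × 6 × 2048 = 64 × 64 × 18 = 73,728. A minimal plain-Python check (the helper `reshape_is_valid` is hypothetical, not part of Keras; shapes exclude the batch dimension, as `Reshape(target_shape=...)` does):

```python
from math import prod


def reshape_is_valid(src_shape, dst_shape):
    """True if a tensor of src_shape can be reshaped to dst_shape (batch dimension excluded)."""
    return prod(src_shape) == prod(dst_shape)


resnet_out = (6, 6, 2048)  # ResNet50 output for a 164x164x3 input, from the question
target = (64, 64, 18)      # shape requested in the question

print(prod(resnet_out), prod(target))        # 73728 73728
print(reshape_is_valid(resnet_out, target))  # True
```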

Then I have a VGG16 model:

from tensorflow.keras.applications import VGG16

VGG_model = VGG16(include_top=False, weights=None)

I want the output of ResNet50 to be reshaped into the desired tensor and fed in as the input to the VGG model. So essentially I want to chain the two models into one. Can someone help me do that? Thank you!


Answer 1:


There are multiple ways to do this. Here is one way, using the Sequential model API:

import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

model = tf.keras.Sequential()
img_shape = (164, 164, 3)

# ResNet50 without its classifier head: output shape (None, 6, 6, 2048)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights=None))

# 6 * 6 * 2048 == 64 * 64 * 18, so the tensor can be reshaped without losing values
model.add(tf.keras.layers.Reshape(target_shape=(64, 64, 18)))

# VGG16 built without an input_shape expects 3-channel input, so map 18 channels down to 3
model.add(tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name='Conv2d'))

VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()
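VGG16 created without an `input_shape` defaults to a 3-channel input, which is presumably why the answer puts a Conv2D with 3 filters between the Reshape and VGG16. With the Keras defaults `strides=1` and `padding='valid'`, a small plain-Python sketch (helper names are hypothetical) reproduces the output size and parameter count of that layer:

```python
def conv2d_output_size(size, kernel):
    """Spatial size after a stride-1 convolution with 'valid' padding (the Keras default)."""
    return size - kernel + 1


def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """One kernel_h x kernel_w x in_channels weight tensor per filter, plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters


print(conv2d_output_size(64, 3))   # 62, so (None, 62, 62, 3)
print(conv2d_params(3, 3, 18, 3))  # 489
```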

The model summary is as follows:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712  
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0         
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489       
_________________________________________________________________
vgg16 (Model)                multiple                  14714688  
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________
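The summary prints `multiple` for the nested vgg16 model instead of a concrete shape. Assuming the layout of `tf.keras.applications.VGG16` ('same'-padded 3×3 convolutions and five 2×2, stride-2 max-pooling layers with 'valid' padding), a plain-Python sketch suggests the 62×62 input shrinks to 1×1, i.e. an output of (None, 1, 1, 512). It also checks that the per-layer counts add up to the totals shown. The helper `vgg16_spatial` is hypothetical.

```python
def vgg16_spatial(size):
    """Spatial size after VGG16's convolutional base (assumed: five 2x2/stride-2 'valid' max pools)."""
    for _ in range(5):
        size = (size - 2) // 2 + 1
    return size


# The 62x62x3 output of the Conv2d layer is what VGG16 receives
print(vgg16_spatial(62))   # 1
print(vgg16_spatial(224))  # 7 for the standard 224x224 input

# Per-layer parameter counts copied from the summary above
layer_params = {"resnet50": 23_587_712, "reshape": 0, "Conv2d": 489, "vgg16": 14_714_688}
total = sum(layer_params.values())
print(total)           # 38302889
# The 53,120 non-trainable parameters are the BatchNormalization moving statistics in ResNet50
print(total - 53_120)  # 38249769
```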




Source: https://stackoverflow.com/questions/61446429/how-do-i-connect-two-keras-models-into-one-model
