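When using a neural network for a classification (or regression) problem, we tend to pick a more complex network to improve accuracy, but a more complex network also takes much longer to train. If we instead build on a model that someone else has already trained, the number of parameters we need to train for our own task drops dramatically.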
Below, we use a pre-trained MobileNetV2 model to classify the cats vs. dogs dataset.
Import the required libraries
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
Load and preprocess the cats vs. dogs dataset
(raw_train, raw_validation, raw_test), metadata = tfds.load(
'cats_vs_dogs',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
shuffle_files=True,
batch_size=None,
with_info=True,
as_supervised=True,
)
IMG_SIZE = 160 # All images will be resized to 160x160
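def format_example(image, label):
    # Convert pixels to float32 and rescale them from [0, 255] to [-1, 1]
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1
    # Resize every image to a fixed IMG_SIZE x IMG_SIZE
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label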
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
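The data-processing code is given here without explanation; for a step-by-step walkthrough, see the first part of the referenced article.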
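format_example maps pixel values from [0, 255] to [-1, 1] with (image/127.5) - 1. This matters because MobileNetV2's ImageNet weights expect inputs in that range; as far as we know, tf.keras.applications.mobilenet_v2.preprocess_input applies the same scaling and could be used instead. The plain-Python sketch below only shows the arithmetic; the helper names are hypothetical, and the inverse mapping is what you would need before displaying a preprocessed image with matplotlib, which expects floats in [0, 1].

```python
def to_mobilenet_range(pixel):
    """Rescale a pixel value from [0, 255] to [-1, 1] (same as format_example)."""
    return pixel / 127.5 - 1.0


def to_display_range(value):
    """Map a value in [-1, 1] back to [0, 1] for plotting."""
    return (value + 1.0) / 2.0


for p in (0, 127.5, 255):
    scaled = to_mobilenet_range(p)
    print(f"{p:>6} -> {scaled:+.1f} -> display {to_display_range(scaled):.2f}")
```

Skipping this step (feeding raw 0 to 255 values to the frozen base model) would give the ImageNet features inputs far outside the range they were trained on.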
Load the MobileNetV2 model
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
base_model.trainable = False
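In tf.keras.applications.MobileNetV2, include_top=False leaves out the classification head at the top of MobileNetV2 (the original model was trained to distinguish ImageNet's 1000 classes, whereas here we only have two). base_model.trainable = False freezes base_model, so none of its parameters are updated during training.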
Show the MobileNetV2 architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 160, 160, 3) 0
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D) (None, 161, 161, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
Conv1 (Conv2D) (None, 80, 80, 32) 864 Conv1_pad[0][0]
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization) (None, 80, 80, 32) 128 Conv1[0][0]
__________________________________________________________________________________________________
Conv1_relu (ReLU) (None, 80, 80, 32) 0 bn_Conv1[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32) 288 Conv1_relu[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32) 128 expanded_conv_depthwise[0][0]
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32) 0 expanded_conv_depthwise_BN[0][0]
__________________________________________________________________________________________________
expanded_conv_project (Conv2D) (None, 80, 80, 16) 512 expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16) 64 expanded_conv_project[0][0]
__________________________________________________________________________________________________
block_1_expand (Conv2D) (None, 80, 80, 96) 1536 expanded_conv_project_BN[0][0]
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96) 384 block_1_expand[0][0]
__________________________________________________________________________________________________
block_1_expand_relu (ReLU) (None, 80, 80, 96) 0 block_1_expand_BN[0][0]
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D) (None, 81, 81, 96) 0 block_1_expand_relu[0][0]
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96) 864 block_1_pad[0][0]
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96) 384 block_1_depthwise[0][0]
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU) (None, 40, 40, 96) 0 block_1_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_1_project (Conv2D) (None, 40, 40, 24) 2304 block_1_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24) 96 block_1_project[0][0]
__________________________________________________________________________________________________
block_2_expand (Conv2D) (None, 40, 40, 144) 3456 block_1_project_BN[0][0]
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_2_expand[0][0]
__________________________________________________________________________________________________
block_2_expand_relu (ReLU) (None, 40, 40, 144) 0 block_2_expand_BN[0][0]
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144) 1296 block_2_expand_relu[0][0]
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144) 576 block_2_depthwise[0][0]
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU) (None, 40, 40, 144) 0 block_2_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_2_project (Conv2D) (None, 40, 40, 24) 3456 block_2_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24) 96 block_2_project[0][0]
__________________________________________________________________________________________________
block_2_add (Add) (None, 40, 40, 24) 0 block_1_project_BN[0][0]
block_2_project_BN[0][0]
__________________________________________________________________________________________________
block_3_expand (Conv2D) (None, 40, 40, 144) 3456 block_2_add[0][0]
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144) 576 block_3_expand[0][0]
__________________________________________________________________________________________________
block_3_expand_relu (ReLU) (None, 40, 40, 144) 0 block_3_expand_BN[0][0]
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D) (None, 41, 41, 144) 0 block_3_expand_relu[0][0]
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144) 1296 block_3_pad[0][0]
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144) 576 block_3_depthwise[0][0]
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU) (None, 20, 20, 144) 0 block_3_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_3_project (Conv2D) (None, 20, 20, 32) 4608 block_3_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32) 128 block_3_project[0][0]
__________________________________________________________________________________________________
block_4_expand (Conv2D) (None, 20, 20, 192) 6144 block_3_project_BN[0][0]
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_4_expand[0][0]
__________________________________________________________________________________________________
block_4_expand_relu (ReLU) (None, 20, 20, 192) 0 block_4_expand_BN[0][0]
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_4_expand_relu[0][0]
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_4_depthwise[0][0]
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_4_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_4_project (Conv2D) (None, 20, 20, 32) 6144 block_4_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32) 128 block_4_project[0][0]
__________________________________________________________________________________________________
block_4_add (Add) (None, 20, 20, 32) 0 block_3_project_BN[0][0]
block_4_project_BN[0][0]
__________________________________________________________________________________________________
block_5_expand (Conv2D) (None, 20, 20, 192) 6144 block_4_add[0][0]
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_5_expand[0][0]
__________________________________________________________________________________________________
block_5_expand_relu (ReLU) (None, 20, 20, 192) 0 block_5_expand_BN[0][0]
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192) 1728 block_5_expand_relu[0][0]
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192) 768 block_5_depthwise[0][0]
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU) (None, 20, 20, 192) 0 block_5_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_5_project (Conv2D) (None, 20, 20, 32) 6144 block_5_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32) 128 block_5_project[0][0]
__________________________________________________________________________________________________
block_5_add (Add) (None, 20, 20, 32) 0 block_4_add[0][0]
block_5_project_BN[0][0]
__________________________________________________________________________________________________
block_6_expand (Conv2D) (None, 20, 20, 192) 6144 block_5_add[0][0]
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192) 768 block_6_expand[0][0]
__________________________________________________________________________________________________
block_6_expand_relu (ReLU) (None, 20, 20, 192) 0 block_6_expand_BN[0][0]
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D) (None, 21, 21, 192) 0 block_6_expand_relu[0][0]
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192) 1728 block_6_pad[0][0]
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192) 768 block_6_depthwise[0][0]
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU) (None, 10, 10, 192) 0 block_6_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_6_project (Conv2D) (None, 10, 10, 64) 12288 block_6_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64) 256 block_6_project[0][0]
__________________________________________________________________________________________________
block_7_expand (Conv2D) (None, 10, 10, 384) 24576 block_6_project_BN[0][0]
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_7_expand[0][0]
__________________________________________________________________________________________________
block_7_expand_relu (ReLU) (None, 10, 10, 384) 0 block_7_expand_BN[0][0]
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_7_expand_relu[0][0]
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_7_depthwise[0][0]
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_7_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_7_project (Conv2D) (None, 10, 10, 64) 24576 block_7_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64) 256 block_7_project[0][0]
__________________________________________________________________________________________________
block_7_add (Add) (None, 10, 10, 64) 0 block_6_project_BN[0][0]
block_7_project_BN[0][0]
__________________________________________________________________________________________________
block_8_expand (Conv2D) (None, 10, 10, 384) 24576 block_7_add[0][0]
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_8_expand[0][0]
__________________________________________________________________________________________________
block_8_expand_relu (ReLU) (None, 10, 10, 384) 0 block_8_expand_BN[0][0]
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_8_expand_relu[0][0]
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_8_depthwise[0][0]
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_8_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_8_project (Conv2D) (None, 10, 10, 64) 24576 block_8_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64) 256 block_8_project[0][0]
__________________________________________________________________________________________________
block_8_add (Add) (None, 10, 10, 64) 0 block_7_add[0][0]
block_8_project_BN[0][0]
__________________________________________________________________________________________________
block_9_expand (Conv2D) (None, 10, 10, 384) 24576 block_8_add[0][0]
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384) 1536 block_9_expand[0][0]
__________________________________________________________________________________________________
block_9_expand_relu (ReLU) (None, 10, 10, 384) 0 block_9_expand_BN[0][0]
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384) 3456 block_9_expand_relu[0][0]
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384) 1536 block_9_depthwise[0][0]
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_9_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_9_project (Conv2D) (None, 10, 10, 64) 24576 block_9_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64) 256 block_9_project[0][0]
__________________________________________________________________________________________________
block_9_add (Add) (None, 10, 10, 64) 0 block_8_add[0][0]
block_9_project_BN[0][0]
__________________________________________________________________________________________________
block_10_expand (Conv2D) (None, 10, 10, 384) 24576 block_9_add[0][0]
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384) 1536 block_10_expand[0][0]
__________________________________________________________________________________________________
block_10_expand_relu (ReLU) (None, 10, 10, 384) 0 block_10_expand_BN[0][0]
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384) 3456 block_10_expand_relu[0][0]
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384) 1536 block_10_depthwise[0][0]
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU) (None, 10, 10, 384) 0 block_10_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_10_project (Conv2D) (None, 10, 10, 96) 36864 block_10_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96) 384 block_10_project[0][0]
__________________________________________________________________________________________________
block_11_expand (Conv2D) (None, 10, 10, 576) 55296 block_10_project_BN[0][0]
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_11_expand[0][0]
__________________________________________________________________________________________________
block_11_expand_relu (ReLU) (None, 10, 10, 576) 0 block_11_expand_BN[0][0]
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_11_expand_relu[0][0]
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_11_depthwise[0][0]
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_11_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_11_project (Conv2D) (None, 10, 10, 96) 55296 block_11_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96) 384 block_11_project[0][0]
__________________________________________________________________________________________________
block_11_add (Add) (None, 10, 10, 96) 0 block_10_project_BN[0][0]
block_11_project_BN[0][0]
__________________________________________________________________________________________________
block_12_expand (Conv2D) (None, 10, 10, 576) 55296 block_11_add[0][0]
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_12_expand[0][0]
__________________________________________________________________________________________________
block_12_expand_relu (ReLU) (None, 10, 10, 576) 0 block_12_expand_BN[0][0]
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576) 5184 block_12_expand_relu[0][0]
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576) 2304 block_12_depthwise[0][0]
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU) (None, 10, 10, 576) 0 block_12_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_12_project (Conv2D) (None, 10, 10, 96) 55296 block_12_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96) 384 block_12_project[0][0]
__________________________________________________________________________________________________
block_12_add (Add) (None, 10, 10, 96) 0 block_11_add[0][0]
block_12_project_BN[0][0]
__________________________________________________________________________________________________
block_13_expand (Conv2D) (None, 10, 10, 576) 55296 block_12_add[0][0]
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576) 2304 block_13_expand[0][0]
__________________________________________________________________________________________________
block_13_expand_relu (ReLU) (None, 10, 10, 576) 0 block_13_expand_BN[0][0]
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D) (None, 11, 11, 576) 0 block_13_expand_relu[0][0]
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576) 5184 block_13_pad[0][0]
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576) 2304 block_13_depthwise[0][0]
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU) (None, 5, 5, 576) 0 block_13_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_13_project (Conv2D) (None, 5, 5, 160) 92160 block_13_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160) 640 block_13_project[0][0]
__________________________________________________________________________________________________
block_14_expand (Conv2D) (None, 5, 5, 960) 153600 block_13_project_BN[0][0]
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_14_expand[0][0]
__________________________________________________________________________________________________
block_14_expand_relu (ReLU) (None, 5, 5, 960) 0 block_14_expand_BN[0][0]
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_14_expand_relu[0][0]
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_14_depthwise[0][0]
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_14_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_14_project (Conv2D) (None, 5, 5, 160) 153600 block_14_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160) 640 block_14_project[0][0]
__________________________________________________________________________________________________
block_14_add (Add) (None, 5, 5, 160) 0 block_13_project_BN[0][0]
block_14_project_BN[0][0]
__________________________________________________________________________________________________
block_15_expand (Conv2D) (None, 5, 5, 960) 153600 block_14_add[0][0]
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_15_expand[0][0]
__________________________________________________________________________________________________
block_15_expand_relu (ReLU) (None, 5, 5, 960) 0 block_15_expand_BN[0][0]
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_15_expand_relu[0][0]
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_15_depthwise[0][0]
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_15_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_15_project (Conv2D) (None, 5, 5, 160) 153600 block_15_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160) 640 block_15_project[0][0]
__________________________________________________________________________________________________
block_15_add (Add) (None, 5, 5, 160) 0 block_14_add[0][0]
block_15_project_BN[0][0]
__________________________________________________________________________________________________
block_16_expand (Conv2D) (None, 5, 5, 960) 153600 block_15_add[0][0]
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960) 3840 block_16_expand[0][0]
__________________________________________________________________________________________________
block_16_expand_relu (ReLU) (None, 5, 5, 960) 0 block_16_expand_BN[0][0]
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960) 8640 block_16_expand_relu[0][0]
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960) 3840 block_16_depthwise[0][0]
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU) (None, 5, 5, 960) 0 block_16_depthwise_BN[0][0]
__________________________________________________________________________________________________
block_16_project (Conv2D) (None, 5, 5, 320) 307200 block_16_depthwise_relu[0][0]
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320) 1280 block_16_project[0][0]
__________________________________________________________________________________________________
Conv_1 (Conv2D) (None, 5, 5, 1280) 409600 block_16_project_BN[0][0]
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization) (None, 5, 5, 1280) 5120 Conv_1[0][0]
__________________________________________________________________________________________________
out_relu (ReLU) (None, 5, 5, 1280) 0 Conv_1_bn[0][0]
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________
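The Param # column can be checked by hand from the layer shapes. This is a minimal sketch assuming, as the Conv1 count of 864 implies, that the convolutions carry no bias (each is followed by BatchNormalization), and that BatchNormalization stores four values per channel: gamma and beta (trainable) plus the moving mean and moving variance (never trainable). The helper functions are illustrative, not Keras APIs.

```python
def conv2d_params(kh, kw, c_in, c_out, use_bias=False):
    """Weights of a standard convolution; no bias by default, since each
    MobileNetV2 convolution is followed by BatchNormalization."""
    return kh * kw * c_in * c_out + (c_out if use_bias else 0)


def depthwise_params(kh, kw, channels):
    """Depthwise convolution with depth multiplier 1: one kh x kw filter
    per input channel, no bias."""
    return kh * kw * channels


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: four values per channel."""
    return 4 * channels


# A few rows of the summary above, recomputed from kernel sizes and channels.
rows = {
    "Conv1": conv2d_params(3, 3, 3, 32),                    # 864
    "bn_Conv1": batchnorm_params(32),                       # 128
    "expanded_conv_depthwise": depthwise_params(3, 3, 32),  # 288
    "expanded_conv_project": conv2d_params(1, 1, 32, 16),   # 512
    "block_1_expand": conv2d_params(1, 1, 16, 96),          # 1536
    "block_16_project": conv2d_params(1, 1, 960, 320),      # 307200
    "Conv_1": conv2d_params(1, 1, 320, 1280),               # 409600
    "Conv_1_bn": batchnorm_params(1280),                    # 5120
}

for name, count in rows.items():
    print(f"{name:25s} {count}")
```

The depthwise layers are the reason MobileNetV2 stays small: a 3x3 depthwise filter over 32 channels needs 288 weights, while a full 3x3 convolution from 32 to 32 channels would need 9,216.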
Add our own layers on top of MobileNetV2
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
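The new head turns the (None, 5, 5, 1280) feature maps into one number per image: GlobalAveragePooling2D averages each channel over the 5x5 grid, and Dense(1) with no activation outputs a raw score (a logit). Because there is no sigmoid, the loss in the compile step below is created with from_logits=True. The NumPy sketch mimics this under stated assumptions: the features, weights and labels are random placeholders, and the loss is the numerically stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits, not a call into TensorFlow.

```python
import numpy as np

# Placeholder batch of feature maps shaped like the base model's output.
rng = np.random.default_rng(0)
features = rng.standard_normal((4, 5, 5, 1280)).astype(np.float32)

# GlobalAveragePooling2D (channels_last): average over height and width.
pooled = features.mean(axis=(1, 2))            # (4, 1280)

# Dense(1), no activation: 1280 weights + 1 bias -> one logit per image.
# Random placeholder weights, not trained values.
w = rng.standard_normal((1280, 1)).astype(np.float32) * 0.01
b = np.zeros(1, dtype=np.float32)
logits = (pooled @ w + b)[:, 0]                # (4,)


def bce_from_logits(y, z):
    """Stable binary cross-entropy on raw logits:
    max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))


labels = np.array([0, 1, 1, 0], dtype=np.float32)  # made-up labels
loss = bce_from_logits(labels, logits).mean()

# sigmoid(z) > 0.5 exactly when z > 0, so the logit's sign is the prediction.
predictions = (logits > 0).astype(np.int32)
print(pooled.shape, logits.shape, float(loss), predictions)
```

The stable form never takes log(0): for a logit of 1000 and label 0 it returns 1000.0, whereas computing sigmoid first would round to 1.0 and fail on log(1 - 1.0). Which class label 1 stands for can be checked with metadata.features['label'].int2str(1).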
Compile and show the final model
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984
_________________________________________________________________
global_average_pooling2d_2 ( (None, 1280) 0
_________________________________________________________________
dense (Dense) (None, 1) 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________
As the summary shows, with base_model.trainable = False only the 1,281 parameters of the newly added Dense layer are trained.
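The 1,281 comes from Dense(1) receiving the 1,280 pooled features: one weight per feature plus a single bias. A small sketch using the totals printed in the two summaries above shows how tiny the trained share of the model is:

```python
# Numbers taken from the model.summary() outputs above.
BASE_PARAMS = 2_257_984   # mobilenetv2_1.00_160, fully frozen
POOLED_FEATURES = 1280    # width of the GlobalAveragePooling2D output

# Dense(1): one weight per input feature plus one bias.
dense_params = POOLED_FEATURES * 1 + 1
total_params = BASE_PARAMS + dense_params
trainable_share = dense_params / total_params

print(f"Dense params:    {dense_params:,}")
print(f"Total params:    {total_params:,}")
print(f"Trainable share: {trainable_share:.4%}")
```

Less than 0.06% of the weights receive gradient updates, which is why each epoch in the log below is fast.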
Train the model
history = model.fit(train_batches,
epochs=10,
validation_data=validation_batches)
Epoch 1/10
582/582 [==============================] - 16s 28ms/step - loss: 0.3749 - accuracy: 0.8174 - val_loss: 0.1736 - val_accuracy: 0.9037
Epoch 2/10
582/582 [==============================] - 14s 24ms/step - loss: 0.2002 - accuracy: 0.9138 - val_loss: 0.1243 - val_accuracy: 0.9351
Epoch 3/10
582/582 [==============================] - 15s 26ms/step - loss: 0.1658 - accuracy: 0.9314 - val_loss: 0.1176 - val_accuracy: 0.9355
Epoch 4/10
582/582 [==============================] - 14s 25ms/step - loss: 0.1493 - accuracy: 0.9385 - val_loss: 0.1108 - val_accuracy: 0.9390
Epoch 5/10
582/582 [==============================] - 15s 25ms/step - loss: 0.1420 - accuracy: 0.9392 - val_loss: 0.1068 - val_accuracy: 0.9402
Epoch 6/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1363 - accuracy: 0.9438 - val_loss: 0.1029 - val_accuracy: 0.9450
Epoch 7/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1280 - accuracy: 0.9459 - val_loss: 0.0985 - val_accuracy: 0.9484
Epoch 8/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1270 - accuracy: 0.9480 - val_loss: 0.0944 - val_accuracy: 0.9514
Epoch 9/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1219 - accuracy: 0.9492 - val_loss: 0.0989 - val_accuracy: 0.9484
Epoch 10/10
582/582 [==============================] - 14s 24ms/step - loss: 0.1238 - accuracy: 0.9475 - val_loss: 0.0990 - val_accuracy: 0.9488
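The 582 steps per epoch in the log follow from the split sizes and the batch size. This sketch assumes the TFDS cats_vs_dogs 'train' split holds 23,262 images (the count TFDS reports after dropping corrupted files, to our knowledge; check metadata.splits['train'].num_examples on your setup) and that each percent boundary is rounded to the nearest example; split_size is a hypothetical helper.

```python
import math

NUM_EXAMPLES = 23_262   # assumption: size of the TFDS 'train' split
BATCH_SIZE = 32


def split_size(start_pct, end_pct, n=NUM_EXAMPLES):
    """Approximate size of a 'train[start%:end%]' slice, assuming each
    percent boundary is rounded to the nearest example."""
    return round(n * end_pct / 100) - round(n * start_pct / 100)


train_n = split_size(0, 80)
val_n = split_size(80, 90)
test_n = split_size(90, 100)

# batch() keeps the final partial batch by default (drop_remainder=False),
# so steps per epoch is a ceiling division.
train_steps = math.ceil(train_n / BATCH_SIZE)
val_steps = math.ceil(val_n / BATCH_SIZE)

print(train_n, val_n, test_n)
print(train_steps, val_steps)
```

Even if TFDS rounds a boundary differently by one example, ceil(18609 / 32) and ceil(18610 / 32) are both 582, matching the log.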
PS: You can also fine-tune only part of MobileNetV2, training some of its layers while keeping the rest frozen.
base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))
# Fine-tune from this layer onwards
fine_tune_at = 100
# Freeze all the layers before the `fine_tune_at` layer
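for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
This keeps the first 100 layers of MobileNetV2 frozen, while the layers from index 100 onward are trained. Because the trainable flags have changed, the model is then recompiled, this time with a learning rate one tenth of the original: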
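The order of the two steps matters: in Keras, assigning trainable on a model propagates the flag to all of its layers, so base_model.trainable = True must come before the loop that refreezes the early layers. The toy classes below are hypothetical stand-ins, not Keras classes; they mimic only this propagation to show the order of operations, and the layer count of 150 is arbitrary.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer; it only tracks `trainable`."""

    def __init__(self, name):
        self.name = name
        self.trainable = True


class ToyModel:
    """Hypothetical stand-in for a Keras model: assigning `trainable` on the
    model overwrites the flag on every layer."""

    def __init__(self, n_layers):
        self.layers = [ToyLayer(f"layer_{i}") for i in range(n_layers)]
        self._trainable = True

    @property
    def trainable(self):
        return self._trainable

    @trainable.setter
    def trainable(self, value):
        self._trainable = value
        for layer in self.layers:
            layer.trainable = value


def count_trainable(model):
    return sum(layer.trainable for layer in model.layers)


fine_tune_at = 100

# Feature-extraction phase: the whole base model is frozen.
base_model = ToyModel(n_layers=150)
base_model.trainable = False

# Fine-tuning phase, in the same order as the code above.
base_model.trainable = True           # 1) unfreeze every layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False           # 2) refreeze layers 0..99
print(count_trainable(base_model))

# Skipping step 1 leaves every layer frozen.
frozen_model = ToyModel(n_layers=150)
frozen_model.trainable = False
for layer in frozen_model.layers[:fine_tune_at]:
    layer.trainable = False
print(count_trainable(frozen_model))
```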
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
The resulting model:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280) 0
_________________________________________________________________
dense (Dense) (None, 1) 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,863,873
Non-trainable params: 395,392
_________________________________________________________________
Source: CSDN
Author: cofisher
Link: https://blog.csdn.net/qq_36758914/article/details/104738838