本项目参考:
https://www.bilibili.com/video/av31500120?t=4657
训练代码
1 # coding: utf-8
2 # Learning from Mofan and Mike G
3 # Recreated by Paprikatree
4 # Convolution NN Train
5
6 import numpy as np
7 from keras.datasets import mnist
8 from keras.utils import np_utils
9 from keras.models import Sequential
10 from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
11 from keras.optimizers import Adam
12 from keras.models import load_model
13
14
15 nb_class = 10
16 nb_epoch = 4
17 batchsize = 128
18
19 '''
20 1st,准备参数
21 X_train: (0,255) --> (0,1) CNN中似乎没有必要?cnn自动转了吗?
22 设置时间函数测试一下两者对比。
23 小技巧:X_train /= 255.0 就可不用转换成浮点了???
24 '''
25 # Preparing your data mnist. MAC /.keras/datasets linux home ./keras/datasets
26 (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
27
28
29 # setup data shape
30 # (-1, 28, 28, 1) -1表示有默认个数据集,28*28是像素,1是1个通道
31 X_train = X_train.reshape(-1, 28, 28, 1) # tensorflow-channel last,while theano-channel first
32 X_test = X_test.reshape(-1, 28, 28, 1)
33
34 X_train = X_train/255.000
35 X_test = X_test/255.000
36
37 # One-hot 6 --> [0,0,0,0,0,1,0,0,0]
38 Y_train = np_utils.to_categorical(Y_train, nb_class)
39 Y_test = np_utils.to_categorical(Y_test, nb_class)
40
41 '''
42 2nd,设置模型
43 '''
44
45 # setup model
46 model = Sequential()
47
48 # 1st convolution layer # 滤波器要在28x28的图上横着走32次
49 model.add(Convolution2D(
50 filters=32, # 此处把filters写成了filter,找了半天。囧
51 kernel_size=[5, 5], # 滤波器是5x5大小的,可以是list列表,也可以是tuple元祖
52 padding='same', # padding也是一个窗口模式
53 input_shape=(28, 28, 1) # 定义输入的数据,必须是元组
54 ))
55 model.add(Activation('relu'))
56 model.add(MaxPool2D(
57 pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
58 strides=(2, 2), # 相当于把图片缩小了。
59 padding="same",
60 ))
61
62 # 2nd Conv2D layer
63 model.add(Convolution2D(
64 filters=64,
65 kernel_size=(5, 5),
66 padding='same',
67 ))
68 model.add(Activation('relu'))
69 model.add(MaxPool2D(
70 pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
71 strides=(2, 2), # 相当于把图片缩小了。
72 padding="same",
73 )) # 讨论,卷积层数和最终结果关系。
74
75 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
76 model.add(Flatten()) # 把卷积层里面的全部转换层一维数组
77 model.add(Dense(1024)) # Dense is output
78 model.add(Activation('relu'))
79
80 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
81 # 把卷积层里面的全部转换层一维数组
82 model.add(Dense(256)) # Dense is output
83 model.add(Activation('tanh'))
84
85 # 2nd Fully connected Dense
86 model.add(Dense(10))
87 model.add(Activation('softmax'))
88
89 '''
90 3rd 定义参数
91 '''
92 # Define Optimizer and setup Param
93 adam = Adam(lr=0.0001) # Adam实例化
94
95 # compile model
96 model.compile(
97 optimizer=adam, # optimizer='Adam'也是可以的,且默认lr=0.001,此处已经实例化为adam
98 loss='categorical_crossentropy',
99 metrics=['accuracy'],
100 )
101
102 # Run network
103 model.fit(x=X_train, # 更多参数可以查看fit函数,alt+鼠标左键单击fit
104 y=Y_train,
105 epochs=nb_epoch,
106 batch_size=batchsize, # p=parameter, batch_size; v=var, batch size
107 verbose=1, # 显示模式
108 validation_data=(X_test, Y_test)
109 )
110 model.save('model_name.h5')
111 # evaluation = model.evaluate(X_test, Y_test) 现在用model.fit(validation_data)
112 # print(evaluation) 效果一样
测试代码:
1 # coding: utf-8
2 # Learning from Mofan and Mike G
3 # Recreated by Paprikatree
4 # Convolution NN Predict
5
6 import numpy as np
7 from keras.models import load_model # ??
8 import matplotlib.pyplot as plt
9 import matplotlib.image as processimage
10
11
12 # load trained model
13 model = load_model('model_name.h5') # 已经训练好了的模型,在根目录下,默认为model_name.h5
14
15
16 # 写一个来预测的类
17 class MainPredictImg(object):
18
19 def __init__(self):
20 pass
21
22 def pred(self, filename):
23 pred_img = processimage.imread(filename)
24 pred_img = np.array(pred_img)
25 pred_img = pred_img.reshape(-1, 28, 28, 1)
26 prediction = model.predict(pred_img)
27 final_prediction = [result.argmax() for result in prediction][0]
28 a = 0
29 for i in prediction[0]:
30 print(a)
31 print('Percent:{:.30%}'.format(i))
32 a = a+1
33 return final_prediction
34
35
36 def main():
37 predict = MainPredictImg()
38 res = predict.pred('4.png')
39 print("your number is:-->", res)
40
41
42 if __name__ == '__main__':
43 main()
来源:oschina
链接:https://my.oschina.net/u/4365836/blog/3705733