Keras【Deep Learning With Python】更优模型探索Keras实现RNN

不想你离开。 提交于 2020-02-26 10:36:45

RNN简介

1.RNN的应用

RNN主要有两个应用,一是评测一个句子出现的可能性,二是文本自动生成。\

2.什么是RNN?

RNN之所以叫RNN是因为它循环处理相同的任务,就是预测句子接下来的单词是什么。RNN认为在这循环的任务中,各个步骤之间不是独立,于是它循环记录前面所有文本的信息(也叫记忆),作为预测当前词的一个输入。
在这里插入图片描述
在RNN中,每个词作为一层,对其进行预测。
在这里插入图片描述

在这里插入图片描述

F函数一般是tanh或者ReLU

是在t时刻词典中所有词出现的概率,也就是||=|vocabulary|

并且所有层共享U和W
在这里插入图片描述

3.RNN用来做什么?

RNNs在NLP中得到了巨大的成功,LSTM是被广泛使用的RNN。LSTM与典型的RNN基本框架一致,只是使用了同的方式来计算隐藏状态。

3.1语言模型和文本生成

语言模型中,输入时经过编码的词向量序列,输出是一系列预测的词。在训练模型的时候,令 在这里插入图片描述,也就是让输出等于下一时刻真实的输入,因为在文本生成中,这一时刻的输出对应的是下一时刻的输入。

     3.2机器翻译

              机器翻译与语言模型的不同是,机器翻译必须等待所有输入结束后才输出,因为这个时候才能得到翻译句子的所有信息。

在这里插入图片描述
3.3语音识别。

              输入一系列的声波信息,然后预测一段语音。

     3.4生成图像描述

              RNNs和 CNN一起,可以用来为未标记的图像生成描述。

4. 训练RNNs

训练RNN和训练传统的神经网络一样,都是使用反向传播算法,但是又有些不同,这里所有步骤都共享同一个参数,每一个步骤的回归输出不仅仅依赖于当前时刻,还依赖前面时刻的步骤,这就叫BPTT算法(时间反向传播)。

Keras代码实现(Mnist)

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.recurrent import SimpleRNN
from keras.optimizers import Adam
# 数据长度-一行有28个像素
input_size = 28
# 序列长度-一共有28行
time_steps = 28
# 隐藏层cell个数
cell_size = 50 
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
x_train = x_train/255.0
x_test = x_test/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)#one hot
 
# 创建模型
model = Sequential()
 
# 循环神经网络
model.add(SimpleRNN(
    units = cell_size, # 输出
    input_shape = (time_steps,input_size), #输入
))
 
# 输出层
model.add(Dense(10,activation='softmax'))
 
# 定义优化器
adam = Adam(lr=1e-4)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=10)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('test loss',loss)
print('test accuracy',accuracy)
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!