对于最简单的RNN,我们可以使用以下两个方法调用,分别是 torch.nn.RNNCell()
和 torch.nn.RNN()
,这两种方式的区别在于 RNNCell()
只能接受序列中单步的输入,且必须传入隐藏状态,而 RNN()
可以接受一个序列的输入,默认会传入全 0 的隐藏状态,也可以自己申明隐藏状态传入。
RNN()的参数:
input_size 表示输入特征的维度;
hidden_size表示输出特征的维度;
num_layers表示网络的层数;
nonlinearity表示选用的是非线性激活函数,默认是‘tanh’;
bias表示是否使用偏置,默认是使用;
batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
dropout 表示是否在输出层应用 dropout;
bidirectional 表示是否使用双向的 rnn,默认是 False;
对于 RNNCell()
,里面的参数就少很多,只有 input_size,hidden_size,bias 以及 nonlinearity;
一般情况下我们都是用 nn.RNN()
而不是 nn.RNNCell()
,因为 nn.RNN()
能够避免我们手动写循环,非常方便,同时如果不特别说明,我们也会选择使用默认的全 0 初始化隐藏状态。
LSTM 和基本的 RNN 是一样的,他的参数也是相同的,同时他也有 nn.LSTMCell()
和 nn.LSTM()
两种形式。
来源:CSDN
作者:xckkcxxck
链接:https://blog.csdn.net/xckkcxxck/article/details/82976940