Understanding input shape to PyTorch LSTM

Submitted by 心不动则不痛 on 2021-01-19 06:21:32

Question


This seems to be one of the most common questions about LSTMs in PyTorch, but I am still unable to figure out what should be the input shape to PyTorch LSTM.

Even after following several posts (1, 2, 3) and trying out the solutions, it doesn't seem to work.

Background: I have encoded text sequences (of variable length) in a batch of size 12, and the sequences are padded and packed using the pad/pack sequence utilities (pack_padded_sequence). MAX_LEN for each sequence is 384, and each token (or word) in the sequence has a dimension of 768. Hence my batch tensor could have one of the following shapes: [12, 384, 768] or [384, 12, 768].

The batch will be my input to the PyTorch rnn module (lstm here).

According to the PyTorch documentation for LSTMs, the input dimensions are (seq_len, batch, input_size), which I understand as follows:
seq_len - the number of time steps in each input sequence.
batch - the size of each batch of input sequences.
input_size - the dimension of each input token or time step.
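As a quick sanity check of these names against the shapes from the question (the tensors here are random placeholders, not real data), the two candidate layouts are related by a simple transpose:

```python
import torch

batch_size, seq_len, input_size = 12, 384, 768

# batch_first layout: [batch_size, seq_len, input_size]
x_bf = torch.randn(batch_size, seq_len, input_size)

# default (batch_first=False) layout: [seq_len, batch_size, input_size]
x_sf = x_bf.transpose(0, 1)
print(x_sf.shape)  # torch.Size([384, 12, 768])
```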

lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)

What should be the exact input_size and hidden_size values here?


Answer 1:


You have explained the structure of your input, but you haven't made the connection between your input dimensions and the LSTM's expected input dimensions.

Let's break down your input (assigning names to the dimensions):

  • batch_size: 12
  • seq_len: 384
  • input_size / num_features: 768

That means the input_size of the LSTM needs to be 768.

The hidden_size does not depend on your input; it determines how many features the LSTM produces, which are used for the hidden state as well as for the output, since the output at each time step is that step's hidden state. You have to decide how many features you want the LSTM to use.
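To make that concrete, here is a small sketch (hidden_size=512 is an arbitrary choice, not dictated by the input) showing that the last time step of the output equals the final hidden state for a single-layer, unidirectional LSTM:

```python
import torch
import torch.nn as nn

# hidden_size=512 is an arbitrary choice, not determined by the input
lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

x = torch.randn(12, 384, 768)           # [batch_size, seq_len, input_size]
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([12, 384, 512]) - hidden state at every time step
print(h_n.shape)     # torch.Size([1, 12, 512])   - final hidden state per layer/direction

# the last time step of `output` is the final hidden state
print(torch.allclose(output[:, -1], h_n[0]))  # True
```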

Finally, for the input shape: setting batch_first=True requires the input to have the shape [batch_size, seq_len, input_size]; in your case that would be [12, 384, 768].

import torch
import torch.nn as nn

# Shape: [batch_size, seq_len, input_size]
x = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(x)
output.size()  # => torch.Size([12, 384, 512])
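Since the question mentions padded and packed variable-length sequences, here is a minimal sketch of feeding a packed batch through the same LSTM (the per-sequence lengths are made up for illustration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, max_len, input_size, hidden_size = 12, 384, 768, 512

x = torch.randn(batch_size, max_len, input_size)        # padded batch
lengths = torch.randint(1, max_len + 1, (batch_size,))  # hypothetical true lengths

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)

# pack so the LSTM skips the padding; enforce_sorted=False sorts internally
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# unpack back to a padded tensor; total_length pins the seq_len dimension
output, out_lengths = pad_packed_sequence(
    packed_out, batch_first=True, total_length=max_len
)

print(output.shape)  # torch.Size([12, 384, 512])
```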


Source: https://stackoverflow.com/questions/61632584/understanding-input-shape-to-pytorch-lstm
