Understanding input shape to PyTorch LSTM

猫巷女王i 2021-01-14 06:48

This seems to be one of the most common questions about LSTMs in PyTorch, but I still cannot figure out what the input shape to a PyTorch LSTM should be.

Even af

1 Answer
  • 2021-01-14 07:08

    You have explained the structure of your input, but you haven't made the connection between your input dimensions and the LSTM's expected input dimensions.

    Let's break down your input (assigning names to the dimensions):

    • batch_size: 12
    • seq_len: 384
    • input_size / num_features: 768

    That means the input_size of the LSTM needs to be 768.

    The hidden_size does not depend on your input. It is the number of features the LSTM produces at each time step, so it sets the size of the hidden state and, with it, the size of the output, because the output is the hidden state of the last layer at every time step. You have to decide how many features you want the LSTM to use.
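    As a minimal sketch of that point, the snippet below runs the same input through two LSTMs that differ only in hidden_size. The small batch_size and seq_len and the two hidden sizes are arbitrary values chosen for illustration, not taken from the question; only input_size=768 matches it.

```python
import torch
import torch.nn as nn

# Illustrative stand-in dimensions; only input_size=768 comes from the question.
batch_size, seq_len, input_size = 4, 10, 768
x = torch.randn(batch_size, seq_len, input_size)

shapes = {}
for hidden_size in (64, 512):
    lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
    output, (h_n, c_n) = lstm(x)
    # output: [batch_size, seq_len, hidden_size]
    # h_n, c_n: [num_layers, batch_size, hidden_size] (not affected by batch_first)
    shapes[hidden_size] = (tuple(output.shape), tuple(h_n.shape), tuple(c_n.shape))
    print(hidden_size, shapes[hidden_size])

# For a single-layer, unidirectional LSTM, the last time step of `output`
# is the final hidden state h_n.
last_step_matches = torch.allclose(output[:, -1, :], h_n[-1])
print(last_step_matches)
```

    The input is identical in both runs; only the last dimension of the output and of the states follows hidden_size.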

    Finally, for the input shape: setting batch_first=True requires the input to have the shape [batch_size, seq_len, input_size], which in your case is [12, 384, 768]. Without it, the LSTM expects [seq_len, batch_size, input_size].

    import torch
    import torch.nn as nn
    
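    # Shape: [batch_size, seq_len, input_size]
    x = torch.randn(12, 384, 768)
    
    # input_size must match the last dimension of x; hidden_size is a free choice
    lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)
    
    # output holds the hidden state of the last layer at every time step
    output, _ = lstm(x)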
    output.size()  # => torch.Size([12, 384, 512])
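
    For comparison, here is a rough sketch of the same tensor fed to an LSTM that keeps the default batch_first=False, which expects the sequence dimension first. The variable names are only illustrative.

```python
import torch
import torch.nn as nn

# Same data layout as in the question: [batch_size, seq_len, input_size]
x_batch_first = torch.randn(12, 384, 768)

# Default layout (batch_first=False): [seq_len, batch_size, input_size]
lstm_default = nn.LSTM(input_size=768, hidden_size=512)

# Swap the first two dimensions instead of setting batch_first=True
x_seq_first = x_batch_first.transpose(0, 1).contiguous()

output_seq_first, _ = lstm_default(x_seq_first)
out_shape = tuple(output_seq_first.shape)
print(out_shape)  # (384, 12, 512)
```

    Either way the feature dimension comes last; batch_first only controls the order of the first two dimensions of the input and the output.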
    