Taking the last state from BiLSTM (BiGRU) in PyTorch

Submitted by 我的未来我决定 on 2021-02-07 07:52:49

Question


After reading several articles, I am still quite confused about the correctness of my implementation for getting the last hidden states from a BiLSTM. Here is what I have read:

  1. Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  2. PackedSequence for seq2seq model (PyTorch forums)
  3. What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  4. Select tensor in a batch of sequences (PyTorch forums)

The approach from the last source (4) seems the cleanest to me, but I am still uncertain whether I understood the thread correctly. Am I using the right final hidden states from the forward and reversed LSTM? This is my implementation:

# pos contains indices of words in the embedding matrix
# seqlengths contains the sequence lengths
# so, for instance, if batch_size is 2 and pos=[4,6,9,3,1] and
# seqlengths contains [3,2], we have a batch with samples
# of variable length [4,6,9] and [3,1]

all_in_embs = self.in_embeddings(pos)
in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
output, lasthidden = self.rnn(in_emb_seqs)
if not self.data_processor.use_gru:
    lasthidden = lasthidden[0]
# u_emb_batch has shape batch_size x embedding_dimension
# sum the last state from the forward and backward direction
u_emb_batch = lasthidden[-1, :, :] + lasthidden[-2, :, :]

Is it correct?


Answer 1:


In the general case, if you want to create your own BiLSTM network, you need to create two regular LSTMs and feed one the regular input sequence and the other the inverted input sequence. After you finish feeding both sequences, you take the last states from both nets and somehow tie them together (sum or concatenate), as in the sketch below.
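
A minimal sketch of that manual approach (all names and sizes below are illustrative, not taken from the question):

import torch
import torch.nn as nn

input_size, hidden_size = 8, 16
fwd_lstm = nn.LSTM(input_size, hidden_size)
bwd_lstm = nn.LSTM(input_size, hidden_size)

x = torch.randn(5, 2, input_size)             # (seq_len, batch, input_size)
_, (h_fwd, _) = fwd_lstm(x)                   # last state of the forward pass
_, (h_bwd, _) = bwd_lstm(torch.flip(x, [0]))  # last state of the reversed pass

# Tie the two last states together, e.g. by summing or concatenating
summed = h_fwd[-1] + h_bwd[-1]                       # (batch, hidden_size)
concat = torch.cat([h_fwd[-1], h_bwd[-1]], dim=-1)   # (batch, 2 * hidden_size)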

As I understand it, you are using the built-in BiLSTM, as in this example (setting bidirectional=True in the nn.LSTM constructor). You then get the concatenated output after feeding the batch, since PyTorch handles all the hassle for you.

If that is the case and you want to sum the hidden states, then you have to use

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

assuming you have only one layer. If you have more layers, your variant seems better.

This is because the result is structured this way (see the documentation):

h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len

By the way,

u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]

should provide the same result.
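
With multiple layers, a more explicit alternative is to reshape h_n into (num_layers, num_directions, batch, hidden_size), as the documentation describes, and index the last layer directly. A minimal sketch with illustrative sizes, not code from the question:

import torch
import torch.nn as nn

num_layers, hidden_size, batch_size = 2, 256, 4
lstm = nn.LSTM(input_size=100, hidden_size=hidden_size,
               num_layers=num_layers, bidirectional=True)

x = torch.randn(10, batch_size, 100)        # (seq_len, batch, input_size)
_, (h_n, _) = lstm(x)                       # h_n: (num_layers * 2, batch, hidden_size)

# Make the (layer, direction) layout explicit instead of relying on -1 / -2
h_n = h_n.view(num_layers, 2, batch_size, hidden_size)
last_forward = h_n[-1, 0]                   # last layer, forward direction
last_backward = h_n[-1, 1]                  # last layer, backward direction
u_emb_batch = last_forward + last_backward  # (batch, hidden_size)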




Answer 2:


Here's a detailed explanation for those working with unpacked sequences:

output is of shape (seq_len, batch, num_directions * hidden_size) (see documentation). This means that the outputs of the forward and backward passes of your GRU are concatenated along the 3rd dimension.

Assuming batch=2 and hidden_size=256 in your example, you can easily separate the outputs of both forward and backward passes by doing:

output = output.view(-1, 2, 2, 256)   # (seq_len, batch_size, num_directions, hidden_size)
output_forward = output[:, :, 0, :]   # (seq_len, batch_size, hidden_size)
output_backward = output[:, :, 1, :]  # (seq_len, batch_size, hidden_size)

(Note: the -1 tells PyTorch to infer that dimension from the others. See this question.)

Equivalently, you can use the torch.chunk function on the original output of shape (seq_len, batch, num_directions * hidden_size):

# Split in 2 tensors along dimension 2 (num_directions)
output_forward, output_backward = torch.chunk(output, 2, 2)

Now you can torch.gather the last hidden state of the forward pass using seqlengths (after reshaping it), and take the last hidden state of the backward pass by selecting the element at position 0:

# First we unsqueeze seqlengths twice so it has the same number
# of dimensions as output_forward
# (batch_size) -> (1, batch_size, 1)
lengths = seqlengths.unsqueeze(0).unsqueeze(2)

# Then we expand it accordingly
# (1, batch_size, 1) -> (1, batch_size, hidden_size) 
lengths = lengths.expand((1, -1, output_forward.size(2)))

last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0)
last_backward = output_backward[0, :, :]

Note that I subtracted 1 from lengths because of the 0-based indexing.

At this point both last_forward and last_backward are of shape (batch_size, hidden_size).
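
For reference, the steps above can be wrapped into one small helper. This is only a sketch; the function name last_states and the toy sizes are mine, and it assumes an unpacked output of shape (seq_len, batch, 2 * hidden_size):

import torch

def last_states(output, seqlengths, hidden_size):
    # Split the concatenated output into the two directions
    output_forward, output_backward = torch.chunk(output, 2, 2)

    # Gather the forward state at each sequence's true last time step
    idx = (seqlengths - 1).view(1, -1, 1).expand(1, -1, hidden_size)
    last_forward = torch.gather(output_forward, 0, idx).squeeze(0)

    # The backward pass finishes at time step 0
    last_backward = output_backward[0]
    return last_forward, last_backward

# Toy usage with random data
out = torch.randn(5, 2, 2 * 256)
lf, lb = last_states(out, torch.tensor([5, 3]), 256)
print(lf.shape, lb.shape)  # torch.Size([2, 256]) torch.Size([2, 256])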




Answer 3:


I tested the biLSTM output and h_n:

# biLSTM is assumed to be an nn.LSTM with bidirectional=True,
# batch_first=True and a single layer, so:
# shape of x is (batch_size, time_steps, input_size)
# shape of output is (batch_size, time_steps, hidden_size * num_directions)
# shape of h_n is (num_directions, batch_size, hidden_size)
output, (h_n, _c_n) = biLSTM(x)

print('step 0 of output from reverse == h_n from reverse?',
    torch.equal(output[:, 0, hidden_size:], h_n[1]))
print('step -1 of output from reverse == h_n from reverse?',
    torch.equal(output[:, -1, hidden_size:], h_n[1]))

Output:

step 0 of output from reverse == h_n from reverse? True
step -1 of output from reverse == h_n from reverse? False

This confirms that h_n for the reverse direction is the hidden state of the first time step.

So, if you really need the hidden state of the last time step from both the forward and reverse directions, you should use:

sum_lasthidden = output[:, -1, :hidden_size] + output[:, -1, hidden_size:]

not

h_n[0,:,:] + h_n[1,:,:]

This is because h_n[1,:,:] is the hidden state of the first time step from the reverse direction.

So the answer from @igrinis

u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

is not correct.

In theory, however, the last-time-step hidden state from the reverse direction only contains information from the last time step of the sequence, since the reverse pass reads the sequence starting from the end.



Source: https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch
