In what order are weights saved in an LSTM kernel in TensorFlow

Submitted by 魔方 西西 on 2019-12-22 01:36:58

Question


I looked into the saved weights for an LSTMCell in TensorFlow. It has one big kernel and one bias vector.

The dimensions of the kernel are

(input_size + hidden_size)*(hidden_size*4)

From what I understand, this kernel encapsulates 4 input-to-hidden affine transforms as well as 4 hidden-to-hidden transforms.

So there should be 4 matrices of size

input_size*hidden_size

and 4 of size

hidden_size*hidden_size

Can someone tell me, or point me to the code, where TF saves these, so I can break the kernel matrix into the smaller matrices?


Answer 1:


The weights are combined as described in the other answer. The gate blocks are stored in the following order, where c is the context and h is the history:

input_c,      input_h
new_input_c,  new_input_h
forget_c,     forget_h
output_c,     output_h

The relevant code is:

if self._state_is_tuple:
  c, h = state
else:
  c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

gate_inputs = math_ops.matmul(
    array_ops.concat([inputs, h], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
    value=gate_inputs, num_or_size_splits=4, axis=1)
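
To answer the original question of breaking the kernel apart, here is a minimal NumPy sketch. It assumes the layout implied by the excerpt above: rows follow concat([inputs, h], 1), so the input rows come first, and columns are four equal blocks in the order i, j, f, o. The helper name split_lstm_kernel and the toy sizes are hypothetical, and the kernel is assumed to have already been loaded as a NumPy array.

```python
import numpy as np

def split_lstm_kernel(kernel, input_size, hidden_size):
    """Split a fused LSTM kernel into per-gate (input, hidden) matrices."""
    assert kernel.shape == (input_size + hidden_size, 4 * hidden_size)
    # Rows: the first input_size rows multiply the inputs, the rest multiply h.
    w_x, w_h = kernel[:input_size], kernel[input_size:]
    # Columns: four equal blocks in the order i, j, f, o.
    names = ("i", "j", "f", "o")
    x_parts = np.split(w_x, 4, axis=1)
    h_parts = np.split(w_h, 4, axis=1)
    return {n: (wx, wh) for n, wx, wh in zip(names, x_parts, h_parts)}

# Toy kernel with recognizable values (hypothetical sizes).
input_size, hidden_size = 3, 2
kernel = np.arange((input_size + hidden_size) * 4 * hidden_size, dtype=np.float32)
kernel = kernel.reshape(input_size + hidden_size, 4 * hidden_size)
gates = split_lstm_kernel(kernel, input_size, hidden_size)
```

Each entry of gates holds one input_size*hidden_size matrix and one hidden_size*hidden_size matrix, which gives the 8 matrices the question asks about.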



Answer 2:


In TensorFlow 1.5, the LSTM variables are defined in the LSTMCell.build method. The source code can be found in rnn_cell_impl.py:

self._kernel = self.add_variable(
    _WEIGHTS_VARIABLE_NAME,
    shape=[input_depth + h_depth, 4 * self._num_units],
    initializer=self._initializer,
    partitioner=maybe_partitioner)
self._bias = self.add_variable(
    _BIAS_VARIABLE_NAME,
    shape=[4 * self._num_units],
    initializer=init_ops.zeros_initializer(dtype=self.dtype))

As you can see, there is just one variable of shape [input_depth + h_depth, 4 * self._num_units], not 8 separate matrices, and all of the affine transforms are computed at once in a single matrix multiplication.
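
As a rough check of that claim, the NumPy sketch below compares the single fused multiplication against a separate computation for one gate. The sizes and random values are made up for illustration, and the gate order i, j, f, o is taken from the split quoted in this thread; it is not TensorFlow code.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_size, hidden_size = 4, 3, 2  # hypothetical sizes

x = rng.standard_normal((batch, input_size))
h = rng.standard_normal((batch, hidden_size))
kernel = rng.standard_normal((input_size + hidden_size, 4 * hidden_size))
bias = np.zeros(4 * hidden_size)

# Fused computation: one matmul, then split into four gate pre-activations.
fused = np.concatenate([x, h], axis=1) @ kernel + bias
i, j, f, o = np.split(fused, 4, axis=1)

# Separate computation for the forget gate (third column block).
cols = slice(2 * hidden_size, 3 * hidden_size)
w_xf = kernel[:input_size, cols]   # input-to-hidden weights for f
w_hf = kernel[input_size:, cols]   # hidden-to-hidden weights for f
f_separate = x @ w_xf + h @ w_hf + bias[cols]

matches = np.allclose(f, f_separate)
```

If matches is True, slicing the fused kernel by rows and columns recovers the same per-gate transform as the batched multiplication.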

The gates are defined this way:

i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4, axis=1)


Source: https://stackoverflow.com/questions/48212694/in-what-order-are-weights-saved-in-a-lstm-kernel-in-tensorflow
