How to pass the variables of an LSTMCell as input instead of letting it create them in TensorFlow?

Submitted by 南笙酒味 on 2020-06-09 05:25:07

Question


When I create a tf.contrib.rnn.LSTMCell, it creates its kernel and bias trainable variables during initialisation.

How the code looks now:

cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char,
                        state_is_tuple=True)

What I want it to look like:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size,
                        state_is_tuple=True)

I want to create those variables myself and pass them to the LSTMCell constructor when instantiating it.

Is there an easy way to do this? I looked at the class source code but it seems that it is within a complex hierarchy of classes.


Answer 1:


I subclassed LSTMCell and changed its __init__ and build methods so that they accept externally created variables. If variables are passed to __init__, build skips creating the kernel and bias and uses the given variables instead.

There might be cleaner ways to do it though.

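import tensorflow as tf
# Internal TF 1.x modules referenced by build() below.
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"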

class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None, var_given=False, kernel=None, bias=None):

        super(MyLSTMCell, self).__init__(num_units,
                 use_peepholes=use_peepholes, cell_clip=cell_clip,
                 initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
                 num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
                 forget_bias=forget_bias, state_is_tuple=state_is_tuple,
                 activation=activation, reuse=reuse, name=name)

        self.var_given = var_given
        if self.var_given:
            # Use the caller-supplied variables; build() will not create new ones.
            self._kernel = kernel
            self._bias = bias


    def build(self, inputs_shape):
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        if self.var_given:
            # self._kernel and self._bias were already set in __init__
            pass
        else:
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True

The calling code then looks like this:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
lstm_fw = MyLSTMCell(....., var_given=True, kernel=kernel, bias=bias)
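Variables passed in this way need to have the same shapes that build would otherwise have used. As a rough sketch based only on the shapes in the build method above, a small helper can compute them up front. The function name lstm_variable_shapes and the example sizes are hypothetical and not part of TensorFlow.

```python
def lstm_variable_shapes(input_depth, num_units, num_proj=None):
    """Return (kernel_shape, bias_shape) as computed in MyLSTMCell.build().

    Hypothetical helper for illustration; not a TensorFlow function.
    """
    if input_depth is None:
        raise ValueError("input_depth must be known")
    # The recurrent state has size num_proj when a projection is used.
    h_depth = num_units if num_proj is None else num_proj
    # One fused kernel and bias cover all four gates, hence the factor 4.
    kernel_shape = [input_depth + h_depth, 4 * num_units]
    bias_shape = [4 * num_units]
    return kernel_shape, bias_shape


# Example sizes are assumptions chosen for illustration.
kernel_shape, bias_shape = lstm_variable_shapes(input_depth=100, num_units=64)
print(kernel_shape, bias_shape)  # [164, 256] [256]
```

These lists could then be given as the shape argument of tf.get_variable when creating kernel and bias. Note that build initialises the bias to zeros, so a zeros initializer would be needed to reproduce the default behaviour.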


Source: https://stackoverflow.com/questions/61389657/how-to-manually-initialize-a-tf-1-x-lstmcell-and-dynamic-rnn
