How do I define a custom LSTM cell in Keras?

予麋鹿 2020-12-10 19:23

I use Keras with the TensorFlow backend. If I want to modify an LSTM cell, for example by "removing" the output gate, how can I do that? The output gate is a multiplicative gate.
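For reference, here is one common way to write a standard LSTM step, assuming no peephole connections and sigmoid gates with tanh activations (the tf.keras defaults). The symbols are a conventional notation, not taken from the question. Removing the output gate amounts to fixing o_t to 1: the cell state is computed as usual, the hidden state becomes tanh(c_t), and the parameters W_o, U_o, b_o disappear.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
  \quad\longrightarrow\quad
  h_t = \tanh(c_t) \ \text{(output gate removed, i.e. } o_t \equiv 1\text{)}
\end{aligned}
```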

1 Answer
  时光说笑 2020-12-10 20:17

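    First, define the cell as your own custom layer. For intuition on how to implement a gated cell, see the LSTMCell source in the Keras repository. A cell needs a state_size attribute and a call(inputs, states) method that returns (output, new_states). The example below is a minimal, non-gated RNN cell rather than an LSTM, but it shows that structure: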

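    import tensorflow as tf
    from tensorflow import keras

    class MinimalRNNCell(keras.layers.Layer):
        """A minimal (non-gated) RNN cell: output = x_t W + h_{t-1} U."""

        def __init__(self, units, **kwargs):
            # Call the base constructor before setting any attributes.
            super().__init__(**kwargs)
            self.units = units
            # Size of the recurrent state; keras.layers.RNN reads this attribute.
            self.state_size = units

        def build(self, input_shape):
            # Input-to-hidden weights.
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            # Hidden-to-hidden (recurrent) weights.
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            # states is a list with one tensor per entry of state_size.
            prev_output = states[0]
            h = tf.matmul(inputs, self.kernel)
            output = h + tf.matmul(prev_output, self.recurrent_kernel)
            # Return the step output and the list of new states.
            return output, [output]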
    

    Then wrap the cell in tf.keras.layers.RNN, which applies it at every time step of the input sequence:

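    from tensorflow import keras

    cell = MinimalRNNCell(32)
    x = keras.Input((None, 5))        # (batch, timesteps, 5 features)
    layer = keras.layers.RNN(cell)
    y = layer(x)                      # (batch, 32): output at the last time step

    # Here's how to use the cell to build a stacked RNN:

    cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
    x = keras.Input((None, 5))
    layer = keras.layers.RNN(cells)
    y = layer(x)                      # (batch, 64): output of the last cell in the stack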
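Applying the same pattern to the original question, an LSTM-style cell without an output gate keeps the cell state c but emits h = tanh(c) directly, so the output-gate weights are never created. This is a minimal sketch, assuming TensorFlow 2.x (tf.keras), sigmoid gates and tanh activations. The class name NoOutputGateLSTMCell and the (i, f, g) block order inside the fused kernels are hypothetical choices for illustration, not the layout used by the Keras LSTMCell source.

```python
import tensorflow as tf


class NoOutputGateLSTMCell(tf.keras.layers.Layer):
    """LSTM-style cell without an output gate (hypothetical name).

    Computes the input gate i, forget gate f and candidate g as in a standard
    LSTM, updates the cell state c, and emits h = tanh(c), i.e. o is fixed at 1.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Two recurrent states, as in a regular LSTM: [h, c].
        self.state_size = [units, units]
        self.output_size = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # Fused weights for three blocks (i, f, g) instead of the usual four;
        # the block order is an arbitrary choice made for this sketch.
        self.kernel = self.add_weight(
            shape=(input_dim, 3 * self.units),
            initializer="glorot_uniform", name="kernel")
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, 3 * self.units),
            initializer="orthogonal", name="recurrent_kernel")
        self.bias = self.add_weight(
            shape=(3 * self.units,), initializer="zeros", name="bias")
        self.built = True

    def call(self, inputs, states):
        h_prev, c_prev = states[0], states[1]
        z = (tf.matmul(inputs, self.kernel)
             + tf.matmul(h_prev, self.recurrent_kernel)
             + self.bias)
        z_i, z_f, z_g = tf.split(z, 3, axis=-1)
        i = tf.sigmoid(z_i)        # input gate
        f = tf.sigmoid(z_f)        # forget gate
        g = tf.tanh(z_g)           # candidate cell update
        c = f * c_prev + i * g     # new cell state
        h = tf.tanh(c)             # no output-gate multiplication
        return h, [h, c]
```

Because state_size has two entries, keras.layers.RNN(NoOutputGateLSTMCell(32), return_state=True) returns the output followed by both final states h and c, like the built-in LSTM layer. The dropout/recurrent_dropout options and the cuDNN-accelerated kernel of keras.layers.LSTM are not reproduced in this sketch.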
    
