Regularization for LSTM in TensorFlow

庸人自扰 2021-02-05 15:10

TensorFlow offers a convenient LSTM wrapper:

    rnn_cell.BasicLSTMCell(num_units, forget_bias=1.0, input_size=None,
                           state_is_tuple=False, activation=tanh)
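
The title asks about L2-regularizing the weights of such a cell. The usual approach is to add a weight penalty to the training loss. In TF 1.x graph code, a common pattern is to add `beta * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])` to the loss, often skipping bias variables. The sketch below mirrors only the arithmetic in plain Python, assuming each weight tensor is a flat list of floats. The names `l2_loss`, `l2_penalty`, `regularized_loss` and the value of `beta` are hypothetical.

```python
# Plain-Python sketch of an L2 weight penalty (no TensorFlow required).
# Assumption: each weight "tensor" is modelled as a flat list of floats.
# All function names and the beta value are hypothetical.


def l2_loss(tensor):
    """Half the sum of squares: the quantity tf.nn.l2_loss computes."""
    return sum(x * x for x in tensor) / 2.0


def l2_penalty(weights, beta):
    """beta times the summed l2_loss of every weight tensor."""
    return beta * sum(l2_loss(w) for w in weights)


def regularized_loss(data_loss, weights, beta):
    """Data loss plus the L2 penalty; the optimizer minimizes this total."""
    return data_loss + l2_penalty(weights, beta)


if __name__ == "__main__":
    lstm_kernel = [0.5, -1.0, 2.0]  # toy stand-in for an LSTM weight matrix
    output_weights = [1.0, 1.0]     # toy stand-in for an output projection
    print(regularized_loss(0.3, [lstm_kernel, output_weights], beta=0.01))
```

Larger `beta` values push the weights harder toward zero. It is a hyperparameter you would tune on validation data.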
        
3 Answers
  •  孤街浪徒
    2021-02-05 15:37

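    TensorFlow has built-in helpers that use the L2 norm, such as tf.clip_by_global_norm, which rescales all gradients together whenever their combined (global) L2 norm exceeds a threshold. Strictly speaking this is gradient clipping rather than a penalty on the weights, but it is commonly used with LSTMs to guard against exploding gradients: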

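        import tensorflow as tf  # TF 1.x API; in TF 2.x these live under tf.compat.v1

        # ^^^ define your LSTM above here ^^^
        # self.losses, self.learning_rate, self.global_step and max_gradient_norm
        # are assumed to be defined by the surrounding model code.
    
        # Every variable the optimizer may update, including the LSTM cell's weights.
        params = tf.trainable_variables()
    
        # Symbolic gradients of the loss with respect to each parameter.
        gradients = tf.gradients(self.losses, params)
    
        # Rescale all gradients by one shared factor so their global L2 norm is at
        # most max_gradient_norm; `norm` is the global norm measured before clipping.
        clipped_gradients, norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
        self.gradient_norms = norm
    
        # Apply the clipped gradients with plain SGD and increment global_step.
        opt = tf.train.GradientDescentOptimizer(self.learning_rate)
        self.updates = opt.apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step)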
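
    For intuition, here is a plain-Python sketch of the computation that the TensorFlow documentation gives for tf.clip_by_global_norm. The global norm is the square root of the sum of squares of every element across all gradients. Each gradient is then multiplied by clip_norm / max(global_norm, clip_norm). Gradients are modelled as flat lists of floats, and the helper names are hypothetical. This mirrors the math, not TensorFlow's actual implementation.

```python
import math


def global_norm(tensors):
    """Square root of the sum of squares of every element in every tensor."""
    return math.sqrt(sum(x * x for t in tensors for x in t))


def clip_by_global_norm(tensors, clip_norm):
    """Pure-Python mirror of the documented tf.clip_by_global_norm formula.

    Every tensor is scaled by clip_norm / max(global_norm, clip_norm), so all
    gradients shrink by the same factor and keep their relative direction.
    Returns the clipped tensors and the norm measured *before* clipping,
    which is the value the answer stores in self.gradient_norms.
    Assumes clip_norm > 0.
    """
    norm = global_norm(tensors)
    scale = clip_norm / max(norm, clip_norm)
    clipped = [[x * scale for x in t] for t in tensors]
    return clipped, norm
```

    Because one shared factor rescales every gradient, the update direction is preserved and only its length is capped. For example, gradients [[3.0], [4.0]] have a global norm of 5.0, so with clip_norm=1.0 they become roughly [[0.6], [0.8]]. Gradients whose global norm is already below the threshold pass through unchanged.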
    

    In your training step, run:

        outputs = session.run([self.updates, self.gradient_norms, self.losses], input_feed)
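
    The session.run call above fetches three values in order: the update op, the pre-clipping global gradient norm, and the loss. Fetching an Operation returns None, so only outputs[1] and outputs[2] carry numbers. The toy loop below is a hypothetical plain-Python stand-in for repeated training steps. It uses a single scalar weight and the loss (w - target)**2. It shows what those fetched values look like over time and how clipping caps each step at learning_rate * max_gradient_norm. The function and argument names are illustrative assumptions.

```python
def train(w, target, learning_rate, max_gradient_norm, steps):
    """Toy stand-in for repeatedly running the answer's update op.

    Minimizes loss = (w - target)**2 for one scalar weight w with plain
    gradient descent and global-norm clipping. Each step records
    (pre-clip gradient norm, loss), mirroring outputs[1] and outputs[2].
    """
    history = []
    for _ in range(steps):
        loss = (w - target) ** 2
        grad = 2.0 * (w - target)   # d(loss)/dw
        norm = abs(grad)            # global L2 norm of a single scalar gradient
        scale = max_gradient_norm / max(norm, max_gradient_norm)
        w -= learning_rate * grad * scale  # SGD step with the clipped gradient
        history.append((norm, loss))
    return w, history
```

    Logging the recorded norms during real training is one way to choose max_gradient_norm. If the norm is almost always above the threshold, every step is being clipped and the effective step size is fixed.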
    
