TensorFlow - Saver.restore not restoring all parameters

≡放荡痞女 提交于 2020-01-04 11:05:12

问题


I was training Bidirectional LSTM type RNN for nearly 24 hours, and due to oscillation in the error I decided to decrease the learning before allowing it to continue training. Since the model is saved using Saver.save(sess,file) at every epoch, I terminated the training with the CTC Loss having minimised to approximately 115.

Now after restoring the model, the initial error rate I am getting is somewhere around 162, which is inconsistent with the flow of error rate I was getting in 7th epoch, and is also what I got in the first epoch. So it is my impression that either "restore" function is not working or if it is working, then there must be something else that is not allowing it to take effect.

Here is my code:

    graph = tf.Graph()
    with graph.as_default():
        # Graph creation
        graph_start = time.time()
        seq_inputs = tf.placeholder(tf.float32, shape=     [None,batch_size,frame_length], name="sequence_inputs")
        seq_lens = tf.placeholder(shape=[batch_size],dtype=tf.int32)
        seq_inputs = seq_bn(seq_inputs,seq_lens)

        initializer = tf.truncated_normal_initializer(mean=0,stddev=0.1)
        forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                          num_proj = hidden_size,
                                          use_peepholes=use_peephole,
                                          initializer=initializer,
                                          state_is_tuple=True)

        forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True)

        backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                           num_proj= hidden_size,
                                           use_peepholes=use_peephole,
                                           initializer=initializer,
                                           state_is_tuple=True)

        backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True)

        [fw_out,bw_out], _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward, cell_bw=backward, inputs=seq_inputs,time_major=True, dtype=tf.float32,                                               sequence_length=tf.cast(seq_lens,tf.int64))


        # Batch normalize forward output
        mew,var_ = tf.nn.moments(fw_out,axes=[0])
        fw_out = tf.nn.batch_normalization(fw_out,mew,var_,0.1,1,1e-6)
        # fw_out = seq_bn(fw_out,seq_lens)

        # Batch normalize backward output
        mew,var_ = tf.nn.moments(bw_out,axes=[0])
        bw_out = tf.nn.batch_normalization(bw_out,mew,var_,0.1,1,1e-6)
        # bw_out = seq_bn(bw_out,seq_lens)

        # Reshaping forward, and backward outputs for affine transformation
        fw_out = tf.reshape(fw_out,[-1,hidden_size])
        bw_out = tf.reshape(bw_out,[-1,hidden_size])

        # Linear Layer params
        W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        b_out = tf.constant(0.1,shape=[n_chars])

        # Perform an affine transformation
        logits =  tf.add(tf.add(tf.matmul(fw_out,W_fw),tf.matmul(bw_out,W_bw)),b_out)
        logits = tf.reshape(logits,[-1,batch_size,n_chars])

        # Use CTC Beam Search Decoder to decode pred string from the prob map
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens)

        # Target params
        indices = tf.placeholder(dtype=tf.int64, shape=[None,2])
        values = tf.placeholder(dtype=tf.int32, shape=[None])
        shape = tf.placeholder(dtype=tf.int64,shape=[2])
        # Make targets
        targets = tf.SparseTensor(indices,values,shape)

        # Compute Loss
        loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens))
        # Compute error rate based on edit distance
        predicted = tf.to_int32(decoded[0])
        error_rate = tf.reduce_sum(tf.edit_distance(predicted,targets,normalize=False))/ \
         tf.to_float(tf.size(targets.values))    

        tvars = tf.trainable_variables()
        grad, _ = tf.clip_by_global_norm(tf.gradients(loss,tvars),max_grad_norm)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,momentum=momentum)
        train_step = optimizer.apply_gradients(zip(grad,tvars))
        graph_end = time.time()
        print("Time elapsed for creating graph: %.3f"%(round(graph_end-graph_start,3)))
        # steps per epoch
        start_time = 0
        steps = int(np.ceil(len(data_train.files)/batch_size))

        loss_tr = []
        log_tr = []
        loss_vl = []
        log_vl = []
        err_tr = []
        err_vl = []
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            #sess.run(tf.initialize_all_variables())
            checkpt_path = tf.train.latest_checkpoint(checkpoint_dir)
            print(saver.restore(sess,checkpt_path))
            print("Model restore from 7th epoch 188th step")
            feed = None
            epoch = None
            step = None
            try:
                for epoch in range(7,epochs+1):
                    if epoch==7:
                       initial_step = 189
                    else:
                       initial_step = 0
                    transcript = []
                    loss_val = 0
                    l_pr = 0
                    start_time = time.time()
                    for step in range(initial_step,steps):
                        train_data, transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_train.next_batch()
                        n_frames = np.reshape(n_frames,[-1])
                        feed = {seq_inputs: train_data, indices:targ_indices, values:targ_values, shape:targ_shape, seq_lens:n_frames}
                        del train_data,targ_indices,targ_values,targ_shape,n_frames

                        # Evaluate loss value, decoded transcript, and log probability
                        _,loss_val,deco,l_pr,err_rt_tr = sess.run([train_step,loss,decoded,log_prob,error_rate],
                                                            feed_dict=feed)
                        del feed
                        loss_tr.append(loss_val)
                        log_tr.append(l_pr)
                        err_tr.append(err_rt_tr)

                        # On validation set
                        val_data, val_transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_val.next_batch()
                        n_frames = np.reshape(n_frames, [-1])
                        feed = {seq_inputs: val_data, indices: targ_indices,values: targ_values, shape: targ_shape, seq_lens: n_frames}
                        del val_data, val_transcript,targ_indices,targ_values,targ_shape,n_frames
                    vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate], feed_dict=feed)
                        del feed
                        loss_vl.append(vl_loss)
                        log_vl.append(l_val_pr)
                        err_vl.append(err_rt_vl)
                        print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f"
                          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl))

                    end_time = time.time()
                    elapsed = round(end_time - start_time, 3)

                    # On training set
                    # Select a random index within batch_size
                    sample_index = np.random.randint(0, batch_size)

                    # Fetch the target transcript
                    actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]]

                    # Fetch the decoded path from probability map
                    pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape)
                    pred_dense = tf.sparse_tensor_to_dense(pred_sparse)
                    ans = pred_dense.eval()
                    #pred = [data_train.reverse_map[i] for i in ans[sample_index, :]]
                    pred = []
                    for i in ans[sample_index,:]:
                        if i == n_chars-1:
                            pred.append(data_train.reverse_map[0])
                        else:
                            pred.append(data_train.reverse_map[i])
                    print("time_elapsed for 200 steps: %.3f, " % (elapsed))
                    if epoch%2 == 0:
                        print("Sample mini-batch results: \n" \
                              "predicted string: ", np.array(pred))
                        print("actual string: ", np.array(actual_str))
                    print("On training set, the loss: %.2f, log_pr: %.3f, error rate %.3f:"% (loss_val, np.mean(l_pr), err_rt_tr))
                    print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl))

                    # Save the trainable parameters after the end of an epoch
                    if epoch > 7:
                        path = saver.save(sess, 'model_%d' % epoch)
                    print("Session saved at: %s" % path)
                    np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
            except (KeyboardInterrupt, SystemExit, Exception), e:
                print("Error/Interruption: %s" % str(e))
                exc_type, exc_obj, exc_tb = sys.exc_info()
                print("Line no: %d" % exc_tb.tb_lineno)
                if epoch > 7:
                    print("Saving model: %s" % saver.save(sess, 'Last.cpkt'))
                print("Current batch: %d" % data_train.b_id)
                print("Current epoch: %d" % epoch)
                print("Current step: %d"%step)
                np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
                print("Clossing TF Session...")
                sess.close()
                print("Terminating Program...")
                sys.exit(0)

回答1:


I think you need to re-initialize your accumulators for each epoch.

So these ones must be put at the beginning, inside the loop.

loss_tr = []
log_tr = []
loss_vl = []
log_vl = []
err_tr = []
err_vl = []


来源:https://stackoverflow.com/questions/38880176/tensorflow-saver-restore-not-restoring-all-parameters

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!