How to load a TensorFlow model and continue training

慢半拍i 2021-02-02 03:38

I want to load a pretrained model and continue training it.
Here is the standard code snippet I use to save the model (pretrain.py):

tf.reset_default         
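For reference, here is a rough sketch of the usual TF 1.x save pattern. It is not the author's pretrain.py. It assumes TF 2.x running the `tf.compat.v1` API, a hypothetical one-layer softmax model (variables `w_o`, `b_o`), and random stand-in data. `Saver.save()` with `global_step` writes the `.index`, `.data-*` and `.meta` files plus a `checkpoint` file. `tf.train.latest_checkpoint()` reads that `checkpoint` file later.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph/session semantics under TF 2.x

# Hypothetical tiny model standing in for the MLP in pretrain.py.
n_input, n_classes, n_steps = 4, 3, 10
tf.reset_default_graph()
X = tf.placeholder(tf.float32, [None, n_input], name="X")
Y = tf.placeholder(tf.float32, [None, n_classes], name="Y")
W = tf.Variable(tf.zeros([n_input, n_classes]), name="w_o")
b = tf.Variable(tf.zeros([n_classes]), name="b_o")
logits = tf.matmul(X, W) + b
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y),
    name="loss_op")
train_op = tf.train.AdamOptimizer(0.01).minimize(loss_op, name="train_op")

# Create the Saver after minimize() so Adam's extra variables are saved too.
saver = tf.train.Saver()

# Hypothetical random training batch (8 samples, one-hot labels).
ckpt_dir = tempfile.mkdtemp()
rng = np.random.RandomState(0)
batch_x = rng.rand(8, n_input).astype(np.float32)
batch_y = np.eye(n_classes, dtype=np.float32)[rng.randint(0, n_classes, 8)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(n_steps):
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
    # Writes model.ckpt-10.index / .data-* / .meta and updates "checkpoint".
    save_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"),
                           global_step=n_steps)
    print(save_path)
```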


        
1 Answer
  • 2021-02-02 04:08

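    I think I found the answer. You don't need to call tf.train.import_meta_graph() if you rebuild the same graph in code and then restore the variable values with saver.restore(sess, tf.train.latest_checkpoint('./')). The rebuilt variables must have the same names as those in the checkpoint. When they do, Saver.restore() assigns their saved values, so you don't need to run tf.global_variables_initializer() either. Here is my code.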

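    # TensorFlow 1.x API; under TF 2.x use `import tensorflow.compat.v1 as tf`
    # followed by tf.disable_v2_behavior().
    import tensorflow as tf

    # n_input, n_classes, learning_rate, training_epochs, total_batch,
    # train_generator and multilayer_perceptron() are assumed to be defined
    # earlier (e.g. copied from pretrain.py); they are not shown here.

    # Rebuild the same graph as in pretrain.py so the variable names
    # match the ones stored in the checkpoint.
    # tf Graph input
    X = tf.placeholder("float", [None, n_input])
    Y = tf.placeholder("float", [None, n_classes])
    mlp_layer_name = ['h1', 'b1', 'h2', 'b2', 'h3', 'b3', 'w_o', 'b_o']
    logits = multilayer_perceptron(X, n_input, n_classes, mlp_layer_name)
    loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y), name='loss_op')
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # minimize() also creates Adam's slot variables; they must exist before
    # the Saver is built so that their saved values are restored as well.
    train_op = optimizer.minimize(loss_op, name='train_op')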
    
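    with tf.Session() as sess:
        # Create the Saver after the whole graph exists, then restore every
        # variable value from the newest checkpoint in './'. Neither
        # tf.train.import_meta_graph() nor global_variables_initializer() is needed.
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint('./'))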
    
        for epoch in range(training_epochs):
            avg_cost = 0.
    
            # Loop over all batches
            for i in range(total_batch):
                batch_x, batch_y = next(train_generator)
    
                # Run optimization op (backprop) and cost op (to get loss value)
                _, c = sess.run([train_op, loss_op], feed_dict={X: batch_x,
                                                                Y: batch_y})
                # Compute average loss
                avg_cost += c / total_batch
    
            print("Epoch: {:3d}, cost = {:.6f}".format(epoch+1, avg_cost))
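To check this approach end to end, here is a small self-contained sketch. It assumes TF 2.x with the `tf.compat.v1` API, a hypothetical one-layer model (`w_o`, `b_o`) in place of the MLP, and random stand-in data. The sketch trains and saves, resets the graph, and rebuilds it in code with the same variable names. It then restores with `Saver.restore()`, without `import_meta_graph()`, and keeps training.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph/session semantics under TF 2.x

N_INPUT, N_CLASSES = 4, 3  # hypothetical sizes for illustration


def build_graph():
    """Rebuild the same graph in code; variable names must match the checkpoint."""
    X = tf.placeholder(tf.float32, [None, N_INPUT])
    Y = tf.placeholder(tf.float32, [None, N_CLASSES])
    W = tf.Variable(tf.zeros([N_INPUT, N_CLASSES]), name="w_o")
    b = tf.Variable(tf.zeros([N_CLASSES]), name="b_o")
    logits = tf.matmul(X, W) + b
    loss_op = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))
    train_op = tf.train.AdamOptimizer(0.01).minimize(loss_op)
    # Saver built after minimize(), so Adam's variables are saved/restored too.
    return X, Y, W, loss_op, train_op, tf.train.Saver()


# Hypothetical random training batch (16 samples, one-hot labels).
rng = np.random.RandomState(0)
batch_x = rng.rand(16, N_INPUT).astype(np.float32)
batch_y = np.eye(N_CLASSES, dtype=np.float32)[rng.randint(0, N_CLASSES, 16)]
ckpt_dir = tempfile.mkdtemp()

# Phase 1 (stands in for pretrain.py): train briefly, then save.
tf.reset_default_graph()
X, Y, W, loss_op, train_op, saver = build_graph()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(20):
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
    saved_W, saved_loss = sess.run([W, loss_op],
                                   feed_dict={X: batch_x, Y: batch_y})
    saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# Phase 2 (the approach above): rebuild the graph, restore, keep training.
tf.reset_default_graph()
X, Y, W, loss_op, train_op, saver = build_graph()
with tf.Session() as sess:
    # No initializer and no import_meta_graph(): restore() assigns every
    # variable in the rebuilt graph from the checkpoint.
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    restored_W = sess.run(W)
    for _ in range(20):
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
    final_loss = sess.run(loss_op, feed_dict={X: batch_x, Y: batch_y})

print("restored weights match:", np.allclose(saved_W, restored_W))
print("loss before/after resuming: {:.4f} -> {:.4f}".format(saved_loss, final_loss))
```

The Saver is created after `minimize()`, so Adam's moment estimates and its `beta1_power`/`beta2_power` accumulators are restored along with the weights. Training therefore resumes from the saved optimizer state instead of restarting it.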
    