TensorFlow: how to use pretrained weights in a new graph?

感动是毒 2021-01-05 14:24

I'm trying to build an object detector with a CNN, using TensorFlow in Python. I would like to train my model to do just object recognition (classification) at fir

2 Answers
  • 2021-01-05 14:38

    I agree with Aechlys about restoring the variables, but the problem is harder when we want to freeze them. For example, suppose we trained these variables and want to reuse them in another model, this time without training them further (training only the new variables, as in transfer learning). You can see the answer I posted here.

    Quick example:

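    import tensorflow as tf  # TensorFlow 1.x API

    with tf.Session() as sess:
        # Rebuild the saved graph from the .meta file, then load the checkpoint values.
        new_saver = tf.train.import_meta_graph(pathToMeta)
        new_saver.restore(sess, pathToNonMeta)

        # Evaluate the restored variable to get its value as a NumPy array.
        weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

    tf.reset_default_graph()  # this will eliminate the variables we restored

    with tf.Session() as sess:
        # Recreate the weights as non-trainable variables initialized from the saved values.
        weights = {
            '1': tf.Variable(weight1, name='w1-bis', trainable=False)
        }
    ...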
    

    We are now sure the restored variables are not a part of the graph.
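The freeze-after-restore idea above can be sketched end to end as follows. This is only an illustration under stated assumptions: it uses `tensorflow.compat.v1` so the TF1-style calls also run on TensorFlow 2.x, it first writes a small stand-in checkpoint to a temporary directory, and the names `ckpt_path`, `frozen_w1`, `new_w` and the toy loss are hypothetical, not part of the original answer.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, sessions and variables

# Hypothetical checkpoint location used only for this sketch.
ckpt_path = os.path.join(tempfile.mkdtemp(), "pretrained.ckpt")

# Stand-in for the pretrained model: a single variable named "w1".
tf.reset_default_graph()
w1 = tf.get_variable("w1", shape=[3], initializer=tf.random_normal_initializer())
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)

# Step 1: import the saved graph and pull the weights out as a NumPy array.
tf.reset_default_graph()
with tf.Session() as sess:
    restorer = tf.train.import_meta_graph(ckpt_path + ".meta")
    restorer.restore(sess, ckpt_path)
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

# Step 2: start a fresh graph so the restored variables are gone, then
# recreate the weights as a non-trainable variable next to a new trainable one.
tf.reset_default_graph()
frozen_w1 = tf.Variable(weight1, name="w1-bis", trainable=False)
new_w = tf.get_variable("new_w", shape=[3], initializer=tf.zeros_initializer())

# Toy loss (an assumption for illustration): only new_w should be updated.
loss = tf.reduce_sum(tf.square(frozen_w1 + new_w))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(train_op)
    frozen_after = sess.run(frozen_w1)
    trainable_names = [v.name for v in tf.trainable_variables()]
```

Because `frozen_w1` is created with `trainable=False`, it is left out of the trainable-variables collection, so the optimizer's `minimize` call does not touch it by default.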

  • 2021-01-05 14:40

    Use a Saver with no arguments to save every variable in the model.

    import tensorflow as tf  # TensorFlow 1.x API

    tf.reset_default_graph()
    v1 = tf.get_variable("v1", [3], initializer=tf.initializers.random_normal)
    v2 = tf.get_variable("v2", [5], initializer=tf.initializers.random_normal)
    saver = tf.train.Saver()  # no var_list: saves every variable in the graph

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_path='./test-case.ckpt')

        print(v1.eval())
        print(v2.eval())
    saver = None

    # Output:
    # v1 = [ 2.1882825   1.159807   -0.26564872]
    # v2 = [0.11437789 0.5742971 ]
    

    Then, in the model whose variables you want to set to the saved values, pass the Saver either a list of the variables to restore or a dictionary of {"name in the checkpoint": variable}.

    tf.reset_default_graph()
    b1 = tf.get_variable("b1", [3], initializer=tf.initializers.random_normal)
    b2 = tf.get_variable("b2", [3], initializer=tf.initializers.random_normal)
    # Map the checkpoint name 'v1' onto the new variable b1; b2 is not restored.
    saver = tf.train.Saver(var_list={'v1': b1})

    with tf.Session() as sess:
        saver.restore(sess, "./test-case.ckpt")
        print(b1.eval())
        print(b2.eval())  # fails: b2 was neither restored nor initialized

    # Output:
    # INFO:tensorflow:Restoring parameters from ./test-case.ckpt
    # b1 = [ 2.1882825   1.159807   -0.26564872]
    # b2 = FailedPreconditionError: Attempting to use uninitialized value b2
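One way to avoid the uninitialized-variable error for `b2` is to initialize the variables that the Saver does not cover. The sketch below is an assumption-laden illustration rather than part of the original answer: it uses `tensorflow.compat.v1`, writes its own small checkpoint to a temporary directory, and the names `ckpt_path`, `restorer` and `ckpt_names` are hypothetical. It also lists the names stored in the checkpoint, since those are the keys the `var_list` dictionary must use.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, sessions and variables

# Hypothetical checkpoint location used only for this sketch.
ckpt_path = os.path.join(tempfile.mkdtemp(), "test-case.ckpt")

# Save a checkpoint containing one variable named "v1".
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)
    saved_v1 = sess.run(v1)

# The checkpoint stores values under the variable names used at save time.
ckpt_names = [name for name, _shape in tf.train.list_variables(ckpt_path)]

# New graph: b1 takes its value from checkpoint entry "v1"; b2 is new.
tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer=tf.random_normal_initializer())
b2 = tf.get_variable("b2", [3], initializer=tf.zeros_initializer())
restorer = tf.train.Saver(var_list={"v1": b1})

with tf.Session() as sess:
    # Initialize only the variables that will not be restored.
    sess.run(tf.variables_initializer([b2]))
    restorer.restore(sess, ckpt_path)
    restored_b1 = sess.run(b1)
    fresh_b2 = sess.run(b2)
```

Running `tf.global_variables_initializer()` before `restore` would work as well, since the restore overwrites the values of the variables it covers.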
    