Load multiple models in TensorFlow

北荒 2021-02-04 09:08

I have written the following convolutional neural network (CNN) class in TensorFlow (I have omitted some lines of code for clarity):

class CNN:
    ...


        
5 answers
  • 2021-02-04 09:38

    This should be a comment on the most up-voted answer, but I do not have enough reputation to comment.

    Anyway: if you (anyone who searched and got to this point) are still having trouble with the solution provided by lpp AND you are using Keras, check the following quote from GitHub:

    This is because Keras shares a global session if no default TF session is provided.

    When model1 is created, it is on graph1. When model1 loads its weights, the weights go into a Keras global session, which is associated with graph1.

    When model2 is created, it is on graph2. When model2 loads its weights, the global session does not know about graph2.

    The solution below may help: give each model its own graph and session.

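    from tensorflow import Graph, Session  # TF 1.x API
    from keras.models import model_from_json

    graph1 = Graph()
    with graph1.as_default():
        session1 = Session()
        with session1.as_default():
            with open('model1_arch.json') as arch_file:
                model1 = model_from_json(arch_file.read())
            model1.load_weights('model1_weights.h5')
            # K.get_session() is session1

    # do the same for graph2, session2, model2; later, re-enter graph1 and
    # session1 (the same two with-blocks) before calling model1.predict(...)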
    
  • 2021-02-04 09:40

    I encountered the same problem and could not solve it (without retraining) with any solution I found on the internet. What I did instead was load each model in its own thread, with both threads communicating with the main thread. The code is simple enough to write; you just have to be careful when synchronizing the threads. In my case, each thread received the input for its model and returned the output to the main thread. It works without any observable overhead.
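A minimal sketch of the thread-per-model pattern described above, using only the standard library. load_model, model_worker and ModelThread are hypothetical names, and the "models" are trivial stand-ins (multiply by 2 or by 10); in practice load_model would build the graph/session and restore the checkpoint, and it is an assumption that your loading code runs correctly inside a worker thread. The lock is one way to do the careful synchronization mentioned: it keeps each request paired with its response if several threads share one model.

```python
import queue
import threading


def load_model(name):
    """Hypothetical stand-in for loading a real model (building its graph and
    restoring its checkpoint). Returns a callable mapping an input to an output."""
    scale = {"model1": 2, "model2": 10}[name]
    return lambda x: x * scale


def model_worker(name, requests, responses):
    # The model is loaded and used only inside this worker thread;
    # the main thread talks to it exclusively through the two queues.
    model = load_model(name)
    while True:
        x = requests.get()
        if x is None:  # sentinel value: stop the worker
            break
        responses.put(model(x))


class ModelThread:
    """One worker thread per model, plus the queues used to reach it."""

    def __init__(self, name):
        self.requests = queue.Queue()
        self.responses = queue.Queue()
        self.lock = threading.Lock()  # keeps each request/response pair together
        self.thread = threading.Thread(
            target=model_worker,
            args=(name, self.requests, self.responses),
            daemon=True,
        )
        self.thread.start()

    def predict(self, x):
        # Send one input and block until the matching output comes back.
        with self.lock:
            self.requests.put(x)
            return self.responses.get()

    def close(self):
        self.requests.put(None)
        self.thread.join()


m1 = ModelThread("model1")
m2 = ModelThread("model2")
out1 = m1.predict(3)
out2 = m2.predict(3)
m1.close()
m2.close()
print(out1, out2)  # prints: 6 30
```

Because the main thread only ever sees queues, each model's loading and inference code stays untouched; only the small wrapper needs to know about threading.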

  • 2021-02-04 09:46

    One way, if you want to train or load multiple models in succession, is to clear the session between them:

    from keras import backend as K

    # load and use model 1

    K.clear_session()  # discard model 1's graph before building the next model

    # load and use model 2

    K.clear_session()
    

    K.clear_session() destroys the current TF graph and creates a new one, which is useful for avoiding clutter from old models/layers. It therefore suits models used one after another rather than side by side.

  • 2021-02-04 09:51

    Yes there is. Use separate graphs.

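    import tensorflow as tf

    g1 = tf.Graph()
    g2 = tf.Graph()

    # each CNN creates its ops and variables in its own graph, so names cannot clash;
    # a tf.Session() created inside a `with gN.as_default()` block runs on gN
    with g1.as_default():
        cnn1 = CNN(..., restore_file='snapshot-model1-10000', ...)
    with g2.as_default():
        cnn2 = CNN(..., restore_file='snapshot-model2-10000', ...)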
    

    EDIT:

    If you want them in the same graph, you'll have to rename some variables. One idea is to put each CNN in a separate scope and let the saver handle only the variables in that scope, e.g.:

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model1'))
    

    and in the CNN class, wrap all of the graph construction in that scope:

    with tf.variable_scope('model1'):
        ...  # build all of this CNN's layers and variables here
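As far as I know, tf.get_collection(key, scope) keeps an item when re.match(scope, item.name) succeeds, i.e. scope is a regular expression anchored only at the start. With more than a few models, 'model1' would then also pick up a scope such as model10. The sketch below reproduces that filter in plain Python on hypothetical variable names (no TensorFlow needed) to show why a trailing slash, e.g. scope='model1/', is the safer choice. select_by_scope is a hypothetical helper.

```python
import re

# Hypothetical variable names, in the form TensorFlow reports them via v.name.
names = [
    "model1/conv1/weights:0",
    "model1/conv1/biases:0",
    "model10/conv1/weights:0",  # another model whose scope also starts with "model1"
    "model2/conv1/weights:0",
]


def select_by_scope(scope, names):
    """Keep the names that the regex `scope` matches at the start (re.match),
    mirroring the filter tf.get_collection(key, scope=...) applies to item.name."""
    regex = re.compile(scope)
    return [n for n in names if regex.match(n)]


loose = select_by_scope("model1", names)    # also catches model10/...
strict = select_by_scope("model1/", names)  # only model1/...
print(loose)
print(strict)
```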
    

    EDIT2:

    Another idea is to rename the variables that the saver manages (I assume you want to use your saved checkpoints without retraining everything). The saver allows variable names in the graph to differ from those in the checkpoint: its var_list argument can be a dict mapping checkpoint names to graph variables. Have a look at the documentation for its initialization.
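A sketch of the name mapping this needs, on plain strings only. It assumes each checkpoint was saved by a standalone CNN (so its names carry no scope) and that, in the shared graph, each CNN is built under its own scope such as model1. checkpoint_name_map is a hypothetical helper; with real variables you would keep the variable object v as the dict value instead of its name and pass the resulting dict to tf.train.Saver(var_list=...), one saver per model.

```python
def checkpoint_name_map(graph_names, scope):
    """Map names stored in an unscoped checkpoint to in-graph names under `scope`.

    graph_names: names as reported by v.name, e.g. "model1/conv1/weights:0".
    Returns e.g. {"conv1/weights": "model1/conv1/weights:0", ...}.
    """
    prefix = scope.rstrip("/") + "/"
    mapping = {}
    for name in graph_names:
        if not name.startswith(prefix):
            continue  # variable belongs to another model
        op_name = name.rsplit(":", 1)[0]  # drop the ":0" output suffix
        mapping[op_name[len(prefix):]] = name
    return mapping


# Hypothetical names: both CNNs live in one graph, each under its own scope.
graph_names = [
    "model1/conv1/weights:0",
    "model1/conv1/biases:0",
    "model2/conv1/weights:0",
]
print(checkpoint_name_map(graph_names, "model1"))
```

Keys are the names as saved by the original, unscoped model; values are where those tensors now live in the combined graph, so both checkpoints can be restored without retraining.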

  • 2021-02-04 09:54

    You need to create two sessions and restore the two models separately. For this to work, do the following:

    1a. When you save the models, add scopes to the variable names so that you know which variables belong to which model:

    # The first model
    tf.Variable(tf.zeros([self.batch_size]), name="model_1/Weights")
    ...
    
    # The second model 
    tf.Variable(tf.zeros([self.batch_size]), name="model_2/Weights")
    ...
    

    1b. Alternatively, if you have already saved the models, you can rename their variables by adding a scope with this script.

    2. When you restore the different models, filter the variables by name like this:

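    import tensorflow as tf  # TF 1.x API

    # weights_1_file / weights_2_file (checkpoint paths), pred (prediction
    # tensor), image (input placeholder) and X (input batch) come from your code

    # The first model
    sess_1 = tf.Session()
    sess_1.run(tf.global_variables_initializer())
    # restore only the variables whose names contain 'model_1'
    saver_1 = tf.train.Saver([v for v in tf.global_variables() if 'model_1' in v.name])
    saver_1.restore(sess_1, weights_1_file)
    sess_1.run(pred, feed_dict={image: X})

    # The second model
    sess_2 = tf.Session()
    sess_2.run(tf.global_variables_initializer())
    saver_2 = tf.train.Saver([v for v in tf.global_variables() if 'model_2' in v.name])
    saver_2.restore(sess_2, weights_2_file)
    sess_2.run(pred, feed_dict={image: X})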
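The substring test 'model_1' in v.name is fine for two models, but it also matches names such as model_10/... once more models share the graph. A sketch of a stricter alternative on plain strings: group the variable names by their first scope component and hand each group to its own saver. group_by_top_scope and the names below are hypothetical; with TensorFlow you would key the groups on v.name but store the variables themselves, then build one tf.train.Saver per group.

```python
from collections import defaultdict


def group_by_top_scope(names):
    """Group variable names by their first scope component (e.g. "model_1")."""
    groups = defaultdict(list)
    for name in names:
        top, sep, _ = name.partition("/")
        if sep:  # names without any scope are left out
            groups[top].append(name)
    return dict(groups)


# Hypothetical variable names as reported by v.name
names = [
    "model_1/Weights:0",
    "model_1/Bias:0",
    "model_10/Weights:0",
    "model_2/Weights:0",
    "global_step:0",
]

groups = group_by_top_scope(names)
substring_hits = [n for n in names if "model_1" in n]  # the check used above
print(groups["model_1"])
print(substring_hits)
```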
    