How to get weights from .pb model in Tensorflow

野的像风 2020-12-15 11:13

I trained a model and then created a .pb file by freezing it. How can I get the weights from the .pb file, or do I need some additional processing to get them?

1 Answer
  • 2020-12-15 11:45

    Let us first load the graph from the .pb file.

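    import tensorflow as tf
    
    GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'  # path to your .pb file
    
    # Read and parse the serialized GraphDef stored in the frozen .pb file.
    # tf.io.gfile.GFile and tf.compat.v1.GraphDef work in recent TF 1.x releases and in TF 2.x.
    print("load graph")
    with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    
    # Importing into a graph is only needed if you also want to run the model;
    # the weights themselves can be read straight from graph_def.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    
    graph_nodes = [n for n in graph_def.node]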
    

    When you freeze a graph into a .pb file, its variables are converted to Const ops, so the weights that were trainable variables are stored as Const nodes in the .pb file. graph_nodes contains every node in the graph, but we are only interested in the Const nodes. Note that not every Const node is a weight: shapes, axes and other literal constants are stored as Const too, so check the node names.

    wts = [n for n in graph_nodes if n.op=='Const']
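
    To see the variable-to-Const conversion in action, here is a minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 graph APIs. The toy graph and the node names x, w and y are hypothetical and not part of the original answer. It freezes a one-variable graph with tf.compat.v1.graph_util.convert_variables_to_constants and lists the resulting nodes; the variable w comes out as a Const node, so the filter above would pick it up.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session APIs, still available in TF 2.x

# Hypothetical toy graph: y = x @ w, with a single trainable variable "w".
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf1.Variable([[1.0, 2.0], [3.0, 4.0]], name="w")
    y = tf.matmul(x, w, name="y")

    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        # Freezing: the current variable values are read from the session
        # and baked into Const nodes; everything not needed for "y" is pruned.
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names=["y"])

# Print each remaining node's name and op type; "w" is now a Const.
for node in frozen_graph_def.node:
    print(node.name, node.op)
```

    Writing frozen_graph_def.SerializeToString() to disk would give you a .pb file like the one in the question.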
    

    Each element of wts is a NodeDef protocol buffer with attributes such as name, op and attr. The weight values are stored in attr['value'].tensor and can be extracted as follows:

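    from tensorflow.python.framework import tensor_util  # public alias: tf.make_ndarray
    
    # Convert each Const node's TensorProto into a NumPy array (Python 3 print syntax).
    for n in wts:
        print("Name of the node - %s" % n.name)
        print("Value - ")
        print(tensor_util.MakeNdarray(n.attr['value'].tensor))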
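
    If you want the weights as ordinary NumPy arrays, for example to save them or load them into another framework, the steps above can be wrapped into a small helper. This is a minimal sketch assuming TensorFlow 2.x; load_const_weights, toy_frozen.pb and the node names dense/kernel and dense/bias are hypothetical, and the demo writes its own toy GraphDef instead of using a real frozen model. As noted above, the returned dict contains every Const node, not only trainable weights.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def load_const_weights(pb_path):
    """Hypothetical helper: return {node_name: ndarray} for every Const node in a .pb GraphDef."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # tf.make_ndarray converts the TensorProto stored in attr["value"] to NumPy.
    return {
        node.name: tf.make_ndarray(node.attr["value"].tensor)
        for node in graph_def.node
        if node.op == "Const"
    }


# Demo only: build a toy graph whose hypothetical "weights" are plain constants,
# and serialize it to a temporary .pb file.
graph = tf.Graph()
with graph.as_default():
    tf.constant(np.arange(6, dtype=np.float32).reshape(2, 3), name="dense/kernel")
    tf.constant(np.zeros(3, dtype=np.float32), name="dense/bias")

pb_path = os.path.join(tempfile.mkdtemp(), "toy_frozen.pb")
with open(pb_path, "wb") as f:
    f.write(graph.as_graph_def().SerializeToString())

weights = load_const_weights(pb_path)
for name, value in weights.items():
    print(name, value.dtype, value.shape)
```

    With a real model you would call load_const_weights(GRAPH_PB_PATH) and, if needed, save the result with np.savez("weights.npz", **weights); node names containing "/" work as keys there.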
    

    Hope this solves your concern.
