【CV】ckpt文件转为pb文件(fasterrcnn)

落爺英雄遲暮 提交于 2020-01-31 03:04:00
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow

def freeze_graph(input_checkpoint,output_graph):
    #指定输出的节点名称,该节点名称必须是原模型中存在的节点。直接用最后输出的节点,可以在tensorboard中查找到,tensorboard只能在linux中使用
    output_node_names = "SCORE/resnet_v1_101_5/cls_prob/cls_prob/scores,SCORE/resnet_v1_101_5/bbox_pred/BiasAdd/bbox_pred/scores,SCORE/resnet_v1_101_5/cls_pred/cls_pred/scores"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #通过 import_meta_graph 导入模型中的图----1
    graph = tf.get_default_graph() #获得默认的图
    input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #通过 saver.restore 从模型中恢复图中各个变量的数据----2
        output_graph_def = graph_util.convert_variables_to_constants(  #通过 graph_util.convert_variables_to_constants 将模型持久化----3
            sess=sess,
            input_graph_def=input_graph_def, #等于:sess.graph_def
            output_node_names=output_node_names.split(","))  #如果有多个输出节点,以逗号隔开
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

input_checkpoint='./checkpoints/res101_faster_rcnn_iter_70000.ckpt'
out_pb_path='./checkpoints/frozen_model.pb'

reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:  # Print tensor name and values
    print("tensor_name: ", key)
    #print(reader.get_tensor(key))

freeze_graph(input_checkpoint, out_pb_path)
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!