Tensorflow: delete nodes from graph

前端 未结 2 675
一整个雨季
一整个雨季 2021-01-28 04:27

I'm trying to delete some nodes from a graph and save the result as a .pb file.

Only the needed nodes can be added to the new mod_graph_def graph, but the problem is that the graph still h… (question text truncated in the source)

相关标签:
2条回答
  • 2021-01-28 04:54

    The previous answer is good, but I would suggest connecting the removed node's input to the next node's input instead of just dropping it. For example, given the chain A -(input b)-> B -(input c)-> C -(input d)-> D, if we remove node B we should not simply remove input c from C but replace it with input b. See the code below:

    # Remove a node and connect its input to the node's followers.
    def remove_node(graph_def, node_name, input_name):
        """Return a copy of *graph_def* without *node_name*, rewiring its consumers.

        Every node that listed ``node_name`` as an input gets ``input_name`` (one
        of the removed node's own inputs) in its place. A chain A -> B -> C
        therefore becomes A -> C when B is removed. If the removed node had no
        inputs, references to it are dropped instead. An empty input name would
        make the graph invalid.

        Args:
            graph_def: GraphDef-like message with a repeated ``node`` field whose
                elements have ``name`` and repeated ``input`` fields.
            node_name: Name of the node to remove.
            input_name: Input of the removed node that its consumers should use
                instead. It must be among that node's inputs unless the node has
                none.

        Returns:
            A new graph of the same type as *graph_def*. For protobuf messages,
            ``extend`` copies the nodes, so *graph_def* itself is left untouched.

        Raises:
            ValueError: if no node is named *node_name*, or if *input_name* is not
                among its inputs.
        """
        kept = []
        removed = None
        for node in graph_def.node:
            if node.name == node_name:
                removed = node
            else:
                kept.append(node)

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if removed is None:
            raise ValueError("Node {} not found in graph".format(node_name))
        if len(removed.input) and input_name not in removed.input:
            raise ValueError(
                "Node input to use is not among inputs of node to remove")

        replacement = input_name if len(removed.input) else ''
        print("Removing {} and using its input {}".format(removed.name,
                                                          replacement))

        # Build the output first and rewrite inputs on its nodes. For protobuf
        # this edits the copies rather than the caller's graph_def.
        mod_graph_def = type(graph_def)()
        mod_graph_def.node.extend(kept)

        for node in mod_graph_def.node:
            if node_name not in node.input:
                continue
            new_inputs = []
            for inp in node.input:
                if inp != node_name:
                    new_inputs.append(inp)
                    continue
                print("For node {} replacing input {} with {}".format(
                    node.name, inp, replacement))
                if replacement:
                    new_inputs.append(replacement)
            del node.input[:]
            node.input.extend(new_inputs)
        return mod_graph_def
    
    0 讨论(0)
  • 2021-01-28 05:00
    def _drop_matching_nodes(graph_def, pattern):
        """Return a new graph without nodes whose name contains *pattern*.

        Input references to dropped nodes are removed, not rewired. Inputs are
        matched by the node name they refer to (with a leading ``^`` and any
        ``:port`` suffix stripped), so the node filter and the input filter
        agree for any *pattern*.
        """
        dropped = set()
        nodes = []
        for node in graph_def.node:
            if pattern in node.name:
                print('Drop', node.name)
                dropped.add(node.name)
            else:
                nodes.append(node)

        mod_graph_def = type(graph_def)()
        mod_graph_def.node.extend(nodes)

        # Delete references to deleted nodes. Edit the copies held by
        # mod_graph_def rather than the parsed input graph.
        for node in mod_graph_def.node:
            inp_names = [inp for inp in node.input
                         if inp.lstrip('^').split(':')[0] not in dropped]
            if len(inp_names) != len(node.input):
                del node.input[:]
                node.input.extend(inp_names)
        return mod_graph_def

    def delete_ops_from_graph(pattern='Neg'):
        """Delete every node whose name contains *pattern* and save the graph.

        Reads a serialized GraphDef from ``input_model_filepath`` and writes the
        pruned graph to ``output_model_filepath``. Both names, like ``tf``, are
        expected to be defined at module level; they are not shown in this
        snippet.

        Args:
            pattern: Substring that selects nodes to delete. The default 'Neg'
                matches the original hard-coded behaviour.
        """
        # `tf.GraphDef` exists only in TF 1.x; TF 2.x exposes it as
        # `tf.compat.v1.GraphDef`.
        graph_def_cls = getattr(tf, 'GraphDef', None) or tf.compat.v1.GraphDef
        with open(input_model_filepath, 'rb') as f:
            graph_def = graph_def_cls()
            graph_def.ParseFromString(f.read())

        mod_graph_def = _drop_matching_nodes(graph_def, pattern)

        with open(output_model_filepath, 'wb') as f:
            f.write(mod_graph_def.SerializeToString())
    
    0 讨论(0)
提交回复
热议问题