TensorFlow Java multi-GPU inference

伪装坚强ぢ 2021-02-14 06:43

I have a server with multiple GPUs and want to make full use of them during model inference inside a Java app. By default TensorFlow seizes all available GPUs, but uses only th

2 answers
  •  广开言路
    2021-02-14 07:04

    In Python (TensorFlow 1.x API) it can be done by importing each frozen graph under a different device scope:

    import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session)

    def get_frozen_graph(graph_file):
        """Read a frozen graph file from disk and return its GraphDef."""
        with tf.gfile.GFile(graph_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        return graph_def

    # First model: import its nodes under the /gpu:1 device scope.
    trt_graph1 = get_frozen_graph('/home/ved/ved_1/frozen_inference_graph.pb')

    with tf.device('/gpu:1'):
        (tf_input_l1, tf_scores_l1, tf_boxes_l1, tf_classes_l1,
         tf_num_detections_l1, tf_masks_l1) = tf.import_graph_def(
            trt_graph1,
            return_elements=['image_tensor:0', 'detection_scores:0',
                             'detection_boxes:0', 'detection_classes:0',
                             'num_detections:0', 'detection_masks:0'])

    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU.
    tf_sess1 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    # Second model: import its nodes under the /gpu:0 device scope.
    trt_graph2 = get_frozen_graph('/home/ved/ved_2/frozen_inference_graph.pb')

    with tf.device('/gpu:0'):
        (tf_input_l2, tf_scores_l2, tf_boxes_l2, tf_classes_l2,
         tf_num_detections_l2) = tf.import_graph_def(
            trt_graph2,
            return_elements=['image_tensor:0', 'detection_scores:0',
                             'detection_boxes:0', 'detection_classes:0',
                             'num_detections:0'])

    tf_sess2 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
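
The snippet above hard-codes one device per model. If there are more models than GPUs, one possible way to spread them out is a round-robin assignment of device strings. The following is only a sketch in plain Python: `assign_devices` is a hypothetical helper, not part of TensorFlow, and the number of GPUs is assumed to be known in advance. Note that it hands out `/gpu:0` first, whereas the snippet above places the first model on `/gpu:1`.

```python
def assign_devices(model_paths, num_gpus):
    """Pair each model path with a device string, cycling over the GPUs.

    Hypothetical helper for illustration; it only builds strings such as
    '/gpu:0' and does not touch TensorFlow itself.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return [(path, "/gpu:%d" % (index % num_gpus))
            for index, path in enumerate(model_paths)]


# Example with the two frozen graphs used above and two GPUs.
placement = assign_devices(
    ["/home/ved/ved_1/frozen_inference_graph.pb",
     "/home/ved/ved_2/frozen_inference_graph.pb"],
    num_gpus=2,
)
for path, device in placement:
    print(device, path)
```

Each `(path, device)` pair could then be used as `with tf.device(device):` around the `tf.import_graph_def` call for that model, in the same way as the two hand-written blocks above.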
    
