Tensorflow Java Multi-GPU inference

伪装坚强ぢ 2021-02-14 06:43

I have a server with multiple GPUs and want to make full use of them during model inference inside a java app. By default tensorflow seizes all available GPUs, but uses only th

2 Answers
  •  南旧 (OP)
    2021-02-14 07:02

    In short: there is a workaround in which you end up with one session per GPU.

    Details:

    The general flow is that the TensorFlow runtime respects the devices specified for operations in the graph. If no device is specified for an operation, then it "places" it based on some heuristics. Those heuristics currently result in "place operation on GPU:0 if GPUs are available and there is a GPU kernel for the operation" (Placer::Run in case you're interested).
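    The following is a deliberately simplified toy model of the heuristic as described above. The class and method names are hypothetical, and the model ignores soft placement, colocation constraints, and everything else the real Placer handles. It only shows why every operation without an explicit device lands on GPU:0 no matter how many GPUs are present.

    ```java
    import java.util.ArrayList;
    import java.util.List;

    // Toy model of the default placement heuristic described above.
    // Hypothetical names; this is NOT TensorFlow's Placer implementation.
    public class PlacementSketch {

        // requestedDevice: device recorded in the graph for the op ("" if none).
        static String place(String requestedDevice, int numGpus, boolean hasGpuKernel) {
            if (!requestedDevice.isEmpty()) {
                return requestedDevice;           // explicit placement is respected
            }
            if (numGpus > 0 && hasGpuKernel) {
                return "/gpu:0";                  // always the first GPU, never GPU:1..N
            }
            return "/cpu:0";                      // no GPU, or no GPU kernel for this op
        }

        public static void main(String[] args) {
            int numGpus = 8;
            List<String> placements = new ArrayList<>();
            placements.add(place("", numGpus, true));        // unpinned op with a GPU kernel
            placements.add(place("", numGpus, false));       // unpinned op, CPU-only kernel
            placements.add(place("/gpu:3", numGpus, true));  // explicitly pinned op
            System.out.println(placements);  // [/gpu:0, /cpu:0, /gpu:3]
        }
    }
    ```

    Pinning each operation to an explicit device, as in the workaround below, is what moves work off GPU:0 in this model.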

    What you are asking for is, I think, a reasonable feature request for TensorFlow: the ability to treat devices in the serialized graph as "virtual" ones to be mapped to a set of "physical" devices at run time, or alternatively to set the "default device". This feature does not currently exist. Adding such an option to ConfigProto is something you may want to file a feature request for.

    I can suggest a workaround in the interim. First, some commentary on your proposed solutions.

    1. Your first idea will surely work but, as you pointed out, is cumbersome.

    2. Setting visible_device_list in the ConfigProto doesn't quite work out, since that is actually a per-process setting and is ignored after the first session is created in the process. This is certainly not documented as well as it should be (and it is somewhat unfortunate that this appears in the per-Session configuration). However, this explains why your suggestion here doesn't work and why you still see a single GPU being used.

    3. This could work.
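    The per-process behavior described in point 2 can be illustrated with a toy model, assuming (as stated above) that the device list is fixed when the first session is created and later values are ignored. FakeSession and the static field are hypothetical stand-ins, not the TensorFlow API:

    ```java
    import java.util.ArrayList;
    import java.util.List;

    // Toy model (hypothetical classes, NOT the TensorFlow API) of a
    // per-process setting that is frozen when the first session is created.
    public class VisibleDevicesSketch {

        // Process-wide state: set once by the first session, then never changes.
        private static String visibleDeviceList = null;

        static final class FakeSession {
            final String devicesSeen;

            FakeSession(String requestedVisibleDevices) {
                synchronized (VisibleDevicesSketch.class) {
                    if (visibleDeviceList == null) {
                        visibleDeviceList = requestedVisibleDevices;  // first session wins
                    }
                    // Any later request is silently ignored.
                    this.devicesSeen = visibleDeviceList;
                }
            }
        }

        public static void main(String[] args) {
            List<String> seen = new ArrayList<>();
            for (int i = 0; i < 4; i++) {
                // Each "session" asks for a different GPU...
                seen.add(new FakeSession(Integer.toString(i)).devicesSeen);
            }
            // ...but all of them see the list chosen by the first one.
            System.out.println(seen);  // [0, 0, 0, 0]
        }
    }
    ```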

    Another option is to end up with different graphs (with operations explicitly placed on different GPUs), resulting in one session per GPU. Something like this can be used to edit the graph and explicitly assign a device to each operation:

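    import org.tensorflow.framework.GraphDef;

    // Returns a copy of the serialized GraphDef with every node pinned to `device`
    // (e.g. "/gpu:1"), overriding any placement already recorded in the graph.
    public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception {
      GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
      for (int i = 0; i < builder.getNodeCount(); ++i) {
        builder.getNodeBuilder(i).setDevice(device);
      }
      return builder.build().toByteArray();
    }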
    

    After which you could create a Graph and Session per GPU using something like:

    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.framework.ConfigProto;

    final int NUM_GPUS = 8;
    // graphDef: the original serialized model (e.g. the bytes of a frozen .pb file).
    // setAllowSoftPlacement: just in case our device modifications were too aggressive
    // (e.g., setting a GPU device on an operation that only has CPU kernels).
    // setLogDevicePlacement: so we can see where each operation ends up.
    byte[] config =
        ConfigProto.newBuilder()
            .setLogDevicePlacement(true)
            .setAllowSoftPlacement(true)
            .build()
            .toByteArray();
    Graph[] graphs = new Graph[NUM_GPUS];
    Session[] sessions = new Session[NUM_GPUS];
    for (int i = 0; i < NUM_GPUS; ++i) {
      // One graph per GPU, with every op pinned to that GPU, and a session on top.
      graphs[i] = new Graph();
      graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i)));
      sessions[i] = new Session(graphs[i], config);
    }
    

    Then use sessions[i] to execute the graph on GPU #i.
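    How requests are spread over sessions[] is up to the application. One simple option is a thread-safe round-robin picker, sketched here under the assumption that a session may be shared by several request threads (check this against the Session documentation for your TensorFlow version). RoundRobin is a hypothetical helper, made generic so it can be exercised without TensorFlow:

    ```java
    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;

    // Hypothetical helper: hands out per-GPU resources (e.g. the sessions[] array)
    // in round-robin order. Thread-safe thanks to the atomic counter.
    public final class RoundRobin<T> {
        private final List<T> items;
        private final AtomicInteger next = new AtomicInteger();

        public RoundRobin(List<T> items) {
            if (items.isEmpty()) {
                throw new IllegalArgumentException("need at least one item");
            }
            this.items = List.copyOf(items);  // immutable snapshot
        }

        public T next() {
            // floorMod keeps the index non-negative even after the counter overflows.
            int i = Math.floorMod(next.getAndIncrement(), items.size());
            return items.get(i);
        }

        public static void main(String[] args) {
            RoundRobin<String> gpus = new RoundRobin<>(List.of("/gpu:0", "/gpu:1", "/gpu:2"));
            StringBuilder order = new StringBuilder();
            for (int k = 0; k < 5; k++) {
                order.append(gpus.next()).append(' ');
            }
            System.out.println(order.toString().trim());  // /gpu:0 /gpu:1 /gpu:2 /gpu:0 /gpu:1
        }
    }
    ```

    In the serving code this would be something like new RoundRobin<>(Arrays.asList(sessions)), calling next() once per request. Round-robin assumes requests cost roughly the same; with very uneven batch sizes, a least-loaded strategy may balance the GPUs better.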

    Hope that helps.
