Tensorflow Java Multi-GPU inference


In short: There is a workaround, where you end up with one session per GPU.

Details:

The general flow is that the TensorFlow runtime respects the devices specified for operations in the graph. If no device is specified for an operation, then it "places" it based on some heuristics. Those heuristics currently result in "place operation on GPU:0 if GPUs are available and there is a GPU kernel for the operation" (Placer::Run in case you're interested).
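To see what this means concretely, you can inspect the device field of each node in a serialized GraphDef; nodes with an empty device string are the ones left to the Placer. A minimal sketch, assuming graphDefBytes holds the serialized graph and using the org.tensorflow.framework proto classes:

// GraphDef and NodeDef are the generated protobuf classes in org.tensorflow.framework.
GraphDef gd = GraphDef.parseFrom(graphDefBytes);
for (NodeDef node : gd.getNodeList()) {
  // An empty device string means the runtime's Placer will decide where this op runs.
  System.out.println(node.getName() + " -> "
      + (node.getDevice().isEmpty() ? "<unassigned>" : node.getDevice()));
}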

What you ask for is, I think, a reasonable feature request for TensorFlow - the ability to treat devices in the serialized graph as "virtual" ones to be mapped to a set of "physical" devices at run time, or alternatively to set a "default device". This feature does not currently exist. Adding such an option to ConfigProto is something you may want to file a feature request for.

I can suggest a workaround in the interim. First, some commentary on your proposed solutions.

  1. Your first idea will surely work, but as you pointed out, is cumbersome.

  2. Setting visible_device_list in the ConfigProto doesn't quite work out, since that is actually a per-process setting and is ignored after the first session is created in the process (see the sketch after this list). This is certainly not documented as well as it should be (and it is somewhat unfortunate that this appears in the per-Session configuration). However, it explains why your suggestion here doesn't work and why you still see a single GPU being used.

  3. This could work.
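For reference, this is roughly how visible_device_list is expressed via GPUOptions in the ConfigProto. As noted in point 2, it only takes effect for the first session created in the process, so it does not give you per-session GPU selection:

byte[] config =
    ConfigProto.newBuilder()
        .setGpuOptions(GPUOptions.newBuilder().setVisibleDeviceList("1").build())
        .build()
        .toByteArray();
// Passing this to a second Session in the same process has no effect;
// the GPU visibility chosen when the first session was created sticks.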

Another option is to end up with different graphs (with operations explicitly placed on different GPUs), resulting in one session per GPU. Something like this can be used to edit the graph and explicitly assign a device to each operation:

// GraphDef here is the org.tensorflow.framework.GraphDef protobuf class
// (available from the TensorFlow "proto" artifact).
public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception {
  GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
  // Overwrite the device field of every node so all operations are pinned
  // to the requested device (e.g., "/gpu:0").
  for (int i = 0; i < builder.getNodeCount(); ++i) {
    builder.getNodeBuilder(i).setDevice(device);
  }
  return builder.build().toByteArray();
}
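For completeness, one way to obtain the serialized graph bytes to pass in; the file path here is a placeholder for wherever your frozen GraphDef lives:

// Hypothetical path to a frozen GraphDef file; substitute your own model.
byte[] graphDef = java.nio.file.Files.readAllBytes(
    java.nio.file.Paths.get("/path/to/frozen_graph.pb"));
byte[] gpu0GraphDef = modifyGraphDef(graphDef, "/gpu:0");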

After which you could create a Graph and Session per GPU using something like:

final int NUM_GPUS = 8;
// setAllowSoftPlacement: Just in case our device modifications were too aggressive
// (e.g., setting a GPU device on an operation that only has CPU kernels)
// setLogDevicePlacement: So we can see where operations end up.
byte[] config =
    ConfigProto.newBuilder()
        .setLogDevicePlacement(true)
        .setAllowSoftPlacement(true)
        .build()
        .toByteArray();
Graph[] graphs = new Graph[NUM_GPUS];
Session[] sessions = new Session[NUM_GPUS];
for (int i = 0; i < NUM_GPUS; ++i) {
  graphs[i] = new Graph();
  graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i)));
  sessions[i] = new Session(graphs[i], config);
}

Then use sessions[i] to execute the graph on GPU #i.
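A minimal sketch of dispatching one request to GPU #i; the operation names "input" and "output" and the example tensor are placeholders for whatever your model actually uses:

// "input" and "output" are hypothetical operation names; substitute the
// names from your own graph. Tensors should be closed when no longer needed.
try (Tensor<?> input = Tensor.create(new float[][] {{1.0f, 2.0f}})) {
  try (Tensor<?> result = sessions[i].runner()
          .feed("input", input)
          .fetch("output")
          .run()
          .get(0)) {
    // ... copy the result out, e.g. with result.copyTo(...)
  }
}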

Hope that helps.
