Question
TF Object Detection API grabs all GPU memory by default, so it's difficult to tell how much I can further increase my batch size. Typically I just continue to increase it until I get a CUDA OOM error.
PyTorch on the other hand doesn't grab all GPU memory by default, so it's easy to see what percentage I have left to work with, without all the trial and error.
Is there a better way to determine batch size with the TF Object Detection API that I'm missing? Something like an allow-growth flag for model_main.py?
Answer 1:
I looked through the source code and found no flag related to this.
However, in the file model_main.py (https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py) you can find the following main function definition:
def main(unused_argv):
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      ...
The idea is to modify it along the following lines:
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
That is, add config_proto and pass it to config through session_config, leaving everything else unchanged.
Also, allow_growth makes the program use as much GPU memory as it needs, so depending on your GPU you might still end up with all of the memory consumed. In that case you may want to use
config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
which defines the fraction of GPU memory the process may use.
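To make the two options above concrete, here is a minimal sketch that builds the session config in one place. It assumes a TensorFlow version that exposes tf.compat.v1 (1.13 or later); make_session_config is a hypothetical helper name for illustration, not part of the Object Detection API.

```python
import tensorflow as tf


def make_session_config(allow_growth=True, memory_fraction=None):
    """Hypothetical helper: build a ConfigProto with the GPU options discussed above."""
    config_proto = tf.compat.v1.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all at start-up.
    config_proto.gpu_options.allow_growth = allow_growth
    if memory_fraction is not None:
        # Cap the share of each visible GPU's memory this process may use.
        config_proto.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return config_proto


session_config = make_session_config(allow_growth=True, memory_fraction=0.9)
print(session_config.gpu_options)
```

The resulting object would then be passed as session_config=... to tf.estimator.RunConfig inside main, exactly as in the modification shown earlier.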
Hope this has helped.
If you do not want to modify the file, it seems that an issue should be opened, because I do not see any flag for this, unless the flag
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
means something related to this. But I do not think so, because from what I can see in model_lib.py it relates to the train, eval and infer configurations, not to GPU usage configuration.
Source: https://stackoverflow.com/questions/55529094/determining-max-batch-size-with-tensorflow-object-detection-api