Device placement error when using the Dataset API with TensorFlow >= 1.11

Posted by 99封情书 on 2019-12-11 02:27:28

Question


My script uses the Dataset API to implement its input pipeline. With TensorFlow 1.10 everything works, but after upgrading to TensorFlow 1.11 or 1.12 I get the following error message at the beginning of training:

Attempted create an iterator on device "/job:localhost/replica:0/task:0/device:GPU:0" from handle defined on device "/job:localhost/replica:0/task:0/device:CPU:0"

Here is a piece of my code:

def build_all_dataset(self):
    self.build_dataset(ROUTE_TRAIN)
    self.build_dataset(ROUTE_VALIDATION)
    self.build_dataset(ROUTE_TEST)

def build_dataset(self, p_route):
    # Create dataset instance.
    # ......
    # ......

    self.dataset[p_route] = dataset
    self.iterators[p_route] = dataset.make_initializable_iterator()
    self.handles[p_route] = (self.iterators[p_route].string_handle()).eval()

def build_model(self):
    with tf.device("/GPU:0"):
        # Build dataset for training/validation/test.
        self.build_all_dataset()
        self.ph_dataset_handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(self.ph_dataset_handle, self.dataset[ROUTE_TRAIN].output_types,
                                                       self.dataset[ROUTE_TRAIN].output_shapes)

        xxx, xxx, xxx = iterator.get_next()
        # Build the following graph.
        # ......
        # ......

while True:
    try:
        session.run(xxx, {self.ph_dataset_handle: self.handles[ROUTE_TRAIN]})
    except tf.errors.OutOfRangeError:
        break

I checked that the placeholder "self.ph_dataset_handle" is placed on /GPU:0, so why does TensorFlow say the handle was defined on device "/job:localhost/replica:0/task:0/device:CPU:0"?

Could you please give me any insight? Thanks!

Source: https://stackoverflow.com/questions/53295253/when-use-dataset-api-got-device-placement-error-with-tensorflow-1-11
