Question
My script uses the Dataset API to implement its input pipeline. With TensorFlow 1.10 everything works, but after upgrading to TensorFlow 1.11 or 1.12, I get the following error at the start of training:
Attempted create an iterator on device "/job:localhost/replica:0/task:0/device:GPU:0" from handle defined on device "/job:localhost/replica:0/task:0/device:CPU:0"
Here is part of my code:
```python
def build_all_dataset(self):
    self.build_dataset(ROUTE_TRAIN)
    self.build_dataset(ROUTE_VALIDATION)
    self.build_dataset(ROUTE_TEST)

def build_dataset(self, p_route):
    # Create dataset instance.
    # ......
    # ......
    self.dataset[p_route] = dataset
    self.iterators[p_route] = dataset.make_initializable_iterator()
    self.handles[p_route] = (self.iterators[p_route].string_handle()).eval()

def build_model(self):
    with tf.device("/GPU:0"):
        # Build dataset for training/validation/test.
        self.build_all_dataset()
        self.ph_dataset_handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(self.ph_dataset_handle,
                                                       self.dataset[ROUTE_TRAIN].output_types,
                                                       self.dataset[ROUTE_TRAIN].output_shapes)
        xxx, xxx, xxx = iterator.get_next()
        # Build the following graph.
        # ......
        # ......

while True:
    try:
        session.run(xxx, {self.ph_dataset_handle: self.handles[ROUTE_TRAIN]})
    except tf.errors.OutOfRangeError:
        break
```
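For readers less familiar with the feedable-iterator pattern used above, here is a pure-Python model (not TensorFlow code) of how it works: each iterator is registered under a string handle, and a single consumer looks the iterator up from whichever handle is fed in. The device check is an assumption that mirrors the rule implied by the error message, namely that the op consuming a handle must run on the same device where the handle's iterator was defined. `IteratorResource`, `HandleRegistry`, `register`, `iterator_from_handle` and `drain` are hypothetical names chosen for illustration only.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, List


@dataclass
class IteratorResource:
    """Hypothetical stand-in for an iterator resource pinned to one device."""
    device: str
    elements: Iterator[int]


@dataclass
class HandleRegistry:
    """Hypothetical model: maps string handles to iterator resources."""
    resources: Dict[str, IteratorResource] = field(default_factory=dict)

    def register(self, name: str, device: str, data: List[int]) -> str:
        # Create an iterator on `device` and return an opaque string handle,
        # loosely analogous to iterator.string_handle() in the question.
        handle = f"{device}/{name}"
        self.resources[handle] = IteratorResource(device, iter(data))
        return handle

    def iterator_from_handle(self, handle: str, op_device: str) -> IteratorResource:
        # Assumed rule (inferred from the error text): the consuming op must
        # run on the device where the handle's resource was defined.
        resource = self.resources[handle]
        if resource.device != op_device:
            raise RuntimeError(
                f'Attempted create an iterator on device "{op_device}" '
                f'from handle defined on device "{resource.device}"'
            )
        return resource


def drain(registry: HandleRegistry, handle: str, op_device: str) -> List[int]:
    # Pull elements until exhausted, like the session.run loop that stops on
    # tf.errors.OutOfRangeError (modeled here by StopIteration).
    resource = registry.iterator_from_handle(handle, op_device)
    out: List[int] = []
    while True:
        try:
            out.append(next(resource.elements))
        except StopIteration:
            break
    return out


if __name__ == "__main__":
    cpu = "/job:localhost/replica:0/task:0/device:CPU:0"
    gpu = "/job:localhost/replica:0/task:0/device:GPU:0"
    registry = HandleRegistry()
    train = registry.register("train", cpu, [1, 2, 3])
    print(drain(registry, train, cpu))
    validation = registry.register("validation", cpu, [4, 5])
    try:
        drain(registry, validation, gpu)
    except RuntimeError as err:
        print(err)
```

In this model, a handle created on CPU drains normally when consumed on CPU, but consuming it from an op placed on GPU raises the same kind of mismatch reported in the question.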
I checked that the placeholder `self.ph_dataset_handle` is placed on /GPU:0, so why does TensorFlow say the handle is defined on device "/job:localhost/replica:0/task:0/device:CPU:0"?
Could you please give any insight? Thanks!
Source: https://stackoverflow.com/questions/53295253/when-use-dataset-api-got-device-placement-error-with-tensorflow-1-11