Tensorflow Datasets with string inputs do not preserve data type

只愿长相守 submitted on 2020-04-14 06:17:15

Question


All reproducible code below is run in Google Colab with TF 2.2.0-rc2.

Adapting the simple example from the documentation for creating a dataset from a simple Python list:

import numpy as np
import tensorflow as tf
tf.__version__
# '2.2.0-rc2'
np.version.version
# '1.18.2'

dataset1 = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 
for element in dataset1: 
  print(element) 
  print(type(element.numpy()))

we get the result

tf.Tensor(1, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(2, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(3, shape=(), dtype=int32)
<class 'numpy.int32'>

where all data types are int32, as expected.

But changing this simple example to feed a list of strings instead of integers:

dataset2 = tf.data.Dataset.from_tensor_slices(['1', '2', '3']) 
for element in dataset2: 
  print(element) 
  print(type(element.numpy()))

gives the result

tf.Tensor(b'1', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'2', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'3', shape=(), dtype=string)
<class 'bytes'>

where, surprisingly, and despite the tensors themselves being of dtype=string, their evaluated values (via .numpy()) are of type bytes.
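A minimal sketch of the usual workaround in eager mode: a tf.string tensor holds raw byte strings, so .numpy() hands back Python bytes, and the conversion to str has to be done explicitly. The UTF-8 encoding here is an assumption about the data.

```python
import tensorflow as tf

dataset2 = tf.data.Dataset.from_tensor_slices(['1', '2', '3'])

# .numpy() on a scalar tf.string tensor returns bytes; decode to get str
# (assumes the strings were UTF-8 encoded)
decoded = [element.numpy().decode('utf-8') for element in dataset2]
print(decoded)  # ['1', '2', '3']
```

This only works where tensors are eager (e.g. a plain Python loop over the dataset); inside Dataset.map the function is traced, so .numpy() is not available there.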

This behavior is not confined to the .from_tensor_slices method; here is the situation with .list_files (the following snippet runs as-is in a fresh Colab notebook):

disc_data = tf.data.Dataset.list_files('sample_data/*.csv') # 4 csv files
for element in disc_data: 
  print(element) 
  print(type(element.numpy()))

the result being:

tf.Tensor(b'sample_data/california_housing_test.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_train_small.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/california_housing_train.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_test.csv', shape=(), dtype=string)
<class 'bytes'>

where, again, the file names in the evaluated tensors are returned as bytes instead of str, despite the tensors themselves being of dtype=string.
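For file names specifically, the bytes only matter once they leave TensorFlow. A minimal sketch, assuming two throwaway CSV files created in a temporary directory as stand-ins for sample_data/*.csv: TensorFlow's own file ops such as tf.io.read_file accept the tf.string tensors directly, and decoding is needed only when a plain Python str is required.

```python
import os
import tempfile
import tensorflow as tf

# Create two small CSV files (hypothetical stand-ins for sample_data/*.csv)
tmp_dir = tempfile.mkdtemp()
for name in ['x.csv', 'y.csv']:
    with open(os.path.join(tmp_dir, name), 'wb') as f:
        f.write(b'col\n1\n')

# shuffle=False gives a deterministic order
files = tf.data.Dataset.list_files(os.path.join(tmp_dir, '*.csv'), shuffle=False)

# TF file ops take the tf.string paths as they are; no decoding needed
contents = files.map(tf.io.read_file)

# Decode only when Python code needs a str path
paths = [p.decode('utf-8') for p in files.as_numpy_iterator()]
print([os.path.basename(p) for p in paths])
print(list(contents.as_numpy_iterator()))
```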

Similar behavior is observed also with the .from_generator method (not shown here).
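A minimal reproduction of the .from_generator case, written for illustration (the generator gen is hypothetical). It uses the output_types/output_shapes arguments available in TF 2.2; later versions also accept output_signature.

```python
import tensorflow as tf

def gen():
    # hypothetical generator that yields Python str values
    yield 'a'
    yield 'b'

ds = tf.data.Dataset.from_generator(gen, output_types=tf.string, output_shapes=())

# The yielded str values come back as bytes
values = list(ds.as_numpy_iterator())
print(values)  # [b'a', b'b']
```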

A final demonstration: as shown in the .as_numpy_iterator method documentation, the following equality condition is evaluated as True:

dataset3 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
                                               'b': [5, 6]}) 

list(dataset3.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5}, 
                                       {'a': (2, 4), 'b': 6}] 
# True

but if we change the elements of b to be strings, the equality condition is now surprisingly evaluated as False!

dataset4 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
                                               'b': ['5', '6']})   # change elements of b to strings

list(dataset4.as_numpy_iterator()) == [{'a': (1, 3), 'b': '5'},   # here
                                       {'a': (2, 4), 'b': '6'}]   # also
# False

probably due to the different data types (bytes vs. str), since the values themselves are evidently identical.
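The mismatch is indeed bytes vs. str: in Python 3, b'5' == '5' is False. A minimal sketch of a hypothetical helper, decode_bytes, that recursively converts bytes to str in the nested dicts/tuples/lists that as_numpy_iterator yields. The raw list below mimics that output with plain ints in place of NumPy integers, an assumption made so the snippet runs without TensorFlow.

```python
def decode_bytes(obj, encoding='utf-8'):
    """Recursively convert bytes to str inside dicts, tuples and lists."""
    if isinstance(obj, bytes):  # also covers numpy.bytes_, a bytes subclass
        return obj.decode(encoding)
    if isinstance(obj, dict):
        return {k: decode_bytes(v, encoding) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(decode_bytes(v, encoding) for v in obj)
    if isinstance(obj, list):
        return [decode_bytes(v, encoding) for v in obj]
    return obj  # numbers and other values pass through unchanged

# Stand-in for list(dataset4.as_numpy_iterator()): 'b' values are bytes
raw = [{'a': (1, 3), 'b': b'5'}, {'a': (2, 4), 'b': b'6'}]
expected = [{'a': (1, 3), 'b': '5'}, {'a': (2, 4), 'b': '6'}]

print(raw == expected)                # False
print(decode_bytes(raw) == expected)  # True
```

Equivalently, comparing against bytes literals (b'5', b'6') instead of str literals makes the original equality check hold without any conversion.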


I didn't stumble upon this behavior by academic experimentation; I am trying to pass my data to TF Datasets using custom functions that read pairs of files from disk, of the form

f = ['filename1', 'filename2']

These custom functions work perfectly well on their own, but when mapped through TF Datasets they give

RuntimeError: not a string

which, after this digging, at least makes sense if the returned data types are indeed bytes and not str.

So, is this a bug (as it seems), or am I missing something here?


Answer 1:


This is known behavior:

From: https://github.com/tensorflow/tensorflow/issues/5552#issuecomment-260455136

TensorFlow converts str to bytes in most places, including sess.run, and this is unlikely to change. The user is free to convert back, but unfortunately it's too large a change to add a unicode dtype to the core. Closing as won't fix for now.

I guess nothing has changed with TensorFlow 2.x: there are still places where strings are converted to bytes, and you have to handle this manually.
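One way to do that conversion for custom Python readers like the ones in the question is a sketch along these lines, not the asker's code: read_pair is a hypothetical stand-in that insists on str arguments. Dataset.map traces its function in graph mode, where .numpy() is unavailable, so the Python code is wrapped in tf.py_function, which runs eagerly and lets the bytes be decoded before the call.

```python
import tensorflow as tf

def read_pair(path1, path2):
    # Hypothetical custom reader that only accepts Python str
    if not isinstance(path1, str) or not isinstance(path2, str):
        raise RuntimeError('not a string')
    return path1 + '|' + path2

def _py_read_pair(pair):
    # Runs eagerly inside tf.py_function: pair.numpy() is an array of bytes
    path1, path2 = (p.decode('utf-8') for p in pair.numpy())
    return read_pair(path1, path2)

def tf_read_pair(pair):
    out = tf.py_function(func=_py_read_pair, inp=[pair], Tout=tf.string)
    out.set_shape([])  # py_function drops static shape information
    return out

pairs = tf.data.Dataset.from_tensor_slices([['a.csv', 'b.csv'], ['c.csv', 'd.csv']])
results = [x.decode('utf-8') for x in pairs.map(tf_read_pair).as_numpy_iterator()]
print(results)  # ['a.csv|b.csv', 'c.csv|d.csv']
```

tf.py_function keeps arbitrary Python code usable inside a pipeline, at the cost of holding the Python GIL and not being serializable with a saved graph; where equivalent TF ops exist (tf.io.read_file, tf.strings.*), they avoid the decoding step entirely.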



Source: https://stackoverflow.com/questions/61131730/tensorflow-datasets-with-string-inputs-do-not-preserve-data-type
