Question
I am following the tutorial https://www.tensorflow.org/tutorials/layers and I want to use it with my own dataset.
import tensorflow as tf

# cnn_model_fn is the model function from the tutorial (omitted here).

def train_input_fn_custom(filenames_array, labels_array, batch_size):
    # Reads an image from a file, decodes it into a dense tensor, and resizes it to a fixed shape.
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_png(image_string, channels=1)
        image_resized = tf.image.resize_images(image_decoded, [40, 40])
        return image_resized, label

    filenames = tf.constant(filenames_array)
    labels = tf.constant(labels_array)
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_parse_function)
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Get data
    filenames_train = ['blackcorner-data/1.png', 'blackcorner-data/2.png']
    labels_train = [0, 1]
    # Create the Estimator
    classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir="/tmp/test_convnet_model")
    # Set up logging for predictions
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)
    # Train the model
    cust_train_input_fn = train_input_fn_custom(
        filenames_array=filenames_train,
        labels_array=labels_train,
        batch_size=3)
    classifier.train(
        input_fn=cust_train_input_fn,
        steps=2000,
        hooks=[logging_hook])

if __name__ == "__main__":
    tf.app.run()
But I get this error:
Traceback (most recent call last):
File "/usr/lib/python3.6/inspect.py", line 1119, in getfullargspec
sigcls=Signature)
File "/usr/lib/python3.6/inspect.py", line 2186, in _signature_from_callable
raise TypeError('{!r} is not a callable object'.format(obj))
TypeError: (<tf.Tensor 'IteratorGetNext:0' shape=(?, 40, 40, ?) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=int32>) is not a callable object
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "cnn_mnist_for_stackoverflow.py", line 139, in <module>
tf.app.run()
File "/home/geo/Projet/ML/cnn_mnist/venv/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
_sys.exit(main(argv))
File "cnn_mnist_for_stackoverflow.py", line 135, in main
hooks=[logging_hook])
...
raise TypeError('unsupported callable') from ex
TypeError: unsupported callable
I don't understand this error; I only know that it comes from train_input_fn_custom. The TensorFlow version is 1.6.
If anyone has an idea, thanks!
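The traceback shows the failure happening inside Python's inspect.getfullargspec, which is apparently called on whatever is passed as input_fn (presumably so the Estimator can see which arguments the function accepts; that purpose is an assumption, it is not shown in the question). A minimal pure-Python reproduction, with a placeholder tuple standing in for the (features, labels) tensors, raises the same chained pair of TypeErrors, while a zero-argument lambda is inspected without complaint:

```python
import inspect

# Placeholder for what train_input_fn_custom(...) returns: a
# (features, labels) tuple, which is data, not a function.
result_of_call = ("features", "labels")

error = None
try:
    # Same call that fails in the traceback above.
    inspect.getfullargspec(result_of_call)
except TypeError as e:
    error = e

print(error)            # the outer "unsupported callable" error
print(error.__cause__)  # the inner "... is not a callable object" error

# Wrapping the call in a lambda gives getfullargspec a real function
# with an empty argument list.
spec = inspect.getfullargspec(lambda: result_of_call)
print(spec.args)  # []
```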
Answer 1:
The input_fn argument to classifier.train() must be a callable object (with no arguments), such as a function or a lambda. In your code, you are passing the results of calling train_input_fn_custom(), rather than a callable object that invokes train_input_fn_custom(). To fix this issue, replace the definition of cust_train_input_fn as follows:
# The `lambda:` creates a callable object with no arguments.
cust_train_input_fn = lambda: train_input_fn_custom(
filenames_array=filenames_train, labels_array=labels_train, batch_size=3)
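To make the difference concrete without needing TensorFlow, here is a hedged pure-Python sketch. fake_train is a hypothetical stand-in for Estimator.train() (it only checks that input_fn is callable and then calls it with no arguments), and the stub train_input_fn_custom returns plain lists instead of tensors. Besides the lambda, it shows functools.partial and a named wrapper function, two standard Python ways to build a zero-argument callable:

```python
import functools


def train_input_fn_custom(filenames_array, labels_array, batch_size):
    # Hypothetical stub: returns a (features, labels) tuple of plain lists
    # instead of the TensorFlow tensors the real function produces.
    return filenames_array[:batch_size], labels_array[:batch_size]


def fake_train(input_fn):
    """Hypothetical stand-in for Estimator.train(): it requires a callable
    and invokes it itself, with no arguments."""
    if not callable(input_fn):
        raise TypeError("input_fn must be callable, got %r" % (input_fn,))
    features, labels = input_fn()
    return features, labels


filenames_train = ['blackcorner-data/1.png', 'blackcorner-data/2.png']
labels_train = [0, 1]

# Option 1: a lambda, as in the answer.
via_lambda = lambda: train_input_fn_custom(
    filenames_array=filenames_train, labels_array=labels_train, batch_size=3)

# Option 2: functools.partial binds the arguments without calling the function.
via_partial = functools.partial(
    train_input_fn_custom,
    filenames_array=filenames_train, labels_array=labels_train, batch_size=3)

# Option 3: a named zero-argument wrapper.
def via_def():
    return train_input_fn_custom(filenames_train, labels_train, 3)

results = [fake_train(fn) for fn in (via_lambda, via_partial, via_def)]

# The original bug: calling the function first passes a tuple, not a callable.
bug_message = None
try:
    fake_train(train_input_fn_custom(filenames_train, labels_train, 3))
except TypeError as e:
    bug_message = str(e)
```

Each of the three forms defers the call until train() invokes it, so with the real Estimator they should behave the same way; what matters is that train_input_fn_custom is not called at the point where input_fn is passed.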
Source: https://stackoverflow.com/questions/49140164/tensorflow-error-unsupported-callable