Inference with a model trained with tf.Dataset

Submitted by 青春壹個敷衍的年華 on 2019-12-22 06:24:54

Question


I have trained a model using the tf.data.Dataset API, so my training code looks something like this:

with graph.as_default():
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(scale_features, num_parallel_calls=n_workers)
    dataset = dataset.shuffle(10000)
    dataset = dataset.padded_batch(batch_size, padded_shapes={...})
    train_iterator = dataset.make_initializable_iterator()
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   dataset.output_types,
                                                   dataset.output_shapes)
    batch = iterator.get_next()
    ...
    # Model code
    ...

with tf.Session(graph=graph) as sess:
    train_handle = sess.run(train_iterator.string_handle())
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        sess.run(train_iterator.initializer)
        while True:
            try:
                sess.run(optimizer, feed_dict={handle: train_handle})
            except tf.errors.OutOfRangeError:
                break

Now that the model is trained, I want to run inference on examples that are not in the datasets, and I am not sure how to go about it.

Just to be clear, I know how to use another dataset; for example, I simply pass a handle to my test set when testing.

The question is: given the scaling scheme, and the fact that the network expects a handle, how would I make a prediction for a new example that has not been written to a TFRecord?

If I modified the batch directly, I would be responsible for the scaling beforehand, which is something I would like to avoid if possible.

So how should I infer single examples from a model trained the tf.data.Dataset way? (This is not for production purposes; it is for evaluating what will happen if I change specific features.)
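For reference, the handle-switching pattern in the question can also cover this case: build a second, placeholder-backed dataset that reuses the same map function, give it its own string handle, and feed the raw example through it so the scaling is applied by the pipeline itself. The sketch below assumes TF1-style graph mode (via `tf.compat.v1`); `scale_features` and the `batch * 2.0` "model" are hypothetical stand-ins for the real ones.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def scale_features(x):
    # Hypothetical stand-in for the real scaling map.
    return x / 10.0

graph = tf.Graph()
with graph.as_default():
    # Placeholder-backed dataset that reuses the SAME map as training.
    example_ph = tf.placeholder(tf.float32, shape=[None, 2])
    infer_dataset = tf.data.Dataset.from_tensor_slices(example_ph)
    infer_dataset = infer_dataset.map(scale_features).batch(1)
    infer_iterator = tf.data.make_initializable_iterator(infer_dataset)

    # The same feedable-handle switch as in the question.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle,
        tf.data.get_output_types(infer_dataset),
        tf.data.get_output_shapes(infer_dataset))
    batch = iterator.get_next()
    predictions = batch * 2.0  # stand-in for the model

with tf.Session(graph=graph) as sess:
    infer_handle = sess.run(infer_iterator.string_handle())
    # Initialize the inference iterator with the raw (unscaled) example;
    # the pipeline's map applies the scaling.
    sess.run(infer_iterator.initializer,
             feed_dict={example_ph: np.array([[5.0, 10.0]], np.float32)})
    out = sess.run(predictions, feed_dict={handle: infer_handle})
```

With the stand-in scaling and model above, the raw example `[5.0, 10.0]` is scaled to `[0.5, 1.0]` by the pipeline and then doubled by the "model". The point of the design is that the scaling logic lives in one place and is shared by both pipelines.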


Answer 1


When you use the dataset API, the graph actually contains a tensor named "IteratorGetNext:0", so you can set the input directly as follows:

# Get the input tensor from the graph
input = graph.get_tensor_by_name("IteratorGetNext:0")
# Define the target tensor you want to evaluate for your prediction
predictions = ...
# Finally, call session.run, feeding the input tensor directly
sess.run(predictions, feed_dict={input: np.asanyarray(images), ...})
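To make the trick above concrete, here is a minimal self-contained sketch (TF1-style graph mode via `tf.compat.v1`; the dataset and the `batch * 2.0` "model" are toy stand-ins for the question's pipeline). Because any tensor in a TF1 graph can be overridden through `feed_dict`, feeding "IteratorGetNext:0" bypasses the dataset entirely:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Toy pipeline standing in for the original TFRecord one.
    dataset = tf.data.Dataset.from_tensor_slices(
        np.arange(8, dtype=np.float32).reshape(4, 2))
    iterator = tf.data.make_initializable_iterator(dataset)
    batch = iterator.get_next()  # this op is named "IteratorGetNext"
    predictions = batch * 2.0    # stand-in for the model

with tf.Session(graph=graph) as sess:
    # Feed the get_next tensor directly; the iterator never runs,
    # so it does not even need to be initialized.
    inp = graph.get_tensor_by_name("IteratorGetNext:0")
    out = sess.run(predictions,
                   feed_dict={inp: np.array([1.0, 3.0], np.float32)})
```

Note that the fed array must match the shape and dtype the pipeline would have produced, and, as the question points out, any scaling done inside the dataset's map must now be applied by hand before feeding.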


Source: https://stackoverflow.com/questions/50940667/inference-with-a-model-trained-with-tf-dataset
