TensorFlow: Predict in the Recurrent Neural Networks for Drawing Classification tutorial

Posted by 三世轮回 on 2019-12-06 12:01:06

I finally solved the problem by taking my input strokes and writing them to disk as a TFRecord. I also had to write the same strokes batch_size times to that TFRecord file; otherwise I got shape mismatch errors. After that, invoking predict worked.

The main addition for prediction was the following function:

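# Written against TensorFlow 1.x: tf.gfile and tf.python_io are 1.x names
# (in 2.x they are available under tf.compat.v1).
import json

import numpy as np
import tensorflow as tf


def create_tfrecord_for_prediction(batch_size, stoke_data, tfrecord_file):
    """Write the same stroke example batch_size times to tfrecord_file."""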
    def parse_line(stoke_data):
        """Parse the stroke JSON and return the ink as a tf.train.Example.

        Returns None if a stroke has mismatched x and y coordinate counts.
        """
        inkarray = json.loads(stoke_data)
        stroke_lengths = [len(stroke[0]) for stroke in inkarray]
        total_points = sum(stroke_lengths)
        np_ink = np.zeros((total_points, 3), dtype=np.float32)
        current_t = 0
        for stroke in inkarray:
            if len(stroke[0]) != len(stroke[1]):
                print("Inconsistent number of x and y coordinates.")
                return None
            for i in [0, 1]:
                np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
            current_t += len(stroke[0])
            np_ink[current_t - 1, 2] = 1  # stroke_end
        # Preprocessing.
        # 1. Size normalization.
        lower = np.min(np_ink[:, 0:2], axis=0)
        upper = np.max(np_ink[:, 0:2], axis=0)
        scale = upper - lower
        scale[scale == 0] = 1
        np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
        # 2. Compute deltas.
        #np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
        #np_ink = np_ink[1:, :]
        np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
        np_ink = np_ink[1:, :]

        features = {}
        features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(value=np_ink.flatten()))
        features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(value=np_ink.shape))
        f = tf.train.Features(feature=features)
        ex = tf.train.Example(features=f)
        return ex

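    if stoke_data is None:
        print("Error: Stroke data cannot be None")
        return

    example = parse_line(stoke_data)
    if example is None:
        # parse_line already reported the inconsistent stroke; nothing to write.
        return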

    # Remove the file if it already exists.
    if tf.gfile.Exists(tfrecord_file):
        tf.gfile.Remove(tfrecord_file)

    # Write the same example batch_size times so predict receives a full batch.
    # The writer is flushed and closed when the with block exits.
    with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
        for _ in range(batch_size):
            writer.write(example.SerializeToString())
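
To sanity-check the preprocessing without TensorFlow, the NumPy part of parse_line can be run on its own. The sketch below is only an illustration: the two-stroke drawing is made up, `preprocess` is a hypothetical name, and it assumes the stroke data is a JSON string of strokes in the form [[x0, x1, ...], [y0, y1, ...]], which is what the indexing in parse_line implies.

```python
import json

import numpy as np

# Made-up drawing with two strokes; each stroke is [x_coords, y_coords].
strokes = [
    [[0, 10, 20], [0, 5, 0]],
    [[20, 20], [0, 10]],
]
stroke_json = json.dumps(strokes)  # the string form that parse_line loads


def preprocess(stroke_json):
    """Mirror the NumPy steps of parse_line (hypothetical helper, no TensorFlow)."""
    inkarray = json.loads(stroke_json)
    total_points = sum(len(stroke[0]) for stroke in inkarray)
    ink = np.zeros((total_points, 3), dtype=np.float32)
    t = 0
    for xs, ys in inkarray:
        ink[t:t + len(xs), 0] = xs
        ink[t:t + len(xs), 1] = ys
        t += len(xs)
        ink[t - 1, 2] = 1  # mark the last point of each stroke
    # 1. Size normalization to the unit square.
    lower = ink[:, 0:2].min(axis=0)
    scale = ink[:, 0:2].max(axis=0) - lower
    scale[scale == 0] = 1  # avoid dividing by zero for flat drawings
    ink[:, 0:2] = (ink[:, 0:2] - lower) / scale
    # 2. Replace positions with deltas, then drop the first point.
    ink[1:, 0:2] = ink[1:, 0:2] - ink[:-1, 0:2]
    return ink[1:, :]


result = preprocess(stroke_json)
print(result.shape)  # (4, 3): one row fewer than the 5 input points
print(result)
```

The (rows, 3) shape printed here corresponds to what the function above stores in the "shape" feature, and the flattened array corresponds to the "ink" feature.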

Then, in the main function, you just have to invoke estimator.predict(), reusing the same input_fn=get_input_fn(...) argument but pointing it at the temporary tfrecord_file you just created.
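
A minimal sketch of that call might look like the following. It assumes the tutorial's get_input_fn takes mode, tfrecord_pattern and batch_size arguments (check this against your copy of the training script); `predict_first` is a hypothetical helper name, and the mode is passed in by the caller so the sketch itself does not import TensorFlow.

```python
def predict_first(estimator, get_input_fn, tfrecord_file, batch_size, predict_mode):
    """Run predict on the temporary TFRecord and return the first result.

    Every record in the file is the same drawing, so the predictions in the
    batch should all be identical and the first one is enough.
    """
    # Assumed signature: get_input_fn(mode, tfrecord_pattern, batch_size).
    input_fn = get_input_fn(
        mode=predict_mode,
        tfrecord_pattern=tfrecord_file,
        batch_size=batch_size,
    )
    # estimator.predict returns a generator; pull a single item from it
    # instead of exhausting it.
    return next(iter(estimator.predict(input_fn=input_fn)))


# Example call (the path is a placeholder; use the same batch_size that was
# passed to create_tfrecord_for_prediction):
# prediction = predict_first(estimator, get_input_fn, "/tmp/predict.tfrecord",
#                            8, tf.estimator.ModeKeys.PREDICT)
```

Taking only the first item is a design choice rather than a requirement; it also avoids looping forever if the input pipeline happens to repeat its dataset.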

Hope this helps
