How to classify a QuickDraw doodle using TensorFlow's sketch RNN tutorial?

Clarifications:

  1. This question is regarding this QuickDraw RNN Drawing Classification TensorFlow tutorial, not the text RNN TensorFlow tutorial

1 Answer

    python train_model.py \
      --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \
      --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \
      --classes_file=rnn_tutorial_data/training.tfrecord.classes

    AFAIK the above command works as well; it simply reads, one by one, all the files in the folder where you downloaded the data.

    create_tfrecord_for_prediction is certainly not my own creation; most of this code was adapted from create_dataset.py, another file from the TensorFlow tutorial.

    Below I have pasted almost all of the new code I added, including my modifications to the main() function.

    # Imports used by the code below (add them if train_model.py does not already have them).
    import json
    from datetime import datetime

    import numpy as np

    def create_tfrecord_for_prediction(batch_size, stroke_data, tfrecord_file):
        def parse_line(stroke_data):
            """Parse the stroke data (a JSON string of strokes) into an ink array and return a tf.train.Example."""
            inkarray = json.loads(stroke_data)
            stroke_lengths = [len(stroke[0]) for stroke in inkarray]
            total_points = sum(stroke_lengths)
            np_ink = np.zeros((total_points, 3), dtype=np.float32)
            current_t = 0
            for stroke in inkarray:
                if len(stroke[0]) != len(stroke[1]):
                    print("Inconsistent number of x and y coordinates.")
                    return None
                for i in [0, 1]:
                    np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
                current_t += len(stroke[0])
                np_ink[current_t - 1, 2] = 1  # stroke_end
            # Preprocessing.
            # 1. Size normalization.
            lower = np.min(np_ink[:, 0:2], axis=0)
            upper = np.max(np_ink[:, 0:2], axis=0)
            scale = upper - lower
            scale[scale == 0] = 1
            np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
            # 2. Compute deltas.
            #np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
            #np_ink = np_ink[1:, :]
            np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
            np_ink = np_ink[1:, :]
    
            features = {}
            features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(value=np_ink.flatten()))
            features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(value=np_ink.shape))
            f = tf.train.Features(feature=features)
            ex = tf.train.Example(features=f)
            return ex
    
        if stroke_data is None:
            print("Error: Stroke data cannot be None")
            return

        example = parse_line(stroke_data)
        if example is None:
            print("Error: Could not parse the provided stroke data")
            return
    
        #Remove the file if it already exists
        if tf.gfile.Exists(tfrecord_file):
            tf.gfile.Remove(tfrecord_file)
    
        writer = tf.python_io.TFRecordWriter(tfrecord_file)
        for i in range(batch_size):
            writer.write(example.SerializeToString())
        writer.flush()
        writer.close()
    
    def get_classes():
      classes = []
      with tf.gfile.GFile(FLAGS.classes_file, "r") as f:
        classes = [x.rstrip() for x in f]
      return classes
    
    def main(unused_args):
      print("%s: I Starting application" % (datetime.now()))
    
      estimator, train_spec, eval_spec = create_estimator_and_specs(
          run_config=tf.estimator.RunConfig(
              model_dir=FLAGS.model_dir,
              save_checkpoints_secs=300,
              save_summary_steps=100))
      tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    
      if FLAGS.predict_for_data is not None:
          print("%s: I Starting prediction" % (datetime.now()))
          class_names = get_classes()
          create_tfrecord_for_prediction(FLAGS.batch_size, FLAGS.predict_for_data, FLAGS.predict_temp_file)
          predict_results = estimator.predict(input_fn=get_input_fn(
              mode=tf.estimator.ModeKeys.PREDICT,
              tfrecord_pattern=FLAGS.predict_temp_file,
              batch_size=FLAGS.batch_size))
    
          #predict_results = estimator.predict(input_fn=predict_input_fn)
          for idx, prediction in enumerate(predict_results):
              index = prediction["class_index"]  # Get the predicted class (index)
              probability = prediction["probabilities"][index]
              class_name = class_names[index]
              print("%s: Predicted Class is: %s with a probability of %f" % (datetime.now(), class_name, probability))
              break  # We only care about the first prediction; the rest are duplicates written just to fill the batch
    

    FLAGS.predict_for_data is the command-line argument that holds the stroke data; FLAGS.predict_temp_file is just the name of the file I use for the temporary input TFRecord.
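
    For completeness, the two new flags can be registered next to the tutorial's existing ones. This is only a sketch, assuming the argparse-based flag setup of train_model.py; the default temp-file path and the example strokes in the invocation below are made up:

    parser.add_argument(
        "--predict_for_data",
        type=str,
        default=None,
        help="JSON string of strokes to classify, e.g. [[[x0,x1,...],[y0,y1,...]], ...].")
    parser.add_argument(
        "--predict_temp_file",
        type=str,
        default="/tmp/predict_temp.tfrecord",
        help="Temporary TFRecord file that holds the single prediction example.")

    With that in place, the invocation looks like the training command plus the extra flag:

    python train_model.py \
      --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \
      --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \
      --classes_file=rnn_tutorial_data/training.tfrecord.classes \
      --predict_for_data='[[[50,100,150],[60,60,120]],[[150,100],[120,180]]]'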

    Note 1: Along with this I also modified some code in get_input_fn(); you can find that change in this PR: https://github.com/tensorflow/models/pull/3440 (not merged yet). The idea is sketched below.
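
    The gist of the change (a rough sketch of the idea, not the exact diff from the PR) is that the Example parser should only expect a class_index feature when it is not running in PREDICT mode:

    def _parse_tfexample_fn(example_proto, mode):
        """Parse a single tf.Example proto; labels are absent in PREDICT mode."""
        feature_to_type = {
            "ink": tf.VarLenFeature(dtype=tf.float32),
            "shape": tf.FixedLenFeature([2], dtype=tf.int64)
        }
        if mode != tf.estimator.ModeKeys.PREDICT:
            # Only train/eval records carry a label.
            feature_to_type["class_index"] = tf.FixedLenFeature([1], dtype=tf.int64)
        parsed_features = tf.parse_single_example(example_proto, feature_to_type)
        labels = None
        if mode != tf.estimator.ModeKeys.PREDICT:
            labels = parsed_features["class_index"]
        parsed_features["ink"] = tf.sparse_tensor_to_dense(parsed_features["ink"])
        return parsed_features, labels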

    Note 2: I also had to modify model_fn() and add the few lines below; my additions come after the comment # Compute current predictions. The early return in PREDICT mode is needed because there are no labels to compute the loss against.

      # Build the model.
      inks, lengths, labels = _get_input_tensors(features, labels)
      convolved, lengths = _add_conv_layers(inks, lengths)
      final_state = _add_rnn_layers(convolved, lengths)
      logits = _add_fc_layers(final_state)
    
      # Compute current predictions.
      predictions = tf.argmax(logits, axis=1)
    
      if mode == tf.estimator.ModeKeys.PREDICT:
          preds = {
              "class_index": predictions,
              #"class_index": predictions[:, tf.newaxis],
              "probabilities": tf.nn.softmax(logits),
              "logits": logits
          }
          #preds = {"logits": logits, "predictions": predictions}
    
          return tf.estimator.EstimatorSpec(mode, predictions=preds)

      # Add the loss.
      cross_entropy = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=labels, logits=logits))
    
      # Add the optimizer.
      train_op = tf.contrib.layers.optimize_loss(
          loss=cross_entropy,
          global_step=tf.train.get_global_step(),
          learning_rate=params.learning_rate,
          optimizer="Adam",
          # some gradient clipping stabilizes training in the beginning.
          clip_gradients=params.gradient_clipping_norm,
          summaries=["learning_rate", "loss", "gradients", "gradient_norm"])
    
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions={"logits": logits, "predictions": predictions},
          loss=cross_entropy,
          train_op=train_op,
          eval_metric_ops={"accuracy": tf.metrics.accuracy(labels, predictions)})
    

    The only thing left is then to figure out how to generate the stroke data. For this you can take one of the existing data files, read it, and extract a single drawing's strokes from it, or you could write a small JavaScript web page that captures strokes. One possibility is sketched below.
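
    For example, if you still have one of the raw QuickDraw .ndjson files around (the same kind of file create_dataset.py consumes), the "drawing" field of each sample is already in the format parse_line expects, so a quick way to grab strokes for --predict_for_data could look like the sketch below (the file name and sample index are just placeholders):

    import json

    def extract_strokes(ndjson_file, sample_index=0):
        """Return the raw stroke array of one sample as a JSON string."""
        with open(ndjson_file) as f:
            for i, line in enumerate(f):
                if i == sample_index:
                    sample = json.loads(line)
                    # "drawing" is a list of strokes: [[x0, x1, ...], [y0, y1, ...]]
                    return json.dumps(sample["drawing"])
        raise IndexError("sample_index is past the end of the file")

    if __name__ == "__main__":
        # Placeholder file name; use whichever class file you downloaded.
        print(extract_strokes("cat.ndjson", sample_index=5))

    The printed JSON string can then be passed directly as the value of --predict_for_data.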
