TensorFlow Java API: setting placeholders for categorical columns

我的梦境 提交于 2019-12-11 05:13:40

问题


I want to use the Java API to run predictions on a model I trained with the Python TensorFlow API, but I am having trouble feeding in the features for prediction in Java.

My Python code looks like this:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from six.moves.urllib.request import urlopen
import numpy as np
import tensorflow as tf

# Names of the 20 input features ('Attribute1' ... 'Attribute20'). These are
# the keys used by the feature columns in main() and by the parsed records.
feature_names = ['Attribute{}'.format(i) for i in range(1, 21)]

# Raw records to run predictions on: each one is a single space-delimited
# string holding the 20 feature values (no label), wrapped in its own list.
# Split into individual fields, a record looks like e.g.
#   ['A11', 6, 'A34', 'A43', 1169, 'A65', 'A75', 4, 'A93', 'A101',
#    4, 'A121', 67, 'A143', 'A152', 2, 'A173', 1, 'A192', 'A201']
prediction_input = [
    ["A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201"],
    ["A11 36 A32 A40 9271 A61 A74 2 A93 A101 1 A123 24 A143 A152 1 A173 1 A192 A201"],
    ["A12 15 A30 A40 1778 A61 A72 2 A92 A101 1 A121 26 A143 A151 2 A171 1 A191 A201"],
]

def predict_input_fn():
    """Build the input pipeline used by ``classifier.predict``.

    Every entry of the module-level ``prediction_input`` is a space-delimited
    record holding the 20 feature values and no label. Each record is parsed
    with ``tf.decode_csv`` and turned into a feature dict keyed by
    ``feature_names``.

    Returns:
        A ``(features, labels)`` tuple: ``features`` is the next batch
        (batch size 1) from a one-shot iterator over the parsed records, and
        ``labels`` is ``None`` because there are no labels at prediction time.
    """
    def decode(zeile):
        # One record default per column; the default's type selects the parsed
        # type of that column ('' -> string / categorical, 0 -> integer / numeric).
        parsed_line = tf.decode_csv(
            zeile,
            [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
             [0], [''], [0], [''], [''], [0], [''], [0], [''], ['']],
            field_delim=' ')
        # No label column when predicting, so every parsed field is a feature.
        return dict(zip(feature_names, parsed_line))

    # from_tensor_slices builds the dataset from an in-memory Python structure.
    dataset = tf.data.Dataset.from_tensor_slices(prediction_input)
    dataset = dataset.map(decode)
    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    next_feature_batch = iterator.get_next()
    return next_feature_batch, None

# Data sets
def train_test_input_fn(dateipfad, mit_shuffle=False, anzahl_wiederholungen=1):
    """Build the input pipeline for training or evaluation.

    Reads ``dateipfad`` line by line. Each line holds the 20 space-delimited
    feature values followed by an integer label.

    Args:
        dateipfad: ("file path") Path of the space-delimited data file.
        mit_shuffle: ("with shuffle") If true, shuffle the parsed records
            using a buffer of 100 elements.
        anzahl_wiederholungen: ("number of repetitions") How many times the
            dataset is repeated.

    Returns:
        A ``(batch_features, batch_labels)`` tuple from a one-shot iterator.
        ``batch_features`` maps each name in ``feature_names`` to a batch of
        values (batch size 1); ``batch_labels`` is the matching batch of labels.
    """
    def parser(zeile):
        # 21 record defaults: the 20 features plus a trailing integer label.
        parsed_line = tf.decode_csv(
            zeile,
            [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
             [0], [''], [0], [''], [''], [0], [''], [0], [''], [''], [0]],
            field_delim=' ')
        label = parsed_line[-1:]  # Last element is the label (kept as a 1-element list).
        del parsed_line[-1]       # Everything except the last element are the features.
        return dict(zip(feature_names, parsed_line)), label

    dataset = tf.data.TextLineDataset(dateipfad)
    dataset = dataset.map(parser)
    if mit_shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(1)
    dataset = dataset.repeat(anzahl_wiederholungen)
    iterator = dataset.make_one_shot_iterator()

    # `batch_features` is a dict whose values are batches of values for each
    # feature; `batch_labels` is a batch of labels.
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

def main():
    """Train, evaluate, export and query the credit-worthiness classifier.

    Trains a ``DNNClassifier`` on ``german.data.train.txt`` and prints its
    accuracy on ``german.data.test.txt``. It then exports a SavedModel under
    ``./export`` whose serving input is parsed from serialized ``tf.Example``
    protos. Finally it prints the predicted class for every record in
    ``prediction_input``.
    """
    def categorical(key, vocabulary):
        # One-hot encode a string feature over a fixed vocabulary; values not
        # in the vocabulary become an all-zero vector.
        return tf.feature_column.indicator_column(
            tf.feature_column.categorical_column_with_vocabulary_list(key, vocabulary))

    def numeric(key):
        # Scalar numeric feature.
        return tf.feature_column.numeric_column(key, shape=[1])

    # NOTE: changing a vocabulary changes the network's input size, so a
    # checkpoint left in ./summaries by an earlier run is incompatible;
    # start from an empty model_dir after such a change.
    feature_columns = [
        categorical('Attribute1', ['A11', 'A12', 'A13', 'A14']),
        numeric('Attribute2'),
        # 'A34' is included so records like ['A11', 6, 'A34', ...] are not
        # silently encoded as an out-of-vocabulary (all-zero) vector.
        categorical('Attribute3', ['A30', 'A31', 'A32', 'A33', 'A34']),
        categorical('Attribute4', ['A40', 'A41', 'A42', 'A43', 'A44', 'A45',
                                   'A46', 'A47', 'A48', 'A49', 'A410']),
        numeric('Attribute5'),
        categorical('Attribute6', ['A61', 'A62', 'A63', 'A64', 'A65']),
        categorical('Attribute7', ['A71', 'A72', 'A73', 'A74', 'A75']),
        numeric('Attribute8'),
        categorical('Attribute9', ['A91', 'A92', 'A93', 'A94', 'A95']),
        categorical('Attribute10', ['A101', 'A102', 'A103']),
        numeric('Attribute11'),
        categorical('Attribute12', ['A121', 'A122', 'A123', 'A124']),
        numeric('Attribute13'),
        categorical('Attribute14', ['A141', 'A142', 'A143']),
        categorical('Attribute15', ['A151', 'A152', 'A153']),
        numeric('Attribute16'),
        categorical('Attribute17', ['A171', 'A172', 'A173', 'A174']),
        numeric('Attribute18'),
        categorical('Attribute19', ['A191', 'A192']),
        categorical('Attribute20', ['A201', 'A202']),
    ]

    classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[100],
                                            n_classes=2,
                                            model_dir="./summaries")

    # Train the model.
    classifier.train(input_fn=lambda: train_test_input_fn("german.data.train.txt", True, 10))

    # Compute the accuracy on the test set.
    accuracy_score = classifier.evaluate(
        input_fn=lambda: train_test_input_fn("german.data.test.txt", False, 4))["accuracy"]
    print("\nTest Genauigkeit: {0:f}\n".format(accuracy_score))

    # Export a SavedModel whose serving input is a batch of serialized
    # tf.Example protos, parsed according to the feature columns above.
    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
    classifier.export_savedmodel("./export", serving_input_receiver_fn, as_text=True)

    predict_results = classifier.predict(input_fn=predict_input_fn)
    for idx, prediction in enumerate(predict_results):
        class_id = prediction["class_ids"][0]  # Predicted class index.
        if class_id == 0:
            print("Ich denke: {}, ist nicht kreditwürdig".format(prediction_input[idx]))
        elif class_id == 1:
            print("Ich denke: {}, ist kreditwürdig".format(prediction_input[idx]))

# Run the full train / evaluate / export / predict flow when executed as a script.
if __name__ == "__main__":
    main()

But I couldn't find anything on how to feed such categorical columns from a Java client. Could you please provide an example of how to do this?

My current attempt looks like this, but I have no idea which Tensor I need to create to run predictions on the trained model in Java:

// First attempt from the question: feed the raw record into the exported model.
// NOTE(review): as written this does not compile (see the feed() note below).
public static void main(String[] args) throws Exception {
    // NOTE(review): "\\" hard-codes the Windows path separator.
    String pfad = System.getProperty("user.dir") + "\\1511523781";
    // NOTE(review): the SavedModelBundle returned by load() is never closed,
    // so its native resources are not released.
    Session session = SavedModelBundle.load(pfad, "serve").session();
    // One raw, space-delimited record (20 features, no label), in the same
    // format as the Python prediction_input entries.
    String example = "A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201";

    // Input tensor (a DT_STRING batch) and class-probability output of the
    // exported graph.
    final String xName = "input_example_tensor";
    final String scoresName = "dnn/head/predictions/probabilities:0";

    // NOTE(review): Runner.feed() takes a Tensor, not a String. Because the model
    // was exported with build_parsing_serving_input_receiver_fn, the fed tensor
    // must be a batch of serialized tf.Example protos, not the raw record.
    List<Tensor<?>> outputs = session.runner()
        .feed(xName, example)
        .fetch(scoresName)
        .run();

    // Outer dimension is batch size; inner dimension is number of classes.
    // NOTE(review): for one example and n_classes=2 the output is [1, 2];
    // copyTo() needs a destination of exactly that shape, so [2][3] is wrong.
    // The fetched output tensors are also never closed.
    float[][] scores = new float[2][3];
    outputs.get(0).copyTo(scores);
    System.out.println(Arrays.deepToString(scores));
  }

Thanks!


回答1:


Since you're using tf.estimator.export.build_parsing_serving_input_receiver_fn, the exported saved model you've created expects a serialized tf.Example protocol buffer as input.

You can use the tf.Example protocol buffer in Java (it is available from Maven; see its Javadoc) with something like this:

import com.google.protobuf.ByteString;
import java.util.Arrays;
import org.tensorflow.*;
import org.tensorflow.example.*;

public class Main {
  // Returns a Feature containing a BytesList whose elements are the UTF-8
  // encoded bytes of the given strings. Used for the categorical attributes.
  public static Feature feature(String... strings) {
    BytesList.Builder b = BytesList.newBuilder();
    for (String s : strings) {
      b.addValue(ByteString.copyFromUtf8(s));
    }
    return Feature.newBuilder().setBytesList(b).build();
  }

  // Returns a Feature containing a FloatList with the given values. Used for
  // the numeric attributes (numeric_column's default dtype is float32).
  public static Feature feature(float... values) {
    FloatList.Builder b = FloatList.newBuilder();
    for (float v : values) {
      b.addValue(v);
    }
    return Feature.newBuilder().setFloatList(b).build();
  }

  // Builds one tf.Example from a single record, runs it through the exported
  // model and prints the class probabilities.
  public static void main(String[] args) throws Exception {
    // Feature keys must match the feature column names used in training.
    Features features =
        Features.newBuilder()
            .putFeature("Attribute1", feature("A12"))
            .putFeature("Attribute2", feature(12))
            .putFeature("Attribute3", feature("A32"))
            .putFeature("Attribute4", feature("A40"))
            .putFeature("Attribute5", feature(7472))
            .putFeature("Attribute6", feature("A65"))
            .putFeature("Attribute7", feature("A71"))
            .putFeature("Attribute8", feature(1))
            .putFeature("Attribute9", feature("A92"))
            .putFeature("Attribute10", feature("A101"))
            .putFeature("Attribute11", feature(2))
            .putFeature("Attribute12", feature("A121"))
            .putFeature("Attribute13", feature(24))
            .putFeature("Attribute14", feature("A143"))
            .putFeature("Attribute15", feature("A151"))
            .putFeature("Attribute16", feature(1))
            .putFeature("Attribute17", feature("A171"))
            .putFeature("Attribute18", feature(1))
            .putFeature("Attribute19", feature("A191"))
            .putFeature("Attribute20", feature("A201"))
            .build();
    Example example = Example.newBuilder().setFeatures(features).build();

    // Resolve the export directory with the platform's path separator instead
    // of hard-coding the Windows "\\" separator.
    String pfad =
        java.nio.file.Paths.get(System.getProperty("user.dir"), "1511523781").toString();
    try (SavedModelBundle model = SavedModelBundle.load(pfad, "serve")) {
      Session session = model.session();
      // Input tensor (a DT_STRING batch of serialized tf.Example protos) and
      // class-probability output of the exported graph.
      final String xName = "input_example_tensor";
      final String scoresName = "dnn/head/predictions/probabilities:0";

      // A rank-1 batch containing one serialized Example.
      try (Tensor<String> inputBatch = Tensors.create(new byte[][] {example.toByteArray()});
          Tensor<Float> output =
              session
                  .runner()
                  .feed(xName, inputBatch)
                  .fetch(scoresName)
                  .run()
                  .get(0)
                  .expect(Float.class)) {
        // copyTo() requires a destination of exactly the output's shape
        // ([batch, classes]), so size it from the tensor itself.
        long[] shape = output.shape();
        float[][] probabilities = new float[(int) shape[0]][(int) shape[1]];
        System.out.println(Arrays.deepToString(output.copyTo(probabilities)));
      }
    }
  }
}

Much of the boilerplate here is for constructing the protocol buffer example. Alternatively, you could use something other than build_parsing_serving_input_receiver_fn to set up the exported model to accept input in a different format.

Side note: You can use the saved_model_cli command-line tool, which is included with the TensorFlow Python installation, to inspect the saved model. For example, something like:

# Print the inputs and outputs (dtype, shape, tensor name) of the "predict"
# signature of the exported SavedModel loaded with the "serve" tag.
saved_model_cli show  \
  --dir ./export/1511523781 \
  --tag_set serve \
  --signature_def predict

will show something like:

The given SavedModel SignatureDef contains the following input(s):
inputs['examples'] tensor_info:
    dtype: DT_STRING
    shape: (-1)
    name: input_example_tensor:0
The given SavedModel SignatureDef contains the following output(s):
...
outputs['probabilities'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 2)
    name: dnn/head/predictions/probabilities:0

This shows that the saved model takes a single input (a batch of DT_STRING elements, i.e. serialized tf.Example protos) and that the output probabilities are a batch of 2-element float vectors.

Hope that helps.



来源:https://stackoverflow.com/questions/47477314/tensorflow-java-api-set-placeholder-for-categorical-columns

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!