I want to run predictions from Java, using the TensorFlow Java API, on a model that I trained with the Python TensorFlow API, but I have trouble feeding the features for prediction into the model in Java.
My Python code looks like this:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves.urllib.request import urlopen
import numpy as np
import tensorflow as tf
# Names of the 20 input features, in the column order of the data records.
# They are the keys the feature columns look up: 'Attribute1' ... 'Attribute20'.
feature_names = ['Attribute{}'.format(i) for i in range(1, 21)]

# Earlier, array-based form of the prediction input (kept for reference):
#prediction_input = np.array([['A11', 6, 'A34', 'A43', 1169, 'A65', 'A75', 4, 'A93', 'A101', 4, 'A121', 67, 'A143', 'A152', 2, 'A173', 1, 'A192', 'A201'],
# ['A12', 18, 'A34', 'A43', 1795, 'A61', 'A75', 3, 'A92', 'A103', 4, 'A121', 48, 'A141', 'A151', 2, 'A173', 1, 'A192', 'A201']])

# Rows to predict on. Each row is a one-element list holding a single
# space-delimited record with 20 fields (no label column).
prediction_input = [
    ["A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201"],
    ["A11 36 A32 A40 9271 A61 A74 2 A93 A101 1 A123 24 A143 A152 1 A173 1 A192 A201"],
    ["A12 15 A30 A40 1778 A61 A72 2 A92 A101 1 A121 26 A143 A151 2 A171 1 A191 A201"],
]
def predict_input_fn():
    """Build the input pipeline used for prediction.

    Slices the module-level ``prediction_input`` rows, parses every
    space-delimited record into 20 columns and keys them by ``feature_names``.

    Returns:
        A ``(features, None)`` tuple. ``features`` is a dict mapping each
        feature name to the next batch (batch size 1) of parsed values; the
        second element is ``None`` because prediction has no labels.
    """
    # One default per column; the default also fixes the column's dtype
    # ('' -> string column, 0 -> integer column).
    record_defaults = [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
                       [0], [''], [0], [''], [''], [0], [''], [0], [''], ['']]

    def decode(zeile):
        """Parse one record and return it as a {feature name: value} dict."""
        parsed_line = tf.decode_csv(zeile, record_defaults, field_delim=' ')
        return dict(zip(feature_names, parsed_line))

    # from_tensor_slices uses the in-memory structure as input.
    dataset = tf.data.Dataset.from_tensor_slices(prediction_input)
    dataset = dataset.map(decode)
    dataset = dataset.batch(1)
    next_feature_batch = dataset.make_one_shot_iterator().get_next()
    return next_feature_batch, None
# Data sets
def train_test_input_fn(dateipfad, mit_shuffle=False, anzahl_wiederholungen=1):
    """Build the input pipeline used for training and evaluation.

    Args:
        dateipfad: Path of a text file with one space-delimited record per
            line: 20 feature columns followed by one integer label column.
        mit_shuffle: If true, shuffle the records with a buffer of 100.
        anzahl_wiederholungen: How many times the data set is repeated.

    Returns:
        A ``(features, labels)`` tuple. ``features`` is a dict mapping each
        feature name to a batch (batch size 1) of values for that feature;
        ``labels`` is the matching batch of labels.
    """
    # One default per column; the default also fixes the column's dtype
    # ('' -> string column, 0 -> integer column). The last entry is the label.
    record_defaults = [[''], [0], [''], [''], [0], [''], [''], [0], [''], [''],
                       [0], [''], [0], [''], [''], [0], [''], [0], [''], [''],
                       [0]]

    def parser(zeile):
        """Split one record into a feature dict and its label."""
        parsed_line = tf.decode_csv(zeile, record_defaults, field_delim=' ')
        features = parsed_line[:-1]  # everything but the last column
        label = parsed_line[-1:]     # last column, kept as a one-element list
        return dict(zip(feature_names, features)), label

    dataset = tf.data.TextLineDataset(dateipfad)
    dataset = dataset.map(parser)
    if mit_shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(1)
    dataset = dataset.repeat(anzahl_wiederholungen)
    batch_features, batch_labels = dataset.make_one_shot_iterator().get_next()
    return batch_features, batch_labels
def main():
    """Train, evaluate, export and query a DNN classifier.

    Trains on "german.data.train.txt", prints the accuracy measured on
    "german.data.test.txt", exports a SavedModel below "./export" that accepts
    serialized tf.Example protos, and finally prints one prediction per row of
    the module-level ``prediction_input``.
    """
    def indicator(name, vocabulary):
        """One-hot (indicator) column for a string feature with a fixed vocabulary."""
        categorical = tf.feature_column.categorical_column_with_vocabulary_list(name, vocabulary)
        return tf.feature_column.indicator_column(categorical)

    def numeric(name):
        """Single-valued numeric column."""
        return tf.feature_column.numeric_column(name, shape=[1])

    feature_columns = [
        indicator('Attribute1', ['A11', 'A12', 'A13', 'A14']),
        numeric('Attribute2'),
        # 'A34' occurs in the sample records, so it has to be part of the
        # vocabulary; values outside the vocabulary are silently ignored.
        indicator('Attribute3', ['A30', 'A31', 'A32', 'A33', 'A34']),
        indicator('Attribute4', ['A40', 'A41', 'A42', 'A43', 'A44', 'A45', 'A46', 'A47', 'A48', 'A49', 'A410']),
        numeric('Attribute5'),
        indicator('Attribute6', ['A61', 'A62', 'A63', 'A64', 'A65']),
        indicator('Attribute7', ['A71', 'A72', 'A73', 'A74', 'A75']),
        numeric('Attribute8'),
        indicator('Attribute9', ['A91', 'A92', 'A93', 'A94', 'A95']),
        indicator('Attribute10', ['A101', 'A102', 'A103']),
        numeric('Attribute11'),
        indicator('Attribute12', ['A121', 'A122', 'A123', 'A124']),
        numeric('Attribute13'),
        indicator('Attribute14', ['A141', 'A142', 'A143']),
        indicator('Attribute15', ['A151', 'A152', 'A153']),
        numeric('Attribute16'),
        indicator('Attribute17', ['A171', 'A172', 'A173', 'A174']),
        numeric('Attribute18'),
        indicator('Attribute19', ['A191', 'A192']),
        indicator('Attribute20', ['A201', 'A202']),
    ]

    # NOTE(review): the remaining constructor arguments were cut off in the
    # posted code. DNNClassifier requires `hidden_units`; the layer sizes
    # below are placeholders - confirm against the original script.
    # n_classes=2 matches the two classes handled in the prediction loop.
    classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 10],
                                            n_classes=2)

    # Train the model.
    classifier.train(input_fn=lambda: train_test_input_fn("german.data.train.txt", True, 10))

    # Compute the accuracy on the test data.
    accuracy_score = classifier.evaluate(input_fn=lambda: train_test_input_fn("german.data.test.txt", False, 4))["accuracy"]
    print("\nTest Genauigkeit: {0:f}\n".format(accuracy_score))

    # Export a SavedModel whose serving input is a batch of serialized
    # tf.Example protos, parsed according to the feature columns.
    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
    classifier.export_savedmodel("./export", serving_input_receiver_fn, as_text=True)

    predict_results = classifier.predict(input_fn=predict_input_fn)
    for idx, prediction in enumerate(predict_results):
        predicted_class = prediction["class_ids"][0]  # index of the predicted class
        if predicted_class == 0:
            print("Ich denke: {}, ist nicht kreditwürdig".format(prediction_input[idx]))
        elif predicted_class == 1:
            print("Ich denke: {}, ist kreditwürdig".format(prediction_input[idx]))


if __name__ == "__main__":
    main()
But I have found nothing on how to feed such categorical columns from a Java client. Can you please provide a sample showing how to do this?
My current state is something like the following, but I have no idea which Tensor I have to create in order to predict on the trained model in Java:
// Asker's work-in-progress Java client; the fragment is incomplete as posted.
public static void main(String[] args) throws Exception {
// Directory of the exported SavedModel, resolved relative to the working directory.
String pfad = System.getProperty("user.dir") + "\\1511523781";
// Load the model with the "serve" tag and take its session.
Session session = SavedModelBundle.load(pfad, "serve").session();
// One raw, space-delimited record (same text as the first Python prediction row).
String example = "A12 12 A32 A40 7472 A65 A71 1 A92 A101 2 A121 24 A143 A151 1 A171 1 A191 A201";
// Names of the graph tensors to feed and to read.
final String xName = "input_example_tensor";
final String scoresName = "dnn/head/predictions/probabilities:0";
// NOTE(review): `example` is a plain String here; which Tensor has to be fed
// instead is the open question of this post.
List<Tensor<?>> outputs = session.runner()
.feed(xName, example)
// Outer dimension is batch size; inner dimension is number of classes
float[][] scores = new float[2][3];
Since you're using tf.estimator.export.build_parsing_serving_input_receiver_fn, the exported saved model you've created expects a serialized tf.Example protocol buffer as input.
You can build the tf.Example protocol buffer in Java (maven, javadoc), using something like this:
import com.google.protobuf.ByteString;
import java.util.Arrays;
import org.tensorflow.*;
import org.tensorflow.example.*;
public class Main {
// Returns a Feature containing a BytesList, where each element of the list
// is the UTF-8 encoded bytes of the Java string.
public static Feature feature(String... strings) {
BytesList.Builder b = BytesList.newBuilder();
for (String s : strings) {
return Feature.newBuilder().setBytesList(b).build();
public static Feature feature(float... values) {
FloatList.Builder b = FloatList.newBuilder();
for (float v : values) {
return Feature.newBuilder().setFloatList(b).build();
public static void main(String[] args) throws Exception {
Features features =
.putFeature("Attribute1", feature("A12"))
.putFeature("Attribute2", feature(12))
.putFeature("Attribute3", feature("A32"))
.putFeature("Attribute4", feature("A40"))
.putFeature("Attribute5", feature(7472))
.putFeature("Attribute6", feature("A65"))
.putFeature("Attribute7", feature("A71"))
.putFeature("Attribute8", feature(1))
.putFeature("Attribute9", feature("A92"))
.putFeature("Attribute10", feature("A101"))
.putFeature("Attribute11", feature(2))
.putFeature("Attribute12", feature("A121"))
.putFeature("Attribute13", feature(24))
.putFeature("Attribute14", feature("A143"))
.putFeature("Attribute15", feature("A151"))
.putFeature("Attribute16", feature(1))
.putFeature("Attribute17", feature("A171"))
.putFeature("Attribute18", feature(1))
.putFeature("Attribute19", feature("A191"))
.putFeature("Attribute20", feature("A201"))
Example example = Example.newBuilder().setFeatures(features).build();
String pfad = System.getProperty("user.dir") + "\\1511523781";
try (SavedModelBundle model = SavedModelBundle.load(pfad, "serve")) {
Session session = model.session();
final String xName = "input_example_tensor";
final String scoresName = "dnn/head/predictions/probabilities:0";
try (Tensor<String> inputBatch = Tensors.create(new byte[][] {example.toByteArray()});
Tensor<Float> output =
.feed(xName, inputBatch)
.expect(Float.class)) {
System.out.println(Arrays.deepToString(output.copyTo(new float[1][2])));
Much of the boilerplate here is to construct the protocol buffer example. Alternatively, you could use something other than build_parsing_serving_input_receiver_fn
to set up the exported model to accept input in a different format.
Side note: You can use the saved_model_cli
command-line tool that is included with the TensorFlow Python installation to inspect the saved model. For example, something like:
saved_model_cli show \
--dir ./export/1511523781 \
--tag_set serve \
--signature_def predict
will show something like:
The given SavedModel SignatureDef contains the following input(s):
inputs['examples'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: input_example_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['probabilities'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 2)
name: dnn/head/predictions/probabilities:0
This suggests that the saved model takes a single input, a batch of DT_STRING
elements, and that the output probabilities are a batch of 2-dimensional float vectors.
Hope that helps.