The goal is to open, in Java, a model created and trained in Python with tensorflow.contrib.learn.learn.DNNClassifier.
At the moment the main issue is to know th
On TensorFlow 1.1, I got an error when running the model without `feed("input_example_tensor", inputTensor)`.
However, I found that a serialized `Example` message (defined in example.proto) can be fed as "input_example_tensor". It took me a long time to figure out how to create a string tensor that holds a serialized protocol buffer.
This is how I created `inputTensor`:
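```java
import java.nio.IntBuffer;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

// bundle is the SavedModelBundle loaded from the exported model directory.
org.tensorflow.example.Example.Builder example = org.tensorflow.example.Example.newBuilder();
/* set some features to example... */
// A byte[] becomes a scalar (shape []) STRING tensor.
Tensor exampleTensor = Tensor.create(example.build().toByteArray());
// Reshape it to [1] (a batch of one example) so it can be fed as "input_example_tensor".
Graph g = bundle.graph();
Session s = bundle.session();
// Op names ("example", "shape", "output") must not clash with nodes already in the model's graph.
Output examplePlaceholder =
    g.opBuilder("Placeholder", "example")
        .setAttr("dtype", exampleTensor.dataType())
        .build().output(0);
// INT32 tensor of shape [1] holding the target shape {1}.
Tensor shapeTensor = Tensor.create(new long[]{1}, IntBuffer.wrap(new int[]{1}));
Output shapeConst = g.opBuilder("Const", "shape")
    .setAttr("dtype", shapeTensor.dataType())
    .setAttr("value", shapeTensor)
    .build().output(0);
Output shaped = g.opBuilder("Reshape", "output")
    .addInput(examplePlaceholder)
    .addInput(shapeConst)
    .build().output(0);
Tensor inputTensor = s.runner().feed(examplePlaceholder, exampleTensor).fetch(shaped).run().get(0);
// Now, inputTensor has shape [1] and is ready to feed.
```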
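The string fed as "input_example_tensor" is nothing more than the serialized bytes of an `Example`. To make that concrete, here is a stdlib-only sketch of the protobuf wire encoding for an `Example` with one float feature, following the field numbers in example.proto and feature.proto (`Example.features = 1`, `Features.feature = 1` as a map, `Feature.float_list = 2`, packed `FloatList.value = 1`). It is only an illustration under those assumptions: in real code you would use the generated `org.tensorflow.example` builders and `toByteArray()` as in the snippet above. The class name `ExampleWireSketch`, its methods, and the feature name `"x"` are hypothetical.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class ExampleWireSketch {
    // Appends an unsigned varint (protobuf base-128 encoding).
    static void writeVarint(ByteArrayOutputStream out, long value) {
        while ((value & ~0x7FL) != 0) {
            out.write((int) ((value & 0x7F) | 0x80));
            value >>>= 7;
        }
        out.write((int) value);
    }

    // Appends a length-delimited field (wire type 2): tag, length, payload.
    static void writeBytesField(ByteArrayOutputStream out, int fieldNumber, byte[] payload) {
        writeVarint(out, ((long) fieldNumber << 3) | 2);
        writeVarint(out, payload.length);
        out.write(payload, 0, payload.length);
    }

    // Serializes Example{features{feature{key: name, value{float_list{value: values}}}}}.
    static byte[] floatExample(String name, float... values) {
        // Packed repeated float: little-endian 4-byte floats inside one length-delimited field.
        ByteBuffer packed = ByteBuffer.allocate(4 * values.length).order(ByteOrder.LITTLE_ENDIAN);
        for (float v : values) {
            packed.putFloat(v);
        }
        ByteArrayOutputStream floatList = new ByteArrayOutputStream();
        writeBytesField(floatList, 1, packed.array());           // FloatList.value = 1

        ByteArrayOutputStream feature = new ByteArrayOutputStream();
        writeBytesField(feature, 2, floatList.toByteArray());     // Feature.float_list = 2

        ByteArrayOutputStream entry = new ByteArrayOutputStream(); // one map<string, Feature> entry
        writeBytesField(entry, 1, name.getBytes(StandardCharsets.UTF_8)); // entry key = 1
        writeBytesField(entry, 2, feature.toByteArray());                 // entry value = 2

        ByteArrayOutputStream features = new ByteArrayOutputStream();
        writeBytesField(features, 1, entry.toByteArray());        // Features.feature = 1

        ByteArrayOutputStream example = new ByteArrayOutputStream();
        writeBytesField(example, 1, features.toByteArray());      // Example.features = 1
        return example.toByteArray();
    }

    public static void main(String[] args) {
        byte[] bytes = floatExample("x", 1.0f);
        System.out.println(bytes.length); // prints 17
    }
}
```

A byte array like this is exactly what `Tensor.create(byte[])` wraps into the scalar string tensor before the reshape to `[1]`.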