Inferencing with TensorFlow Serving using Java

删除回忆录丶 submitted on 2020-01-05 03:33:31


We are transitioning existing Java production code to use TensorFlow Serving (TFS) for inferencing. We have already retrained our models and saved them using the new SavedModel format (no more frozen graphs!!).
From the documentation that I have read, TFS does not directly support Java. However, it does provide a gRPC interface, and gRPC in turn provides a Java interface.

My question: what are the steps involved in bringing up a Java application that uses TFS?

[Edit: moved steps to a solution]


It took four days to piece this together, as documentation and examples are still limited.
I'm sure there are better ways to do this, but this is what I found so far:

  • I cloned the tensorflow/tensorflow, tensorflow/serving and google/protobuf repos on GitHub.
  • I compiled the following protobuf files using the protoc protobuf compiler with the grpc-java plugin. I hate the fact that there are so many scattered .proto files to compile, but I wanted the minimal set, and compiling whole directories would have drawn in many unneeded .proto files. Here is the minimal set I needed to compile our Java app:
    • serving_repo/tensorflow_serving/apis/*.proto
    • serving_repo/tensorflow_serving/config/model_server_config.proto
    • serving_repo/tensorflow_serving/core/logging.proto
    • serving_repo/tensorflow_serving/core/logging_config.proto
    • serving_repo/tensorflow_serving/util/status.proto
    • serving_repo/tensorflow_serving/sources/storage_path/file_system_storage_path_source.proto
    • serving_repo/tensorflow_serving/config/log_collector_config.proto
    • tensorflow_repo/tensorflow/core/framework/tensor.proto
    • tensorflow_repo/tensorflow/core/framework/tensor_shape.proto
    • tensorflow_repo/tensorflow/core/framework/types.proto
    • tensorflow_repo/tensorflow/core/framework/resource_handle.proto
    • tensorflow_repo/tensorflow/core/example/example.proto
    • tensorflow_repo/tensorflow/core/protobuf/tensorflow_server.proto
    • tensorflow_repo/tensorflow/core/example/feature.proto
    • tensorflow_repo/tensorflow/core/protobuf/named_tensor.proto
    • tensorflow_repo/tensorflow/core/protobuf/config.proto
  • Note that protoc will compile even withOUT grpc-java present; however, most of the critical entry points will be mysteriously missing. If the generated gRPC service classes (such as the PredictionServiceGrpc used further down) are missing, then grpc-java is not being executed.
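A quick way to check for this is to look for files ending in Grpc.java in the generated sources, since grpc-java names each generated service class after the service plus a Grpc suffix. The following is only a minimal sketch using the Java standard library; the class name GrpcStubCheck and the default output directory src are assumptions made for illustration.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper, not part of TensorFlow Serving or grpc-java.
final class GrpcStubCheck {

    // Keeps only the file names that look like grpc-java service classes, sorted.
    static List<String> grpcStubNames(Collection<String> fileNames) {
        return fileNames.stream()
                        .filter(name -> name.endsWith("Grpc.java"))
                        .sorted()
                        .collect(Collectors.toList());
    }

    // Collects the names of all regular files under outputDir and filters them.
    static List<String> findGrpcStubs(Path outputDir) {
        try (Stream<Path> files = Files.walk(outputDir)) {
            List<String> names = files.filter(Files::isRegularFile)
                                      .map(p -> p.getFileName().toString())
                                      .collect(Collectors.toList());
            return grpcStubNames(names);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Assumed output directory; pass the real one as the first argument.
        Path outputDir = Paths.get(args.length > 0 ? args[0] : "src");
        List<String> stubs = findGrpcStubs(outputDir);
        if (stubs.isEmpty()) {
            System.out.println("No *Grpc.java files found: the grpc-java plugin probably did not run.");
        } else {
            System.out.println("Found gRPC service classes: " + stubs);
        }
    }
}
```

If the list comes back empty after a successful protoc run, the plugin flag is the first thing to check.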
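  • Command line example (with line breaks inserted for readability; the last argument is the set of .proto files to compile, shown here for the apis directory from the list above):
$ ./protoc -I=/Users/foobar/protobuf_repo/src \
   -I=/Users/foobar/tensorflow_repo \
   -I=/Users/foobar/tfserving_repo \
   --plugin=protoc-gen-grpc-java=/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe \
   --java_out=src \
   --grpc-java_out=src \
   /Users/foobar/tfserving_repo/tensorflow_serving/apis/*.proto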
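Because the .proto files are scattered over several directories, it can help to assemble the protoc argument list in code rather than by hand. This is a rough sketch, not a tested build step: the class ProtocCommand and its parameters are made up for illustration, and --proto_path is the long form of -I. Note that ProcessBuilder does not expand shell globs such as *.proto, so the file paths have to be listed explicitly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper that only assembles the argument list; it does not run protoc.
final class ProtocCommand {

    static List<String> build(String protoc, String grpcPlugin, String outDir,
                              List<String> includeDirs, List<String> protoFiles) {
        List<String> cmd = new ArrayList<>();
        cmd.add(protoc);
        for (String dir : includeDirs) {
            cmd.add("--proto_path=" + dir);          // one include root per repo
        }
        cmd.add("--plugin=protoc-gen-grpc-java=" + grpcPlugin);
        cmd.add("--java_out=" + outDir);             // message classes
        cmd.add("--grpc-java_out=" + outDir);        // service stubs (needs the plugin)
        cmd.addAll(protoFiles);                      // explicit paths, no globs
        return cmd;
    }

    public static void main(String[] args) {
        List<String> cmd = build(
                "./protoc",
                "/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe",
                "src",
                Arrays.asList("/Users/foobar/protobuf_repo/src",
                              "/Users/foobar/tensorflow_repo",
                              "/Users/foobar/tfserving_repo"),
                Arrays.asList("/Users/foobar/tensorflow_repo/tensorflow/core/framework/tensor.proto"));
        System.out.println(String.join(" ", cmd));
        // To run it: new ProcessBuilder(cmd).inheritIO().start().waitFor();
    }
}
```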
  • Following the gRPC documentation, I created a Channel and a stub:
ManagedChannel mChannel;
PredictionServiceGrpc.PredictionServiceBlockingStub mBlockingstub;
mChannel = ManagedChannelBuilder.forAddress(host,port).usePlaintext().build();
mBlockingstub = PredictionServiceGrpc.newBlockingStub(mChannel);
  • I followed several documents to piece together the steps that follow:
    • The gRPC documents discuss stubs (blocking and async)
    • This article gives an overview of the process, but with Python
    • This sample code was critical for examples of the newBuilder syntax.
  • Maven imports are:
    • io.grpc:grpc-all
    • org.tensorflow:libtensorflow
    • org.tensorflow:proto
  • Here is sample code:
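// Imports, assuming the default package names of the generated classes
import io.grpc.StatusRuntimeException;
import java.util.logging.Level;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;

// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();

TensorShapeProto.Dim featuresDim1  = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto     featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
// Assumption: a float tensor. Use the dtype your model expects and add the values (e.g. addFloatVal).
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT)
                     .setTensorShape(featuresShape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();

// Now prepare for the inference request over gRPC to the TF Serving server.
// modelName (String) and modelVersion (long) are placeholders for the served model's name and version.
com.google.protobuf.Int64Value version =
        com.google.protobuf.Int64Value.newBuilder().setValue(modelVersion).build();

Model.ModelSpec.Builder model = Model.ModelSpec
                                     .newBuilder()
                                     .setName(modelName)
                                     .setVersion(version);  // type = Int64Value
Model.ModelSpec     modelSpec = model.build();

Predict.PredictRequest request;
request = Predict.PredictRequest.newBuilder()
                                .putInputs("image", featuresTensorProto)
                                .setModelSpec(modelSpec)
                                .build();

Predict.PredictResponse response;

try {
    response = mBlockingstub.predict(request);
    // Refer to the sample code mentioned above for handling the outputs.

    java.util.Map<java.lang.String, org.tensorflow.framework.TensorProto> outputs = response.getOutputsMap();
    for (java.util.Map.Entry<java.lang.String, org.tensorflow.framework.TensorProto> entry : outputs.entrySet()) {
        System.out.println("Response with the key: " + entry.getKey() + ", value: " + entry.getValue());
    }
} catch (StatusRuntimeException e) {
    // logger (java.util.logging.Logger) and success are assumed to be fields of the surrounding class.
    logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus());
    success = false;
}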
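The sample code sets a shape on the features tensor but does not show how the values are laid out. A TensorProto keeps its values in a flat repeated field in row-major order, alongside the shape. The sketch below uses only the standard library to show that mapping for a 2-D float array; the class TensorLayout and its method names are hypothetical. With the generated builders, the shape entries would go into Dim.setSize(...) and the flat list into addAllFloatVal(...).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper illustrating row-major flattening; not part of any TensorFlow API.
final class TensorLayout {

    // Shape {rows, cols} of a rectangular 2-D array; rejects ragged input.
    static long[] shapeOf(float[][] data) {
        int cols = data.length == 0 ? 0 : data[0].length;
        for (float[] row : data) {
            if (row.length != cols) {
                throw new IllegalArgumentException("all rows must have the same length");
            }
        }
        return new long[] {data.length, cols};
    }

    // Values in row-major order: the whole first row, then the second row, and so on.
    static List<Float> flatten(float[][] data) {
        List<Float> flat = new ArrayList<>();
        for (float[] row : data) {
            for (float v : row) {
                flat.add(v);
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] features = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        System.out.println(Arrays.toString(shapeOf(features))); // [2, 3]
        System.out.println(flatten(features));                  // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }
}
```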

