Loading a TensorFlow model in Java. Model created with DNNClassifier in Python

Asked 2021-02-10 02:48 by 我在风中等你

The goal is to open, in Java, a model created/trained in Python with tensorflow.contrib.learn.DNNClassifier.

At the moment the main issue is knowing the names of the tensors to pass to feed and fetch on the Java side.
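
For context, loading a SavedModel and running one prediction from Java has the following general shape. This is only a sketch: the class name, the path, and the two operation-name constants are placeholders (identifying the real names is exactly the open question, addressed in the answers below).

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class LoadDnnModel {
    // Placeholder names: the answers below discuss which names the
    // DNNClassifier export actually uses for its inputs and outputs.
    static final String INPUT_OP = "input_op_name_goes_here";
    static final String OUTPUT_OP = "output_op_name_goes_here";

    public static void main(String[] args) {
        // Load the SavedModel directory exported from Python, using the "serve" tag.
        try (SavedModelBundle bundle = SavedModelBundle.load("/path/to/ModelSave", "serve")) {
            float[][] features = new float[1][35];  // one sample, 35 features (as in the accepted answer)
            try (Tensor input = Tensor.create(features)) {
                Tensor output = bundle.session().runner()
                        .feed(INPUT_OP, input)
                        .fetch(OUTPUT_OP)
                        .run().get(0);
                System.out.println(output);
            }
        }
    }
}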

3 Answers
  • 2021-02-10 03:24

    I got an error without feed("input_example_tensor", inputTensor) on TensorFlow 1.1.

    But I found that an example.proto (tf.train.Example) can be fed as "input_example_tensor", although it took a lot of time to figure out how to create a string tensor from a serialized protocol buffer.

    This is how I created inputTensor.

    // "bundle" is the SavedModelBundle loaded from the exported model directory,
    // and "s" is its Session (bundle.session()), as in the other answer.
    org.tensorflow.example.Example.Builder example = org.tensorflow.example.Example.newBuilder();
    /* set some features to example... */

    // Serialize the Example proto into a scalar DT_STRING tensor.
    Tensor exampleTensor = Tensor.create(example.build().toByteArray());
    // Here, the shape of exampleTensor is not specified yet.

    // Set the shape to feed this as "input_example_tensor".
    Graph g = bundle.graph();
    Output examplePlaceholder =
            g.opBuilder("Placeholder", "example")
             .setAttr("dtype", exampleTensor.dataType())
             .build().output(0);
    Tensor shapeTensor = Tensor.create(new long[]{1}, IntBuffer.wrap(new int[]{1}));
    Output shapeConst = g.opBuilder("Const", "shape")
            .setAttr("dtype", shapeTensor.dataType())
            .setAttr("value", shapeTensor)
            .build().output(0);
    Output shaped = g.opBuilder("Reshape", "output")
            .addInput(examplePlaceholder)
            .addInput(shapeConst)
            .build().output(0);

    // Run the small reshape subgraph; inputTensor now has shape [1] and is ready to feed.
    Tensor inputTensor = s.runner().feed(examplePlaceholder, exampleTensor).fetch(shaped).run().get(0);
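
    The "/* set some features to example... */" step depends on the feature columns the model was trained with. As an illustration only (the feature name "x" and its value are made up here, not taken from this model), populating a single float feature with the org.tensorflow.example protobuf classes could look like this:

    // Hypothetical: "x" must match a feature column name used when training
    // the DNNClassifier in Python; 0.5f is an arbitrary example value.
    org.tensorflow.example.Feature feature =
            org.tensorflow.example.Feature.newBuilder()
                    .setFloatList(org.tensorflow.example.FloatList.newBuilder().addValue(0.5f))
                    .build();
    example.setFeatures(
            org.tensorflow.example.Features.newBuilder().putFeature("x", feature));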
    
  • 2021-02-10 03:35

    OK, I finally solved it: the main problem was the name of the input to use in Java, which is "dnn/input_from_feature_columns/input_from_feature_columns/concat" and not "input_example_tensor".

    I discovered this by navigating the graph with: tensorboard --logdir=D:\python\Workspace\Autoencoder\src\dnn\ModelSave
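
    If TensorBoard is not handy, the operation names can also be dumped from Java itself. A minimal sketch, assuming a reasonably recent 1.x version of the TensorFlow Java API (where Graph exposes an operations() iterator):

    import java.util.Iterator;

    import org.tensorflow.Graph;
    import org.tensorflow.Operation;
    import org.tensorflow.SavedModelBundle;

    public class ListOperations {
        public static void main(String[] args) {
            // Load the same SavedModel directory and print every operation name,
            // so the feed/fetch names can be found without TensorBoard.
            try (SavedModelBundle bundle = SavedModelBundle.load(
                    "/java/workspace/APIJavaSampleCode/tfModels/dnn/ModelSave", "serve")) {
                Graph g = bundle.graph();
                Iterator<Operation> ops = g.operations();
                while (ops.hasNext()) {
                    Operation op = ops.next();
                    System.out.println(op.name() + " (" + op.type() + ")");
                }
            }
        }
    }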

    Here is the Java code:

    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    public class HelloTF {
        public static void main(String[] args) throws Exception {
            // Load the exported SavedModel with the "serve" tag and get its session.
            SavedModelBundle bundle = SavedModelBundle.load("/java/workspace/APIJavaSampleCode/tfModels/dnn/ModelSave", "serve");
            Session s = bundle.session();

            // One sample with 35 feature values, converted to float.
            double[] inputDouble = {1.0,0.7982741870963959,1.0,-0.46270838239235024,0.040320274521029376,0.443451913224413,-1.0,1.0,1.0,-1.0,0.36689718911339564,-0.13577379160035796,-0.5162916256414466,-0.03373651520104648,1.0,1.0,1.0,1.0,0.786999801054777,-0.43856035121103853,-0.8199093927945158,1.0,-1.0,-1.0,-0.1134921695894473,-1.0,0.6420892436196663,0.7871737734493178,1.0,0.6501788845358409,1.0,1.0,1.0,-0.17586627413625022,0.8817194210401085};
            float[] inputFloat = new float[inputDouble.length];
            for (int i = 0; i < inputFloat.length; i++) {
                inputFloat[i] = (float) inputDouble[i];
            }

            // Shape [1, 35]: a batch containing a single sample.
            float[][] data = new float[1][35];
            data[0] = inputFloat;
            Tensor inputTensor = Tensor.create(data);

            // Feed the concatenated feature-column tensor and fetch the class probabilities.
            Tensor result = s.runner()
                    .feed("dnn/input_from_feature_columns/input_from_feature_columns/concat", inputTensor)
                    //.feed("input_example_tensor", inputTensor)
                    //.fetch("tensorflow/serving/classify")
                    .fetch("dnn/multi_class_head/predictions/probabilities")
                    //.fetch("dnn/zero_fraction_3/Cast")
                    .run().get(0);

            // Copy the [1, 5] probability vector out of the result and take the argmax.
            float[][] m = new float[1][5];
            float[][] vector = result.copyTo(m);
            float maxVal = 0;
            int inc = 0;
            int predict = -1;
            for (float val : vector[0]) {
                System.out.println(val + "  ");
                if (val > maxVal) {
                    predict = inc;
                    maxVal = val;
                }
                inc++;
            }
            System.out.println(predict);
        }
    }

    I have tested the output:

    Python side:

    Prediction for sample_2 is:[3] 
    Prediction for sample_2 is:[array([ 0.17157166,  0.24475774,  0.16158019,  0.24648622,  0.17560424], dtype=float32)] 
    

    Java side:

    0.17157166  
    0.24475774  
    0.16158019   
    0.24648622  
    0.17560424  
    3
    
  • 2021-02-10 03:36

    The error message offers a clue: the tensor named "input_example_tensor" in the model expects to have string contents, whereas you provided float values.

    Judging by the name of the tensor and your code, I'd guess that the tensor you're feeding is defined in input_fn_utils.py. This tensor is passed to the tf.parse_example() op, which expects a vector of tf.train.Example protocol buffers, serialized as strings.
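
    In other words, if you do want to feed "input_example_tensor" directly, it should receive a 1-D string tensor whose elements are serialized tf.train.Example protos. A rough sketch of the Java side (it reuses the "example" builder and "bundle" from the first answer; the fetch name is the one reported in the accepted answer and may differ in your export):

    // Serialize one (or more) tf.train.Example protos; a byte[][] is treated
    // by the Java API as a 1-D vector of DT_STRING values.
    byte[][] serialized = new byte[][] { example.build().toByteArray() };

    try (Tensor input = Tensor.create(serialized)) {
        Tensor result = bundle.session().runner()
                .feed("input_example_tensor", input)
                .fetch("dnn/multi_class_head/predictions/probabilities")  // name from the accepted answer
                .run().get(0);
        // ... copy the probabilities out of result as shown in the accepted answer ...
    }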
