Loading sklearn model in Java. Model created with DNNClassifier in python

轻奢々 2021-02-10 02:45

The goal is to load, in Java, a model created and trained in Python with tensorflow.contrib.learn.learn.DNNClassifier.

At the moment the main issue is to know th

3 answers
  •  长发绾君心
    2021-02-10 03:13

    On TensorFlow 1.1, I got an error when running the model without feed("input_example_tensor", inputTensor).

    I found that a serialized Example protocol buffer (example.proto) can be fed as "input_example_tensor", although it took me a long time to figure out how to create a string tensor from a serialized protocol buffer.

    This is how I created inputTensor:

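    import java.nio.IntBuffer;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    // bundle is assumed to be an already loaded org.tensorflow.SavedModelBundle.
    org.tensorflow.example.Example.Builder example = org.tensorflow.example.Example.newBuilder();
    /* set some features to example... */

    // Serialize the Example into a string tensor.
    Tensor exampleTensor = Tensor.create(example.build().toByteArray());
    // Here, exampleTensor does not yet have the shape [1].

    // Add a Placeholder -> Reshape pair to the graph so the tensor can be
    // given shape [1] and fed as "input_example_tensor".
    Graph g = bundle.graph();
    Session s = bundle.session(); // assumed: the session of the same bundle
    Output examplePlaceholder = g.opBuilder("Placeholder", "example")
            .setAttr("dtype", exampleTensor.dataType())
            .build().output(0);
    // Constant holding the target shape: one dimension of size 1.
    Tensor shapeTensor = Tensor.create(new long[]{1}, IntBuffer.wrap(new int[]{1}));
    Output shapeConst = g.opBuilder("Const", "shape")
            .setAttr("dtype", shapeTensor.dataType())
            .setAttr("value", shapeTensor)
            .build().output(0);
    Output shaped = g.opBuilder("Reshape", "output")
            .addInput(examplePlaceholder)
            .addInput(shapeConst)
            .build().output(0);

    // Run only the reshape and keep its result.
    Tensor inputTensor = s.runner().feed(examplePlaceholder, exampleTensor).fetch(shaped).run().get(0);
    // Now inputTensor has shape [1] and is ready to feed.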
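
To make it easier to see what ends up inside the string tensor, the following is a minimal sketch that hand-encodes the protobuf wire bytes of an Example holding a single float feature, using only the Java standard library. It assumes the standard tf.Example schema (Example.features = 1, Features.feature = 1 as a map, Feature.float_list = 2, FloatList.value = 1, packed). The class name ExampleWireSketch and the feature name "x" are hypothetical, and this is not how the answer above builds its input: in real code the generated builder and example.build().toByteArray() are the better choice.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

/** Illustration only: hand-encodes a tf.Example with one float feature. */
final class ExampleWireSketch {

    /** Encodes payload as a length-delimited protobuf field (wire type 2). */
    static byte[] lengthDelimited(int fieldNumber, byte[] payload) {
        if (fieldNumber < 1 || fieldNumber > 15 || payload.length > 127) {
            // Larger values need multi-byte varints, which this sketch omits.
            throw new IllegalArgumentException("outside the range this sketch supports");
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write((fieldNumber << 3) | 2);      // tag: field number plus wire type 2
        out.write(payload.length);              // single-byte varint length
        out.write(payload, 0, payload.length);  // the field contents
        return out.toByteArray();
    }

    /** Concatenates two byte arrays. */
    static byte[] concat(byte[] a, byte[] b) {
        byte[] result = new byte[a.length + b.length];
        System.arraycopy(a, 0, result, 0, a.length);
        System.arraycopy(b, 0, result, a.length, b.length);
        return result;
    }

    /** Serialized Example with one entry in its feature map: name -> [value]. */
    static byte[] singleFloatExample(String name, float value) {
        // Packed floats are stored as 4-byte little-endian IEEE 754 values.
        byte[] packed = ByteBuffer.allocate(4)
                .order(ByteOrder.LITTLE_ENDIAN)
                .putFloat(value)
                .array();
        byte[] floatList = lengthDelimited(1, packed);     // FloatList.value
        byte[] feature = lengthDelimited(2, floatList);    // Feature.float_list
        byte[] key = lengthDelimited(1, name.getBytes(StandardCharsets.UTF_8)); // map key
        byte[] val = lengthDelimited(2, feature);          // map value
        byte[] features = lengthDelimited(1, concat(key, val)); // Features.feature entry
        return lengthDelimited(1, features);               // Example.features
    }

    public static void main(String[] args) {
        byte[] bytes = singleFloatExample("x", 1.0f);
        StringBuilder hex = new StringBuilder();
        for (byte b : bytes) {
            hex.append(String.format("%02X ", b));
        }
        System.out.println(hex.toString().trim());
    }
}
```

Under that schema assumption, a byte array like this is the kind of payload that Tensor.create(byte[]) turns into the string tensor, which then still has to be reshaped to [1] as shown above.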
    
