How to create/initialize a Variable with Tensorflow 1.0 Java API

忘掉有多难 2021-01-20 10:20

I'm trying to port this line of Python code:

my_var = tf.Variable(3, name="input_a")

to Java. I was able to do this with tf.consta

1 answer
  • 2021-01-20 11:05

    I had the same need and used TensorFlow's Assign op to give my variable a value. First define the variable node the way you did, then add an Assign node that takes the variable and the initial value. Later in the graph I refer to this Assign node instead of the bare variable, so running the graph no longer raises java.lang.IllegalStateException: Attempting to use uninitialized value.

    I wrapped Graph in a GraphBuilder class (the code is Scala) and added the methods I needed:

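    import org.tensorflow.{DataType, Graph, Output, Shape, Tensor}
    
    class GraphBuilder(g: Graph) {
      // Declares a variable node; it holds no value until an Assign op runs
      def variable(name: String, dataType: DataType, shape: Shape): Output = {
        g.opBuilder("Variable", name)
          .setAttr("dtype", dataType)
          .setAttr("shape", shape)
          .build()
          .output(0)
      }
    
      // Writes `value` into `variable`; the output is the variable after assignment
      def assign(value: Output, variable: Output): Output = {
        g.opBuilder("Assign", "Assign/" + variable.op().name())
          .addInput(variable)
          .addInput(value)
          .build()
          .output(0)
      }
    
      // Used below but not shown originally: adds a Const node holding `value`
      def constant(name: String, value: Any): Output = {
        val t = Tensor.create(value)
        try g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0)
        finally t.close()
      }
    }
    
    val builder = new GraphBuilder(graph) // graph: your org.tensorflow.Graph
    val WValue = Array.fill(numFeatures)(Array.fill(hiddenDim)(0.0))
    val W = builder.variable("W", DataType.DOUBLE, Shape.make(numFeatures, hiddenDim))
    val W_init = builder.assign(builder.constant("Wval", WValue), W)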
    

    Note that because the rest of the graph reads through the Assign node, the variable is reset to the preset value on every forward pass, so this approach is not suited for training. In any case, according to https://github.com/tensorflow/tensorflow/issues/5518, the Java API does not provide the training ops by default, so you would need to add those dependencies yourself.
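A minimal sketch of the alternative: run the initializer only once instead of reading through the Assign node. It assumes the TensorFlow Java 1.x artifact (`org.tensorflow:tensorflow`) is on the classpath and mirrors `tf.Variable(3, name="input_a")` from the question. The Assign op is run a single time as a session target (`Session.Runner.addTarget`), roughly like running a variable's initializer once in Python, and afterwards only ops that read the variable directly are fetched, so later runs do not reset it. The object `VariableInitOnce`, the helper `constant`, and the op names `input_a/init`, `input_a/initial_value`, `two` and `double_a` are hypothetical. It adds no training ops.

```scala
import org.tensorflow.{DataType, Graph, Session, Shape, Tensor}

object VariableInitOnce {
  // Hypothetical helper: adds a Const node holding `value` (e.g. a boxed Int).
  // The return type is inferred so this compiles whether or not Output is generic.
  def constant(g: Graph, name: String, value: Any) = {
    val t = Tensor.create(value)
    try g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0)
    finally t.close() // the graph keeps its own copy of the value
  }

  // Builds the graph, runs the initializer once, then reads the variable twice.
  def initOnceAndRead(): List[Int] = {
    val g = new Graph()
    try {
      // Counterpart of tf.Variable(3, name="input_a"): an int32 scalar variable
      val inputA = g.opBuilder("Variable", "input_a")
        .setAttr("dtype", DataType.INT32)
        .setAttr("shape", Shape.scalar())
        .build()
        .output(0)

      // Initializer op: writes the constant 3 into the variable when it is run
      g.opBuilder("Assign", "input_a/init")
        .addInput(inputA)
        .addInput(constant(g, "input_a/initial_value", 3))
        .build()

      // Downstream op reads the variable itself, not the Assign output
      g.opBuilder("Mul", "double_a")
        .addInput(inputA)
        .addInput(constant(g, "two", 2))
        .build()

      val s = new Session(g)
      try {
        // Run the initializer exactly once, as a target (nothing is fetched)
        s.runner().addTarget("input_a/init").run()
        // Later runs only fetch double_a, so the variable is not re-assigned
        List.fill(2) {
          val out = s.runner().fetch("double_a").run().get(0)
          try out.intValue() finally out.close()
        }
      } finally s.close()
    } finally g.close()
  }

  def main(args: Array[String]): Unit =
    println(initOnceAndRead())
}
```

Both reads should see the value set by the single initializer run; if the init target were skipped, the fetch would fail with the same "Attempting to use uninitialized value" error as in the question.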
