I'm trying to port this line of Python code:

```python
my_var = tf.Variable(3, name="input_a")
```

to Java. I was able to do this with tf.consta
Having the same need as you, I used TensorFlow's `Assign` node to give my variable its value. First define your variable node the way you did, then add an `Assign` node that writes the corresponding value into it. Later in my graph I refer to this new `Assign` node instead of the bare variable, so the graph no longer raises `java.lang.IllegalStateException: Attempting to use uninitialized value`.
I wrapped the `Graph` in a `GraphBuilder` class and added the required methods:
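```scala
import org.tensorflow.{DataType, Graph, Output, Shape, Tensor}

class GraphBuilder(g: Graph) {
  // Declares a variable; it must be assigned before it can be read.
  def variable(name: String, dataType: DataType, shape: Shape): Output =
    g.opBuilder("Variable", name)
      .setAttr("dtype", dataType)
      .setAttr("shape", shape)
      .build()
      .output(0)
  // Assign(ref, value): writes `value` into `variable` each time this op runs.
  def assign(value: Output, variable: Output): Output =
    g.opBuilder("Assign", "Assign/" + variable.op().name())
      .addInput(variable)
      .addInput(value)
      .build()
      .output(0)
  // Assumed helper (called below but not shown in the post): a Const op built from a Java array or boxed scalar.
  def constant(name: String, value: AnyRef): Output = {
    val t = Tensor.create(value)
    try g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0)
    finally t.close()
  }
}

// numFeatures and hiddenDim are the model's dimensions.
val graph = new Graph()
val builder = new GraphBuilder(graph)
val WValue = Array.fill(numFeatures)(Array.fill(hiddenDim)(0.0))
val W = builder.variable("W", DataType.DOUBLE, Shape.make(numFeatures, hiddenDim))
val W_init = builder.assign(builder.constant("Wval", WValue), W)
```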
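For the question's `tf.Variable(3, name="input_a")`, a minimal self-contained sketch of the same idea might look like the following. It assumes the TensorFlow 1.x Java API (`org.tensorflow.Graph`, `Session`, `Tensor`, `Graph.opBuilder`) used from Scala; the op names `Assign/input_a` and `input_a/initial_value` and the object `VariableDemo` are illustrative choices, not anything from the post. Here the `Assign` op is run once as a target, much like running an initializer in Python, and the variable is then fetched by its own name.

```scala
import org.tensorflow.{DataType, Graph, Operation, Session, Shape, Tensor}

// Sketch only (TensorFlow 1.x Java API assumed): Scala version of
//   my_var = tf.Variable(3, name="input_a")
object VariableDemo {

  // Scalar INT32 Const op; the op copies the tensor, so it can be closed after build().
  private def constInt(g: Graph, name: String, value: Int): Operation = {
    val t = Tensor.create(Integer.valueOf(value))
    try g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build()
    finally t.close()
  }

  def run(): Int = {
    val g = new Graph()
    try {
      // The variable itself: reading it before assignment fails as "uninitialized".
      val v = g.opBuilder("Variable", "input_a")
        .setAttr("dtype", DataType.INT32)
        .setAttr("shape", Shape.scalar())
        .build()
      val initial = constInt(g, "input_a/initial_value", 3)
      // Assign(ref, value): writes 3 into the variable whenever this op runs.
      g.opBuilder("Assign", "Assign/input_a")
        .addInput(v.output(0))
        .addInput(initial.output(0))
        .build()

      val s = new Session(g)
      try {
        // Run the assignment once, as an initializer.
        s.runner().addTarget("Assign/input_a").run()
        // Then read the variable directly.
        val result = s.runner().fetch("input_a").run().get(0)
        try result.intValue() finally result.close()
      } finally s.close()
    } finally g.close()
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Fetching the variable op rather than the `Assign` output means later runs of the same `Session` read whatever value the variable currently holds, instead of re-running the assignment.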
The `Assign` nodes reset your variables to the preset value on every forward pass, so this approach is not suited for training either. In any case, from this issue it seems that you need to add the training dependencies yourself, because by default the Java API does not provide the training nodes: https://github.com/tensorflow/tensorflow/issues/5518.
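To make the reset behaviour concrete, here is a hedged Scala sketch, again assuming the TensorFlow 1.x Java API. The op names `counter`, `zero` and `one` and the object `InitOnceDemo` are made up for illustration. It runs the `Assign` once as an initializer, applies two `AssignAdd` updates, and then fetches the `Assign` output, which writes the initial value back into the variable.

```scala
import org.tensorflow.{DataType, Graph, Operation, Session, Shape, Tensor}

// Sketch only (TensorFlow 1.x Java API assumed): initialize once, update with AssignAdd,
// then show that fetching the Assign output overwrites the updates.
object InitOnceDemo {

  // Scalar INT32 Const op (the op keeps its own copy of the tensor).
  private def constInt(g: Graph, name: String, value: Int): Operation = {
    val t = Tensor.create(Integer.valueOf(value))
    try g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build()
    finally t.close()
  }

  // Runs the graph up to `name` and returns that scalar INT32 output.
  private def fetchInt(s: Session, name: String): Int = {
    val t = s.runner().fetch(name).run().get(0)
    try t.intValue() finally t.close()
  }

  def run(): (Int, Int, Int) = {
    val g = new Graph()
    try {
      val counter = g.opBuilder("Variable", "counter")
        .setAttr("dtype", DataType.INT32)
        .setAttr("shape", Shape.scalar())
        .build()
      val zero = constInt(g, "zero", 0)
      val one = constInt(g, "one", 1)
      g.opBuilder("Assign", "Assign/counter")
        .addInput(counter.output(0)).addInput(zero.output(0)).build()
      g.opBuilder("AssignAdd", "AssignAdd/counter")
        .addInput(counter.output(0)).addInput(one.output(0)).build()

      val s = new Session(g)
      try {
        s.runner().addTarget("Assign/counter").run()    // initialize once: counter = 0
        s.runner().addTarget("AssignAdd/counter").run() // counter = 1
        s.runner().addTarget("AssignAdd/counter").run() // counter = 2
        val afterUpdates = fetchInt(s, "counter")       // updates persist within the session
        val viaAssign = fetchInt(s, "Assign/counter")   // re-runs Assign, writing 0 back
        val afterReset = fetchInt(s, "counter")         // the earlier updates are gone
        (afterUpdates, viaAssign, afterReset)
      } finally s.close()
    } finally g.close()
  }

  def main(args: Array[String]): Unit = println(run())
}
```

This only covers how variable state behaves across runs. It does not add any gradient or optimizer nodes, which is what the linked issue is about.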