TensorFlow: getting variable by name

后悔当初 2020-12-08 09:36

When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value

4 Answers
  • 2020-12-08 09:56

    Based on @mrry's answer, I think it is better to create and use the following function, because there are also local variables and other variables that are not in the global variables collection (they live in different collections):

    def get_var_by_name(query_name, var_list):
        """
        Get a variable by its exact name; return None if there is no match.
    
        e.g.
        local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
        the_var = get_var_by_name('accuracy/total:0', local_vars)
        """
        for var in var_list:
            if var.name == query_name:
                return var
        return None
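
    The function above only helps if you pass it the right list. As a minimal, TensorFlow-free sketch of why that matters, the code below uses a hypothetical `FakeVar` stand-in (anything with a `.name` attribute) and a plain dict standing in for the graph's collections; with real TensorFlow you would pass lists such as `tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)` instead. `find_in_collections` is an illustrative helper, not a TensorFlow API.

```python
from collections import namedtuple

# Hypothetical stand-in for a TensorFlow variable; only .name matters here.
FakeVar = namedtuple("FakeVar", ["name"])


def get_var_by_name(query_name, var_list):
    """Return the first item in var_list whose .name equals query_name, else None."""
    return next((var for var in var_list if var.name == query_name), None)


def find_in_collections(query_name, collections):
    """Search several collections (dict: key -> list of variables) in order.

    Returns (collection_key, var) for the first match, or (None, None).
    """
    for key, var_list in collections.items():
        var = get_var_by_name(query_name, var_list)
        if var is not None:
            return key, var
    return None, None


# Toy "graph": weights in a global list, a metric accumulator in a local list,
# mirroring the 'accuracy/total:0' example from the docstring above.
collections = {
    "global_variables": [FakeVar("dense/kernel:0"), FakeVar("dense/bias:0")],
    "local_variables": [FakeVar("accuracy/total:0"), FakeVar("accuracy/count:0")],
}

# Searching only the global list misses the local variable...
print(get_var_by_name("accuracy/total:0", collections["global_variables"]))
# ...but searching every collection finds it.
print(find_in_collections("accuracy/total:0", collections))
```

    Searching the collections in a fixed order keeps the result deterministic if the same name ever appeared in more than one list.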
    
  • 2020-12-08 09:57

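    If you want to read the value of a variable stored in a model checkpoint, use tf.train.load_variable("model_folder_name", "Variable name"). Note that it returns a copy of the saved value as a NumPy array, not a tf.Variable.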

  • 2020-12-08 10:12

    The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:

    >>> import tensorflow as tf  # TensorFlow 1.x graph-mode API
    >>> with tf.variable_scope("foo"):
    ...   bar1 = tf.get_variable("bar", (2,3)) # create
    ... 
    >>> with tf.variable_scope("foo", reuse=True):
    ...   bar2 = tf.get_variable("bar")  # reuse
    ... 
    
    >>> with tf.variable_scope("", reuse=True): # root variable scope
    ...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
    ... 
    >>> (bar1 is bar2) and (bar2 is bar3)
    True
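
    A rough way to picture this behaviour is a registry keyed by the full scoped name. The sketch below is a simplified, hypothetical model (`VariableStore` and `plain_variable` are illustrative names, not TensorFlow internals): `get_variable` registers a new entry when reuse is off and looks one up when reuse is on, while the stand-in for `tf.Variable()` never registers anything, so a later `get_variable()` lookup cannot find it.

```python
from contextlib import contextmanager


class VariableStore:
    """Toy registry mapping full scoped names to variables (hypothetical model)."""

    def __init__(self):
        self._vars = {}     # full name -> variable (a plain dict here)
        self._scopes = []   # stack of (scope_name, reuse) pairs

    @contextmanager
    def variable_scope(self, name, reuse=False):
        self._scopes.append((name, reuse))
        try:
            yield
        finally:
            self._scopes.pop()

    def _full_name(self, name):
        # The empty (root) scope contributes no prefix.
        parts = [scope for scope, _ in self._scopes if scope]
        return "/".join(parts + [name])

    def _reusing(self):
        # In this toy, reuse is inherited: once an enclosing scope sets it, it stays on.
        return any(reuse for _, reuse in self._scopes)

    def get_variable(self, name, shape=None):
        full_name = self._full_name(name)
        if self._reusing():
            if full_name not in self._vars:
                raise ValueError(f"Variable {full_name} does not exist")
            return self._vars[full_name]
        if full_name in self._vars:
            raise ValueError(f"Variable {full_name} already exists")
        self._vars[full_name] = {"name": full_name, "shape": shape}
        return self._vars[full_name]


def plain_variable(name):
    """Stand-in for tf.Variable(): builds a variable but never registers it."""
    return {"name": name, "shape": None}


store = VariableStore()
with store.variable_scope("foo"):
    bar1 = store.get_variable("bar", (2, 3))   # create
with store.variable_scope("foo", reuse=True):
    bar2 = store.get_variable("bar")           # reuse
with store.variable_scope("", reuse=True):
    bar3 = store.get_variable("foo/bar")       # reuse from the root scope
print((bar1 is bar2) and (bar2 is bar3))       # True

baz = plain_variable("baz")  # exists, but the store knows nothing about it
try:
    with store.variable_scope("", reuse=True):
        store.get_variable("baz")
except ValueError as err:
    print(err)                                 # Variable baz does not exist
```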
    

    If you did not create the variable using tf.get_variable(), you have a couple of options. First, you can use tf.global_variables() (as @mrry suggests):

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
    >>> bar1 is bar2
    True
    

    Or you can use tf.get_collection() like so:

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
    >>> bar1 is bar2
    True
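
    One caveat: the `tf.get_collection()` documentation says `scope` is matched against each item's `name` with `re.match`, so it behaves as a prefix filter, and `scope="bar"` would also pick up a variable named `barbell`. The pure-Python sketch below mimics that filtering on plain name strings (`filter_by_scope` and the names are illustrative) to show the pitfall; an anchored pattern such as `r"bar:\d+$"` should likewise work as the `scope` argument, since it is treated as a regular expression.

```python
import re


def filter_by_scope(names, scope):
    """Mimic get_collection(scope=...): keep names that re.match the scope."""
    return [name for name in names if re.match(scope, name)]


names = ["bar:0", "barbell:0", "foo/bar:0"]

# Prefix semantics: "bar" matches "bar:0" and "barbell:0", but not "foo/bar:0".
print(filter_by_scope(names, "bar"))        # ['bar:0', 'barbell:0']

# Anchoring the pattern to the ":<index>" suffix selects only the exact variable.
print(filter_by_scope(names, r"bar:\d+$"))  # ['bar:0']
```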
    

    Edit

    You can also use get_tensor_by_name():

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> graph = tf.get_default_graph()
    >>> bar2 = graph.get_tensor_by_name("bar:0")
    >>> bar1 is bar2
    False
    
    bar2 is not the same object: get_tensor_by_name() returns the tf.Tensor produced by the variable's operation, not the tf.Variable itself, although both refer to the same value.
    

    Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0. If the operation has multiple outputs, they have the same name as the operation plus :0, :1, :2, and so on.
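
    To make the convention concrete, here is a small hypothetical helper (plain Python, not part of TensorFlow) that splits a tensor name such as `bar:0` into its operation name and output index. It also explains why one answer compares `var.name` with `"Variable_23:0"` while another compares `var.op.name` with plain `"bar"`: the first is the tensor name, the second the operation name.

```python
def split_tensor_name(tensor_name):
    """Split "op_name:index" into (op_name, index).

    rpartition splits on the last ':', which separates the output index.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"not a tensor name of the form 'op:index': {tensor_name!r}")
    return op_name, int(index)


def tensor_name(op_name, index=0):
    """Build the name of output `index` of operation `op_name`."""
    return f"{op_name}:{index}"


print(split_tensor_name("bar:0"))      # ('bar', 0)
print(split_tensor_name("foo/bar:0"))  # ('foo/bar', 0)
print(tensor_name("split", 2))         # split:2
```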

  • 2020-12-08 10:20

    The easiest way to get a variable by name is to search for it in the tf.global_variables() collection:

    var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
    

    This works well for ad hoc reuse of existing variables. A more structured approach, for when you want to share variables between multiple parts of a model, is covered in the Sharing Variables tutorial.
