TensorFlow variable scope: reuse if variable exists

醉酒成梦 2020-12-02 15:17

I want a piece of code that creates a variable within a scope if it doesn't exist, and accesses the variable if it already exists. I need it to be the same c

4 Answers
  • 2020-12-02 15:59

    The new AUTO_REUSE option does the trick.

    From the tf.variable_scope API docs: "if reuse=tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise."

    Basic example of sharing a variable with AUTO_REUSE:

    def foo():
      with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
        v = tf.get_variable("v", [1])
      return v
    
    v1 = foo()  # Creates v.
    v2 = foo()  # Gets the same, existing v.
    assert v1 == v2
    
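    For contrast, here is a minimal sketch of what happens without AUTO_REUSE (assuming TF 1.x graph-mode semantics; the scope name "bar" is just an illustration): attempting to create the same variable twice in a scope that does not set reuse raises a ValueError.

    import tensorflow as tf  # TF 1.x; under TF 2.x the same symbols live in tf.compat.v1

    with tf.variable_scope("bar"):
        v = tf.get_variable("v", [1])        # creates bar/v

    try:
        with tf.variable_scope("bar"):       # reuse is not set, so a second
            v = tf.get_variable("v", [1])    # creation attempt is an error
    except ValueError as err:
        print(err)  # explains that bar/v already exists
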
  • 2020-12-02 16:03

    Although using a "try...except" clause works, I think a more elegant and maintainable way is to separate the variable initialization process from the "reuse" process.

    def initialize_variable(scope_name, var_name, shape):
        with tf.variable_scope(scope_name) as scope:
            v = tf.get_variable(var_name, shape)
            scope.reuse_variables()
    
    def get_scope_variable(scope_name, var_name):
        with tf.variable_scope(scope_name, reuse=True):
            v = tf.get_variable(var_name)
        return v
    

    Since we often only need to initialize a variable once but reuse/share it many times, separating the two processes makes the code cleaner. This way, we also don't need to go through the "try" clause every time to check whether the variable has already been created.

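    A minimal usage sketch under the definitions above (the scope and variable names "encoder" / "w" are hypothetical placeholders):

    # Run once to create encoder/w, then fetch the shared variable as needed.
    initialize_variable("encoder", "w", [10, 10])

    w1 = get_scope_variable("encoder", "w")
    w2 = get_scope_variable("encoder", "w")
    assert w1 is w2  # both handles refer to the same underlying variable
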
  • 2020-12-02 16:06

    We can write our own abstraction over tf.variable_scope that uses reuse=None on the first call and reuse=True on subsequent calls:

    def variable_scope(name_or_scope, *args, **kwargs):
      if isinstance(name_or_scope, str):
        scope_name = tf.get_variable_scope().name + '/' + name_or_scope
      elif isinstance(name_or_scope, tf.VariableScope):
        scope_name = name_or_scope.name
    
      if scope_name in variable_scope.scopes:
        kwargs['reuse'] = True
      else:
        variable_scope.scopes.add(scope_name)
    
      return tf.variable_scope(name_or_scope, *args, **kwargs)
    variable_scope.scopes = set()
    

    Usage:

    with variable_scope("foo"): #create the first time
        v = tf.get_variable("v", [1])
    
    with variable_scope("foo"): #reuse the second time
        v = tf.get_variable("v", [1])
    
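    One caveat with this wrapper: it keys the reuse decision on the fully qualified scope name (tf.get_variable_scope().name plus the requested name) stored in a module-level set, so the bookkeeping persists across graphs within the same process. If you build more than one graph, you would have to reset it yourself, e.g. with variable_scope.scopes.clear().
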
  • 2020-12-02 16:12

    get_variable() raises a ValueError when creating a new variable whose shape is not declared, or when reuse is violated during variable creation. Therefore, you can try this:

    def get_scope_variable(scope_name, var, shape=None):
        with tf.variable_scope(scope_name) as scope:
            try:
                v = tf.get_variable(var, shape)
            except ValueError:
                scope.reuse_variables()
                v = tf.get_variable(var)
        return v
    
    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v')
    assert v1 == v2
    

    Note that the following also works:

    v1 = get_scope_variable('foo', 'v', [1])
    v2 = get_scope_variable('foo', 'v', [1])
    assert v1 == v2
    

    UPDATE: the newer API now supports auto-reuse:

    def get_scope_variable(scope, var, shape=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            v = tf.get_variable(var, shape)
        return v
    