I want a piece of code that creates a variable within a scope if it doesn't exist, and accesses the variable if it already exists. I need it to be the same c
The new AUTO_REUSE option does the trick.
From the tf.variable_scope API docs: if reuse=tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise.
Basic example of sharing a variable with AUTO_REUSE:
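import tensorflow as tf  # TensorFlow 1.x graph-mode API
def foo():
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
        v = tf.get_variable("v", [1])
    return v
v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 is v2  # both names refer to the very same Variable object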
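To make the three `reuse` behaviors concrete without TensorFlow installed, here is a toy, pure-Python model of the rules `get_variable` follows. It is only a sketch: `ToyVariableStore` and the `AUTO_REUSE` sentinel are hypothetical stand-ins, not TensorFlow's implementation. It assumes the documented rules that reuse off refuses to return an existing variable, `reuse=True` refuses to create a new one, and AUTO_REUSE does whichever is needed; shape checks on reuse are not modeled.

```python
# Toy, TensorFlow-free model of get_variable's reuse rules (hypothetical names).
AUTO_REUSE = object()  # stand-in sentinel for tf.AUTO_REUSE


class ToyVariableStore:
    def __init__(self):
        self._vars = {}  # fully qualified name -> variable record

    def get_variable(self, name, shape=None, reuse=False):
        exists = name in self._vars
        if reuse is AUTO_REUSE:
            # Create on the first request, return the same object afterwards.
            return self._vars[name] if exists else self._create(name, shape)
        if reuse:
            # reuse=True: the variable must already exist.
            if not exists:
                raise ValueError(f"Variable {name} does not exist")
            return self._vars[name]
        # reuse off: the variable must not exist yet.
        if exists:
            raise ValueError(f"Variable {name} already exists")
        return self._create(name, shape)

    def _create(self, name, shape):
        # A brand-new variable needs a shape.
        if shape is None:
            raise ValueError(f"Shape of a new variable ({name}) must be defined")
        self._vars[name] = {"name": name, "shape": list(shape)}
        return self._vars[name]


store = ToyVariableStore()
v1 = store.get_variable("foo/v", [1], reuse=AUTO_REUSE)  # creates foo/v
v2 = store.get_variable("foo/v", [1], reuse=AUTO_REUSE)  # returns foo/v
assert v1 is v2
```

AUTO_REUSE is the only mode in which the same call works both the first time and every time after, which is why it replaces the manual bookkeeping in the other answers.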
Although using a "try...except..." clause works, I think a more elegant and maintainable way would be to separate the variable initialization process from the "reuse" process.
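import tensorflow as tf  # TensorFlow 1.x graph-mode API
def initialize_variable(scope_name, var_name, shape):
    # Create the variable once.
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variables()
    return v
def get_scope_variable(scope_name, var_name):
    # Look up an existing variable; raises ValueError if it was never created.
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v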
Since we usually initialize a variable only once but reuse/share it many times, separating the two processes makes the code cleaner. This way we also don't need to go through the "try" clause every time to check whether the variable has already been created.
We can write our own abstraction over tf.variable_scope that uses reuse=None on the first call and reuse=True on subsequent calls:
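import tensorflow as tf  # TensorFlow 1.x graph-mode API
def variable_scope(name_or_scope, *args, **kwargs):
    # Compute the fully qualified scope name; it is the key we remember.
    if isinstance(name_or_scope, str):
        parent = tf.get_variable_scope().name
        scope_name = parent + '/' + name_or_scope if parent else name_or_scope
    elif isinstance(name_or_scope, tf.VariableScope):
        scope_name = name_or_scope.name
    else:
        raise TypeError('name_or_scope must be a str or a tf.VariableScope')
    if scope_name in variable_scope.scopes:
        kwargs['reuse'] = True  # seen before: reuse its variables
    else:
        variable_scope.scopes.add(scope_name)  # first time: allow creation
    return tf.variable_scope(name_or_scope, *args, **kwargs)
variable_scope.scopes = set()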
Usage:
with variable_scope("foo"):  # create the first time
    v = tf.get_variable("v", [1])
with variable_scope("foo"):  # reuse the second time
    v = tf.get_variable("v", [1])
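The bookkeeping in that wrapper boils down to "the first time a fully qualified scope name is seen, allow creation; afterwards, reuse". Below is a TensorFlow-free sketch of just that decision. It assumes scope names are joined with '/' and that the root scope's name is the empty string (as `tf.get_variable_scope().name` is at the top level); `qualified_name` and `make_reuse_tracker` are hypothetical helpers for illustration only.

```python
def qualified_name(parent, name):
    # Join a parent scope name and a child name; the root scope is "".
    return f"{parent}/{name}" if parent else name


def make_reuse_tracker():
    seen = set()  # fully qualified scope names entered so far

    def reuse_flag_for(parent, name):
        full = qualified_name(parent, name)
        if full in seen:
            return True  # seen before: reuse existing variables
        seen.add(full)
        return None  # first visit: let the scope create variables

    return reuse_flag_for


reuse_flag_for = make_reuse_tracker()
print(reuse_flag_for("", "foo"))       # None  (first entry of "foo")
print(reuse_flag_for("", "foo"))       # True  (second entry of "foo")
print(reuse_flag_for("outer", "foo"))  # None  ("outer/foo" is a different scope)
```

Keying on the fully qualified name rather than the bare string matters: the same short name opened under two different parents refers to two different sets of variables.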
A ValueError is raised in get_variable() when creating a new variable without declaring its shape, or when violating reuse during variable creation. Therefore, you can try this:
import tensorflow as tf  # TensorFlow 1.x graph-mode API
def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)  # try to create it
        except ValueError:
            scope.reuse_variables()  # already exists: switch to reuse
            v = tf.get_variable(var)
    return v
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 is v2
Note that the following also works:
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 is v2
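The try/except approach is the "easier to ask forgiveness than permission" pattern: attempt creation, and fall back to reuse when creation fails. Here is a TensorFlow-free sketch of that control flow; `ToyScopeStore` and `toy_get_scope_variable` are hypothetical stand-ins that only mimic the two ValueError conditions described above, not TensorFlow code.

```python
class ToyScopeStore:
    """Hypothetical stand-in for a variable store (not TensorFlow code)."""

    def __init__(self):
        self._vars = {}  # fully qualified name -> variable record

    def create(self, name, shape):
        # Like get_variable with reuse off: fails if the name is taken
        # or if a new variable is requested without a shape.
        if name in self._vars:
            raise ValueError(f"Variable {name} already exists")
        if shape is None:
            raise ValueError(f"Shape of new variable {name} must be defined")
        self._vars[name] = {"name": name, "shape": list(shape)}
        return self._vars[name]

    def get(self, name):
        # Like get_variable with reuse=True: fails if the name is unknown.
        if name not in self._vars:
            raise ValueError(f"Variable {name} does not exist")
        return self._vars[name]


def toy_get_scope_variable(store, scope_name, var, shape=None):
    full = f"{scope_name}/{var}"
    try:
        return store.create(full, shape)  # first attempt: create
    except ValueError:
        return store.get(full)  # creation failed: fall back to reuse


s = ToyScopeStore()
v1 = toy_get_scope_variable(s, "foo", "v", [1])
v2 = toy_get_scope_variable(s, "foo", "v")
assert v1 is v2
```

One caveat the sketch makes visible: the same ValueError covers both failure causes, so asking for a brand-new variable without a shape falls through to the reuse branch and fails there with a "does not exist" error instead.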
UPDATE: The new API now supports auto-reuse:
def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v