缘由
最近一直在看深度学习的代码,又一次看到了slim.arg_scope()的嵌套使用,具体代码如下:
with slim.arg_scope( [slim.conv2d, slim.separable_conv2d], weights_initializer=tf.truncated_normal_initializer( stddev=weights_initializer_stddev), activation_fn=activation_fn, normalizer_fn=slim.batch_norm if use_batch_norm else None): with slim.arg_scope([slim.batch_norm], **batch_norm_params): with slim.arg_scope( [slim.conv2d], weights_regularizer=slim.l2_regularizer(weight_decay)): with slim.arg_scope( [slim.separable_conv2d], weights_regularizer=depthwise_regularizer) as arg_sc: return arg_sc
由上述代码可以看到,第一层argscope有slim.conv2d参数,第三层也有这个参数,那么不同层的参数是如何相互补充,作用到之后的代码块中,就是这篇博文的出发点。
准备工作
我们先看一下arg_scope的函数声明:
@tf_contextlib.contextmanager def arg_scope(list_ops_or_scope, **kwargs):
有函数修饰符@tf_contextlib.contextmanager修饰arg_scope函数,我们先研究下这个函数修饰符。
@的作用
@之后一般接一个可调用对象(tf_contextlib.contextmanager),一起构成函数修饰符(装饰器),这个可调用对象将被修饰函数(arg_scope)作为参数,执行一系列辅助操作,我们来看一个demo:
import time def my_time(func): print(time.ctime()) return func() @my_time # 从这里可以看出@time 等价于 time(xxx()),但是这种写法你得考虑python代码的执行顺序 def xxx(): print(‘Hello world!‘) 运行结果: Wed Jul 26 23:01:21 2017 Hello world!
在这个例子中,xxx函数实现我们的主要功能,打印Hello world!,但我们想给xxx函数添加一些辅助操作,于是我们用函数修饰符@my_time,使xxx函数先打印时间。整个例子的执行流程为调用my_time可调用对象,它接受xxx函数作为参数,先打印时间,再执行xxx函数。
上下文管理器
既然arg_scope函数存在装饰器,那么我们应该了解一下装饰器提供了什么辅助功能,代码为:
import contextlib as _contextlib from tensorflow.python.util import tf_decorator def contextmanager(target): """A tf_decorator-aware wrapper for `contextlib.contextmanager`. Usage is identical to `contextlib.contextmanager`. Args: target: A callable to be wrapped in a contextmanager. Returns: A callable that can be used inside of a `with` statement. """ context_manager = _contextlib.contextmanager(target) return tf_decorator.make_decorator(target, context_manager, ‘contextmanager‘)
可以看到导入了contextlib库,这个库提供了contextmanager函数,这也是一个装饰器,它使被修饰的函数具有上下文管理器的功能。上下文管理器的功能是在我们执行一段代码块之前做一些准备工作,执行完代码块之后做一些收尾工作,同样先来看一个上下文管理器的例子:
import time class MyTimer(object): def __init__(self, verbose = False): self.verbose = verbose def __enter__(self): self.start = time.time() return self def __exit__(self, *unused): self.end = time.time() self.secs = self.end - self.start self.msecs = self.secs * 1000 if self.verbose: print "elapsed time: %f ms" %self.msecs
with MyTimer(True):
print(‘Hello world!‘)
类MyTimer中的__enter__和__exit__方法分别是准备工作和收尾工作。整个代码的执行过程为:先执行__enter__方法,__enter__方法中的返回值(这个例子中是self)可以用到代码块中,再执行语句块,这个例子中是print函数,最后执行__exit__方法,更多关于上下文管理器的内容可以看这,我的例子也是从那copy的。contextlib中实现上下文管理器稍有不同,一样来看个例子:
from contextlib import contextmanager @contextmanager def tag(name): print "<%s>" % name yield print "</%s>" % name >>> with tag("h1"): ... print "foo"
运行结果: <h1> foo </h1>
tag函数中yield之前的代码相当于__enter__方法,yield产生的生成器相当于__enter__方法的返回值,yield之后的代码相当于__exit__方法。
arg_scope方法
这里我把arg_scope方法中代码稍微做了一些精简,代码如下:
arg_scope = [{}]
@tf_contextlib.contextmanager def arg_scope(list_ops_or_scope, **kwargs):try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key = arg_scope_func_key(op) # op的代号 if not has_arg_scope(op): # op是否用@slim.add_arg_scope修饰,这会在下一篇中介绍 raise ValueError(‘%s is not decorated with @add_arg_scope‘, _name_op(op)) if key in current_scope: current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) current_scope[key] = current_kwargs else: current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: _get_arg_stack().pop()
# demo
with slim.arg_scope( [slim.conv2d, slim.separable_conv2d], weights_initializer=tf.truncated_normal_initializer( stddev=weights_initializer_stddev), activation_fn=activation_fn, normalizer_fn=slim.batch_norm if use_batch_norm else None): with slim.arg_scope([slim.batch_norm], **batch_norm_params): with slim.arg_scope( [slim.conv2d], weights_regularizer=slim.l2_regularizer(weight_decay)): with slim.arg_scope( [slim.separable_conv2d], weights_regularizer=depthwise_regularizer) as arg_sc: return arg_sc
我们沿着demo一步步看,其中arg_scope是一个栈。先看第一层,current_arg_scope()函数返回栈中最后一个元素,此时是空字典{},由于字典为空,所以会把conv2d和separable_conv2d加入字典,此时栈为[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs}],然后执行接下来的代码块,即第二层with,finally中函数要在代码块执行完后再执行;第二层执行完后栈为[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs},{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs, ‘batch_norm‘: batch_norm_params}],可以看到是将第一层的字典复制之后检查其中是否有与第二层相同的op,相同的op就把参数更新,不同的op就增加键值对,如这里的batch_norm。
回到我们开头提到的问题,不同层的参数是如何互相补充的?现在我们可以看到,参数存储在栈中,每叠加一层,就在原有参数基础上把新参数添加上去。
原文:https://www.cnblogs.com/zzy-tf/p/9356883.html