The baselines project is designed to be quite flexible, and its structure is a bit complex. Because the project is large and its functions call one another, tracing down from a single function can go six or seven levels deep, and each function has a lot of hyperparameters, which makes it easy to get lost.
In what follows we only look at the DQN part of the source code, ignore everything unrelated, and break it down along a single thread. Think of it as a recursion game: we explore one level deeper at a time, return once we reach the bottom, and prune along the way so that anything unrelated to the network graph is skipped for now.
First we find the entry point of the recursion. Under deepq there is an experiments directory, which is full of examples; pong is an experiment for one of the Atari games.
Below is the code of train_pong:
1.
from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
import numpy as np
np.seterr(invalid='ignore')
def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )
    model.save('pong_model.pkl')
    env.close()
if __name__ == '__main__':
main()
The code above mainly calls into the deepq package. Look closely at deepq.learn(): it takes a lot of parameters. For now we only care about the network structure, and clearly the convolution layers (convs) and the fully connected layers (hiddens) belong to that part. Ctrl+click on learn to jump to the definition of learn(). Below are the hyperparameters of learn(). Don't be scared: it looks like a lot, but it is not as complex as you imagine, because its complexity is beyond what you can imagine.
2.
def learn(env,
          network,
          seed=None,
          lr=5e-4,
          total_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=100,
          checkpoint_freq=10000,
          checkpoint_path=None,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          param_noise=False,
          callback=None,
          load_path=None,
          **network_kwargs
          ):
However, parameters such as convs do not appear here, which means they are collected into **network_kwargs (a dict). Comparing the two signatures carefully, we can tell that:
**network_kwargs contains these three entries: {convs, hiddens, dueling}.
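This is just standard Python keyword-argument forwarding. A minimal sketch of the mechanism (the names here are made up for illustration, not taken from baselines):

def fake_learn(env, network, lr=5e-4, **network_kwargs):
    # Any keyword argument that does not match a named parameter lands in this dict.
    print(network_kwargs)

fake_learn("fake_env", "conv_only",
           lr=1e-4,                                      # matches a named parameter, NOT in network_kwargs
           convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],   # no matching name -> collected
           hiddens=[256],
           dueling=True)
# prints: {'convs': [(32, 8, 4), (64, 4, 2), (64, 3, 1)], 'hiddens': [256], 'dueling': True}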
Next, let's see which function actually uses **network_kwargs:
3.
q_func = build_q_func(network, **network_kwargs)
This is the only line that uses it, nice! Click in, and a whole new world opens up:
4.
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        # print('network:', network)
        network = get_network_builder(network)(**network_kwargs)
        # print('network:', network)

    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores

            return q_out

    return q_func_builder
We are getting close: this is clearly where the network structure is built. Two of the three entries in **network_kwargs (hiddens and dueling) are consumed as named parameters here, leaving only convs still unused. From snippet 3 we know build_q_func returns a function, namely the q_func_builder defined here. Before studying that function, we have to descend one more level and look at the line at the top:
5.
network = get_network_builder(network)(**network_kwargs)
From the arguments passed in earlier we know network = "conv_only", a string. From that string a function is returned, and **network_kwargs (i.e. the convolution-layer parameters) is then passed to the returned function. Presumably the end result is a convolutional network structure. Let's click in and take a look:
6.
def get_network_builder(name):
    if callable(name):  # callable() checks whether an object can be called like a function
        # print('name', name)
        return name
    elif name in mapping:
        # print('mapping', mapping)
        return mapping[name]
    else:
        raise ValueError('Unknown network type: {}'.format(name))
This code is concise: it first checks whether name is callable, then checks whether it is in the mapping dict, and returns the corresponding function. Clearly our "conv_only" is fetched from mapping.
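As a small aside, the callable(name) branch means get_network_builder also accepts a function directly and simply hands it back; a minimal sketch (my_builder is made up for illustration, not from baselines):

def my_builder(**kwargs):        # a hypothetical network builder
    def network_fn(X):
        return X
    return network_fn

print(get_network_builder(my_builder) is my_builder)   # True: callables pass straight through
print(get_network_builder("conv_only"))                # strings are looked up in mapping
# get_network_builder("no_such_net")                   # would raise ValueError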
Let's print mapping and take a look:
7.
print(mapping)
output:
mapping {'mlp': <function mlp at 0x000001C90316B5E8>, 'cnn': <function cnn at 0x000001C90316B678>, 'impala_cnn': <function impala_cnn at 0x000001C90316B708>, 'cnn_small': <function cnn_small at 0x000001C90316B798>, 'lstm': <function lstm at 0x000001C90316B828>, 'cnn_lstm': <function cnn_lstm at 0x000001C90316B8B8>, 'impala_cnn_lstm': <function impala_cnn_lstm at 0x000001C90316B948>, 'cnn_lnlstm': <function cnn_lnlstm at 0x000001C90316B9D8>, 'conv_only': <function conv_only at 0x000001C90316BA68>}
network: <function conv_only.<locals>.network_fn at 0x000001C9030AF678>
There are quite a few functions in there. So when were all these functions added to mapping?
Searching for mapping in the same file, we find the register() function:
8.
mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk
It simply stores the name and the function into mapping.
Searching for register again, we find many @register decorators:
9.
@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):...
@register("cnn")
def cnn(**conv_kwargs):...
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):....
In fact, once @register("xxx") is added above a function, the register() function from snippet 8 is invoked and that function goes straight into mapping, which matches exactly the functions we printed from mapping. Decorators are a nice thing.
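To see the whole registration flow in isolation, here is a self-contained sketch of the same decorator pattern (the toy_net name and builder are made up for illustration; only the register/mapping mechanism mirrors baselines):

mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func   # remember the builder under its string name
        return func            # hand the function back unchanged
    return _thunk

@register("toy_net")           # equivalent to: toy_net = register("toy_net")(toy_net)
def toy_net(width=4):
    def network_fn(x):
        return [v * width for v in x]
    return network_fn

network_fn = mapping["toy_net"](width=2)   # look up by string, then call the builder
print(network_fn([1, 2, 3]))               # [2, 4, 6]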
By this point I have lost count of how many levels we have jumped through. As long as the route stays clear in your head, you won't get lost. So let's focus directly on the conv_only function.
10.
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    def network_fn(X):
        out = tf.cast(X, tf.float32) / 255.
        with tf.variable_scope("convnet"):
            for num_outputs, kernel_size, stride in convs:
                out = tf.contrib.layers.convolution2d(out,
                                                      num_outputs=num_outputs,
                                                      kernel_size=kernel_size,
                                                      stride=stride,
                                                      activation_fn=tf.nn.relu,
                                                      **conv_kwargs)
        return out
    return network_fn
First look at the overall structure. Sure enough, just as we guessed in 5, the **network_kwargs passed in earlier becomes this function's convs. It returns a network_fn function, which will later receive the image input, i.e. the X parameter here. Now let's look at this function carefully:
- convs is a list of tuples. X is the image input, a Tensor of shape [batch, 84, 84, 4].
- The pixels are first scaled to floats in [0, 1] (divide by 255.), then the elements of convs are iterated over to build the convolutional network.
- num_outputs is the number of output channels, i.e. the number of convolution kernels; kernel_size is the kernel size; stride is the step size of the sliding window. These are the usual convolutional-network hyperparameters.
- The output out is a Tensor of shape [batch, size, size, 64]; the exact spatial size is worked out in the sketch below.
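As a sanity check on that shape, here is a small sketch (plain Python, not from baselines) that applies the standard output-size formula to the convs above. Which branch applies depends on the padding mode; tf.contrib.layers.convolution2d defaults to 'SAME' unless a padding kwarg is forwarded through **conv_kwargs:

import math

def out_size(size, kernel, stride, padding):
    if padding == 'SAME':                    # tf.contrib.layers.convolution2d default
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1     # 'VALID'

for padding in ('SAME', 'VALID'):
    size = 84
    for _, kernel, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
        size = out_size(size, kernel, stride, padding)
    print(padding, size)   # SAME -> 11, VALID -> 7, so out is [batch, 11, 11, 64] or [batch, 7, 7, 64]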
Next, let's go back to 4 and look at it again:
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        # print('network:', network)
        network = get_network_builder(network)(**network_kwargs)
        # print('network:', network)

    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores

            return q_out

    return q_func_builder
From the analysis above, after the statement network = get_network_builder(network)(**network_kwargs) has executed, network is effectively the network_fn(X) function from 10.
Now let's analyze q_func_builder():
- latent = network(input_placeholder)
  First the input placeholder is fed to network, which returns latent, a [batch, size, size, 64] Tensor. After the tuple check (and the NotImplementedError for recurrent policies),
- latent = layers.flatten(latent)
  stretches the 4-D convolutional output into a 2-D Tensor, ready for the fully connected layers.
- action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
  This line, looped over hiddens, builds the fully connected layers.
- action_out = layers.layer_norm(action_out, center=True, scale=True)
  This normalizes the fully connected layer; the operation comes from the layer-normalization paper.
- action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)
  Then the output layer is attached.
- What follows is the dueling optimization; I won't go into the theory here. In short, a dueling network adds a second branch to the original one, and that branch outputs a single node representing the value of the current state. The action-score branch has its per-state mean subtracted (centering it), the state value is added back, and the two branches are merged into a single q_out.
- Finally, the whole q_func_builder function is returned.
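To make the dueling aggregation concrete, here is a tiny numeric sketch (plain NumPy with made-up numbers, not the baselines code) of the same computation the dueling block performs:

import numpy as np

action_scores = np.array([[2.0, 4.0, 6.0]])   # advantage branch: one score per action
state_score = np.array([[10.0]])              # value branch: a single node per state

action_scores_mean = action_scores.mean(axis=1, keepdims=True)   # 4.0
action_scores_centered = action_scores - action_scores_mean      # [-2., 0., 2.]
q_out = state_score + action_scores_centered                     # [[ 8., 10., 12.]]
print(q_out)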
At this point the construction of the network has been fully analyzed. Now let's go back to 3 and look again:
q_func = build_q_func(network, **network_kwargs)
With this single call, learn() obtains a function that builds the network graph. At this point q_func is only a function: no arguments have been passed yet and no graph has been constructed.
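To illustrate what happens later (a rough, hypothetical call for illustration, not a line from this snippet), deepq's graph-building code will eventually invoke the builder with an observation placeholder, an action count, and a variable scope, roughly like this, and only then are the TensorFlow ops created:

obs_ph = tf.placeholder(tf.uint8, [None, 84, 84, 4])              # stacked Atari frames
q_values = q_func(obs_ph, num_actions=6, scope="q_func")          # builds the online network, shape [None, 6]
target_q = q_func(obs_ph, num_actions=6, scope="target_q_func")   # a second, independent copy for the target network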
Source: CSDN
Author: 橘子JUZI
Link: https://blog.csdn.net/qq_41832757/article/details/104439740