Reinforcement Learning: Dissecting the Network Architecture for Atari Games in the baselines Project


The baselines project is flexibly designed and its structure is fairly complex. The project is large, its functions call one another constantly, tracing a single function can take you six or seven levels deep, and every function carries a pile of hyperparameters, so it is easy to get dizzy.
Here we will read only the DQN part of the source code, ignore everything unrelated, and unravel it along a single thread. Treat it as a game of recursion: descend level by level, return once we reach the bottom, and prune along the way, skipping anything not related to the network graph.
First, find the recursion entry point. Under deepq there is an experiments directory full of examples; pong is one of the Atari experiments.
Below is the code of train_pong:


1.

from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
import numpy as np
np.seterr(invalid='ignore')

def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close()

if __name__ == '__main__':
    main()

The code above mainly calls into deepq. Take a careful look at the deepq.learn() call: it has a lot of parameters. For now we only care about the network architecture, and the convolutional part (convs) and the fully connected part (hiddens) obviously belong to it. Ctrl+click learn to jump to the definition of the learn() method. Below are learn()'s hyperparameters. Don't be scared: there are many, but it's not as complicated as you imagine, because its complexity is beyond your imagination.


2.

def learn(env,
          network,
          seed=None,
          lr=5e-4,
          total_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=100,
          checkpoint_freq=10000,
          checkpoint_path=None,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          param_noise=False,
          callback=None,
          load_path=None,
          **network_kwargs
          ):

However, convs and friends do not appear among these named parameters, which means they must be hiding in **network_kwargs (a dict of the remaining keyword arguments). Careful comparison of the two signatures shows that **network_kwargs holds exactly three entries: {convs, hiddens, dueling}.
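A minimal sketch of how these keywords travel (the stub below is illustrative, not the real learn()):

# Toy stub: keyword arguments that match no named parameter
# are collected into the network_kwargs dict, exactly as in learn().
def learn(env, network, lr=5e-4, **network_kwargs):
    print(network_kwargs)

learn(env=None, network="conv_only", lr=1e-4,
      convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
      hiddens=[256],
      dueling=True)
# -> {'convs': [(32, 8, 4), (64, 4, 2), (64, 3, 1)], 'hiddens': [256], 'dueling': True}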
Next, let's see which function actually uses **network_kwargs:


3.

q_func = build_q_func(network, **network_kwargs)

Only this one line uses it. Nice! Step into it, and a new world opens up:


4.

def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        # print('network:', network)
        network = get_network_builder(network)(**network_kwargs)
        # print('network:', network)
    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores
            return q_out

    return q_func_builder

We're getting close: this function is clearly building the network. Two of the **network_kwargs entries are consumed here as named parameters (hiddens and dueling), leaving only convs unused. From section 3 we know build_q_func returns a function, namely q_func_builder here. Before reading that function, we have to descend one more level; first look at this line near the top:


5.

network = get_network_builder(network)(**network_kwargs)

From the arguments we just passed, we know network = "conv_only", a string. From this string a function is returned, and **network_kwargs (i.e. the convolution parameters) is then passed to that returned function, so presumably what finally comes back is a convolutional network structure. Note that this one line is really two calls chained together, as the small decomposition below shows; then we step into get_network_builder:
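A paraphrase of that one-liner in two steps (the intermediate name builder is mine):

# network = get_network_builder(network)(**network_kwargs), decomposed:
builder = get_network_builder("conv_only")                     # step 1: look up a builder function by name
network = builder(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)])  # step 2: call it; this returns network_fn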


6.

def get_network_builder(name):

    if callable(name):  # callable() checks whether an object can be called like a function
        # print('name',name)
        return name
    elif name in mapping:
       # print('mapping',mapping)
        return mapping[name]

    else:
        raise ValueError('Unknown network type: {}'.format(name))

This code is very concise: first check whether name is itself callable (see the quick demo below); otherwise look it up in the mapping dict and return the corresponding function. Our "conv_only" is clearly fetched from mapping.
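A quick demonstration of what callable() reports (toy values):

print(callable(print))        # True: functions can be called
print(callable("conv_only"))  # False: a plain string cannot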
Let's print mapping and take a look:


7.

print(mapping)
output:
mapping {'mlp': <function mlp at 0x000001C90316B5E8>, 'cnn': <function cnn at 0x000001C90316B678>, 'impala_cnn': <function impala_cnn at 0x000001C90316B708>, 'cnn_small': <function cnn_small at 0x000001C90316B798>, 'lstm': <function lstm at 0x000001C90316B828>, 'cnn_lstm': <function cnn_lstm at 0x000001C90316B8B8>, 'impala_cnn_lstm': <function impala_cnn_lstm at 0x000001C90316B948>, 'cnn_lnlstm': <function cnn_lnlstm at 0x000001C90316B9D8>, 'conv_only': <function conv_only at 0x000001C90316BA68>}
network: <function conv_only.<locals>.network_fn at 0x000001C9030AF678>

So there are quite a few functions in there. When were they added to mapping? Searching this file for mapping turns up the register() function:


8.

mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func
        return func
    return _thunk

It simply stores the name-to-function pair in mapping.
Searching for register next, we find a number of register decorators:


9.

@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):...

@register("cnn")
def cnn(**conv_kwargs):...

@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):....

Once @register("xxx") is attached to a function definition, the register() function from section 8 runs and puts the function straight into mapping, which matches exactly the contents of mapping we printed earlier. Decorators are a good thing; a toy version of this registry pattern is sketched below.
By this point I've lost count of how many levels we've jumped through. As long as the route stays clear in your head, you won't get lost. So let's turn directly to the conv_only function.
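A self-contained toy of the same register/mapping pattern (the names greet and "greet" are made up for illustration):

# A toy registry mirroring baselines' register()/mapping mechanism.
mapping = {}

def register(name):
    def _thunk(func):
        mapping[name] = func   # store the function under the given name
        return func            # hand the function back unchanged
    return _thunk

@register("greet")             # runs at import time: mapping["greet"] = greet
def greet():
    return "hello"

print(mapping)                 # {'greet': <function greet at 0x...>}
print(mapping["greet"]())      # hello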


10.

@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    def network_fn(X):
        out = tf.cast(X, tf.float32) / 255.
        with tf.variable_scope("convnet"):
            for num_outputs, kernel_size, stride in convs:
                out = tf.contrib.layers.convolution2d(out,
                                           num_outputs=num_outputs,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           activation_fn=tf.nn.relu,
                                           **conv_kwargs)

        return out
    return network_fn

Look at the overall structure first. Just as predicted in section 5, the **network_kwargs passed in earlier has become this function's convs. It returns a network_fn function, into which the image input, the X parameter here, will later be passed. Now let's read this function carefully!

  • convs is a list of tuples. X is the input image, a Tensor with shape=[batch, 84, 84, 4].
  • X is first cast to float and scaled to the range [0, 1] (divided by 255), then the entries of convs are iterated over to stack up the convolutional layers.
  • num_outputs is the number of channels, i.e. the number of convolution kernels; kernel_size is the kernel size; stride is the step the kernel moves each time. These are the standard convolution hyperparameters.
  • The output out is a Tensor of shape [batch, size, size, 64]; the small calculator below works out what size actually is.
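Since tf.contrib.layers.convolution2d defaults to 'SAME' padding (an assumption worth verifying against your TF version), the per-layer output size is ceil(input_size / stride). A minimal sketch:

import math

def conv_out(size, kernel, stride, padding="SAME"):
    # SAME: ceil(size / stride); VALID: ceil((size - kernel + 1) / stride)
    if padding == "SAME":
        return math.ceil(size / stride)
    return math.ceil((size - kernel + 1) / stride)

size = 84  # the input is [batch, 84, 84, 4]
for num_outputs, kernel, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
    size = conv_out(size, kernel, stride)
    print(size, num_outputs)   # 21 32, then 11 64, then 11 64

So with SAME padding the spatial size goes 84 → 21 → 11 → 11 and out is [batch, 11, 11, 64] (with VALID padding it would be 20 → 9 → 7).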


Now let's return to section 4 and read build_q_func again:
def build_q_func(network, hiddens=[256], dueling=True, layer_norm=False, **network_kwargs):
    if isinstance(network, str):
        from baselines.common.models import get_network_builder
        # print('network:', network)
        network = get_network_builder(network)(**network_kwargs)
        # print('network:', network)
    def q_func_builder(input_placeholder, num_actions, scope, reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            latent = network(input_placeholder)
            if isinstance(latent, tuple):
                if latent[1] is not None:
                    raise NotImplementedError("DQN is not compatible with recurrent policies yet")
                latent = latent[0]

            latent = layers.flatten(latent)

            with tf.variable_scope("action_value"):
                action_out = latent
                for hidden in hiddens:
                    action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None)
                    if layer_norm:
                        action_out = layers.layer_norm(action_out, center=True, scale=True)
                    action_out = tf.nn.relu(action_out)
                action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)

            if dueling:
                with tf.variable_scope("state_value"):
                    state_out = latent
                    for hidden in hiddens:
                        state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=None)
                        if layer_norm:
                            state_out = layers.layer_norm(state_out, center=True, scale=True)
                        state_out = tf.nn.relu(state_out)
                    state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
                action_scores_mean = tf.reduce_mean(action_scores, 1)
                action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
                q_out = state_score + action_scores_centered
            else:
                q_out = action_scores
            return q_out

    return q_func_builder

From the analysis above, after network = get_network_builder(network)(**network_kwargs) executes, network is effectively the network_fn(X) function from section 10.

Now let's analyze q_func_builder():

  • latent = network(input_placeholder) first feeds the input through the network, returning latent, a [batch, h, w, 64] Tensor.
  • After the tuple check and the error for recurrent policies, latent = layers.flatten(latent) flattens the 4-D convolutional output into 2-D, ready for the fully connected layers.
  • action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=None) builds the fully connected layers by looping over hiddens.
  • action_out = layers.layer_norm(action_out, center=True, scale=True) optionally normalizes each fully connected layer; this operation comes from the Layer Normalization paper.
  • action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None) then attaches the output layer.
  • What follows implements the dueling refinement. I won't go into the theory; briefly, a dueling network adds a second head to the original network whose output is a single node representing the value of the current state. The per-action scores have their mean subtracted, and the two heads are then merged into one output (see the sketch after this list).
  • Finally, the whole q_func_builder function is returned.
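A tiny numeric sketch of that dueling combination (made-up numbers: one state, three actions):

import numpy as np

state_score = np.array([[2.0]])              # V(s), shape [batch, 1]
action_scores = np.array([[1.0, 3.0, 5.0]])  # per-action scores, shape [batch, num_actions]

# The same arithmetic as in q_func_builder: Q = V + (A - mean(A))
action_scores_mean = action_scores.mean(axis=1, keepdims=True)   # 3.0
q_out = state_score + (action_scores - action_scores_mean)
print(q_out)  # [[0. 2. 4.]] -- the relative ordering of the actions is unchanged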

With that, the construction of the network is fully analyzed. Now let's go back to section 3 one last time:

q_func = build_q_func(network, **network_kwargs)

A single call to this one statement is enough to obtain a function that builds the network graph. But q_func is still just a function: no arguments have been passed in and no graph has been constructed yet.
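Purely for illustration (this is not the actual call site; in baselines the builder is invoked later inside deepq's graph-building code), using it would look roughly like this:

import tensorflow as tf

# q_func was returned by build_q_func("conv_only", convs=..., hiddens=[256], dueling=True)
obs_ph = tf.placeholder(tf.uint8, [None, 84, 84, 4])      # a batch of stacked 84x84 frames
q_values = q_func(obs_ph, num_actions=6, scope="q_func")  # the graph is actually built here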
