Soft actor critic with discrete action space


Question


I'm trying to implement the soft actor-critic (SAC) algorithm for a discrete action space, and I'm having trouble with the loss function.

Here is the link to SAC with a continuous action space: https://spinningup.openai.com/en/latest/algorithms/sac.html

I do not know what I'm doing wrong.

The problem is that the networks do not learn anything on the CartPole environment.

The full code is on GitHub: https://github.com/tk2232/sac_discrete/blob/master/sac_discrete.py

Here is my idea of how to calculate the loss for discrete actions.
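
For context, the snippets below use a few imports and a small _params helper that are defined in the full file linked above. A minimal sketch of what they might look like (the exact definitions are in the repository; _params here is my guess at a scope-based variable collector):

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense

Categorical = tf.distributions.Categorical  # TF 1.x distribution used by PolicyNet


def _params(scope):
    # Collect the trainable variables created under the given variable scope,
    # so each network's optimizer only updates its own weights.
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)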

Value Network

class ValueNet:
    def __init__(self, sess, state_size, hidden_dim, name):
        self.sess = sess

        with tf.variable_scope(name):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='value_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='value_targets')
            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            self.values = Dense(units=1, activation=None)(x)

            optimizer = tf.train.AdamOptimizer(0.001)

            loss = 0.5 * tf.reduce_mean((self.values - tf.stop_gradient(self.targets)) ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def get_value(self, s):
        return self.sess.run(self.values, feed_dict={self.states: s})

    def update(self, s, targets):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.targets: targets})

In the Q-network I gather the Q-values at the indices of the collected actions.

Example

q_out = [[0.5533, 0.4444], [0.2222, 0.6666]]
collected_actions = [0, 1]
gather = [[0.5533], [0.6666]]

Gather function

def gather_tensor(params, idx):
    idx = tf.stack([tf.range(tf.shape(idx)[0]), idx[:, 0]], axis=-1)
    params = tf.gather_nd(params, idx)
    return params
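
As a quick standalone check (my own sketch, not from the original post) that gather_tensor picks the Q-value of the taken action for each row; the reshape mirrors what SoftQNetwork does with the result:

import tensorflow as tf

q_out = tf.constant([[0.5533, 0.4444], [0.2222, 0.6666]])
collected_actions = tf.constant([[0], [1]], dtype=tf.int32)

# gather_tensor returns a 1-D tensor; reshape it into a column as in SoftQNetwork
gathered = tf.reshape(gather_tensor(q_out, collected_actions), shape=(-1, 1))

with tf.Session() as sess:
    print(sess.run(gathered))  # [[0.5533], [0.6666]]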

Q Network

class SoftQNetwork:
    def __init__(self, sess, state_size, action_size, hidden_dim, name):
        self.sess = sess

        with tf.variable_scope(name):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='q_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='q_targets')
            self.actions = tf.placeholder(dtype=tf.int32, shape=[None, 1], name='q_actions')

            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            x = Dense(units=action_size, activation=None)(x)
            self.q = tf.reshape(gather_tensor(x, self.actions), shape=(-1, 1))

            optimizer = tf.train.AdamOptimizer(0.001)

            loss = 0.5 * tf.reduce_mean((self.q - tf.stop_gradient(self.targets)) ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def update(self, s, a, target):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.actions: a, self.targets: target})

    def get_q(self, s, a):
        return self.sess.run(self.q, feed_dict={self.states: s, self.actions: a})

Policy Net

class PolicyNet:
    def __init__(self, sess, state_size, action_size, hidden_dim):
        self.sess = sess

        with tf.variable_scope('policy_net'):
            self.states = tf.placeholder(dtype=tf.float32, shape=[None, state_size], name='policy_states')
            self.targets = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='policy_targets')
            self.actions = tf.placeholder(dtype=tf.int32, shape=[None, 1], name='policy_actions')

            x = Dense(units=hidden_dim, activation='relu')(self.states)
            x = Dense(units=hidden_dim, activation='relu')(x)
            self.logits = Dense(units=action_size, activation=None)(x)
            dist = Categorical(logits=self.logits)

            optimizer = tf.train.AdamOptimizer(0.001)

            # Get action
            self.new_action = dist.sample()
            self.new_log_prob = dist.log_prob(self.new_action)

            # Calc loss
            log_prob = dist.log_prob(tf.squeeze(self.actions))
            loss = tf.reduce_mean(tf.squeeze(self.targets) - 0.2 * log_prob)
            self.train_op = optimizer.minimize(loss, var_list=_params('policy_net'))

    def get_action(self, s):
        action = self.sess.run(self.new_action, feed_dict={self.states: s[np.newaxis, :]})
        return action[0]

    def get_next_action(self, s):
        next_action, next_log_prob = self.sess.run([self.new_action, self.new_log_prob], feed_dict={self.states: s})
        return next_action.reshape((-1, 1)), next_log_prob.reshape((-1, 1))

    def update(self, s, a, target):
        self.sess.run(self.train_op, feed_dict={self.states: s, self.actions: a, self.targets: target})

Update function

def soft_q_update(batch_size, frame_idx):
    gamma = 0.99
    alpha = 0.2

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    action = action.reshape((-1, 1))
    reward = reward.reshape((-1, 1))
    done = done.reshape((-1, 1))

Q_target

v_ = value_net_target.get_value(next_state)
q_target = reward + (1 - done) * gamma * v_

V_target

next_action, next_log_prob = policy_net.get_next_action(state)
q1 = soft_q_net_1.get_q(state, next_action)
q2 = soft_q_net_2.get_q(state, next_action)
q = np.minimum(q1, q2)
v_target = q - alpha * next_log_prob

Policy_target

q1 = soft_q_net_1.get_q(state, action)
q2 = soft_q_net_2.get_q(state, action)
policy_target = np.minimum(q1, q2)
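
For completeness, the rest of soft_q_update then feeds these targets into the update methods defined above. A rough sketch, under the assumption that the network instances are named value_net, soft_q_net_1, soft_q_net_2 and policy_net (the exact wiring is in the linked file):

# Both Q-networks regress toward the soft Bellman target
soft_q_net_1.update(state, action, q_target)
soft_q_net_2.update(state, action, q_target)

# The value network regresses toward min(Q1, Q2) minus the entropy term
value_net.update(state, v_target)

# The policy network is pushed toward actions with a high soft Q-value
policy_net.update(state, action, policy_target)

# Finally, value_net_target is moved slowly toward value_net (e.g. Polyak averaging)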

Answer 1:


Since the algorithm is generic to both discrete and continuous policies, the key idea is that we need a discrete distribution that is reparametrizable. The extension then requires minimal code changes relative to continuous SAC: just swap the policy distribution class.

There is one such distribution: the Gumbel-Softmax distribution. PyTorch does not have it built in, so I simply extend it from a close cousin (RelaxedOneHotCategorical) that already has the right rsample(), and add a correct log-prob calculation method. With the ability to calculate a reparametrized action and its log prob, SAC works beautifully for discrete actions with minimal extra code, as seen below.

    def calc_log_prob_action(self, action_pd, reparam=False):
        '''Calculate log_probs and actions with option to reparametrize from paper eq. 11'''
        samples = action_pd.rsample() if reparam else action_pd.sample()
        if self.body.is_discrete:  # this is straightforward using GumbelSoftmax
            actions = samples
            log_probs = action_pd.log_prob(actions)
        else:
            mus = samples
            actions = self.scale_action(torch.tanh(mus))
            # paper Appendix C. Enforcing Action Bounds for continuous actions
            log_probs = (action_pd.log_prob(mus) - torch.log(1 - actions.pow(2) + 1e-6).sum(1))
        return log_probs, actions


# ... for discrete action, GumbelSoftmax distribution

class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
    '''
    A differentiable Categorical distribution using reparametrization trick with Gumbel-Softmax
    Explanation http://amid.fish/assets/gumbel.html
    NOTE: use this in place of PyTorch's RelaxedOneHotCategorical distribution since its log_prob is not working right (returns positive values)
    Papers:
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al, 2017)
    [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017)
    '''

    def sample(self, sample_shape=torch.Size()):
        '''Gumbel-softmax sampling. Note rsample is inherited from RelaxedOneHotCategorical'''
        u = torch.empty(self.logits.size(), device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
        noisy_logits = self.logits - torch.log(-torch.log(u))
        return torch.argmax(noisy_logits, dim=-1)

    def log_prob(self, value):
        '''value is one-hot or relaxed'''
        if value.shape != self.logits.shape:
            value = F.one_hot(value.long(), self.logits.shape[-1]).float()
            assert value.shape == self.logits.shape
        return - torch.sum(- value * F.log_softmax(self.logits, -1), -1)
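
The class above assumes import torch.nn.functional as F and from torch import distributions. A small usage sketch of my own (not part of the answer), showing a hard sample for acting and a reparametrized relaxed sample for the gradient path:

import torch
import torch.nn.functional as F
from torch import distributions

logits = torch.randn(4, 2, requires_grad=True)            # 4 states, 2 discrete actions
dist = GumbelSoftmax(temperature=torch.tensor(1.0), logits=logits)

action = dist.sample()              # hard argmax sample, shape (4,), for acting in the env
relaxed = dist.rsample()            # differentiable relaxed one-hot sample, shape (4, 2)
log_prob = dist.log_prob(relaxed)   # per-sample log-probability, shape (4,)
log_prob.sum().backward()           # gradients flow back into the logits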

And here are the LunarLander results (see the benchmark link below): SAC learns to solve it really fast.

The full implementation code is in SLM Lab at https://github.com/kengz/SLM-Lab/blob/master/slm_lab/agent/algorithm/sac.py

The SAC benchmark results on Roboschool (continuous) and LunarLander (discrete) are shown here: https://github.com/kengz/SLM-Lab/pull/399




Answer 2:


This repo may be helpful. The description says the repo contains a PyTorch implementation of SAC for a discrete action space. There is a file with the SAC algorithm for continuous action spaces and a file with SAC adapted for discrete action spaces.




Answer 3:


There is a paper about SAC with discrete action spaces. It says that SAC for discrete action spaces doesn't need re-parametrization tricks like Gumbel-Softmax; instead, it needs a few other modifications. Please refer to the paper for more details.

Paper / author's implementation (without code for Atari) / reproduction (with code for Atari)
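
To make the modifications concrete, here is my own rough sketch (not the authors' code) of the discrete-SAC losses as I understand the paper: the policy outputs probabilities over all actions, so the expectations over actions are computed exactly rather than estimated from samples, and no reparametrization is needed. Names such as policy_net and q1_target are placeholders.

import torch

def discrete_sac_losses(policy_net, q1_net, q2_net, q1_target, q2_target,
                        state, next_state, reward, done, gamma=0.99, alpha=0.2):
    # Policy loss: E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min Q(s, a)) ]
    probs = torch.softmax(policy_net(state), dim=-1)
    log_probs = torch.log(probs + 1e-8)
    q_min = torch.min(q1_net(state), q2_net(state))
    policy_loss = (probs * (alpha * log_probs - q_min)).sum(dim=-1).mean()

    # Soft Bellman target, with the expectation over next actions taken exactly
    with torch.no_grad():
        next_probs = torch.softmax(policy_net(next_state), dim=-1)
        next_log_probs = torch.log(next_probs + 1e-8)
        next_q = torch.min(q1_target(next_state), q2_target(next_state))
        v_next = (next_probs * (next_q - alpha * next_log_probs)).sum(dim=-1, keepdim=True)
        q_target = reward + gamma * (1.0 - done) * v_next

    return policy_loss, q_target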

I hope it helps you.



Source: https://stackoverflow.com/questions/56226133/soft-actor-critic-with-discrete-action-space
