Computing gradient of the model with modified weights

Posted by 别等时光非礼了梦想 on 2021-01-28 11:16:47

Question


I was implementing Sharpness-Aware Minimization (SAM) using TensorFlow. Simplified, the algorithm is as follows:

  1. Compute gradient using current weight W
  2. Compute ε according to the equation in the paper
  3. Compute gradient using the weights W + ε
  4. Update model using gradient from step 3

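For reference (my paraphrase, not part of the original question), the perturbation from the SAM paper (Foret et al., 2020) for general p, q with 1/p + 1/q = 1 is

    ε̂(w) = ρ · sign(∇L(w)) · |∇L(w)|^(q-1) / ‖∇L(w)‖_q^(q/p)

which for the common case p = q = 2 reduces to ε̂(w) = ρ · ∇L(w) / ‖∇L(w)‖₂, i.e. a step of length ρ along the normalized gradient.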
I have implemented steps 1 and 2 already, but I'm having trouble implementing step 3 in the code below:

def train_step(self, data, rho=0.05, p=2, q=2):
    # 1/p + 1/q is a float sum, so compare with a tolerance rather than `!=`,
    # and raise a public exception type instead of the private
    # tf.python.framework.errors_impl.InvalidArgumentError
    if abs((1 / p) + (1 / q) - 1) > 1e-9:
        raise ValueError('p, q must be specified so that 1/p + 1/q = 1')
    x, y = data
        
    # compute first backprop
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(y, y_pred)
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
        
    # compute neighborhoods (epsilon_hat) from first backprop
    trainable_w_plus_epsilon_hat = [
        # per the SAM paper, the sign applies to the gradient, not the loss
        w + (rho * tf.sign(g) * (tf.pow(tf.abs(g), q - 1) / tf.pow(tf.norm(g, ord=q), q / p)))
        for w, g in zip(trainable_vars, gradients)
    ]
        
    ### HOW TO SET TRAINABLE WEIGHTS TO `w_plus_epsilon_hat`?
    #
    # TODO:
    #     1. compute gradient using trainable weights from `trainable_w_plus_epsilon_hat`
    #     2. update `trainable_vars` using gradient from step 1
    #
    #########################################################

    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}

Is there any way to compute gradients using the trainable weights from `trainable_w_plus_epsilon_hat`?
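One way to do this (a sketch of the general technique, not an answer from the original thread): Keras trainable variables are `tf.Variable` objects, so you can shift them in place with `assign_add`, run a second forward/backward pass under a new `GradientTape`, and restore the originals with `assign_sub` before applying the optimizer update. The helper name `gradients_at_perturbed_weights` and the explicit `loss_fn` argument are my own illustration; inside a custom `train_step` you would use `self` and `self.compiled_loss` instead:

```python
import tensorflow as tf

def gradients_at_perturbed_weights(model, loss_fn, x, y, epsilons):
    """Compute gradients at W + epsilon_hat by temporarily shifting the
    model's variables in place, then restoring them afterwards."""
    trainable_vars = model.trainable_variables
    # Step 3a: move each variable to w + epsilon_hat (in place).
    for var, eps in zip(trainable_vars, epsilons):
        var.assign_add(eps)
    # Step 3b: second forward/backward pass at the perturbed weights.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, trainable_vars)
    # Step 3c: restore the original weights before the optimizer update.
    for var, eps in zip(trainable_vars, epsilons):
        var.assign_sub(eps)
    return grads
```

In `train_step` you would then finish step 4 with `self.optimizer.apply_gradients(zip(grads, trainable_vars))`. Note that `assign_add`/`assign_sub` restores the weights only up to floating-point round-off; if exact restoration matters, keep copies of the original values instead.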

Source: https://stackoverflow.com/questions/65381773/computing-gradient-of-the-model-with-modified-weights
