Question
I was implementing Sharpness-Aware Minimization (SAM) in TensorFlow. Simplified, the algorithm is:
1. Compute the gradient using the current weights W
2. Compute ε according to the equation in the paper
3. Compute the gradient using the weights W + ε
4. Update the model using the gradient from step 3
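For context, the perturb → re-gradient → restore pattern these steps describe can be sketched on plain `tf.Variable`s (this is a minimal illustration, not the Keras `train_step` below; `w`, `loss_fn`, and the learning rate are illustrative assumptions, and ε is shown for the common p = q = 2 case, where it reduces to ρ·g/‖g‖₂):

```python
import tensorflow as tf

# Illustrative setup: a single weight vector and a toy quadratic loss.
w = tf.Variable([1.0, -2.0])
loss_fn = lambda: tf.reduce_sum(tf.square(w))
rho = 0.05

# Step 1: gradient at the current weights W.
with tf.GradientTape() as tape:
    loss = loss_fn()
grad = tape.gradient(loss, w)

# Step 2: epsilon for p = q = 2 reduces to rho * g / ||g||_2.
eps = rho * grad / (tf.norm(grad) + 1e-12)

# Step 3: temporarily move to W + eps in place, then take the gradient there.
w.assign_add(eps)
with tf.GradientTape() as tape:
    loss_adv = loss_fn()
grad_adv = tape.gradient(loss_adv, w)

# Step 4: restore W, then update with the sharpness-aware gradient
# (a plain SGD step with an illustrative learning rate of 0.1).
w.assign_sub(eps)
w.assign_sub(0.1 * grad_adv)
```

The key design point is that the perturbation is applied with `assign_add` and undone with `assign_sub`, so the second tape differentiates with respect to the *same* variable objects and the resulting gradient can be applied back to the original weights.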
I have implemented steps 1 and 2 already, but I am having trouble implementing step 3 in the code below:
def train_step(self, data, rho=0.05, p=2, q=2):
    if (1 / p) + (1 / q) != 1:
        raise tf.errors.InvalidArgumentError(
            None, None, 'p, q must be specified so that 1/p + 1/q = 1')
    x, y = data

    # compute first backprop
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(y, y_pred)
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # compute neighborhoods (epsilon_hat) from first backprop
    trainable_w_plus_epsilon_hat = [
        w + (rho * tf.sign(g) * (tf.pow(tf.abs(g), q - 1) / tf.pow(tf.norm(g, ord=q), q / p)))
        for w, g in zip(trainable_vars, gradients)
    ]

    ### HOW TO SET TRAINABLE WEIGHTS TO `w_plus_epsilon_hat`?
    #
    # TODO:
    # 1. compute gradient using trainable weights from `trainable_w_plus_epsilon_hat`
    # 2. update `trainable_vars` using gradient from step 1
    #
    #########################################################
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
Is there any way to compute the gradient using the trainable weights from `trainable_w_plus_epsilon_hat`?
Source: https://stackoverflow.com/questions/65381773/computing-gradient-of-the-model-with-modified-weights