Update only part of the word embedding matrix in Tensorflow

小蘑菇 2020-12-01 02:09

Assuming that I want to update a pre-trained word-embedding matrix during training, is there a way to update only a subset of the word embedding matrix?

I have looke

2 Answers
  • 2020-12-01 03:02

    Since you just want to select the elements to be updated (and not to change the gradients), you can do as follows.

    Let indices_to_update be a boolean tensor that marks the entries you wish to update, and let entry_stop_gradients be the helper defined in the link. Then:

    gather_emb = entry_stop_gradients(gather_emb, indices_to_update)
    

    (Source)
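
The linked helper is not reproduced here. Its described effect is that entries not flagged in indices_to_update receive no gradient, which in TensorFlow is typically achieved by routing the frozen entries through tf.stop_gradient. A framework-free sketch of that effect, using plain Python lists in place of tensors, might look like the following; mask_gradient and sgd_step are hypothetical names for illustration and this is not the linked implementation:

```python
def mask_gradient(grad, update_mask):
    """Zero the gradient wherever update_mask is False (hypothetical helper)."""
    return [
        [g if keep else 0.0 for g, keep in zip(g_row, m_row)]
        for g_row, m_row in zip(grad, update_mask)
    ]


def sgd_step(params, grad, lr):
    """Plain gradient-descent step: params - lr * grad, element by element."""
    return [
        [p - lr * g for p, g in zip(p_row, g_row)]
        for p_row, g_row in zip(params, grad)
    ]


# Toy data (assumed for illustration): two gathered embedding rows.
gather_emb = [[1.0, 2.0], [3.0, 4.0]]
grad = [[0.5, 0.5], [0.5, 0.5]]
indices_to_update = [[True, True], [False, False]]  # train row 0 only

masked = mask_gradient(grad, indices_to_update)
updated = sgd_step(gather_emb, masked, lr=1.0)
```

The forward values are untouched by the mask; only the entries flagged True move during the update step.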

  • 2020-12-01 03:10

    TL;DR: With the default implementation of opt.minimize(loss), TensorFlow generates a sparse update for word_emb that modifies only the rows of word_emb that participated in the forward pass.

    The gradient of the tf.gather(word_emb, indices) op with respect to word_emb is a tf.IndexedSlices object (see the implementation for more details). This object represents a sparse tensor that is zero everywhere, except for the rows selected by indices. A call to opt.minimize(loss) calls AdamOptimizer._apply_sparse(word_emb_grad, word_emb), which makes a call to tf.scatter_sub(word_emb, ...)* that updates only the rows of word_emb that were selected by indices. Note that optimizers that keep momentum, such as Adam, may also move rows that received gradients in earlier steps, because their accumulated momentum is still nonzero.
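
To make the row-sparse behaviour concrete, here is a framework-free sketch that mimics applying an (indices, values) gradient with a plain gradient-descent step. It uses Python lists as a stand-in for the variable; apply_sparse_sgd is a hypothetical name, and the assumption that duplicate indices accumulate mirrors how scatter-style subtraction behaves:

```python
def apply_sparse_sgd(table, indices, values, lr):
    """Subtract lr * values[k] from table[indices[k]] in place.

    Rows that never appear in `indices` are left untouched, and a row that
    appears more than once receives every one of its updates.
    """
    for row_id, grad_row in zip(indices, values):
        table[row_id] = [w - lr * g for w, g in zip(table[row_id], grad_row)]
    return table


# Toy embedding table (assumed for illustration): 4 words, 2 dimensions.
word_emb = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]

# Sparse gradient: rows 1 and 3 took part in the forward pass, row 1 twice.
indices = [1, 3, 1]
values = [[0.5, 0.5], [1.0, 1.0], [0.25, 0.25]]

apply_sparse_sgd(word_emb, indices, values, lr=1.0)
```

Rows 0 and 2 keep their original values because no slice of the gradient refers to them.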

    If, on the other hand, you want to modify the tf.IndexedSlices gradient that is returned by opt.compute_gradients(loss, [word_emb]), you can perform arbitrary TensorFlow operations on its indices and values properties, and create a new tf.IndexedSlices that can be passed, paired with the variable, to opt.apply_gradients([(..., word_emb)]). For example, you could cap the gradients using MyCapper() (as in the example) using the following calls:

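    # TF 1.x API. compute_gradients returns a list of (gradient, variable) pairs.
    (grad, _), = opt.compute_gradients(loss, var_list=[word_emb])
    # Rebuild the sparse gradient with capped values (MyCapper is user-defined).
    capped = tf.IndexedSlices(MyCapper(grad.values), grad.indices, grad.dense_shape)
    train_op = opt.apply_gradients([(capped, word_emb)])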
    

    Similarly, you could change the set of rows that will be modified by creating a new tf.IndexedSlices with different indices (and a matching set of values).
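
As a rough illustration of changing the index set, the sketch below filters a sparse gradient so that only rows in an allowed set survive. It is plain Python rather than TensorFlow; filter_sparse_grad and trainable_rows are hypothetical names, and in TensorFlow the same idea would mean filtering grad.indices and grad.values together before building the new tf.IndexedSlices:

```python
def filter_sparse_grad(indices, values, trainable_rows):
    """Keep only the (index, value) slices whose row is in trainable_rows."""
    kept = [(i, v) for i, v in zip(indices, values) if i in trainable_rows]
    kept_indices = [i for i, _ in kept]
    kept_values = [v for _, v in kept]
    return kept_indices, kept_values


# Toy sparse gradient (assumed for illustration): rows 0, 2 and 5 were used.
indices = [0, 2, 5]
values = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]

# Only rows 2 and 5 should be trained; row 0 stays frozen.
trainable_rows = {2, 5}

kept_indices, kept_values = filter_sparse_grad(indices, values, trainable_rows)
```

The indices and values must be filtered in lockstep, since each value slice belongs to the index at the same position.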


    * In general, if you want to update only part of a variable in TensorFlow, you can use the tf.scatter_update(), tf.scatter_add(), or tf.scatter_sub() operators, which respectively set, add to (+=) or subtract from (-=) the value previously stored in a variable.
