Loading pre-trained word2vec to initialise embedding_lookup in the Estimator model_fn


Question


I am solving a text classification problem. I defined my classifier using the Estimator class with my own model_fn. I would like to use Google's pre-trained word2vec embeddings as the initial values and then optimise them further for the task at hand.

I saw this post: Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
which explains how to go about it in 'raw' TensorFlow code. However, I would really like to use the Estimator class.

As an extension, I would then like to train this code on Cloud ML Engine. Is there a good way of passing in the fairly large file of initial values?

Let's say we have something like:

def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat']  # shape=[-1, params["sequence_length"]]
        # ... what goes here to initialize W?

        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions
    return _model_fn

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)

Answer 1:


Embeddings are typically large enough that the only viable approach is using them to initialize a tf.Variable in your graph. This will allow you to take advantage of parameter servers in distributed training, etc.

For this (and anything else), I would recommend you use the new "core" estimator, tf.estimator.Estimator, as this will make things much easier.
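
For reference, the question's setup carries over almost unchanged. A minimal sketch with the core API, assuming the same read_data input_fn and params dict from the question:

estimator = tf.estimator.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.train(input_fn=read_data, max_steps=2500)

The main difference is that your model_fn must return a tf.estimator.EstimatorSpec rather than the contrib ModelFnOps.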

From the answer in the link you provided, and knowing that we want a variable rather than a constant, we can take either approach:

(2) Initialize the variable using a feed dict, or (3) Load the variable from a checkpoint


I'll cover option (3) first since it's much easier and better:

In your model_fn, simply initialize a variable using the Tensor returned by a tf.contrib.framework.load_variable call. This requires:

  1. That you have a valid TF checkpoint containing your embeddings
  2. That you know the fully qualified name of the embeddings variable within the checkpoint (one way to find it is sketched below).
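
A quick way to find that name is to list every variable stored in the checkpoint. A minimal sketch, assuming the same (hypothetical) GCS checkpoint path used in the code below:

import tensorflow as tf

# Prints (name, shape) for every variable in the checkpoint so you can
# pick out the fully qualified name of the embeddings variable.
for name, shape in tf.contrib.framework.list_variables(
        'gs://my-bucket/word2vec_checkpoints/'):
    print(name, shape)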

The code is pretty simple:

def model_fn(features, labels, mode, params):
  embeddings = tf.Variable(tf.contrib.framework.load_variable(
      'gs://my-bucket/word2vec_checkpoints/',
      'a/fully/qualified/scope/embeddings'
  ))
  ....
  return tf.estimator.EstimatorSpec(...)

However, this approach won't work for you if your embeddings weren't produced by another TF model, hence option (2).


For (2), we need to use tf.train.Scaffold, which is essentially a configuration object holding all the options for starting a tf.Session (which the Estimator intentionally hides for many reasons).

You may specify a Scaffold in the tf.estimator.EstimatorSpec you return from your model_fn.

We create a placeholder in our model_fn, use it as the initial value of our embedding variable, and then pass an init_feed_dict via the Scaffold, e.g.:

def model_fn(features, labels, mode, params):
  embed_ph = tf.placeholder(
      shape=[params["vocab_size"], params["embedding_size"]],
      dtype=tf.float32)
  embeddings = tf.Variable(embed_ph)
  # Define your model
  return tf.estimator.EstimatorSpec(
      ...,  # normal EstimatorSpec args
      scaffold=tf.train.Scaffold(
          init_feed_dict={embed_ph: my_embedding_numpy_array})
  )

What's happening here is that the init_feed_dict populates the embed_ph placeholder at initialization time, which allows the embeddings variable's initializer op (an assignment from the placeholder) to run.
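
As for where my_embedding_numpy_array comes from: a minimal sketch, assuming gensim is available, a local copy of the GoogleNews word2vec binary, and a vocab list ordered to match the integer ids your input_fn produces (the file path and vocab are assumptions, not part of the original answer):

import numpy as np
from gensim.models import KeyedVectors

# Load the pre-trained vectors (assumed local path; on Cloud ML Engine you
# would typically stage the file in GCS and copy it down first).
word2vec = KeyedVectors.load_word2vec_format(
    'GoogleNews-vectors-negative300.bin', binary=True)

# Build the [vocab_size, embedding_size] matrix in the same order as your
# integer ids, falling back to a small random vector for OOV words.
my_embedding_numpy_array = np.array(
    [word2vec[w] if w in word2vec
     else np.random.uniform(-0.25, 0.25, word2vec.vector_size)
     for w in vocab],
    dtype=np.float32)

Because the array is fed in when the session is initialised rather than baked into the graph as a constant, the GraphDef stays small, which matters for embedding matrices of this size.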




Source: https://stackoverflow.com/questions/44680769/loading-pre-trained-word2vec-to-initialise-embedding-lookup-in-the-estimator-mod
