How to import word2vec into TensorFlow Seq2Seq model?

太阳男子 2021-02-11 05:23

I am playing with the TensorFlow sequence-to-sequence translation model. I was wondering if I could import my own word2vec into this model? Rather than using its original 'dense re

2 answers
  •  抹茶落季
    2021-02-11 06:16

    The seq2seq embedding_* functions do create embedding matrices very similar to those from word2vec. Each matrix is stored in a variable with a name something like this:

    EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"

    Knowing this, you can simply overwrite that variable. First get your word2vec vectors into some format, say a text file. Assuming your vocabulary is in model.vocab, you can then assign the vectors you read as illustrated by the snippet below (it is only a snippet and you will have to adapt it to make it work, but I hope it shows the idea).

        import sys

        import numpy as np
        import tensorflow as tf
        from tensorflow.python.platform import gfile

        # FLAGS, model and session are assumed to come from the surrounding script.
        vectors_variable = [v for v in tf.trainable_variables()
                            if EMBEDDING_KEY in v.name]
        if len(vectors_variable) != 1:
          print("Word vector variable not found or too many.")
          sys.exit(1)
        vectors_variable = vectors_variable[0]
        # Fetch the current embedding matrix as a NumPy array.
        vectors = vectors_variable.eval(session=session)
        # Expected vector length: the embedding dimension of the model.
        vec_size = vectors.shape[1]
        print("Setting word vectors from %s" % FLAGS.word_vector_file)
        with gfile.GFile(FLAGS.word_vector_file, mode="r") as f:
          # Lines have format: dog 0.045123 -0.61323 0.413667 ...
          for line in f:
            line_parts = line.split()
            if not line_parts:
              continue  # Skip blank lines.
            # The first part is the word.
            word = line_parts[0]
            if word in model.vocab:
              # Remaining parts are components of the vector.
              # (A list is needed: in Python 3, map() returns an iterator.)
              word_vector = np.array([float(x) for x in line_parts[1:]])
              if len(word_vector) != vec_size:
                print("Warn: Word '%s', Expecting vector size %d, found %d"
                      % (word, vec_size, len(word_vector)))
              else:
                vectors[model.vocab[word]] = word_vector
        # Assign the modified vectors to the vectors_variable in the graph.
        session.run([vectors_variable.initializer],
                    {vectors_variable.initializer.inputs[1]: vectors})
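
The parsing step in the snippet above can be tried without TensorFlow. The following is only a minimal sketch using plain NumPy. The function name fill_embedding_matrix, the toy vocabulary and the sample lines are hypothetical and not part of the seq2seq code. It assumes the vocabulary is a dict from word to row index, as model.vocab is used above.

```python
import numpy as np


def fill_embedding_matrix(lines, vocab, matrix):
    """Copy pretrained vectors into the rows of `matrix` named by `vocab`.

    lines  -- iterable of strings such as "dog 0.045 -0.613 0.414"
    vocab  -- dict mapping word -> row index (hypothetical stand-in
              for model.vocab)
    matrix -- 2-D float array (vocab_size, vec_size), modified in place
    Returns the number of rows that were overwritten.
    """
    vec_size = matrix.shape[1]
    replaced = 0
    for line in lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # blank line or a word with no vector
        word, values = parts[0], parts[1:]
        if word not in vocab:
            continue  # word is not in the model vocabulary
        if len(values) != vec_size:
            continue  # wrong dimensionality, keep the existing row
        matrix[vocab[word]] = [float(x) for x in values]
        replaced += 1
    return replaced


# Toy data: "bird" is out of vocabulary and "cat" has the wrong size.
vocab = {"_PAD": 0, "dog": 1, "cat": 2}
matrix = np.zeros((3, 3), dtype=np.float32)
lines = ["dog 0.1 0.2 0.3", "bird 0.4 0.5 0.6", "cat 1.0 2.0"]
n = fill_embedding_matrix(lines, vocab, matrix)
print(n, matrix[1])
```

Rows for words missing from the file keep whatever values the matrix already had, so in the real model they would stay at their initialized (and later trained) values.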
    
