Question
I have used Keras with pre-trained word embeddings, but I am not quite sure how to do the same with a scikit-learn model.
I need to do this in sklearn as well because I am using vecstack to ensemble a Keras Sequential model with an sklearn model.
This is what I have done for the Keras model:
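import os
import numpy as np

glove_dir = '/home/Documents/Glove'

# Parse the GloVe file: each line is a word followed by its 200 float components.
embeddings_index = {}
with open(os.path.join(glove_dir, 'glove.6B.200d.txt'), 'r', encoding='utf-8') as f:
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

embedding_dim = 200

# Row i of embedding_matrix holds the GloVe vector of the word with index i.
# max_words and word_index come from the Keras Tokenizer fitted earlier (not shown).
embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():
    if i < max_words:
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            # Words not found in GloVe keep an all-zero row.
            embedding_matrix[i] = embedding_vector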
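from keras.models import Sequential
from keras.layers import Embedding

model = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
# ...
# Load the GloVe matrix into the embedding layer and freeze it.
model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = False
model.compile(...)
model.fit(...)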
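To make the layout of embedding_matrix concrete, here is a minimal sketch of the same loading logic factored into two small functions and run on a tiny in-memory stand-in for glove.6B.200d.txt (3-dimensional vectors instead of 200). The function names and the toy word_index are hypothetical, chosen only for illustration. Note that row 0 stays zero (Keras' Tokenizer numbers words from 1), and so does the row of any word GloVe does not know.

```python
import io

import numpy as np


# Hypothetical helpers for illustration, mirroring the loading code above.
def load_glove(lines):
    """Parse GloVe text lines ("word v1 v2 ... vd") into a dict of float32 vectors."""
    embeddings_index = {}
    for line in lines:
        values = line.split()
        embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")
    return embeddings_index


def build_embedding_matrix(embeddings_index, word_index, max_words, embedding_dim):
    """Row i holds the vector of the word with Tokenizer index i; unknown words stay zero."""
    matrix = np.zeros((max_words, embedding_dim), dtype="float32")
    for word, i in word_index.items():
        if i < max_words:
            vector = embeddings_index.get(word)
            if vector is not None:
                matrix[i] = vector
    return matrix


# Toy stand-in for the real GloVe file.
toy_file = io.StringIO("the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\n")
embeddings_index = load_glove(toy_file)

# Hypothetical Tokenizer.word_index; "zebra" is not in the toy GloVe file.
word_index = {"the": 1, "cat": 2, "zebra": 3}
embedding_matrix = build_embedding_matrix(embeddings_index, word_index,
                                          max_words=4, embedding_dim=3)
print(embedding_matrix)
```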
I am very new to scikit-learn. From what I have seen, to make a model in sklearn you do:
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
lr.fit(X_train, y_train)
lr.predict(X_test)
So my question is: how do I use pre-trained GloVe with this model? Where do I pass the pre-trained GloVe embedding_matrix?
Thank you very much and I really appreciate your help.
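Unlike Keras' Embedding layer, a scikit-learn estimator such as LogisticRegression has no slot for an embedding matrix: fit expects a 2-D numeric array of shape (n_samples, n_features). One common workaround (not the only one) is to represent each document by the average of its word vectors. Below is a minimal sketch that does this directly with an embedding_matrix like the one in the question, assuming the documents are already lists of integer word indices (e.g. from Tokenizer.texts_to_sequences) with 0 used for padding. The helper sequences_to_features and the toy data are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


# Hypothetical helper (not a library function).
def sequences_to_features(sequences, embedding_matrix):
    """Average the embedding rows of each sequence into one fixed-length vector.

    Index 0 (padding) and indices outside the matrix are skipped; a sequence
    with no usable index maps to a zero vector.
    """
    n_rows, dim = embedding_matrix.shape
    features = np.zeros((len(sequences), dim), dtype="float32")
    for row, seq in enumerate(sequences):
        idx = [i for i in seq if 0 < i < n_rows]
        if idx:
            features[row] = embedding_matrix[idx].mean(axis=0)
    return features


# Toy embedding matrix: row 0 is padding, rows 1-4 are word vectors.
embedding_matrix = np.array([
    [0.0, 0.0],
    [1.0, 0.0],  # word 1
    [0.9, 0.1],  # word 2
    [0.0, 1.0],  # word 3
    [0.1, 0.9],  # word 4
], dtype="float32")

train_seqs = [[1, 2], [2, 1, 0], [3, 4], [4, 3, 0]]
y_train = [0, 0, 1, 1]
X_train = sequences_to_features(train_seqs, embedding_matrix)  # shape (4, 2)

lr = LogisticRegression()
lr.fit(X_train, y_train)

X_test = sequences_to_features([[1], [4]], embedding_matrix)
print(lr.predict(X_test))
```

Averaging discards word order, so this behaves more like a bag-of-words baseline than the Keras model does.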
Answer 1:
You can simply use the Zeugma library. You can install it with pip install zeugma, then create and train your model with the following lines of code (assuming corpus_train and corpus_test are lists of strings):
from sklearn.linear_model import LogisticRegression
from zeugma.embeddings import EmbeddingTransformer

# Map each document to a fixed-length vector built from pre-trained GloVe embeddings
glove = EmbeddingTransformer('glove')
x_train = glove.transform(corpus_train)

model = LogisticRegression()
model.fit(x_train, y_train)

x_test = glove.transform(corpus_test)
model.predict(x_test)
You can also use other pre-trained embeddings or train your own (see Zeugma's documentation for the available models and for how to train your own).
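If you prefer not to add a dependency, or want to reuse the embeddings_index dict already built from the GloVe file in the question, a small custom transformer can do the averaging inside a scikit-learn Pipeline, which then exposes the usual fit/predict interface. This is a minimal sketch under stated assumptions: MeanEmbeddingVectorizer is a hypothetical class (not part of Zeugma or scikit-learn), documents are tokenized by lowercasing and splitting on whitespace (the glove.6B vectors are lowercase), and unknown words are simply skipped. It illustrates the general idea and is not a description of Zeugma's internals.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline


# Hypothetical class, not part of Zeugma or scikit-learn.
class MeanEmbeddingVectorizer(BaseEstimator, TransformerMixin):
    """Map each document (a string) to the mean of its words' vectors.

    embeddings_index is a dict word -> 1-D numpy array. Words missing from
    the dict are ignored; a document with no known words becomes a zero vector.
    """

    def __init__(self, embeddings_index=None):
        self.embeddings_index = embeddings_index

    def fit(self, X, y=None):
        # Nothing to learn: the vectors are pre-trained. Just record their size.
        self.dim_ = len(next(iter(self.embeddings_index.values())))
        return self

    def transform(self, X):
        out = np.zeros((len(X), self.dim_), dtype="float32")
        for row, doc in enumerate(X):
            vectors = [self.embeddings_index[w] for w in doc.lower().split()
                       if w in self.embeddings_index]
            if vectors:
                out[row] = np.mean(vectors, axis=0)
        return out


# Toy 2-dimensional "embeddings" standing in for the real GloVe dict.
toy_index = {
    "good": np.array([1.0, 0.0]),
    "great": np.array([0.9, 0.1]),
    "bad": np.array([0.0, 1.0]),
    "awful": np.array([0.1, 0.9]),
}
corpus_train = ["good great", "great good movie", "bad awful", "awful bad movie"]
y_train = [1, 1, 0, 0]

model = make_pipeline(MeanEmbeddingVectorizer(toy_index), LogisticRegression())
model.fit(corpus_train, y_train)
print(model.predict(["good", "awful"]))
```

With the real file you would pass the full embeddings_index instead of toy_index.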
Source: https://stackoverflow.com/questions/55198750/using-pretrained-glove-word-embedding-with-scikit-learn