Visualise word2vec generated from gensim

清歌不尽 2021-01-31 10:54

I have trained a doc2vec model and the corresponding word2vec model on my own corpus using gensim. I want to visualise the word2vec vectors using t-SNE, with the words shown as labels. As in, each dot in the figure has

2 answers
  • 2021-01-31 11:10

    With the following, you can export your model to two TSV files (one with the vectors, one with the matching words as metadata) and then use this page for visualization.

    with open(self.word_tensors_TSV, 'bw') as file_vector, open(self.word_meta_TSV, 'bw') as file_metadata:
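        # model.wv.vocab exists in gensim < 4.0; in gensim >= 4.0 iterate over model.wv.index_to_key instead
        for word in model.wv.vocab:
            file_metadata.write((word + '\n').encode('utf-8', errors='replace'))
            # index through model.wv: model[word] is deprecated and was removed in gensim 4.0
            vector_row = '\t'.join(str(x) for x in model.wv[word])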
            file_vector.write((vector_row + '\n').encode('utf-8', errors='replace'))
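
For reference, here is a minimal standalone sketch of the same export that does not depend on a particular gensim version. The function name `write_embedding_tsv` and the file names in the usage note are made up for illustration; it only assumes you can supply the words and their vectors in matching order.

```python
def write_embedding_tsv(words, vectors, vector_path, metadata_path):
    """Write one tab-separated vector per line, and the matching word per line.

    `words` and `vectors` must be in the same order; returns the number of rows written.
    """
    count = 0
    with open(vector_path, "w", encoding="utf-8", newline="\n") as f_vec, \
         open(metadata_path, "w", encoding="utf-8", newline="\n") as f_meta:
        for word, vector in zip(words, vectors):
            f_meta.write(word + "\n")
            f_vec.write("\t".join(str(x) for x in vector) + "\n")
            count += 1
    return count
```

With gensim 4.0 or later this could be called as `write_embedding_tsv(model.wv.index_to_key, model.wv[model.wv.index_to_key], "vectors.tsv", "metadata.tsv")`; with older versions, pass `list(model.wv.vocab)` as the word list instead.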
    

    :)

  • 2021-01-31 11:11

    Two parts to the answer: how to get the word labels, and how to plot the labels on a scatterplot.

    Word labels in gensim's word2vec

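    model.wv.vocab is a dict of {word: Vocab object}, where the Vocab object holds the word's index and count rather than the vector itself. (This applies to gensim versions before 4.0; from 4.0 on it was replaced by model.wv.key_to_index and model.wv.index_to_key.) To load the data into X for t-SNE, I made one change.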

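    vocab = list(model.wv.vocab)  # gensim >= 4.0: list(model.wv.key_to_index)
    X = model.wv[vocab]           # index through model.wv; model[vocab] was removed in gensim 4.0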
    

    This accomplishes two things: (1) it gets you a standalone vocab list for the final dataframe to plot, and (2) when you index model, you can be sure that you know the order of the words.
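
If you are not sure which gensim version the code will run on, something like the following sketch may help. The helper name `vocab_and_vectors` is hypothetical; it assumes `wv` is the model's `model.wv` object and simply tries the gensim 4.x attribute first, then falls back to the older one.

```python
def vocab_and_vectors(wv):
    """Return (vocab, X), where X[i] is the vector for vocab[i]."""
    # gensim >= 4.0 exposes key_to_index; older versions expose vocab instead.
    key_to_index = getattr(wv, "key_to_index", None)
    if key_to_index is None:
        key_to_index = wv.vocab
    vocab = list(key_to_index)
    # Indexing with the list keeps the rows in the same order as vocab.
    return vocab, wv[vocab]
```

Usage would then be `vocab, X = vocab_and_vectors(model.wv)`, after which the rest of the answer applies unchanged.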

    Proceed as before with

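    from sklearn.manifold import TSNE

    tsne = TSNE(n_components=2)      # project down to 2 dimensions for plotting
    X_tsne = tsne.fit_transform(X)   # one (x, y) row per word, in the same order as vocab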
    

    Now let's put X_tsne together with the vocab list. This is easy with pandas, so import pandas as pd if you don't have that yet.

    df = pd.DataFrame(X_tsne, index=vocab, columns=['x', 'y'])
    

    The vocab words are the indices of the dataframe now.

    I don't have your dataset, but in the other Stack Overflow question you mentioned, an example df that uses sklearn's newsgroups data would look something like

                            x             y
    politics    -1.524653e+20 -1.113538e+20
    worry        2.065890e+19  1.403432e+20
    mu          -1.333273e+21 -5.648459e+20
    format      -4.780181e+19  2.397271e+19
    recommended  8.694375e+20  1.358602e+21
    arguing     -4.903531e+19  4.734511e+20
    or          -3.658189e+19 -1.088200e+20
    above        1.126082e+19 -4.933230e+19
    

    Scatterplot

    I like the object-oriented approach to matplotlib, so this starts out a little different.

    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    
    ax.scatter(df['x'], df['y'])
    

    Lastly, the annotate method will label coordinates. The first two arguments are the text label and the 2-tuple of (x, y) coordinates. Using iterrows(), this can be very succinct:

    for word, pos in df.iterrows():
        ax.annotate(word, pos)
    

    [Thanks to Ricardo in the comments for this suggestion.]
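
With a large vocabulary, labelling every dot can make the plot unreadable. One possible workaround, sketched here under the assumption that labelling only the first few points is acceptable, is a small helper that caps the number of annotations. The name `annotate_points` and the `max_labels` parameter are made up for this example; the only matplotlib call it relies on is `ax.annotate(text, xy)`.

```python
def annotate_points(ax, labels, coords, max_labels=None):
    """Annotate up to max_labels points on ax; return how many were labelled.

    `coords` is an iterable of (x, y) pairs in the same order as `labels`.
    """
    labelled = 0
    for label, (x, y) in zip(labels, coords):
        if max_labels is not None and labelled >= max_labels:
            break
        ax.annotate(label, (x, y))
        labelled += 1
    return labelled
```

For the dataframe above this could be called as `annotate_points(ax, df.index, df[['x', 'y']].to_numpy(), max_labels=200)`. Which words get labelled depends on the order of your vocab list.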

    Then do plt.show() or fig.savefig(). Depending on your data, you'll probably have to mess with ax.set_xlim and ax.set_ylim to see into a dense cloud. This is the newsgroup example without any tweaking:

    [Figure: annotated scatterplot of the newsgroup example]

    You can modify dot size, color, etc., too. Happy fine-tuning!
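
As a starting point for choosing the axis limits mentioned above, here is a rough sketch that picks bounds covering the central bulk of the points and ignores far-out outliers. The helper `central_limits` and its default cut-offs of 5% and 95% are assumptions chosen for illustration, not something from the original answer.

```python
def central_limits(values, lower=0.05, upper=0.95):
    """Return (lo, hi) such that roughly the central share of values lies between them."""
    ordered = sorted(values)
    if not ordered:
        raise ValueError("values must not be empty")
    last = len(ordered) - 1
    # Pick the elements nearest to the requested lower and upper fractions.
    lo = ordered[int(round(lower * last))]
    hi = ordered[int(round(upper * last))]
    return lo, hi
```

You could then try `ax.set_xlim(*central_limits(df['x']))` and `ax.set_ylim(*central_limits(df['y']))`, and widen or narrow the fractions until the cloud is readable.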
