Add legend to scatter plot (PCA)

Submitted by 依然范特西╮ on 2019-11-26 04:01:36

Question


I am new to Python and found this excellent PCA biplot suggestion (Plot PCA loadings and loading in biplot in sklearn (like R's autoplot)). I tried to add a legend to the plot for the different targets, but the command plt.legend() doesn't work.

Is there an easy way to do it? As an example, here is the iris data with the biplot code from the link above.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
import pandas as pd
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
X = iris.data
y = iris.target
#In general a good idea is to scale the data
scaler = StandardScaler()
scaler.fit(X)
X=scaler.transform(X)    

pca = PCA()
x_new = pca.fit_transform(X)

def myplot(score,coeff,labels=None):
    xs = score[:,0]
    ys = score[:,1]
    n = coeff.shape[0]
    scalex = 1.0/(xs.max() - xs.min())
    scaley = 1.0/(ys.max() - ys.min())
    plt.scatter(xs * scalex,ys * scaley, c = y)
    for i in range(n):
        plt.arrow(0, 0, coeff[i,0], coeff[i,1],color = 'r',alpha = 0.5)
        if labels is None:
            plt.text(coeff[i,0]* 1.15, coeff[i,1] * 1.15, "Var"+str(i+1), color = 'g', ha = 'center', va = 'center')
        else:
            plt.text(coeff[i,0]* 1.15, coeff[i,1] * 1.15, labels[i], color = 'g', ha = 'center', va = 'center')
plt.xlim(-1,1)
plt.ylim(-1,1)
plt.xlabel("PC{}".format(1))
plt.ylabel("PC{}".format(2))
plt.grid()

#Call the function. Use only the 2 PCs.
myplot(x_new[:,0:2],np.transpose(pca.components_[0:2, :]))
plt.show()

Any suggestions for PCA biplots are welcome, including different code if that makes adding the legend easier!


Answer 1:


I recently proposed an easy way to add a legend to a scatter plot; see the GitHub PR. It is still under discussion.

In the meantime, you need to create the legend manually from the unique labels in y. For each label, create a Line2D object with the same marker as the scatter plot uses, colored through the scatter's colormap and norm, and pass the list of these objects to plt.legend as the handles.

scatter = plt.scatter(xs * scalex,ys * scaley, c = y)
labels = np.unique(y)
handles = [plt.Line2D([],[],marker="o", ls="", 
                      color=scatter.cmap(scatter.norm(yi))) for yi in labels]
plt.legend(handles, labels)
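For anyone who wants to try this outside the biplot function, here is a minimal self-contained sketch of the same idea. It uses synthetic data as a stand-in for the first two PCA scores, and the helper name class_legend_handles and the "class N" label text are made up for illustration; they are not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def class_legend_handles(scatter, classes):
    """Build one proxy marker per class, colored like the scatter points.

    Hypothetical helper: it maps each class value through the scatter's
    own norm and colormap, exactly as the snippet above does inline.
    """
    return [
        Line2D([], [], marker="o", ls="",
               color=scatter.cmap(scatter.norm(c)))
        for c in classes
    ]


# Synthetic stand-in (an assumption) for the 2-D PCA scores and targets.
rng = np.random.default_rng(0)
y = np.repeat([0, 1, 2], 50)
scores = rng.normal(size=(150, 2)) + y[:, None]

fig, ax = plt.subplots()
scatter = ax.scatter(scores[:, 0], scores[:, 1], c=y)

classes = np.unique(y)
handles = class_legend_handles(scatter, classes)
legend = ax.legend(handles, [f"class {c}" for c in classes], title="Target")
# Call fig.savefig(...) or switch to an interactive backend to view the plot.
```

In the biplot code, the same thing happens if you keep the return value of plt.scatter inside myplot and build the handles from it. On Matplotlib 3.1 or later, the scatter object also has a legend_elements() method that returns handles and labels directly, so plt.legend(*scatter.legend_elements()) should be a shorter alternative.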



Source: https://stackoverflow.com/questions/50654620/add-legend-to-scatter-plot-pca
