Matplotlib adding legend based on existing color series

半阙折子戏 2020-12-01 19:54

I plotted some data using a scatter plot, specified as follows:

plt.scatter(rna.data['x'], rna.data['y'], s=size,
            c=rna.data['colors


        
2 Answers
  • 2020-12-01 20:47

    You can create the legend handles using an empty plot with the color based on the colormap and normalization of the scatter plot.

    import pandas as pd
    import numpy as np; np.random.seed(1)
    import matplotlib.pyplot as plt
    
    x = [np.random.normal(5,2, size=20), np.random.normal(10,1, size=20),
         np.random.normal(5,1, size=20), np.random.normal(10,1, size=20)]
    y = [np.random.normal(5,1, size=20), np.random.normal(5,1, size=20),
         np.random.normal(10,2, size=20), np.random.normal(10,2, size=20)]
    c = [np.ones(20)*(i+1) for i in range(4)]
    
    df = pd.DataFrame({"x":np.array(x).flatten(), 
                       "y":np.array(y).flatten(), 
                       "colors":np.array(c).flatten()})
    
    size=81
    sc = plt.scatter(df['x'], df['y'], s=size, c=df['colors'], edgecolors='none')
    
    lp = lambda i: plt.plot([],color=sc.cmap(sc.norm(i)), ms=np.sqrt(size), mec="none",
                            label="Feature {:g}".format(i), ls="", marker="o")[0]
    handles = [lp(i) for i in np.unique(df["colors"])]
    plt.legend(handles=handles)
    plt.show()
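
The empty-plot handles above match the scatter colors because `sc.cmap(sc.norm(i))` repeats the lookup the scatter plot does internally: the value is first normalized to [0, 1], then mapped through the colormap. The following is a minimal sketch of that check, not part of the original answer; the `values` array is made up for illustration and the Agg backend is only assumed so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

values = np.array([1.0, 1.0, 2.0, 3.0, 4.0])  # illustrative color values
x = np.arange(len(values))

fig, ax = plt.subplots()
sc = ax.scatter(x, x, c=values)

# Manual lookup: normalize to [0, 1], then index into the colormap.
manual = sc.cmap(sc.norm(values))
# to_rgba applies the same norm and colormap in a single call.
via_to_rgba = sc.to_rgba(values)

print(np.allclose(manual, via_to_rgba))
```

Each row of `manual` is an RGBA tuple, which is why it can be passed directly as `color=` when building a legend handle.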
    

    Alternatively, you may filter your dataframe by the values in the colors column, e.g. using groupby, and plot one scatter plot for each feature:

    import pandas as pd
    import numpy as np; np.random.seed(1)
    import matplotlib.pyplot as plt
    
    x = [np.random.normal(5,2, size=20), np.random.normal(10,1, size=20),
         np.random.normal(5,1, size=20), np.random.normal(10,1, size=20)]
    y = [np.random.normal(5,1, size=20), np.random.normal(5,1, size=20),
         np.random.normal(10,2, size=20), np.random.normal(10,2, size=20)]
    c = [np.ones(20)*(i+1) for i in range(4)]
    
    df = pd.DataFrame({"x":np.array(x).flatten(), 
                       "y":np.array(y).flatten(), 
                       "colors":np.array(c).flatten()})
    
    size=81
    cmap = plt.cm.viridis
    norm = plt.Normalize(df['colors'].values.min(), df['colors'].values.max())
    
    for i, dff in df.groupby("colors"):
        plt.scatter(dff['x'], dff['y'], s=size, c=cmap(norm(dff['colors'])), 
                    edgecolors='none', label="Feature {:g}".format(i))
    
    plt.legend()
    plt.show()
    

    Both methods produce the same plot.

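
A further option, not mentioned in the answer above and assuming Matplotlib 3.1 or newer, is to let the scatter plot build its own handles with `legend_elements()`. The sketch below uses made-up data (four color values, 20 points each) and replaces the automatically generated labels with the "Feature" labels used in the answer; the Agg backend is only assumed so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
colors = np.repeat([1, 2, 3, 4], 20)  # made-up data: four features, 20 points each
x = rng.normal(5, 1, size=colors.size)
y = rng.normal(5, 1, size=colors.size)

fig, ax = plt.subplots()
sc = ax.scatter(x, y, s=81, c=colors, edgecolors="none")

# With this few unique color values, legend_elements() returns one handle
# per unique value, in ascending order. The generated labels are discarded.
handles, _ = sc.legend_elements(prop="colors")
labels = ["Feature {:g}".format(v) for v in np.unique(colors)]
legend = ax.legend(handles, labels)
```

This avoids building handles by hand, at the cost of depending on a newer Matplotlib release.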
  • 2020-12-01 20:51

    Altair can be a great choice here.

    Continuous classes

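    import altair as alt
    import numpy as np
    import pandas as pd
    
    df = pd.DataFrame(40*np.random.randn(10, 3), columns=['A', 'B', 'C'])
    
    # The numeric column 'C' is drawn with a continuous color scale.
    # properties() sets the chart size; configure_cell() from older Altair
    # releases is no longer available.
    alt.Chart(df).mark_circle().encode(x='A', y='B', color='C').properties(width=200, height=150)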
    

    Discrete classes

    import altair as alt
    
    df = pd.DataFrame(10*np.random.randn(40, 2), columns=['A', 'B'])
    df['C'] = np.random.choice(['alpha', 'beta', 'gamma', 'delta'], size=40)
    
    # The string column 'C' gets one legend entry per category.
    alt.Chart(df).mark_circle().encode(x='A', y='B', color='C').properties(width=200, height=150)
    
