setting a legend matching the colours in pyplot.scatter

迷失自我 2021-01-07 07:03

Suppose my data is organized in the following way:

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
colours = [1, 1, 0, 1, -1]
label         

3 Answers
  • 2021-01-07 07:23

    Just a remark, not exactly answering the question:

    If you use seaborn, it takes exactly one line:

    import seaborn as sns 
    x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
    y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
    #colors = [1, 1, 0, 1, -1]
    labels = ["a", "a", "b", "a", "c"]
    ax = sns.scatterplot(x=x_values, y=y_values, hue=labels)
    

    PS

    But the question is about matplotlib, so see the other answers. One might also look at https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html, in particular the subsection "Automated legend creation".

    However, I don't find those examples easy to adapt to what you need.
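For what it's worth, the "Automated legend creation" approach from that gallery page can be adapted with a little glue. This is a minimal sketch, assuming matplotlib 3.1 or later (where `PathCollection.legend_elements` is available) and assuming each numeric colour value corresponds to exactly one label. `legend_elements()` returns one marker handle per distinct colour value, ordered by value, so the labels are attached through a value-to-label dict (`value_to_label` is just an illustrative name):

```python
import matplotlib.pyplot as plt

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
colours = [1, 1, 0, 1, -1]
labels = ["a", "a", "b", "a", "c"]

fig, ax = plt.subplots()
sc = ax.scatter(x_values, y_values, c=colours)

# legend_elements() builds one marker handle per unique colour value,
# in ascending order of value (here -1, 0, 1)
handles, _ = sc.legend_elements()

# Map each colour value to its label (assumes one label per value)
value_to_label = dict(zip(colours, labels))
legend_labels = [value_to_label[v] for v in sorted(set(colours))]

ax.legend(handles, legend_labels)
# plt.show()  # uncomment to display the figure
```

The handles are marker-only lines coloured through the scatter's own colormap and norm, so they match the points without any manual colour lookup.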

  • 2021-01-07 07:24

    You can always make your own legend as follows:

    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    
    x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
    y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
    
    a = 'red'
    b = 'blue'
    c = 'yellow'
    
    colours = [a, a, b, a, c]
    labels = ["a", "a", "b", "a", "c"]
    
    axis = plt.gca()
    axis.scatter(x_values, y_values, c=colours)
    
    # Create a legend: one coloured patch per label
    handles = [mpatches.Patch(color=colour, label=label) for label, colour in [('a', a), ('b', b), ('c', c)]]
    plt.legend(handles=handles, loc="upper left", frameon=True)
    
    plt.show()
    

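    This produces a scatter plot with a legend in the upper-left corner showing a red patch for "a", a blue patch for "b" and a yellow patch for "c".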
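The `Patch` handles in this answer are drawn as rectangles. If you would rather have round legend markers that look like the scatter points, one option is `matplotlib.lines.Line2D` handles with no line and an `"o"` marker. A minimal sketch, assuming the same label-to-colour choices as in the answer; `scatter_with_marker_legend` is a hypothetical helper name, not a matplotlib function:

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def scatter_with_marker_legend(x, y, labels, label_colours):
    """Hypothetical helper: colour points by label, legend uses round markers."""
    fig, ax = plt.subplots()
    # Look up each point's colour from its label
    ax.scatter(x, y, c=[label_colours[lab] for lab in labels])
    # One marker-only handle (no connecting line) per label, in dict order
    handles = [
        Line2D([0], [0], ls="", marker="o", color=colour, label=lab)
        for lab, colour in label_colours.items()
    ]
    ax.legend(handles=handles, loc="upper left", frameon=True)
    return ax


x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
labels = ["a", "a", "b", "a", "c"]
ax = scatter_with_marker_legend(
    x_values, y_values, labels, {"a": "red", "b": "blue", "c": "yellow"}
)
# plt.show()  # uncomment to display the figure
```

Keeping the colours in a single dict means the points and the legend cannot drift apart; `Patch` is still the simpler choice if only the colour matters.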

  • 2021-01-07 07:26

    If you want to use a colormap, you can create one legend entry for each unique (colour value, label) pair, as shown below. This works for any number of distinct values. The legend handles are empty plots that draw only a marker, so they look like the scatter points.

    import matplotlib.pyplot as plt
    
    x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
    y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
    colors = [1, 1, 0, 1, -1]
    labels = ["a", "a", "b", "a", "c"]
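    # One legend entry per unique (colour value, label) pair, sorted for a stable order
    clset = sorted(set(zip(colors, labels)))
    
    ax = plt.gca()
    sc = ax.scatter(x_values, y_values, c=colors, cmap="brg")
    
    # Marker-only empty plots, coloured with the scatter's own norm and colormap
    handles = [ax.plot([], color=sc.get_cmap()(sc.norm(c)), ls="", marker="o")[0] for c, l in clset]
    labels = [l for c, l in clset]
    ax.legend(handles, labels)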
    
    plt.show()
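The colour lookup `sc.get_cmap()(sc.norm(c))` matches the points because, by default, `scatter` normalizes the numeric `c` values linearly between their minimum and maximum before indexing the colormap. A small sketch of that mapping, assuming the default linear `Normalize` (no explicit `norm`, `vmin` or `vmax` passed to `scatter`); `legend_colours` is a hypothetical helper name:

```python
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize


def legend_colours(values, cmap_name="brg"):
    """Map each distinct value to the RGBA colour scatter gives it by default.

    Mirrors the default behaviour: a linear Normalize between min(values)
    and max(values), then a lookup in the named colormap.
    """
    norm = Normalize(vmin=min(values), vmax=max(values))
    cmap = plt.get_cmap(cmap_name)
    return {v: cmap(norm(v)) for v in sorted(set(values))}


colors = [1, 1, 0, 1, -1]
for value, rgba in legend_colours(colors).items():
    # -1 maps to position 0.0, 0 to 0.5 and 1 to 1.0 along the colormap
    print(value, rgba)
```

With values -1, 0 and 1, the three entries land at the start, middle and end of the colormap, which is why they stay distinguishable in the legend.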
    
