Multicolor scatter plot legend in Python

谎友^ 2020-12-21 14:05

I have some basic data on car engine size, horsepower and body type (sample shown below):

         body-style  engine-size  horsepower
0   convertible          1         


        
1 answer
  • 2020-12-21 14:49

    In matplotlib, you can easily build a custom legend. In your example, retrieve the colour-label pairs from your dictionary and create one patch per pair to use as the legend handles:

    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import numpy as np
    import pandas as pd
    
    # this part just recreates your dataset
    wtf = pd.read_csv("test.csv", sep=r"\s+")
    col_dict = {'convertible': 'red', 'hatchback': 'blue', 'sedan': 'purple', 'wagon': 'yellow', 'hardtop': 'green'}
    # map each body style to its colour; derive a marker size from the two numeric columns
    wtf["colour_column"] = wtf["body-style"].map(col_dict)
    wtf["comp_ratio_size"] = np.square(wtf["horsepower"] - wtf["engine-size"])
    
    fig = plt.figure(figsize=(8, 8), dpi=75)
    ax = fig.gca()
    ax.scatter(wtf['engine-size'], wtf['horsepower'], c=wtf["colour_column"], s=wtf['comp_ratio_size'], alpha=0.4)
    # axis labels follow the plotted columns: x is engine size, y is horsepower
    ax.set_xlabel("engine size")
    ax.set_ylabel("horsepower")
    
    # build one legend patch per (label, colour) pair in the colour dictionary
    leg_el = [mpatches.Patch(facecolor=value, edgecolor="black", label=key, alpha=0.4) for key, value in col_dict.items()]
    ax.legend(handles=leg_el)
    
    plt.show()
    

    Output: [figure: scatter plot of horsepower against engine size, with one legend patch per body style]
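
If round legend markers that match the scatter points are preferred over rectangular patches, `Line2D` proxy handles are one alternative. The following is only a sketch under assumptions: the six data rows are made up for illustration (they are not the asker's data), the marker size is fixed at a constant instead of the computed size column, and the output filename `scatter_legend.png` is hypothetical.

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd

# Hypothetical sample rows, invented for this sketch; the real data
# would come from the asker's file.
df = pd.DataFrame({
    "body-style": ["convertible", "hatchback", "sedan", "wagon", "hardtop", "sedan"],
    "engine-size": [130, 97, 109, 120, 146, 136],
    "horsepower": [111, 69, 102, 95, 116, 110],
})
col_dict = {"convertible": "red", "hatchback": "blue", "sedan": "purple",
            "wagon": "yellow", "hardtop": "green"}

fig, ax = plt.subplots(figsize=(8, 8), dpi=75)
ax.scatter(df["engine-size"], df["horsepower"],
           c=df["body-style"].map(col_dict), s=80, alpha=0.4)
ax.set_xlabel("engine size")
ax.set_ylabel("horsepower")

# One proxy handle per body style: a line with no stroke that carries
# only a round marker in the matching colour.
legend_handles = [
    Line2D([0], [0], linestyle="none", marker="o", markersize=10,
           markerfacecolor=colour, markeredgecolor="black",
           alpha=0.4, label=style)
    for style, colour in col_dict.items()
]
legend = ax.legend(handles=legend_handles, title="body style")
legend_labels = [text.get_text() for text in legend.get_texts()]

# Hypothetical output filename; plt.show() would work as well.
fig.savefig("scatter_legend.png")
```

Because the handles are built from `col_dict` and not from the plotted data, every body style appears in the legend even if it has no rows in the data, which is the same behaviour as the patch-based legend.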
