Why doesn't the color of the points in a scatter plot match the color of the points in the corresponding legend?

旧巷少年郎 2020-12-07 04:11

I have a sample scatter plot made with matplotlib using the code below.

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 100, 501)
y = np.sin(x)         


        
相关标签:
1 Answer
  • 2020-12-07 04:53

    A colorbar can be added via plt.colorbar(). This lets you read off directly which values correspond to which colors.
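    A minimal colorbar sketch, reusing the data and colormap from the answer's full example further down. The Agg backend line is an assumption so the snippet runs without a display; drop it for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless, non-interactive backend
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 100, 501)
y = np.sin(x) * np.cos(x / 50.)

# Color each point by its x value
sc = plt.scatter(x, y, c=x, cmap='plasma')

# The colorbar takes its colormap and normalization from the scatter's PathCollection
cbar = plt.colorbar(sc, label='x')
```

    Because the colorbar is tied to `sc`, its range follows the normalization of the scatter (here 0 to 100).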

    Having the points in the legend show different colors is of course also nice, although it would not convey any quantitative information.

    Unfortunately, matplotlib does not provide a built-in way to achieve this. One option is therefore to subclass the legend handler that creates the legend handle and implement the feature there.

    Here we create a ScatterHandler with a custom create_collection method, in which we build the desired PathCollection, and we use this handler by specifying it in the handler_map dictionary of the legend:

    handler_map={type(sc): ScatterHandler()}
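    To see the handler_map mechanism in isolation, here is a sketch that maps the scatter's type (PathCollection) to matplotlib's built-in HandlerPathCollection instead of the custom class. It only changes which handler draws the entry (here with numpoints=4, an illustrative choice); it does not spread the colors over the colormap, which is what the custom handler adds.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless, non-interactive backend
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
from matplotlib.legend_handler import HandlerPathCollection

x = np.linspace(0, 100, 501)
y = np.sin(x) * np.cos(x / 50.)

fig, ax = plt.subplots()
sc = ax.scatter(x, y, c=x, cmap='plasma', label='xy data sample')

# handler_map maps an artist type (or a specific artist) to the handler that
# draws its legend entry. type(sc) is PathCollection, so every scatter in the
# legend is drawn by this handler.
handler_map = {type(sc): HandlerPathCollection(numpoints=4)}
leg = ax.legend(handler_map=handler_map)
```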
    

    The following code may look a bit complicated at first sight; however, you can simply copy the class into your code and use it without understanding every detail.

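    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.legend_handler import HandlerRegularPolyCollection
    
    class ScatterHandler(HandlerRegularPolyCollection):
        def update_prop(self, legend_handle, orig_handle, legend):
            # Attach the handle to the legend (figure, transform), but do NOT
            # copy properties from orig_handle: that would overwrite the color
            # array set in create_collection.
            legend._set_artist_props(legend_handle)
            legend_handle.set_clip_box(None)
            legend_handle.set_clip_path(None)
    
        # matplotlib >= 3.6 passes this argument as offset_transform;
        # older versions named it (and the keyword below) transOffset.
        def create_collection(self, orig_handle, sizes, offsets, offset_transform):
            # New collection of the same type, reusing the marker path and the
            # colormap/normalization of the original scatter.
            p = type(orig_handle)([orig_handle.get_paths()[0]],
                                  sizes=sizes, offsets=offsets,
                                  offset_transform=offset_transform,
                                  cmap=orig_handle.get_cmap(),
                                  norm=orig_handle.norm)
    
            a = orig_handle.get_array()
            if a is not None:
                # Spread the legend markers evenly over the data range.
                p.set_array(np.linspace(a.min(), a.max(), len(offsets)))
            else:
                # No colormap in use: copy the plain colors instead.
                self._update_prop(p, orig_handle)
            return p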
    
    
    x = np.linspace(0, 100, 501)
    y = np.sin(x)*np.cos(x/50.)
    
    sc = plt.scatter(x, y, cmap='plasma', c=x, label='xy data sample')
    
    legend_dict = dict(ncol=1, loc='best', scatterpoints=4, fancybox=True, shadow=True)
    plt.legend(handler_map={type(sc) : ScatterHandler()}, **legend_dict)
    
    plt.show()
    
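    The color rule inside create_collection can be checked without building a legend. This sketch reproduces it by hand: evenly spaced values from the data minimum to maximum (one per legend marker, 4 here to match scatterpoints=4 in legend_dict), mapped through the scatter's own norm and colormap. The variable names are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless, non-interactive backend
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 100, 501)
y = np.sin(x) * np.cos(x / 50.)
sc = plt.scatter(x, y, cmap='plasma', c=x)

n_markers = 4  # same as scatterpoints=4 in legend_dict
a = sc.get_array()

# Same rule as ScatterHandler.create_collection: evenly spaced values, min to max
legend_values = np.linspace(a.min(), a.max(), n_markers)

# Map those values through the scatter's norm and colormap to get RGBA colors
legend_colors = sc.cmap(sc.norm(legend_values))
```

    The first legend marker therefore gets the color of the smallest value and the last one the color of the largest value, so the legend runs through the whole colormap.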
