How to put multiple colormap patches in a matplotlib legend?

深忆病人 2021-01-20 22:59

Situation at hand:

I have multiple groups of lines, where the lines within the same group vary according to some group-specific parameter. I assign each …

1 Answer
  • 2021-01-20 23:30

    I adapted ImportanceOfBeingErnest's answer to "Create a matplotlib mpatches with a rectangle bi-colored for figure legend" to this case. As linked there, the section "Implementing a custom legend handler" in the matplotlib legend guide was particularly helpful.

    Result:

    [Image: the resulting plot with its colormap legend]

    Solution:

    I created a class HandlerColormap derived from HandlerBase, matplotlib's base class for legend handlers. HandlerColormap takes a colormap and a number of stripes as arguments.

    The cmap argument should be a matplotlib colormap instance.
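Any `matplotlib.colors.Colormap` should work here, not only built-in ones such as `plt.cm.spring`. A colormap is callable: given a float in [0, 1] it returns an RGBA tuple, which is how `HandlerColormap` colors its stripes. A minimal sketch with a two-color `ListedColormap` (the colors are an arbitrary choice for illustration):

```python
from matplotlib.colors import ListedColormap

# A discrete colormap with two entries:
# values in [0, 0.5) map to red, values in [0.5, 1] map to blue.
two_colors = ListedColormap(["red", "blue"])

print(two_colors(0.25))  # RGBA of "red"
print(two_colors(0.75))  # RGBA of "blue"
```

Passed as `HandlerColormap(two_colors, num_stripes=2)`, the stripes would be sampled at 0.25 and 0.75, giving one red and one blue stripe.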

    The num_stripes argument determines how smooth (or how stepped) the color gradient in the legend patch looks.
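To see what `num_stripes` does to the colors, here is a small sketch of the sampling rule used in `create_artists`: stripe `i` of `n` takes the colormap value at the center of its slice of [0, 1], i.e. `(2 * i + 1) / (2 * n)`. The helper name `stripe_sample_points` is hypothetical and only mirrors that expression.

```python
def stripe_sample_points(num_stripes):
    """Return the colormap position sampled for each stripe (hypothetical helper).

    Mirrors (2 * i + 1) / (2 * num_stripes) from HandlerColormap: each stripe
    is colored at the center of its equal-width slice of [0, 1].
    """
    if num_stripes < 1:
        raise ValueError("num_stripes must be at least 1")
    return [(2 * i + 1) / (2 * num_stripes) for i in range(num_stripes)]


print(stripe_sample_points(1))  # [0.5]: a single flat stripe
print(stripe_sample_points(4))  # [0.125, 0.375, 0.625, 0.875]
```

More stripes give a smoother-looking gradient at the cost of more `Rectangle` artists per legend entry. Since 0 and 1 are never sampled exactly, the extreme colors of the colormap only appear approximately in the patch.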

    As the HandlerBase documentation instructs, I override its create_artists method and draw within the given dimensions (xdescent, ydescent, width, height), so the patch scales automatically with fontsize. In this create_artists method I create multiple stripes (slim matplotlib Rectangles), each colored according to the input colormap.
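The stripe geometry can be checked without drawing anything. This sketch uses a hypothetical helper `stripe_boxes` that splits a handle box into `num_stripes` equal-width rectangles, as the loop in `create_artists` does. The box origin is passed in explicitly here as an assumption; inside the handler it is derived from `xdescent` and `ydescent`.

```python
def stripe_boxes(x0, y0, width, height, num_stripes):
    """Return (x, y, w, h) for each stripe of a legend handle box (hypothetical helper)."""
    stripe_width = width / num_stripes
    # Stripe i starts i stripe-widths to the right of the box origin
    # and spans the full height of the handle box.
    return [(x0 + i * stripe_width, y0, stripe_width, height)
            for i in range(num_stripes)]


boxes = stripe_boxes(0.0, 0.0, 20.0, 7.0, 4)
print(boxes)
# [(0.0, 0.0, 5.0, 7.0), (5.0, 0.0, 5.0, 7.0), (10.0, 0.0, 5.0, 7.0), (15.0, 0.0, 5.0, 7.0)]
```

The stripes tile the box exactly, so their widths sum to `width`. Because the legend derives `width` and `height` from its `fontsize` (via `handlelength` and `handleheight`), the patch follows the font size without extra code.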

    Code:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    from matplotlib.legend_handler import HandlerBase
    
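    class HandlerColormap(HandlerBase):
        """Legend handler that draws a colormap as a row of colored stripes."""
        def __init__(self, cmap, num_stripes=8, **kw):
            super().__init__(**kw)
            self.cmap = cmap                # matplotlib Colormap instance
            self.num_stripes = num_stripes  # number of stripes in the patch

        def create_artists(self, legend, orig_handle,
                           xdescent, ydescent, width, height, fontsize, trans):
            stripes = []
            stripe_width = width / self.num_stripes
            for i in range(self.num_stripes):
                # Place stripe i inside the handle box (same origin convention
                # as matplotlib's HandlerPatch) and color it with the colormap
                # value at the center of its slice of [0, 1].
                s = Rectangle([-xdescent + i * stripe_width, -ydescent],
                              stripe_width,
                              height,
                              fc=self.cmap((2 * i + 1) / (2 * self.num_stripes)),
                              transform=trans)
                stripes.append(s)
            return stripes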
    
    x_array = np.linspace(1, 10, 10)
    y_array = x_array
    param_max = x_array.size
    # one colormap per group of lines
    cmaps = [plt.cm.spring, plt.cm.winter]
    plt.figure()
    for param, (x, y) in enumerate(zip(x_array, y_array)):  
        x_line1 = np.linspace(x, 1.5 * x, 10)
        y_line1 = np.linspace(y**2, y**2 - x, 10)
        x_line2 = np.linspace(1.2 * x, 1.5 * x, 10)
        y_line2 = np.linspace(2 * y, 2 * y - x, 10)
        # plot lines with color depending on param using different colormaps:
        plt.plot(x_line1, y_line1, c=cmaps[0](param / param_max))
        plt.plot(x_line2, y_line2, c=cmaps[1](param / param_max))
    
    # raw strings so that "\in" is passed to mathtext instead of being read as an escape
    cmap_labels = [r"parameter 1 $\in$ [0, 10]", r"parameter 2 $\in$ [-1, 1]"]
    # create proxy artists as handles:
    cmap_handles = [Rectangle((0, 0), 1, 1) for _ in cmaps]
    handler_map = dict(zip(cmap_handles, 
                           [HandlerColormap(cm, num_stripes=8) for cm in cmaps]))
    plt.legend(handles=cmap_handles, 
               labels=cmap_labels, 
               handler_map=handler_map, 
               fontsize=12)
    plt.show()
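
If you would rather not build an instance-keyed handler_map, the legend guide's other protocol (an object with a `legend_artist` method) can be registered by type, so every proxy of that type uses the same handler. This is a hedged sketch, not the code above: `ColormapProxy` and `ColormapProxyHandler` are hypothetical names, the proxy is a plain Python object rather than a `Rectangle`, and the Agg backend is selected only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


class ColormapProxy:
    """Hypothetical placeholder handle that only carries a colormap."""
    def __init__(self, cmap, num_stripes=8):
        self.cmap = cmap
        self.num_stripes = num_stripes


class ColormapProxyHandler:
    """Handler following the legend guide's legend_artist protocol."""
    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        # Same origin convention as matplotlib's HandlerPatch.
        x0, y0 = -handlebox.xdescent, -handlebox.ydescent
        width, height = handlebox.width, handlebox.height
        n = orig_handle.num_stripes
        stripes = []
        for i in range(n):
            stripe = Rectangle((x0 + i * width / n, y0), width / n, height,
                               facecolor=orig_handle.cmap((2 * i + 1) / (2 * n)),
                               edgecolor="none",
                               transform=handlebox.get_transform())
            handlebox.add_artist(stripe)
            stripes.append(stripe)
        return stripes[0]  # the legend keeps one artist per entry


fig, ax = plt.subplots()
handles = [ColormapProxy(plt.cm.spring), ColormapProxy(plt.cm.winter, num_stripes=16)]
labels = ["group 1", "group 2"]
# Keyed by type: every ColormapProxy instance is drawn by the same handler.
leg = ax.legend(handles, labels, handler_map={ColormapProxy: ColormapProxyHandler()})
fig.canvas.draw()
```

The trade-off is that the number of stripes now lives on each proxy rather than on the handler, which keeps the handler_map to a single entry however many groups there are.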
    