pyplot combine multiple line labels in legend

鱼传尺愫 2020-12-29 07:00

I have data that results in multiple lines being plotted, and I want to give these lines a single label in my legend. I think this can be better demonstrated using the example b

6 Answers
  • 2020-12-29 07:21

    Personally, I'd write a small helper function if I planned on doing this often:

    from matplotlib import pyplot
    import numpy
    
    
    a = numpy.array([[ 3.57,  1.76,  7.42,  6.52],
                     [ 1.57,  1.2 ,  3.02,  6.88],
                     [ 2.23,  4.86,  5.12,  2.81],
                     [ 4.48,  1.38,  2.14,  0.86],
                     [ 6.68,  1.72,  8.56,  3.23]])
    
    
    def plotCollection(ax, xs, ys, *args, **kwargs):
        # Plot every column of xs/ys; all resulting lines share kwargs["label"].
        ax.plot(xs, ys, *args, **kwargs)

        if "label" in kwargs:
            # Remove duplicate labels, keeping the first handle seen for each.
            # Query the axes that was passed in rather than pyplot.gca().
            handles, labels = ax.get_legend_handles_labels()
            newLabels, newHandles = [], []
            for handle, label in zip(handles, labels):
                if label not in newLabels:
                    newLabels.append(label)
                    newHandles.append(handle)

            ax.legend(newHandles, newLabels)
    
    ax = pyplot.subplot(1,1,1)  
    plotCollection(ax, a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
    plotCollection(ax, a[:,1::2].T, a[:, ::2].T, 'b', label='data_b')
    pyplot.show()
    

    An easier (and, in my opinion, clearer) way than what you have to remove duplicates from the legend's handles and labels is this:

    handles, labels = pyplot.gca().get_legend_handles_labels()
    newLabels, newHandles = [], []
    for handle, label in zip(handles, labels):
      if label not in newLabels:
        newLabels.append(label)
        newHandles.append(handle)
    pyplot.legend(newHandles, newLabels)
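
The same de-duplication can be sketched more compactly with a dictionary, since dict keys are unique and keep insertion order (Python 3.7+). One difference from the loop above: for a repeated label the dict keeps the last handle rather than the first, which only matters if lines sharing a label are styled differently. The helper name `legend_without_duplicates` is hypothetical and not part of Matplotlib.

```python
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])


def legend_without_duplicates(ax, **legend_kwargs):
    """Draw a legend on `ax` with one entry per distinct label (hypothetical helper)."""
    handles, labels = ax.get_legend_handles_labels()
    # Later handles overwrite earlier ones for the same label;
    # the key order follows each label's first appearance.
    by_label = dict(zip(labels, handles))
    return ax.legend(list(by_label.values()), list(by_label.keys()), **legend_kwargs)


fig, ax = plt.subplots()
ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r', label='data_a')  # 5 lines, same label
ax.plot(a[:, 1::2].T, a[:, ::2].T, 'b', label='data_b')  # 5 lines, same label
legend = legend_without_duplicates(ax, loc='best')
```

Call `plt.show()` afterwards to display the figure; the legend should then list `data_a` and `data_b` once each.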
    
  • 2020-12-29 07:33

    Matplotlib gives you a nice interface to collections of lines, LineCollection. The code is straightforward:

    import numpy
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    
    a = numpy.array([[ 3.57,  1.76,  7.42,  6.52],
                     [ 1.57,  1.2 ,  3.02,  6.88],
                     [ 2.23,  4.86,  5.12,  2.81],
                     [ 4.48,  1.38,  2.14,  0.86],
                     [ 6.68,  1.72,  8.56,  3.23]])
    
    xs = a[:,::2]
    ys = a[:, 1::2]
    lines = LineCollection([list(zip(x,y)) for x,y in zip(xs, ys)], label='data_a')
    f, ax = plt.subplots(1, 1)
    ax.add_collection(lines)
    ax.legend()
    ax.set_xlim([xs.min(), xs.max()]) # have to set manually
    ax.set_ylim([ys.min(), ys.max()])
    plt.show()
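
As a variation on the snippet above, the manual axis limits can likely be replaced by a call to `ax.autoscale()`, which asks the axes to recompute its view from the data it already holds. This is only a sketch: building the segments with `np.stack` and passing `colors='r'` are choices made here for illustration, not part of the original answer.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

xs = a[:, ::2]
ys = a[:, 1::2]

# Shape (5, 2, 2): one segment per row of `a`, each a pair of (x, y) points.
segments = np.stack([xs, ys], axis=-1)

lines = LineCollection(segments, colors='r', label='data_a')
fig, ax = plt.subplots()
ax.add_collection(lines)
ax.autoscale()  # recompute view limits from the collection instead of set_xlim/set_ylim
ax.legend()
```

Because the whole collection is a single artist, it contributes a single legend entry no matter how many segments it contains.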
    

    [Image: resulting plot]

  • 2020-12-29 07:35

    A NumPy solution based on will's response above.

    import numpy as np
    import matplotlib.pyplot as plt
    a = np.array([[3.57, 1.76, 7.42, 6.52],
                  [1.57, 1.20, 3.02, 6.88],
                  [2.23, 4.86, 5.12, 2.81],
                  [4.48, 1.38, 2.14, 0.86],
                  [6.68, 1.72, 8.56, 3.23]])
    
    plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
    handles, labels = plt.gca().get_legend_handles_labels()
    

    Assuming that equal labels have equal handles, get unique labels and their respective indices, which correspond to handle indices.

    labels, ids = np.unique(labels, return_index=True)
    handles = [handles[i] for i in ids]
    plt.legend(handles, labels, loc='best')
    plt.show()
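
One thing to keep in mind with this approach: `np.unique` returns its result sorted, so the legend entries end up in alphabetical order rather than plot order. A minimal sketch of an order-preserving variant follows; the label and handle values are made-up placeholders standing in for what `get_legend_handles_labels()` would return.

```python
import numpy as np

# Placeholder data for illustration only; real handles would be Line2D objects.
labels = ['data_b', 'data_b', 'data_a', 'data_a', 'data_b']
handles = ['h0', 'h1', 'h2', 'h3', 'h4']

# return_index gives the index of the first occurrence of each unique label.
_, first_ids = np.unique(labels, return_index=True)

# Sorting those indices restores the order in which labels first appeared.
first_ids = np.sort(first_ids)
unique_labels = [labels[i] for i in first_ids]
unique_handles = [handles[i] for i in first_ids]

print(unique_labels)   # ['data_b', 'data_a']
print(unique_handles)  # ['h0', 'h2']
```

The resulting lists can be passed to `plt.legend(unique_handles, unique_labels)` in the same way as in the snippet above.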
    
  • 2020-12-29 07:37

    I would do this trick:

    for i in range(len(a)):
      plt.plot(a[i,::2].T, a[i, 1::2].T, 'r', label='data_a' if i==0 else None)
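
For reference, here is a self-contained sketch of the same trick, assuming the array `a` from the other answers. It uses the label `'_nolegend_'` instead of `None` for the unlabeled lines; Matplotlib skips labels that start with an underscore when it builds the legend automatically.

```python
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
for i, row in enumerate(a):
    # Only the first line carries the real label; the rest are hidden
    # from the legend by the leading underscore.
    ax.plot(row[::2], row[1::2], 'r', label='data_a' if i == 0 else '_nolegend_')
ax.legend(loc='best')
```

This keeps one plot call per line, which is convenient when each line also needs its own style or data handling.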
    
  • 2020-12-29 07:38

    Using will's suggestion and another question here, this is the fix I ended up with:

    handles, labels = plt.gca().get_legend_handles_labels()
    i = 1
    while i < len(labels):
        if labels[i] in labels[:i]:
            del labels[i]
            del handles[i]
        else:
            i += 1
    
    plt.legend(handles, labels)
    

    The new plot looks like this:

    [Image: modified multiple line plot legend]

  • 2020-12-29 07:41

    A low-tech solution is to make two plot calls: one that plots your data, and a second that plots nothing but carries the label for the legend:

    import numpy as np
    import matplotlib.pyplot as plt

    a = np.array([[ 3.57,  1.76,  7.42,  6.52],
                  [ 1.57,  1.2 ,  3.02,  6.88],
                  [ 2.23,  4.86,  5.12,  2.81],
                  [ 4.48,  1.38,  2.14,  0.86],
                  [ 6.68,  1.72,  8.56,  3.23]])
    
    plt.plot(a[:,::2].T, a[:, 1::2].T, 'r')
    plt.plot([],[], 'r', label='data_a')
    
    plt.legend(loc='best')
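
A closely related sketch builds the legend entry as an explicit proxy artist instead of an empty plot call. `Line2D` and the `handles=` argument of `legend` are standard Matplotlib; the variable name `proxy` and the choice to style it with `color='r'` are assumptions made here to match the plotted lines.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r')  # data lines, deliberately unlabeled

# The proxy is never added to the axes; it only describes the legend entry.
proxy = Line2D([], [], color='r', label='data_a')
legend = ax.legend(handles=[proxy], loc='best')
```

Since the proxy is not drawn on the axes, it does not affect autoscaling or the list of plotted lines.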
    

    Here's the result:

    [Image: result]
