Getting Colorbar instance of scatter plot in pandas/matplotlib

盖世英雄少女心 2020-12-06 02:40

How do I get the internally created colorbar instance of a plot created by pandas.DataFrame.plot?

Here is an example that generates a colored scatter plot:



        
2 Answers
  • 2020-12-06 03:19

    pandas does not return the colorbar's axes, so we have to locate them ourselves:

    1. First, get the figure instance with plt.gcf():

    In [61]:
    
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    import itertools as it
    
    # [ (0,0), (0,1), ..., (9,9) ]
    xy_positions = list( it.product( range(10), range(10) ) )
    
    df = pd.DataFrame( xy_positions, columns=['x','y'] )
    
    # draw 100 floats
    df['score'] = np.random.random( 100 )
    
    ax = df.plot( kind='scatter',
                  x='x',
                  y='y',
                  c='score',
                  s=500)
    ax.set_xlim( [-0.5,9.5] )
    ax.set_ylim( [-0.5,9.5] )
    
    f = plt.gcf()
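    Instead of relying on pyplot's "current figure", you can also ask the axes that pandas returned for its parent figure via ax.figure. A minimal sketch that rebuilds the same random data; the Agg backend is an assumption so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import numpy as np
import pandas as pd

# Same 10x10 grid of points, each with a random score
df = pd.DataFrame(
    [(x, y) for x in range(10) for y in range(10)], columns=["x", "y"]
)
df["score"] = np.random.random(100)

ax = df.plot(kind="scatter", x="x", y="y", c="score", s=500)

# Every Axes knows which Figure it belongs to, independent of pyplot state
fig = ax.figure
```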
    

    2. How many axes does this figure have?

    In [62]:
    
    f.get_axes()
    Out[62]:
    [<matplotlib.axes._subplots.AxesSubplot at 0x120a4d450>,
     <matplotlib.axes._subplots.AxesSubplot at 0x120ad0050>]
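    Since one of these two axes is the ax that pandas returned, another option is to pick whichever axes is not ax, so the result does not depend on the order of get_axes(). A sketch assuming, as here, that the figure holds only the scatter axes and its colorbar (other_axes is a hypothetical name):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import numpy as np
import pandas as pd

df = pd.DataFrame(
    [(x, y) for x in range(10) for y in range(10)], columns=["x", "y"]
)
df["score"] = np.random.random(100)
ax = df.plot(kind="scatter", x="x", y="y", c="score", s=500)

# All axes in the figure except the plot axes; in this single-plot
# figure that leaves only the colorbar axes
other_axes = [a for a in ax.figure.axes if a is not ax]
cax = other_axes[0]
cax.set_ylabel("test")
```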
    

    3. The first axes (i.e., the first one created) contains the plot:

    In [63]:
    
    ax
    Out[63]:
    <matplotlib.axes._subplots.AxesSubplot at 0x120a4d450>
    

    4. Therefore, the second axes is the colorbar axes:

    In [64]:
    
    cax = f.get_axes()[1]
    # and we can modify it, e.g.:
    cax.set_ylabel('test')
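    If you need the Colorbar object itself rather than only its axes: matplotlib stores a colorbar on the mappable it was built from, in the mappable's colorbar attribute. pandas builds its colorbar from the scatter's PathCollection, so a sketch, assuming that collection is the last one on ax (true for this single scatter plot), is:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import numpy as np
import pandas as pd

df = pd.DataFrame(
    [(x, y) for x in range(10) for y in range(10)], columns=["x", "y"]
)
df["score"] = np.random.random(100)
ax = df.plot(kind="scatter", x="x", y="y", c="score", s=500)

# The scatter points are a PathCollection (a ScalarMappable) on the plot axes
points = ax.collections[-1]
# matplotlib attaches the Colorbar created from a mappable to that mappable
cb = points.colorbar
cb.set_label("test")  # label via the Colorbar API
cax = cb.ax           # the colorbar's own axes, if needed
```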
    
  • 2020-12-06 03:20

    It's not quite the same, but you could plot with matplotlib directly; plt.colorbar then returns the Colorbar instance:

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    import itertools as it
    
    # [ (0,0), (0,1), ..., (9,9) ]
    xy_positions = list( it.product( range(10), range(10) ) )
    
    df = pd.DataFrame( xy_positions, columns=['x','y'] )
    
    # draw 100 floats
    df['score'] = np.random.random( 100 )
    
    fig = plt.figure()
    ax = fig.add_subplot(111)
    
    s = ax.scatter(df.x, df.y, c=df.score, s=500)
    cb = plt.colorbar(s)
    cb.set_label('desired_label')
    
    ax.set_xlim( [-0.5,9.5] )
    ax.set_ylim( [-0.5,9.5] )
    
    plt.show()
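    The same approach in matplotlib's object-oriented style: Figure.colorbar returns the Colorbar, and its ax attribute is the colorbar axes. A sketch using the Agg backend (an assumption, so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(
    [(x, y) for x in range(10) for y in range(10)], columns=["x", "y"]
)
df["score"] = np.random.random(100)

fig, ax = plt.subplots()
s = ax.scatter(df.x, df.y, c=df.score, s=500)

# ax= tells matplotlib which axes to take space from for the colorbar
cb = fig.colorbar(s, ax=ax)
cb.set_label("desired_label")

ax.set_xlim(-0.5, 9.5)
ax.set_ylim(-0.5, 9.5)
```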
    