pandas - multi index plotting

后端 未结 2 1580
被撕碎了的回忆
被撕碎了的回忆 2020-11-28 09:09

I have some data where I've manipulated the dataframe using the following code:

import pandas as pd
import numpy as np

data = pd.DataFrame([[0,0,0,3,6,5,6,         


        
相关标签:
2条回答
  • 2020-11-28 09:26

    I would use a factor plot from seaborn.

    Say you have data like this:

    import numpy as np
    import pandas
    
    import seaborn
    # Apply seaborn's 'ticks' style to the plots created below.
    seaborn.set(style='ticks') 
    # Fix the RNG seed so the random values (and the table printed below)
    # are reproducible.
    np.random.seed(0)

    # Labels for the three index levels; every combination becomes one row.
    groups = ('Group 1', 'Group 2')
    sexes = ('Male', 'Female')
    means = ('Low', 'High')
    index = pandas.MultiIndex.from_product(
        [groups, sexes, means],
        names=['Group', 'Sex', 'Mean'],
    )

    # One random integer per index entry, drawn with low=20 and high=100.
    values = np.random.randint(low=20, high=100, size=len(index))

    # Build the frame on the MultiIndex first, then turn the index levels
    # into ordinary columns so they can be referenced by name when plotting.
    data = pandas.DataFrame(data={'val': values}, index=index)
    data = data.reset_index()
    print(data)
    
         Group     Sex  Mean  val
    0  Group 1    Male   Low   64
    1  Group 1    Male  High   67
    2  Group 1  Female   Low   84
    3  Group 1  Female  High   87
    4  Group 2    Male   Low   87
    5  Group 2    Male  High   29
    6  Group 2  Female   Low   41
    7  Group 2  Female  High   56
    

    You can then create the factor plot with one command, plus an extra line to remove some redundant (for your data) x-labels:

    # ``catplot`` is the current name of the function formerly called
    # ``factorplot`` (renamed in seaborn 0.9; the old name has since been
    # removed), so this is the same "factor plot" under its new name.
    # Draws one bar-chart facet per Sex, with Group on the x-axis and the
    # bars within each group coloured by Mean.
    fg = seaborn.catplot(x='Group', y='val', hue='Mean',
                         col='Sex', data=data, kind='bar')
    # Blank the x-axis label on every facet; the tick labels already name
    # the groups.
    fg.set_xlabels('')
    

    Which gives me:

    0 讨论(0)
  • 2020-11-28 09:32

    In a related question I found an alternative solution by @Stein that codes the multiindex levels as different labels. Here is what it looks like for your example:

    import pandas as pd
    import matplotlib.pyplot as plt
    from itertools import groupby
    import numpy as np 
    %matplotlib inline
    
    # Labels for the three index levels; every combination becomes one row.
    groups = ('Group 1', 'Group 2')
    sexes = ('Male', 'Female')
    means = ('Low', 'High')
    index = pd.MultiIndex.from_product(
        [groups, sexes, means],
        names=['Group', 'Sex', 'Mean'],
    )

    # One random integer per index entry, drawn with low=20 and high=100.
    values = np.random.randint(low=20, high=100, size=len(index))

    # Build the frame and immediately unstack the last index level ('Mean')
    # into columns, so each remaining (Group, Sex) row carries two separate
    # value columns to plot.
    data = pd.DataFrame(data={'val': values}, index=index).unstack(level=-1)
    
    def add_line(ax, xpos, ypos):
        """Add a short gray vertical line to ``ax`` at horizontal position ``xpos``.

        The line runs between ``ypos`` and ``ypos + .1`` and is positioned
        using the ``ax.transAxes`` transform.
        """
        separator = plt.Line2D(
            [xpos, xpos],
            [ypos + .1, ypos],
            transform=ax.transAxes,
            color='gray',
        )
        # Turn clipping off before attaching the line to the axes.
        separator.set_clip_on(False)
        ax.add_line(separator)
    
    def label_len(my_index,level):
        """Return ``(label, run_length)`` pairs for one level of ``my_index``.

        Consecutive equal labels at ``level`` are collapsed into a single
        entry whose second element is the number of entries in that run.
        Only adjacency at this one level is considered, so equal neighbouring
        labels are merged regardless of their labels at other levels.
        """
        runs = []
        for label, members in groupby(my_index.get_level_values(level)):
            runs.append((label, len(list(members))))
        return runs
    
    def label_group_bar_table(ax, df):
        """Write the index labels of ``df`` below ``ax`` as stacked rows.

        Levels are processed from the last index level to the first. Each
        level gets its own row of centred text labels, with a separator line
        at the start of every run of equal labels and one more at the end of
        the row. Rows start at y = -.1 and step down by .1 per level; all
        positions are passed with the ``ax.transAxes`` transform.
        """
        ypos = -.1
        # Horizontal width allotted to a single index entry.
        scale = 1./df.index.size
        for level in reversed(range(df.index.nlevels)):
            pos = 0
            for label, span in label_len(df.index, level):
                # Centre the label over the entries it spans.
                ax.text((pos + .5 * span)*scale, ypos, label,
                        ha='center', transform=ax.transAxes)
                # Separator at the left edge of this run.
                add_line(ax, pos*scale, ypos)
                pos += span
            # Closing separator at the right edge of the row.
            add_line(ax, pos*scale , ypos)
            ypos -= .1
    
    # Bar chart of the unstacked 'val' columns; returns the matplotlib axes.
    ax = data['val'].plot(kind='bar')
    #Below 2 lines remove default labels
    ax.set_xticklabels('')
    ax.set_xlabel('')
    # Replace them with one row of labels per index level.
    label_group_bar_table(ax, data)
    

    This gives:

    0 讨论(0)
提交回复
热议问题