How to label the groups in a grouped barplot in pyplot?

[愿得一人] 2021-01-25 04:12

I want to make a grouped bar plot in pyplot where I can label both the individual bars and the groups. I was able to label the bars but don't know how to label the groups.

1 Answer
  • 2021-01-25 05:15

    A default seaborn barplot with hue='Year' creates a legend to show the years:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    labels = ['apples', 'bananas', 'coconuts', 'dates', 'elderberries', 'figs', 'grapes']
    years = [2017, 2018, 2019]
    # long-form data: np.repeat/np.tile pair every fruit with every year exactly once
    df = pd.DataFrame({'Fruit': np.repeat(labels, len(years)),
                       'Year': np.tile(years, len(labels)),
                       'Amount': np.random.uniform(1.5, 5, len(labels)*len(years))})
    ax = sns.barplot(x='Fruit', y='Amount', hue='Year', data=df)
    plt.show()
    

    Placing labels with ax.text() and a blended transform (x in data coordinates, y in axes coordinates) puts them at the desired positions below each group. Vertical separator lines can be added via axvline, which uses such a blended transform by default; clip_on=False is needed so they can be drawn outside the main axes area. The year labels can be put at the position of each bar as tick labels, after which the legend and the default xlabel can be removed.

    from matplotlib import pyplot as plt
    from matplotlib import transforms
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    labels = ['apples', 'bananas', 'coconuts', 'dates', 'elderberries', 'figs', 'grapes']
    years = [2017, 2018, 2019]
    # long-form data: np.repeat/np.tile pair every fruit with every year exactly once
    df = pd.DataFrame({'Fruit': np.repeat(labels, len(years)),
                       'Year': np.tile(years, len(labels)),
                       'Amount': np.random.uniform(1.5, 5, len(labels)*len(years))})
    fig, ax = plt.subplots(figsize=(12, 4))
    ax = sns.barplot(x='Fruit', y='Amount', hue='Year', palette='Reds', data=df, ax=ax)
    
    # one tick per bar, at the bar's center, labelled with its year
    year_pos = np.sort([p.get_x() + p.get_width()/2 for p in ax.patches])
    ax.set_xticks(year_pos)
    ax.set_xticklabels(np.tile(years, len(labels)), rotation=30)
    ax.get_legend().remove()  # the years are now shown as tick labels
    ax.set_xlabel('')  # remove default xlabel
    
    # group centers: mean of each run of len(years) consecutive bar centers
    fruit_pos = year_pos.reshape(-1, len(years)).mean(axis=1)
    # x in data coordinates, y in axes coordinates (0 = bottom of the axes)
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    
    # group labels below the year tick labels
    for pos, label in zip(fruit_pos, labels):
        ax.text(pos, -0.25, label, transform=trans, ha='center', va='bottom', color='steelblue', fontsize=14)
    # dotted separators halfway between groups; clip_on=False lets them extend below the axes
    for pos in (fruit_pos[:-1] + fruit_pos[1:]) / 2:
        ax.axvline(pos, 0, -0.25, color='steelblue', ls=':', clip_on=False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()
    plt.show()
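The tick, group and separator positions in the code above come from simple arithmetic on the bar centers. The NumPy sketch below reproduces those steps without drawing anything. It assumes (this is an assumption about seaborn's layout, not something read from the plot) that category i sits at x = i and that its hue bars split a total width of 0.8 evenly.

```python
import numpy as np

n_groups, n_hues = 7, 3   # 7 fruits, 3 years, as in the example
width = 0.8               # assumed total width of one group of bars

# Assumed bar centers: group i is centered at x = i and its n_hues bars
# are spaced width / n_hues apart around that center.
offsets = (np.arange(n_hues) - (n_hues - 1) / 2) * width / n_hues
bar_centers = (np.arange(n_groups)[:, None] + offsets).ravel()

# Same steps as the answer: sort the bar centers ...
year_pos = np.sort(bar_centers)
# ... average each run of n_hues consecutive bars to get the group centers ...
fruit_pos = year_pos.reshape(-1, n_hues).mean(axis=1)
# ... and place separators halfway between neighbouring group centers.
separators = (fruit_pos[:-1] + fruit_pos[1:]) / 2

print(np.round(fruit_pos, 3))
print(np.round(separators, 3))
```

With real bars the centers come from p.get_x() + p.get_width()/2 as in the answer; the reshape only works if every group has exactly len(years) bars.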
    
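To see what the blended transform does, a point can be pushed through it and compared with the two transforms it is built from. In this minimal sketch (the three bars are made-up data), the x coordinate follows ax.transData and the y coordinate follows ax.transAxes, so y = -0.25 always means a quarter of the axes height below the bottom, whatever the data limits are. ax.get_xaxis_transform() returns the same kind of blend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; only coordinates are inspected
from matplotlib import pyplot as plt
from matplotlib import transforms
import numpy as np

fig, ax = plt.subplots(figsize=(6, 3))
ax.bar([0, 1, 2], [3.0, 1.5, 4.0])  # illustrative data only
fig.canvas.draw()                   # apply pending autoscaling so transData is final

trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
point = (1.0, -0.25)                # x in data units, y as a fraction of the axes

blended = trans.transform(point)                     # display (pixel) coordinates
x_data = ax.transData.transform((1.0, 0.0))[0]       # x via data coordinates
y_axes = ax.transAxes.transform((0.0, -0.25))[1]     # y via axes coordinates

print(blended, (x_data, y_axes))
# the x-axis transform of the Axes is the same data/axes blend
print(np.allclose(ax.get_xaxis_transform().transform(point), blended))
```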

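Since the question asks about pyplot, here is a sketch of the same idea in plain matplotlib without seaborn. It is an adaptation rather than the answer's code: the bar layout (a group width of 0.8, bars side by side) is chosen by hand, the data are random as above, and the group labels use ax.get_xaxis_transform() for the x-data/y-axes blend.

```python
from matplotlib import pyplot as plt
import numpy as np

labels = ['apples', 'bananas', 'coconuts', 'dates', 'elderberries', 'figs', 'grapes']
years = [2017, 2018, 2019]
amounts = np.random.uniform(1.5, 5, (len(labels), len(years)))  # rows: fruits, cols: years

group_width = 0.8                       # chosen by hand for this sketch
bar_width = group_width / len(years)
x = np.arange(len(labels))              # one group per fruit at x = 0, 1, 2, ...
offsets = (np.arange(len(years)) - (len(years) - 1) / 2) * bar_width

fig, ax = plt.subplots(figsize=(12, 4))
shades = np.linspace(0.4, 0.9, len(years))
for j, (offset, shade) in enumerate(zip(offsets, shades)):
    ax.bar(x + offset, amounts[:, j], width=bar_width, color=plt.cm.Reds(shade))

# Bar-level labels: one tick per bar, labelled with its year.
year_pos = (x[:, None] + offsets).ravel()
ax.set_xticks(year_pos)
ax.set_xticklabels(np.tile(years, len(labels)), rotation=30)

# Group-level labels: x in data coordinates, y in axes coordinates.
trans = ax.get_xaxis_transform()
for pos, label in zip(x, labels):
    ax.text(pos, -0.25, label, transform=trans, ha='center', va='bottom',
            color='steelblue', fontsize=14)
# Dotted separators halfway between neighbouring groups, extending below the axes.
for pos in x[:-1] + 0.5:
    ax.axvline(pos, 0, -0.25, color='steelblue', ls=':', clip_on=False)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()
```

Computing the bar positions directly avoids reading them back from ax.patches, so the group centers are simply the integers in x.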