matplotlib: plotting histogram plot just above scatter plot

后端 未结 2 494
独厮守ぢ
独厮守ぢ 2021-01-05 05:31

I would like to make attractive scatter plots with histograms above and to the right of the scatter plot, as is possible in seaborn with jointplot:

I am looking

相关标签:
2条回答
  • 2021-01-05 05:40

    Here's an example of how to do it, using gridspec.GridSpec:

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec
    import numpy as np

    # Demo data: 50 points drawn uniformly from [0, 1) on each axis.
    x = np.random.rand(50)
    y = np.random.rand(50)

    fig = plt.figure()

    # 4x4 grid: the joint (scatter) plot takes the lower-left 3x3 cells, the
    # x-marginal the top row above it, the y-marginal the right-hand column.
    gs = GridSpec(4, 4)

    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    # Share each marginal's data axis with the joint plot so the histograms
    # stay lined up with the scatter points, including after zooming/panning.
    ax_marg_x = fig.add_subplot(gs[0, 0:3], sharex=ax_joint)
    ax_marg_y = fig.add_subplot(gs[1:4, 3], sharey=ax_joint)

    ax_joint.scatter(x, y)
    ax_marg_x.hist(x)
    ax_marg_y.hist(y, orientation="horizontal")

    # Turn off tick labels on marginals (only the shared data axes; the
    # count axes keep their labels).
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel('Joint x label')
    ax_joint.set_ylabel('Joint y label')

    # Set labels on marginals (these are the count axes)
    ax_marg_y.set_xlabel('Marginal x label')
    ax_marg_x.set_ylabel('Marginal y label')
    plt.show()
    

    0 讨论(0)
  • 2021-01-05 05:56

    I encountered the same problem today. Additionally, I wanted a cumulative distribution function (CDF) for each of the marginals.

    Code:

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import numpy as np

    # Demo data: 10,000 samples per axis (beta-distributed x, normal y).
    x = np.random.beta(2,5,size=int(1e4))
    y = np.random.randn(int(1e4))

    # 3x3 grid: main scatter in the lower-left 2x2 cells, x-distribution in
    # the top row, y-distribution in the right-hand column. The marginals
    # share their data axis with the main plot so they stay aligned.
    fig = plt.figure(figsize=(8,8))
    gs = gridspec.GridSpec(3, 3)
    ax_main = plt.subplot(gs[1:3, :2])
    ax_xDist = plt.subplot(gs[0, :2],sharex=ax_main)
    ax_yDist = plt.subplot(gs[1:3, 2],sharey=ax_main)

    ax_main.scatter(x,y,marker='.')
    ax_main.set(xlabel="x data", ylabel="y data")

    # x marginal: counts on the left axis, cumulative curve on a twin axis.
    ax_xDist.hist(x,bins=100,align='mid')
    ax_xDist.set(ylabel='count')
    ax_xCumDist = ax_xDist.twinx()
    # `density=True` replaces the `normed=True` keyword, which was removed
    # from hist() in matplotlib 3.1 and now raises an error.
    ax_xCumDist.hist(x,bins=100,cumulative=True,histtype='step',density=True,color='r',align='mid')
    ax_xCumDist.tick_params('y', colors='r')
    ax_xCumDist.set_ylabel('cumulative',color='r')

    # y marginal: same layout rotated by 90 degrees (twin axis on top).
    ax_yDist.hist(y,bins=100,orientation='horizontal',align='mid')
    ax_yDist.set(xlabel='count')
    ax_yCumDist = ax_yDist.twiny()
    ax_yCumDist.hist(y,bins=100,cumulative=True,histtype='step',density=True,color='r',align='mid',orientation='horizontal')
    ax_yCumDist.tick_params('x', colors='r')
    ax_yCumDist.set_xlabel('cumulative',color='r')

    plt.show()
    

    I hope this helps the next person searching for a scatter plot with marginal distributions.

    0 讨论(0)
提交回复
热议问题