Plotting two distributions in seaborn.jointplot

忘了有多久 2020-12-09 19:36

I have two pandas dataframes I would like to plot in the same seaborn jointplot. It looks something like this (commands are done in an IPython shell; ipyth

3 Answers
  • 2020-12-09 20:05

    Here is one way to do it by modifying the underlying data of sns.JointGrid.

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # simulate some artificial data
    # ========================================
    np.random.seed(0)
    data1 = np.random.multivariate_normal([0,0], [[1,0.5],[0.5,1]], size=200)
    data2 = np.random.multivariate_normal([0,0], [[1,-0.8],[-0.8,1]], size=100)
    
    # both df1 and df2 hold bivariate normal samples; df1 has 200 rows, df2 has 100
    df1 = pd.DataFrame(data1, columns=['x1', 'y1'])
    df2 = pd.DataFrame(data2, columns=['x2', 'y2'])
    
    
    # plot
    # ========================================   
    graph = sns.jointplot(x=df1.x1, y=df1.y1, color='r')
    
    graph.x = df2.x2
    graph.y = df2.y2
    graph.plot_joint(plt.scatter, marker='x', c='b', s=50)
    

  • 2020-12-09 20:15

    A better solution, in my opinion, is to use the axes handles for the joint and marginal distributions that sns.jointplot returns. Using those (the names are ax_joint, ax_marg_x and ax_marg_y), it is also possible to draw on the marginal distribution plots.

    import numpy as np
    import pandas as pd
    import seaborn as sns

    data1 = np.random.randn(100)
    data2 = np.random.randn(100)
    data3 = np.random.randn(100)
    data4 = np.random.randn(100)

    df1 = pd.DataFrame({'col1': data1, 'col2': data2})
    df2 = pd.DataFrame({'col1': data3, 'col2': data4})

    # pass x and y as keywords; recent seaborn versions no longer accept them positionally
    axs = sns.jointplot(x='col1', y='col2', data=df1)
    # overlay the second dataframe on the joint axes
    axs.ax_joint.scatter('col1', 'col2', data=df2, c='r', marker='x')

    # drawing pdf instead of histograms on the marginal axes
    axs.ax_marg_x.cla()
    axs.ax_marg_y.cla()
    # kdeplot stands in for the deprecated distplot; passing y= draws the curve vertically
    sns.kdeplot(x=df1.col1, ax=axs.ax_marg_x)
    sns.kdeplot(y=df1.col2, ax=axs.ax_marg_y)
    

  • 2020-12-09 20:18

    It might be easier, after drawing the jointplot, to switch to the axes you want to draw on and then use normal pyplot or axes-based seaborn plots:

    g = sns.jointplot(...)
    plt.sca(g.ax_joint)  # or g.ax_marg_x / g.ax_marg_y; plt.sca expects an Axes object
    plt.plot(...)        # or plt.scatter(...), or sns.kdeplot(..., ax=g.ax_joint)
    

    The axes are attributes of the object that jointplot returns: ax_joint for the 2D plot, and ax_marg_x or ax_marg_y for the 1D plots on the sides.

    Furthermore, if you want to keep the jointplot layout but draw every panel yourself with pyplot, clear the axes first with its cla method, e.g. for the 2D plot:

    g.ax_joint.cla()
    