How to overlay a Seaborn jointplot with a “marginal” (distribution histogram) from a different dataset

后端 未结 3 1028
死守一世寂寞
死守一世寂寞 2021-02-06 07:22

I have plotted a Seaborn JointPlot from a set of "observed counts vs concentration" which are stored in a pandas DataFrame. I would like to overlay (

3条回答
  •  深忆病人
    2021-02-06 07:41

    I wrote a function to plot it, very loosely based on @blue_chip's idea. You might still need to tweak it a bit for your specific needs.

    Here is an example usage:

    Example data:

    import seaborn as sns, numpy as np, matplotlib.pyplot as plt, pandas as pd
    # Total number of random draws per data set; each draw pair forms one (x, y) row.
    n = 1000
    # Offsets added to the standard-normal samples of the two base clouds.
    m1 = -3
    m2 = 3

    # Two shifted Gaussian clouds, reshaped so that consecutive draws become rows.
    df1 = pd.DataFrame(np.random.randn(n).reshape(-1, 2) + m1, columns=['x', 'y'])
    df2 = pd.DataFrame(np.random.randn(n).reshape(-1, 2) + m2, columns=['x', 'y'])
    # Third cloud: element-wise sum of the first two (computed before the
    # string-valued 'kind' column is attached to them).
    df3 = pd.DataFrame(df1.values + df2.values, columns=['x', 'y'])

    # Tag every row with the name of the distribution it came from.
    for frame, label in ((df1, 'dist1'), (df2, 'dist2'), (df3, 'dist1+dist2')):
        frame['kind'] = label

    # Stack the three frames into the single long table used for plotting.
    df = pd.concat([df1, df2, df3])
    

    Function definition:

    def multivariateGrid(col_x, col_y, col_k, df, k_is_color=False, scatter_alpha=.5):
        """Draw a joint scatter plot with per-group and overall marginals.

        The rows of ``df`` are grouped by ``col_k``. Every group is drawn as its
        own scatter layer on the joint axes and gets its own distribution plot
        on both marginal axes. A grey distribution of the complete data set is
        then overlaid on each marginal, and a legend listing the group keys is
        attached to the joint axes.

        Parameters
        ----------
        col_x, col_y : str
            Names of the columns plotted on the x and y axes.
        col_k : str
            Name of the column whose values define the groups.
        df : pandas.DataFrame
            Data containing the three columns named above.
        k_is_color : bool, optional
            If True, each group key is passed on as the colour of that group's
            scatter points and marginals. Otherwise no explicit colour is
            passed for the groups.
        scatter_alpha : float, optional
            Opacity applied to the scatter points.

        Returns
        -------
        seaborn.JointGrid
            The grid that was drawn on, so callers can label axes, add further
            layers or save the figure.
        """
        # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
        # it is kept here so the required seaborn version does not change --
        # confirm against the installed version before upgrading.
        g = sns.JointGrid(
            x=col_x,
            y=col_y,
            data=df
        )

        def colored_scatter(x, y, c=None):
            """Build a ``plot_joint`` callback that scatters one group's points."""
            def scatter(*args, **kwargs):
                # plot_joint hands over the grid's full x/y vectors; they are
                # deliberately ignored so only this group's points are drawn.
                if c is not None:
                    kwargs['c'] = c
                kwargs['alpha'] = scatter_alpha
                plt.scatter(x, y, **kwargs)

            return scatter

        def draw_marginals(frame, color):
            """Overlay the x and y distributions of ``frame`` on the marginal axes."""
            sns.distplot(
                frame[col_x].values,
                ax=g.ax_marg_x,
                color=color,
            )
            sns.distplot(
                frame[col_y].values,
                ax=g.ax_marg_y,
                color=color,
                vertical=True,
            )

        legends = []
        for name, df_group in df.groupby(col_k):
            legends.append(name)
            color = name if k_is_color else None
            g.plot_joint(
                colored_scatter(df_group[col_x], df_group[col_y], color),
            )
            draw_marginals(df_group, color)

        # Global distribution of the whole data set, on top of the per-group ones.
        draw_marginals(df, 'grey')

        # Attach the legend to the joint axes explicitly instead of relying on
        # whichever axes pyplot currently considers "current".
        g.ax_joint.legend(legends)
        return g
        
    

    Usage:

    # Plot y against x, split by the 'kind' column of the example data.
    multivariateGrid(col_x='x', col_y='y', col_k='kind', df=df)
    

提交回复
热议问题