How to plot heatmap for high-dimensional dataset?

北海茫月 2021-01-13 17:34

I would greatly appreciate it if you could tell me how to plot a high-resolution heatmap for a large dataset with approximately 150 features.

My code is as follows:

2 answers
  • 2021-01-13 18:16

    If I understand your problem correctly, I think all you have to do is increase your figure size:

    f, ax = plt.subplots(figsize=(20, 20))
    

    instead of

    f, ax = plt.subplots(figsize=(9, 9))
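    Instead of guessing a size, you can scale it with the number of features so each cell gets a fixed amount of space (figsize is in inches). This is a minimal sketch: heatmap_figsize and inches_per_cell are hypothetical helpers, and the 0.15 inches per cell and 6-inch floor are arbitrary assumptions to tune. With 150 features, a 9 x 9 figure leaves only 0.06 inches per cell, while 20 x 20 leaves about 0.13.

```python
def heatmap_figsize(n_features, inches_per_cell=0.15, min_inches=6.0):
    """Return a (width, height) tuple for plt.subplots(figsize=...).

    Hypothetical helper: scales a square figure linearly with the number
    of features so every heatmap cell gets roughly `inches_per_cell`
    inches, never going below `min_inches` for small matrices.
    """
    side = max(min_inches, n_features * inches_per_cell)
    return (side, side)


def inches_per_cell(figsize_side, n_features):
    """Space available to one cell, ignoring margins, ticks and colorbar."""
    return figsize_side / n_features


if __name__ == "__main__":
    # Compare the two figure sizes discussed in this answer for 150 features.
    for side in (9, 20):
        print(side, round(inches_per_cell(side, 150), 3))
    print(heatmap_figsize(150))
```

    You would then create the axes with fig, ax = plt.subplots(figsize=heatmap_figsize(n)), where n is your number of features.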
    
  • 2021-01-13 18:18

    Adjusting the figsize and dpi worked for me.

    I adapted your code and doubled the size of the heatmap to 165 x 165. The rendering takes a while, but the PNG looks fine. My backend is "module://ipykernel.pylab.backend_inline".

    As noted in my original answer, I'm pretty sure you forgot to close the figure object before creating a new one. If you get weird effects, try plt.close("all") before fig, ax = plt.subplots().
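    A minimal sketch of why closing matters: pyplot keeps every figure created by plt.subplots() alive until it is closed, so re-running a cell or a loop piles up large figures, and plt.get_fignums() lists the ones still open. The Agg backend is forced here only so the sketch runs without a display (an assumption; drop that line in Jupyter), and the small integer grid is a stand-in for a real heatmap.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Small stand-in for a heatmap: a 10 x 10 grid of integers.
data = [[(i + j) % 7 for j in range(10)] for i in range(10)]

# Each plt.subplots() call registers a new figure with pyplot; nothing is
# freed until the figure is closed, so repeated runs pile up open figures.
for _ in range(3):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(data, cmap="Blues")
print("open figures:", plt.get_fignums())

# Option 1: close everything before building the next large heatmap.
plt.close("all")
print("after plt.close('all'):", plt.get_fignums())

# Option 2: close each figure explicitly once it has been saved.
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(data, cmap="Blues")
buf = io.BytesIO()  # in-memory file instead of writing a PNG to disk
fig.savefig(buf, format="png", dpi=100)
plt.close(fig)
print("after plt.close(fig):", plt.get_fignums())
```

    Closing the specific figure after savefig is the more targeted option when several figures are meant to stay open.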

    import pandas as pd
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    
    print(plt.get_backend())
    
    # close any existing plots
    plt.close("all")
    
    df = pd.read_csv("Financial Distress.csv")
    # select out the desired columns
    df = df.iloc[:, 3:].select_dtypes(include=['float64','int64'])
    
    # copy columns to double size of dataframe
    df2 = df.copy()
    df2.columns = "c_" + df2.columns
    df3 = pd.concat([df, df2], axis=1)
    
    # get the correlation coefficient between the different columns
    corr = df3.iloc[:, 1:].corr()
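    # DataFrame.as_matrix() was removed in pandas 1.0; use a copy of the values
    arr_corr = corr.to_numpy(copy=True)
    # mask out the upper triangle (diagonal included) by setting it to NaN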
    arr_corr[np.triu_indices_from(arr_corr)] = np.nan
    
    fig, ax = plt.subplots(figsize=(24, 18))
    
    # draw onto the axes created above; tiny annotation font for ~165 cells per row
    sns.heatmap(arr_corr, cbar=True, vmin=-0.5, vmax=0.5,
                fmt='.2f', annot_kws={'size': 3}, annot=True,
                square=True, cmap=plt.cm.Blues, ax=ax)
    
    ticks = np.arange(corr.shape[0]) + 0.5
    ax.set_xticks(ticks)
    ax.set_xticklabels(corr.columns, rotation=90, fontsize=8)
    ax.set_yticks(ticks)
    ax.set_yticklabels(corr.index, rotation=0, fontsize=8)
    
    ax.set_title('correlation matrix')
    plt.tight_layout()
    plt.savefig("corr_matrix_incl_anno_double.png", dpi=300)
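    Some rough arithmetic behind the figsize/dpi choice above, as a stdlib-only sketch (the helper names are illustrative, not matplotlib functions). The saved PNG has figsize x dpi pixels, so figsize=(24, 18) at dpi=300 gives 7200 x 5400 pixels. With square=True the 18-inch height is roughly the binding side, leaving at most 5400 / 165 ≈ 33 pixels per cell before margins, tick labels and the colorbar take their share. Font sizes are in points (1/72 inch), so the size-3 annotations are 3 x 300 / 72 = 12.5 pixels tall, and 165 rows over 18 inches leave about 7.9 pt per row for the fontsize=8 tick labels.

```python
POINTS_PER_INCH = 72  # matplotlib font sizes are given in points


def pixels_per_cell(side_inches, dpi, n_cells):
    """Upper bound on pixels per heatmap cell along one side (ignores margins)."""
    return side_inches * dpi / n_cells


def points_to_pixels(points, dpi):
    """Height in pixels of text set at `points` when saved at `dpi`."""
    return points * dpi / POINTS_PER_INCH


def max_label_points(side_inches, n_labels):
    """Row pitch in points if labels are spread evenly over the whole side;
    tick labels much larger than this start to collide."""
    return side_inches * POINTS_PER_INCH / n_labels


if __name__ == "__main__":
    # Numbers from the answer: figsize=(24, 18), dpi=300, 165 x 165 matrix.
    print(f"pixels per cell: {pixels_per_cell(18, 300, 165):.1f}")
    print(f"size-3 annotation height: {points_to_pixels(3, 300):.1f} px")
    print(f"row pitch for tick labels: {max_label_points(18, 165):.1f} pt")
```

    Raising dpi multiplies the pixel budget without changing the layout, while raising figsize also gives the point-sized fonts more room relative to the cells.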
    

    [Figures: the full heatmap, and a zoom of its top-left section]
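    A note on the masking step: np.triu_indices_from(arr) includes the diagonal (k=0), so the self-correlations of 1.0 are hidden as well; use k=1 to keep them. The sketch below checks that on a random 150 x 150 correlation matrix (the random data and size are stand-ins for the real dataset). The commented-out line shows the seaborn alternative of passing a boolean mask with the labelled DataFrame, which is not executed here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features = 150  # stand-in for the real number of columns

# Random data -> correlation matrix of shape (n_features, n_features).
data = rng.normal(size=(500, n_features))
arr_corr = np.corrcoef(data, rowvar=False)

# Boolean masks for the upper triangle, with and without the diagonal.
upper_incl_diag = np.triu(np.ones_like(arr_corr, dtype=bool))       # k=0
upper_excl_diag = np.triu(np.ones_like(arr_corr, dtype=bool), k=1)  # keeps diagonal

# Same effect as the answer's arr_corr[np.triu_indices_from(arr_corr)] = np.nan
masked = arr_corr.copy()
masked[np.triu_indices_from(masked)] = np.nan

print("cells hidden (k=0):", int(upper_incl_diag.sum()))
print("cells hidden (k=1):", int(upper_excl_diag.sum()))
print("NaN cells after triu_indices_from:", int(np.isnan(masked).sum()))

# Seaborn alternative (not run here): keep the labelled DataFrame and pass
# the mask, so tick labels come from the columns without set_xticklabels:
#   sns.heatmap(corr, mask=upper_incl_diag, xticklabels=True,
#               yticklabels=True, cmap="Blues", ax=ax)
```

    xticklabels=True and yticklabels=True ask seaborn to draw every label; its default ("auto") may thin them out on a matrix this dense.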
