Error while drawing an animation of a Seaborn heatmap for a 3D volume

眼角桃花 2021-01-15 21:40

Trying to visualize the cross-correlation between two volumes, img_3D and mask_3D, using a Seaborn heatmap and animation f…

1 Answer
  • 2021-01-15 22:10

    Check this code:

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from matplotlib.animation import FuncAnimation
    from scipy.signal import correlate
    
    sns.set()
    
    # Load the two 3D volumes
    img = np.load('img.npy')
    act = np.load('act.npy')
    
    # mode='same' keeps the output the same shape as img,
    # so result.shape[2] is the number of slices (frames)
    result = correlate(img, act, mode='same')
    
    def updatefig(sl):
        """Draw slice `sl` of the correlation volume on `ax`."""
        ax.cla()  # clear the previous frame
        print(sl + 1, '/', result.shape[2])
        sns.heatmap(result[..., sl], ax=ax, cbar=False)
        ax.set_title("frame {}".format(sl + 1))
        ax.axis('off')
    
    fig, ax = plt.subplots()
    # Keep a reference to the animation, otherwise it is garbage-collected
    ani = FuncAnimation(fig, updatefig, frames=result.shape[2], interval=5)
    
    plt.show()
    

    which gives this animation (I halved the animation shown below to keep the file size under 2 MB; the code above reproduces all 40 frames):
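    The answer does not say how the GIF was produced or halved. One possible way is sketched below under stated assumptions: random synthetic volumes stand in for `img.npy` and `act.npy` (their shapes and the output name `correlation.gif` are hypothetical), the Agg backend renders off-screen, `frames=range(0, n, 2)` keeps every second slice, and Matplotlib's `PillowWriter` writes the GIF. Because `correlate(..., mode='same')` returns an array with the same shape as its first input, `result.shape[2]` is the number of slices.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend: no window is opened
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.animation import FuncAnimation, PillowWriter
from scipy.signal import correlate

# Hypothetical stand-ins for img.npy and act.npy
rng = np.random.default_rng(0)
img = rng.random((32, 32, 40))
act = rng.random((5, 5, 5))

# mode='same' returns an array with the same shape as img
result = correlate(img, act, mode="same")

fig, ax = plt.subplots()
# Keep every second slice: half the frames, roughly half the GIF size
frames = range(0, result.shape[2], 2)

def updatefig(sl):
    """Redraw the heatmap for slice `sl` (title keeps the original slice number)."""
    ax.cla()
    sns.heatmap(result[..., sl], ax=ax, cbar=False)
    ax.set_title("frame {}".format(sl + 1))
    ax.axis("off")

ani = FuncAnimation(fig, updatefig, frames=frames, interval=100)
# Hypothetical output file name; PillowWriter ships with Matplotlib
ani.save("correlation.gif", writer=PillowWriter(fps=10))
plt.close(fig)
```

    Dropping frames should roughly halve the file; the `dpi` argument of `Animation.save` is another way to shrink it.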


    EDIT

    To add a fixed colorbar to the heatmap, use this code:

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from matplotlib.animation import FuncAnimation
    from scipy.signal import correlate
    
    sns.set()
    
    img = np.load('img.npy')
    act = np.load('act.npy')
    
    result = correlate(img, act, mode='same')
    
    # Global color limits, computed once, so every frame shares the same scale
    vmin, vmax = result.min(), result.max()
    
    def updatefig(sl):
        ax.cla()
        cbar_ax.cla()  # clear the old colorbar so colorbars don't accumulate
        print(sl + 1, '/', result.shape[2])
        sns.heatmap(result[..., sl],
                    ax=ax,
                    cbar=True,
                    cbar_ax=cbar_ax,  # draw the colorbar in the dedicated axes
                    vmin=vmin,
                    vmax=vmax)
        ax.set_title("frame {}".format(sl + 1))
        ax.axis('off')
    
    # Two columns: a wide one for the heatmap, a narrow one for the colorbar
    grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
    fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws, figsize=(10, 8))
    ani = FuncAnimation(fig, updatefig, frames=result.shape[2], interval=5)
    
    plt.show()
    

    which produces this animation (halved, like the previous one):
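    The colorbar stays fixed because `vmin`/`vmax` are taken from the whole volume. When they are omitted, `sns.heatmap` derives the color limits from the array it receives, here a single slice, so the same correlation value can land on a different color in each frame. The sketch below illustrates this with `matplotlib.colors.Normalize`, assuming a small random array as a stand-in for `result` and an arbitrary tracked value of 0.5 (both hypothetical).

```python
import numpy as np
from matplotlib.colors import Normalize

# Hypothetical stand-in for the correlation volume `result`
rng = np.random.default_rng(1)
result = rng.normal(size=(8, 8, 4))

value = 0.5  # arbitrary correlation value tracked across frames

# Without vmin/vmax: each slice is scaled to its own min/max
per_slice = [
    float(Normalize(vmin=result[..., k].min(), vmax=result[..., k].max())(value))
    for k in range(result.shape[2])
]

# With vmin/vmax from the whole volume: one shared scale for all frames
global_norm = Normalize(vmin=result.min(), vmax=result.max())
shared = [float(global_norm(value)) for _ in range(result.shape[2])]

print("per-slice colormap positions:", np.round(per_slice, 3))
print("shared colormap positions:   ", np.round(shared, 3))
```

    With the shared normalizer the tracked value maps to one colormap position in every frame, which is what makes a single fixed colorbar valid for the whole animation.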
