Generate a heatmap in Matplotlib using a scatter data set

南方客 2020-11-22 09:22

I have a set of X,Y data points (about 10k) that are easy to plot as a scatter plot but that I would like to represent as a heatmap.

I looked through the examples in

12 answers
  •  伪装坚强ぢ
    2020-11-22 10:12

    Instead of using np.histogram2d, which generally produces blocky, noisy-looking histograms, I would suggest py-sphviewer, a Python package for rendering particle simulations with an adaptive smoothing kernel. It can be installed easily with pip (see the webpage documentation). Consider the following code, which is based on the example:

    import numpy as np
    import numpy.random
    import matplotlib.pyplot as plt
    import sphviewer as sph
    
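    def myplot(x, y, nb=32, xsize=500, ysize=500):
        """Render x, y as an adaptively smoothed density image with py-sphviewer."""
        xmin = np.min(x)
        xmax = np.max(x)
        ymin = np.min(y)
        ymax = np.max(y)
    
        # Centre the camera on the middle of the data
        x0 = (xmin+xmax)/2.
        y0 = (ymin+ymax)/2.
    
        # Positions are passed as a 3 x N array; z stays 0
        pos = np.zeros([3, len(x)])
        pos[0,:] = x
        pos[1,:] = y
        w = np.ones(len(x))  # equal weight for every point
    
        P = sph.Particles(pos, w, nb=nb)
        S = sph.Scene(P)
        S.update_camera(r='infinity', x=x0, y=y0, z=0,
                        xsize=xsize, ysize=ysize)
        R = sph.Render(S)
        R.set_logscale()
        img = R.get_image()
        extent = R.get_extent()
        # Shift the extent by the camera centre to get data coordinates
        for i, j in zip(range(4), [x0,x0,y0,y0]):
            extent[i] += j
        print(extent)
        return img, extent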
    
    fig = plt.figure(1, figsize=(10,10))
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)
    
    
    # Generate some test data
    x = np.random.randn(1000)
    y = np.random.randn(1000)
    
    # Plot a regular scatter plot for comparison
    ax1.plot(x,y,'k.', markersize=5)
    ax1.set_xlim(-3,3)
    ax1.set_ylim(-3,3)
    
    heatmap_16, extent_16 = myplot(x,y, nb=16)
    heatmap_32, extent_32 = myplot(x,y, nb=32)
    heatmap_64, extent_64 = myplot(x,y, nb=64)
    
    ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
    ax2.set_title("Smoothing over 16 neighbors")
    
    ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
    ax3.set_title("Smoothing over 32 neighbors")
    
    # Make the heatmap using smoothing over 64 neighbors
    ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
    ax4.set_title("Smoothing over 64 neighbors")
    
    plt.show()
    

    which produces the following image:

    As you can see, the images look pretty nice, and we can identify different substructures in them. These images are constructed by spreading a given weight for every point over a certain domain, defined by the smoothing length, which in turn is given by the distance to the nb-th nearest neighbor (I've chosen 16, 32 and 64 for the examples). As a result, points in high-density regions are spread over smaller areas than points in low-density regions.
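    To make the adaptive-smoothing idea concrete, here is a rough NumPy-only sketch of the same principle. It is not how py-sphviewer is implemented. I assume a Gaussian kernel whose sigma equals the smoothing length (py-sphviewer uses its own SPH kernel), compute neighbour distances by brute force in chunks, and the helper names smoothing_lengths and adaptive_heatmap are hypothetical.

```python
import numpy as np


def smoothing_lengths(x, y, nb=32, chunk=256):
    """Distance from each point to its nb-th nearest neighbour (brute force, chunked)."""
    pts = np.column_stack([x, y]).astype(float)
    h = np.empty(len(pts))
    for start in range(0, len(pts), chunk):
        block = pts[start:start + chunk]
        # Distances from this block of points to every point: shape (len(block), N)
        d = np.sqrt(((block[:, None, :] - pts[None, :, :]) ** 2).sum(axis=-1))
        # After partitioning, index 0 is the point itself (distance 0), so
        # index nb holds the distance to its nb-th nearest other point.
        h[start:start + chunk] = np.partition(d, nb, axis=1)[:, nb]
    return h


def adaptive_heatmap(x, y, nb=32, bins=200):
    """Spread unit weight per point with a Gaussian whose sigma is its smoothing length."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xedges = np.linspace(x.min(), x.max(), bins + 1)
    yedges = np.linspace(y.min(), y.max(), bins + 1)
    xc = 0.5 * (xedges[:-1] + xedges[1:])  # pixel centres along x
    yc = 0.5 * (yedges[:-1] + yedges[1:])  # pixel centres along y
    # Never smooth below half a pixel, so every point reaches at least one pixel.
    hmin = max(0.5 * max(xedges[1] - xedges[0], yedges[1] - yedges[0]), 1e-12)
    h = np.maximum(smoothing_lengths(x, y, nb=nb), hmin)

    img = np.zeros((bins, bins))  # rows = y, columns = x, as imshow expects
    for xi, yi, s in zip(x, y, h):
        # Only evaluate the kernel on pixels within 3 sigma of the point.
        ix = np.nonzero(np.abs(xc - xi) < 3 * s)[0]
        iy = np.nonzero(np.abs(yc - yi) < 3 * s)[0]
        kernel = np.outer(np.exp(-0.5 * ((yc[iy] - yi) / s) ** 2),
                          np.exp(-0.5 * ((xc[ix] - xi) / s) ** 2))
        # Normalise so each point contributes a total weight of 1.
        img[np.ix_(iy, ix)] += kernel / kernel.sum()
    return img, [xedges[0], xedges[-1], yedges[0], yedges[-1]]


rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = rng.standard_normal(1000)
img, extent = adaptive_heatmap(x, y, nb=32)
```

    The result can be drawn like the panels above, e.g. ax.imshow(np.log10(img + 1e-3), extent=extent, origin='lower', aspect='auto'), where the log (with an arbitrary small offset) roughly plays the role of R.set_logscale(). For ~10k points, scipy.spatial.cKDTree(pts).query(pts, k=nb + 1) is a faster way to get the same neighbour distances (column nb of the returned distances).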

    The function myplot is just a simple helper I wrote to pass the x, y data to py-sphviewer, which does the actual rendering.
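    For comparison, the fixed-bin np.histogram2d approach mentioned at the start of this answer can be sketched as follows (the helper name histogram_heatmap is hypothetical). Note that np.histogram2d returns counts indexed as [x_bin, y_bin], so the array has to be transposed before passing it to imshow with origin='lower'.

```python
import numpy as np


def histogram_heatmap(x, y, bins=50):
    """Fixed-bin 2-D histogram, laid out for plt.imshow(..., origin='lower')."""
    counts, xedges, yedges = np.histogram2d(x, y, bins=bins)
    # histogram2d indexes counts[x_bin, y_bin]; imshow expects rows = y, so transpose.
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    return counts.T, extent


rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = rng.standard_normal(1000)
heatmap, extent = histogram_heatmap(x, y, bins=50)
```

    It is drawn with plt.imshow(heatmap, extent=extent, origin='lower', aspect='auto'). Every bin has the same size regardless of the local density, which is why this version looks blockier than the adaptively smoothed images.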
