Generate a heatmap in MatPlotLib using a scatter data set

南方客 2020-11-22 09:22

I have a set of X,Y data points (about 10k) that are easy to plot as a scatter plot but that I would like to represent as a heatmap.

I looked through the examples in

  • 2020-11-22 10:12

    Instead of using np.histogram2d, which in general produces quite ugly histograms, I suggest using py-sphviewer, a Python package for rendering particle simulations with an adaptive smoothing kernel, which can be easily installed with pip (see the documentation on its webpage). Consider the following code, which is based on the example:

    import numpy as np
    import numpy.random
    import matplotlib.pyplot as plt
    import sphviewer as sph
    
    def myplot(x, y, nb=32, xsize=500, ysize=500):   
        xmin = np.min(x)
        xmax = np.max(x)
        ymin = np.min(y)
        ymax = np.max(y)
    
        x0 = (xmin+xmax)/2.
        y0 = (ymin+ymax)/2.
    
        pos = np.zeros([3, len(x)])
        pos[0,:] = x
        pos[1,:] = y
        w = np.ones(len(x))
    
        P = sph.Particles(pos, w, nb=nb)
        S = sph.Scene(P)
        S.update_camera(r='infinity', x=x0, y=y0, z=0, 
                        xsize=xsize, ysize=ysize)
        R = sph.Render(S)
        R.set_logscale()
        img = R.get_image()
        extent = R.get_extent()
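        # Shift the camera-centred extent back to data coordinates
        # (Python 3: range/print() instead of xrange/print statement).
        for i, j in zip(range(4), [x0, x0, y0, y0]):
            extent[i] += j
        print(extent)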
        return img, extent
    
    fig = plt.figure(1, figsize=(10,10))
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)
    
    
    # Generate some test data
    x = np.random.randn(1000)
    y = np.random.randn(1000)
    
    #Plotting a regular scatter plot
    ax1.plot(x,y,'k.', markersize=5)
    ax1.set_xlim(-3,3)
    ax1.set_ylim(-3,3)
    
    heatmap_16, extent_16 = myplot(x,y, nb=16)
    heatmap_32, extent_32 = myplot(x,y, nb=32)
    heatmap_64, extent_64 = myplot(x,y, nb=64)
    
    ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
    ax2.set_title("Smoothing over 16 neighbors")
    
    ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
    ax3.set_title("Smoothing over 32 neighbors")
    
    #Make the heatmap using a smoothing over 64 neighbors
    ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
    ax4.set_title("Smoothing over 64 neighbors")
    
    plt.show()
    

    which produces the following image:

    As you can see, the images look pretty nice, and we are able to identify different substructures in them. These images are constructed by spreading a given weight for every point over a certain domain, defined by the smoothing length, which in turn is given by the distance to the nb-th nearest neighbor (I've chosen 16, 32 and 64 for the examples). So higher-density regions are typically spread over smaller areas than lower-density regions.

    The function myplot is just a very simple function that I've written to pass the x,y data to py-sphviewer, which does the rest.
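    If you'd rather avoid the extra dependency, the adaptive-smoothing idea described above can be approximated with SciPy alone. This is a sketch under stated assumptions, not py-sphviewer's algorithm: the helper name adaptive_smooth is hypothetical, each point's smoothing length is taken as its distance to the nb-th nearest neighbour (found with scipy.spatial.cKDTree), and a Gaussian kernel of that width stands in for whatever kernel py-sphviewer actually uses. Each point deposits a total weight of 1, so points in dense regions get narrower, taller kernels.

```python
import numpy as np
from scipy.spatial import cKDTree


def adaptive_smooth(x, y, nb=32, nx=200, ny=200):
    """Hypothetical helper: spread each point over a Gaussian whose width is
    its distance to the nb-th nearest neighbour (an approximation of the
    adaptive-kernel idea, NOT py-sphviewer's actual kernel)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    pts = np.column_stack([x, y])
    # k = nb + 1 because each point's closest match is itself (distance 0).
    dist, _ = cKDTree(pts).query(pts, k=nb + 1)
    h = dist[:, -1]

    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    # Pixel centres of an nx-by-ny image covering the data range.
    xc = xmin + (np.arange(nx) + 0.5) * (xmax - xmin) / nx
    yc = ymin + (np.arange(ny) + 0.5) * (ymax - ymin) / ny
    X, Y = np.meshgrid(xc, yc)  # shape (ny, nx): rows are y
    # Keep every kernel at least one pixel wide so it cannot vanish.
    h = np.maximum(h, max((xmax - xmin) / nx, (ymax - ymin) / ny))

    img = np.zeros((ny, nx))
    for xi, yi, hi in zip(x, y, h):
        k = np.exp(-0.5 * ((X - xi) ** 2 + (Y - yi) ** 2) / hi ** 2)
        img += k / k.sum()  # each point contributes a total weight of 1
    extent = [xmin, xmax, ymin, ymax]
    return img, extent
```

    The result can be displayed like the images above, e.g. plt.imshow(img, extent=extent, origin='lower', aspect='auto'). Looping over points keeps memory small; for much larger data sets or images you would want to restrict each kernel to a window around its point.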

  • 2020-11-22 10:13

    If you are using Matplotlib 1.2.x or later, you can use plt.hist2d:

    import numpy as np
    import matplotlib.pyplot as plt
    
    x = np.random.randn(100000)
    y = np.random.randn(100000)
    plt.hist2d(x,y,bins=100)
    plt.show()
    

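    plt.hist2d also returns the counts and bin edges it computed (they come from np.histogram2d), so the numbers behind the picture remain available. A sketch, assuming a logarithmic colour scale via matplotlib.colors.LogNorm (often helpful when a few bins dominate) and the Agg backend so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit it interactively
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)
y = rng.standard_normal(100_000)

fig, ax = plt.subplots()
# hist2d returns the counts, the bin edges and the image artist.
counts, xedges, yedges, image = ax.hist2d(x, y, bins=100, norm=LogNorm())
fig.colorbar(image, ax=ax, label="counts per bin")

# The counts match what np.histogram2d gives for the same edges.
ref, _, _ = np.histogram2d(x, y, bins=[xedges, yedges])
```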

  • 2020-11-22 10:15

    Edit: For a better approximation of Alejandro's answer, see below.

    I know this is an old question, but I wanted to add something to Alejandro's answer: if you want a nice smoothed image without using py-sphviewer, you can instead use np.histogram2d and apply a Gaussian filter (gaussian_filter from scipy.ndimage) to the heatmap:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
    
    
    def myplot(x, y, s, bins=1000):
        heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins)
        heatmap = gaussian_filter(heatmap, sigma=s)
    
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
        return heatmap.T, extent
    
    
    fig, axs = plt.subplots(2, 2)
    
    # Generate some test data
    x = np.random.randn(1000)
    y = np.random.randn(1000)
    
    sigmas = [0, 16, 32, 64]
    
    for ax, s in zip(axs.flatten(), sigmas):
        if s == 0:
            ax.plot(x, y, 'k.', markersize=5)
            ax.set_title("Scatter plot")
        else:
            img, extent = myplot(x, y, s)
            ax.imshow(img, extent=extent, origin='lower', cmap=cm.jet)
            ax.set_title(r"Smoothing with $\sigma$ = %d" % s)
    
    plt.show()
    

    Produces:

    The scatter plot and s=16 plotted on top of each other for Agape Gal'lo:


    One difference I noticed between my Gaussian filter approach and Alejandro's approach is that his method shows local structures much better than mine. Therefore I implemented a simple nearest-neighbour method at the pixel level: for each pixel, it computes the inverse of the sum of the distances to the n closest data points. At high resolution this method is computationally expensive, and I think there's a quicker way, so let me know if you have any improvements.

    Update: As I suspected, there's a much faster method using SciPy's scipy.spatial.cKDTree. See Gabriel's answer for the implementation.
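    A minimal sketch of that k-d tree speed-up, assuming scipy.spatial.cKDTree (the function name nearest_neighbours_kdtree is hypothetical, and this is not Gabriel's code). It computes the same quantity as the brute-force loop: for every pixel, one over the sum of the distances to its n closest data points, in the same view coordinates (data scaled to [0, reso], pixels at integer positions). Instead of recomputing every distance per pixel, it builds the tree once and queries all pixel positions in a single call:

```python
import numpy as np
from scipy.spatial import cKDTree


def nearest_neighbours_kdtree(xs, ys, reso, n_neighbours):
    """Hypothetical k-d tree version of the per-pixel inverse distance sum."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
    # Same view coordinates as the brute-force version.
    xv = (xs - extent[0]) / (extent[1] - extent[0]) * reso
    yv = (ys - extent[2]) / (extent[3] - extent[2]) * reso
    tree = cKDTree(np.column_stack([xv, yv]))

    # One query point per pixel; row index is y, column index is x.
    py, px = np.mgrid[0:reso, 0:reso]
    pixels = np.column_stack([px.ravel(), py.ravel()])
    d, _ = tree.query(pixels, k=n_neighbours)
    # With k=1 the distances come back 1-D, so force shape (npixels, k).
    d = d.reshape(len(pixels), -1)
    im = (1.0 / d.sum(axis=1)).reshape(reso, reso)  # im[y, x]
    return im, extent
```

    It can be swapped in for nearest_neighbours in the plotting loop. As with the loop version, n_neighbours=1 divides by zero if a data point lands exactly on a pixel position.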

    Anyway, here's my code:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    
    
    def data_coord2view_coord(p, vlen, pmin, pmax):
        dp = pmax - pmin
        dv = (p - pmin) / dp * vlen
        return dv
    
    
    def nearest_neighbours(xs, ys, reso, n_neighbours):
        im = np.zeros([reso, reso])
        extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
    
        xv = data_coord2view_coord(xs, reso, extent[0], extent[1])
        yv = data_coord2view_coord(ys, reso, extent[2], extent[3])
        for x in range(reso):
            for y in range(reso):
                xp = (xv - x)
                yp = (yv - y)
    
                d = np.sqrt(xp**2 + yp**2)
    
                im[y][x] = 1 / np.sum(d[np.argpartition(d.ravel(), n_neighbours)[:n_neighbours]])
    
        return im, extent
    
    
    n = 1000
    xs = np.random.randn(n)
    ys = np.random.randn(n)
    resolution = 250
    
    fig, axes = plt.subplots(2, 2)
    
    for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 64]):
        if neighbours == 0:
            ax.plot(xs, ys, 'k.', markersize=2)
            ax.set_aspect('equal')
            ax.set_title("Scatter Plot")
        else:
            im, extent = nearest_neighbours(xs, ys, resolution, neighbours)
            ax.imshow(im, origin='lower', extent=extent, cmap=cm.jet)
            ax.set_title("Smoothing over %d neighbours" % neighbours)
            ax.set_xlim(extent[0], extent[1])
            ax.set_ylim(extent[2], extent[3])
    plt.show()
    

    Result:

  • 2020-11-22 10:16

    Make a 2-dimensional array whose elements correspond to the cells of your final image, called, say, heatmap_cells, and initialize it to all zeros.

    Choose two scaling factors, say x_scale and y_scale, that give the size of one array cell in data units along each dimension. Choose them so that all of your data points fall within the bounds of the heatmap array.

    Then, for each raw data point with coordinates x_value and y_value, increment the corresponding cell:

    heatmap_cells[floor(x_value/x_scale),floor(y_value/y_scale)]+=1
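    A runnable NumPy version of these steps, under a few assumptions that go beyond the recipe above: the helper name bin_counts is hypothetical, the grid spans the data's own min/max (it subtracts the minimum so negative coordinates don't produce negative indices, which Python would silently wrap around), points on the upper edge are clipped into the last cell, and the data must not be constant along either axis:

```python
import numpy as np


def bin_counts(x, y, nx, ny):
    """Hypothetical helper: count points per cell of an nx-by-ny grid."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_min, y_min = x.min(), y.min()
    # Cell size in data units along each axis (the x_scale / y_scale above).
    x_scale = (x.max() - x_min) / nx
    y_scale = (y.max() - y_min) / ny
    ix = np.floor((x - x_min) / x_scale).astype(int)
    iy = np.floor((y - y_min) / y_scale).astype(int)
    # Points exactly on the upper edge would land one cell past the end.
    ix = np.clip(ix, 0, nx - 1)
    iy = np.clip(iy, 0, ny - 1)
    heatmap_cells = np.zeros((nx, ny), dtype=int)
    # np.add.at accumulates repeated indices; heatmap_cells[ix, iy] += 1
    # would count several points sharing a cell only once.
    np.add.at(heatmap_cells, (ix, iy), 1)
    return heatmap_cells
```

    To display it, transpose so rows are y, e.g. plt.imshow(bin_counts(x, y, 50, 50).T, origin='lower').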

  • 2020-11-22 10:19

    If you don't want hexagons, you can use numpy's histogram2d function:

    import numpy as np
    import numpy.random
    import matplotlib.pyplot as plt
    
    # Generate some test data
    x = np.random.randn(8873)
    y = np.random.randn(8873)
    
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=50)
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    
    plt.clf()
    plt.imshow(heatmap.T, extent=extent, origin='lower')
    plt.show()
    

    This makes a 50x50 heatmap. If you want, say, 512x384, you can put bins=(512, 384) in the call to histogram2d.
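    The transpose in the imshow call matters here: np.histogram2d returns heatmap[x_bin, y_bin], while imshow treats the first axis as image rows (y). A quick shape check with the bins=(512, 384) case, using the same test data size as above:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(8873)
y = rng.standard_normal(8873)

# bins=(nx, ny): 512 bins along x, 384 bins along y.
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(512, 384))
# Transposed, the image is 384 rows (y) by 512 columns (x), as imshow expects.
image = heatmap.T
```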

    Example: Matplotlib heat map example

  • 2020-11-22 10:23

    In Matplotlib's lexicon, I think you want a hexbin plot.

    If you're not familiar with this type of plot, it's just a bivariate histogram in which the xy-plane is tessellated by a regular grid of hexagons.

    As with any histogram, you discretize the plotting region into a set of windows (here, hexagons), assign each point to one of these windows, and count the points falling in each; finally, map the counts onto a color array, and you've got a hexbin diagram.
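    The assignment step can be sketched in NumPy. This is one way to do it, not necessarily Matplotlib's implementation, and the helper name hex_centers is hypothetical. Hexagon centres are placed on two interleaved rectangular lattices, A at (i*a, j*b) and B at ((i + 0.5)*a, (j + 0.5)*b) with b = a*sqrt(3); together they form a triangular lattice whose nearest-centre regions are regular hexagons, so each point goes to whichever of its nearest A centre and nearest B centre is closer:

```python
import numpy as np


def hex_centers(x, y, a):
    """Hypothetical helper: centre of the hexagon containing each point,
    for regular hexagons whose centres are a distance a apart."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    b = a * np.sqrt(3.0)
    # Nearest centre in lattice A (independent rounding per axis).
    ax_, ay_ = np.round(x / a) * a, np.round(y / b) * b
    # Nearest centre in lattice B, offset by half a cell on both axes.
    bx_, by_ = (np.floor(x / a) + 0.5) * a, (np.floor(y / b) + 0.5) * b
    use_a = (x - ax_) ** 2 + (y - ay_) ** 2 <= (x - bx_) ** 2 + (y - by_) ** 2
    return np.where(use_a, ax_, bx_), np.where(use_a, ay_, by_)


rng = np.random.default_rng(0)
x, y = rng.standard_normal(10_000), rng.standard_normal(10_000)
cx, cy = hex_centers(x, y, a=0.5)
# Points per hexagon: these counts are what gets mapped onto the colour scale.
centres, counts = np.unique(np.column_stack([cx, cy]), axis=0, return_counts=True)
```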

    Although hexagons are less commonly used than, e.g., circles or squares, it is intuitive that they are a better choice for the geometry of the binning container:

    • hexagons have nearest-neighbor symmetry: all six neighbors of a hexagonal bin are equidistant from its center, whereas a square bin's four diagonal neighbors are farther away than its four edge neighbors; and

    • the hexagon is the regular polygon with the most sides that tiles the plane (i.e., you can safely re-model your kitchen floor with hexagonal tiles and have no void space between them when you are finished; this is not true for any regular polygon with n >= 7 sides).

    (Matplotlib uses the term hexbin plot, and so, AFAIK, do all of the plotting libraries for R; still, I don't know whether this is the generally accepted term for plots of this type, though I suspect it is, given that hexbin is short for hexagonal binning, which describes the essential step in preparing the data for display.)


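    from matplotlib import pyplot as PLT
    from matplotlib import cm as CM
    import numpy as NP
    
    
    def bivariate_normal(X, Y, sigmax, sigmay, mux, muy):
        # Stand-in for matplotlib.mlab.bivariate_normal, which has been
        # removed from recent Matplotlib versions (uncorrelated case only).
        z = ((X - mux) / sigmax) ** 2 + ((Y - muy) / sigmay) ** 2
        return NP.exp(-z / 2) / (2 * NP.pi * sigmax * sigmay)
    
    
    x = y = NP.linspace(-5, 5, 100)
    X, Y = NP.meshgrid(x, y)
    Z1 = bivariate_normal(X, Y, 2, 2, 0, 0)
    Z2 = bivariate_normal(X, Y, 4, 1, 1, 1)
    ZD = Z2 - Z1
    x = X.ravel()
    y = Y.ravel()
    z = ZD.ravel()
    gridsize = 30
    PLT.subplot(111)
    
    # With bins=None, each hexagon's color corresponds directly to its value.
    # 'C' is optional: it supplies a value at each x-y coordinate, and each
    # hexagon shows the mean of the C values falling in it (reduce_C_function
    # defaults to numpy.mean). If 'C' is None (default), the result is a pure
    # 2D histogram of counts.
    
    PLT.hexbin(x, y, C=z, gridsize=gridsize, cmap=CM.jet, bins=None)
    PLT.axis([x.min(), x.max(), y.min(), y.max()])
    
    cb = PLT.colorbar()
    cb.set_label('mean value')
    PLT.show()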
    

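    The example above colors hexagons by the mean of gridded C values. For raw scatter data like the question's, leave C out so each hexagon is colored by its point count. A sketch, assuming the Agg backend so it runs headless; hexbin returns a PolyCollection whose get_array() and get_offsets() give the per-hexagon counts and centers:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit it interactively
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
y = rng.standard_normal(10_000)

fig, ax = plt.subplots()
# No C argument: each hexagon is colored by how many points fall in it.
# mincnt=1 leaves empty hexagons undrawn.
hb = ax.hexbin(x, y, gridsize=40, cmap="viridis", mincnt=1)
fig.colorbar(hb, ax=ax, label="counts")

counts = hb.get_array()     # one count per drawn hexagon
centres = hb.get_offsets()  # the matching hexagon centers
```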
