Generate a heatmap in Matplotlib using a scatter data set

南方客 2020-11-22 09:22

I have a set of X,Y data points (about 10k) that are easy to plot as a scatter plot but that I would like to represent as a heatmap.

I looked through the examples in

  • 2020-11-22 09:59

    And the initial question was... how to convert scatter values to grid values, right? histogram2d does count the frequency per cell; however, if you have other data per cell than just the frequency, you need to do some additional work.
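    One way to handle "other data per cell" without interpolation is to average z within each histogram cell: bin once with z as the weights, bin again without weights to get the counts, then divide. A minimal NumPy-only sketch; binned_mean is a hypothetical helper name, and the synthetic data only loosely mimic the ranges in this answer.

```python
import numpy as np

def binned_mean(x, y, z, bins=50, range_=None):
    """Mean of z in each (x, y) cell; NaN where a cell received no samples."""
    # Sum of z per cell: histogram2d with z as weights
    z_sum, xedges, yedges = np.histogram2d(x, y, bins=bins, range=range_, weights=z)
    # Number of samples per cell, on exactly the same edges
    counts, _, _ = np.histogram2d(x, y, bins=[xedges, yedges])
    mean = np.full_like(z_sum, np.nan)
    np.divide(z_sum, counts, out=mean, where=counts > 0)
    # histogram2d puts x on axis 0; transpose so rows follow y, as imshow expects
    return mean.T, xedges, yedges

# Synthetic example data (made up, for illustration only)
rng = np.random.default_rng(0)
x = rng.uniform(-10, 4, 10_000)
y = rng.uniform(-4, 11, 10_000)
z = np.exp(-((x + 3) ** 2 + (y - 3) ** 2) / 20)
mean_z, xedges, yedges = binned_mean(x, y, z, bins=40)
# e.g. plt.imshow(mean_z, origin='lower',
#                 extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
```

    Empty cells stay NaN, which imshow leaves blank. scipy.stats.binned_statistic_2d(x, y, z, statistic='mean', bins=40) computes the same per-cell mean in one call (its result also has x on the first axis).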

    x = data_x  # between -10 and 4, log-gamma of an SVC
    y = data_y  # between -4 and 11, log-C of an SVC
    z = data_z  # between 0 and 0.78, F1 values from a difficult dataset
    

    So I have a dataset with Z results for X and Y coordinates. However, I calculated a few points outside the area of interest (with large gaps between them) and heaps of points in a small area of interest.

    Yes, here it becomes more difficult, but also more fun. Some libraries (sorry):

    from matplotlib import pyplot as plt
    from matplotlib import cm
    import numpy as np
    from scipy.interpolate import griddata
    

    pyplot is my graphics engine today; cm provides a range of color maps with some interesting choices; numpy does the calculations; and griddata attaches values to a fixed grid.

    The last one is especially important because the xy points in my data are not evenly distributed. First, let's start with some boundaries that fit my data and an arbitrary grid size. The original data also has data points outside these x and y boundaries.

    #determine grid boundaries
    gridsize = 500
    x_min = -8
    x_max = 2.5
    y_min = -2
    y_max = 7
    

    So we have defined a grid with 500 points along each axis (500 x 500 cells) between the min and max values of x and y.

    In my data, there are many more than 500 values in the area of high interest, whereas in the low-interest area there are not even 200 values in the whole grid, and even fewer between the plot boundaries x_min and x_max.

    So, to get a nice picture, the task is to average the high-interest values and to fill the gaps elsewhere.

    Now I define my grid. For each xx-yy pair, I want to have a color.

    xx = np.linspace(x_min, x_max, gridsize)  # array of x values
    yy = np.linspace(y_min, y_max, gridsize)  # array of y values
    grid = np.array(np.meshgrid(xx, yy))      # shape (2, len(yy), len(xx))
    grid = grid.reshape(2, grid.shape[1]*grid.shape[2]).T  # shape (n, 2)
    

    Why the strange shape? scipy.interpolate.griddata wants the points at which to interpolate as an array of shape (n, D).
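    To make the shape bookkeeping concrete, here is a small sketch with deliberately different lengths for xx and yy (the values are made up). It checks that the stack-and-reshape construction equals np.column_stack of the raveled meshgrid arrays, and that x varies fastest, so a flat result of length n reshapes back to (len(yy), len(xx)). In this answer both lengths equal gridsize, so the order happens not to matter there.

```python
import numpy as np

xx = np.linspace(-8, 2.5, 4)   # 4 x values
yy = np.linspace(-2, 7, 3)     # 3 y values (deliberately a different length)

# Stack the two meshgrid arrays and flatten to (n, 2)
grid = np.array(np.meshgrid(xx, yy))                      # (2, len(yy), len(xx))
grid = grid.reshape(2, grid.shape[1] * grid.shape[2]).T   # (n, 2)

# Equivalent and arguably more readable
X, Y = np.meshgrid(xx, yy)
grid2 = np.column_stack([X.ravel(), Y.ravel()])

# x changes fastest along the flat axis, so a per-point result reshapes to
# len(yy) rows and len(xx) columns: rows follow y, as imshow expects
values = grid[:, 0] + 10 * grid[:, 1]   # any per-point quantity
image = values.reshape(len(yy), len(xx))
```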

    griddata calculates one value per grid point using a predefined method. I chose "nearest": empty grid points are filled with the value of the nearest data point. This makes areas with less information look as if they had bigger cells (even though they don't). One could choose "linear" interpolation instead; then areas with less information look less sharp. Matter of taste, really.
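    A small sketch comparing the two methods on made-up data (the test function and the grid extent are arbitrary). The grid deliberately extends beyond the data: "nearest" fills every grid point, while "linear" returns NaN outside the convex hull of the data points unless fill_value is given.

```python
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(0)
points = rng.uniform(0.0, 1.0, size=(200, 2))          # scattered (x, y) samples
z = np.sin(3 * points[:, 0]) * np.cos(3 * points[:, 1])

# Evaluation grid reaching past the data on every side
xx = np.linspace(-0.5, 1.5, 50)
yy = np.linspace(-0.5, 1.5, 50)
X, Y = np.meshgrid(xx, yy)
targets = np.column_stack([X.ravel(), Y.ravel()])

nearest = griddata(points, z, targets, method='nearest').reshape(X.shape)
linear = griddata(points, z, targets, method='linear').reshape(X.shape)
# fill_value replaces the NaN that 'linear' produces outside the convex hull
linear_filled = griddata(points, z, targets, method='linear',
                         fill_value=0.0).reshape(X.shape)
```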

    points = np.array([x, y]).T  # because griddata wants it that way
    z_grid2 = griddata(points, z, grid, method='nearest')
    # You get a 1D vector as the result. Reshape it to picture format:
    # x varies fastest in `grid`, so rows correspond to y and columns to x.
    z_grid2 = z_grid2.reshape(yy.shape[0], xx.shape[0])
    

    And hop, we hand over to matplotlib to display the plot

    fig = plt.figure(1, figsize=(10, 10))
    ax1 = fig.add_subplot(111)
    ax1.imshow(z_grid2, extent=[x_min, x_max, y_min, y_max],
               origin='lower', cmap=cm.magma)
    ax1.set_title("SVC: empty spots filled by nearest neighbours")
    ax1.set_xlabel('log gamma')
    ax1.set_ylabel('log C')
    plt.show()
    

    Around the pointy part of the V shape, you can see that I did a lot of calculations during my search for the sweet spot, whereas the less interesting parts almost everywhere else have a lower resolution.

  • 2020-11-22 10:04

    I'm afraid I'm a little late to the party, but I had a similar question a while ago. The accepted answer (by @ptomato) helped me out, but I'd also like to post this in case it's of use to someone.

    
    ''' I wanted to create a heatmap resembling a football pitch which would show the different actions performed '''
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    # fixing random state for reproducibility
    np.random.seed(1234324)
    
    fig = plt.figure(12)
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    
    # One cell per 10 m: 6 rows x 10 columns, roughly the ratio of the pitch
    hmap = np.zeros((6, 10), dtype=int)
    
    xlist = np.random.uniform(low=0.0, high=100.0, size=20)
    ylist = np.random.uniform(low=0.0, high=100.0, size=20)
    
    # UEFA pitch standards are 105 m x 68 m; scale to units of 10 m
    xlist = (xlist/100)*10.5
    ylist = (ylist/100)*6.8
    
    ax1.scatter(xlist, ylist)
    
    # int of the co-ordinates to populate the array; clip so that points in the
    # last, partial 10 m band land in the edge cell instead of indexing past it
    xlist_int = np.clip(xlist.astype(int), 0, hmap.shape[1] - 1)
    ylist_int = np.clip(ylist.astype(int), 0, hmap.shape[0] - 1)
    
    for i, j in zip(xlist_int, ylist_int):
        # this populates the array according to the x,y co-ordinate values it encounters
        hmap[j, i] += 1
    
    # Reversing the rows is necessary because imshow draws row 0 at the top
    hmap = hmap[::-1]
    
    im = ax2.imshow(hmap)
    plt.show()

    Here's the result
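    The counting loop can also be vectorized with np.add.at, which, unlike the fancy-indexed hmap[rows, cols] += 1, adds once per occurrence when the same (row, column) pair appears more than once. A minimal sketch, assuming the same 6 x 10 grid and a pitch scaled to 10.5 x 6.8 units (10 m per unit); the clipping sends points from the last partial band into the edge cells. Passing origin='lower' to imshow would be an alternative to reversing the rows.

```python
import numpy as np

rng = np.random.default_rng(1234324)
xs = rng.uniform(0.0, 10.5, size=20)   # pitch length in units of 10 m
ys = rng.uniform(0.0, 6.8, size=20)    # pitch width in units of 10 m

n_rows, n_cols = 6, 10
cols = np.clip(xs.astype(int), 0, n_cols - 1)
rows = np.clip(ys.astype(int), 0, n_rows - 1)

# Loop version, as in the answer
hmap_loop = np.zeros((n_rows, n_cols), dtype=int)
for i, j in zip(cols, rows):
    hmap_loop[j, i] += 1

# Vectorized version: np.add.at accumulates repeated (row, col) pairs
hmap_vec = np.zeros((n_rows, n_cols), dtype=int)
np.add.at(hmap_vec, (rows, cols), 1)
```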

  • 2020-11-22 10:05

    Seaborn now has the jointplot function, which should work nicely here:

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # Generate some test data
    x = np.random.randn(8873)
    y = np.random.randn(8873)
    
    sns.jointplot(x=x, y=y, kind='hex')
    plt.show()
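    Without seaborn, matplotlib's own Axes.hexbin draws the same kind of hexagonal-bin heatmap (without the marginal histograms that jointplot adds). A minimal sketch with similar test data; gridsize=40 and the viridis colormap are arbitrary choices, and the Agg backend is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(8873)
y = rng.standard_normal(8873)

fig, ax = plt.subplots()
# gridsize sets the number of hexagons in the x-direction
hb = ax.hexbin(x, y, gridsize=40, cmap='viridis')
fig.colorbar(hb, ax=ax, label='points per hexagon')
counts = hb.get_array()  # one count per hexagon
# plt.show() in an interactive session
plt.close(fig)
```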
    

  • 2020-11-22 10:06

    Very similar to @Piti's answer, but using 1 call instead of 2 to generate the points:

    import numpy as np
    import matplotlib.pyplot as plt
    
    pts = 1000000
    mean = [0.0, 0.0]
    cov = [[1.0,0.0],[0.0,1.0]]
    
    x,y = np.random.multivariate_normal(mean, cov, pts).T
    plt.hist2d(x, y, bins=50, cmap=plt.cm.jet)
    plt.show()
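    With many points from a single Gaussian, a linear color scale spends most of its range on the few central bins, so the tails all look alike. One option is a logarithmic color scale via matplotlib.colors.LogNorm. A minimal sketch (fewer points; bins and colormap are arbitrary) that also keeps the counts hist2d returns. Those counts are indexed [x_bin, y_bin], like np.histogram2d, so transpose them before passing them to imshow.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np

rng = np.random.default_rng(0)
pts = 100_000
x, y = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], pts).T

fig, ax = plt.subplots()
# hist2d returns the counts, the bin edges and the drawn QuadMesh
counts, xedges, yedges, image = ax.hist2d(x, y, bins=50, norm=LogNorm(),
                                          cmap='viridis')
fig.colorbar(image, ax=ax, label='points per bin (log scale)')
# plt.show() in an interactive session
plt.close(fig)
```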
    


  • 2020-11-22 10:08

    Here's Jurgy's great nearest-neighbour approach, but implemented using scipy.spatial.cKDTree. In my tests it's about 100x faster.

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from scipy.spatial import cKDTree
    
    
    def data_coord2view_coord(p, resolution, pmin, pmax):
        dp = pmax - pmin
        dv = (p - pmin) / dp * resolution
        return dv
    
    
    n = 1000
    xs = np.random.randn(n)
    ys = np.random.randn(n)
    
    resolution = 250
    
    extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
    xv = data_coord2view_coord(xs, resolution, extent[0], extent[1])
    yv = data_coord2view_coord(ys, resolution, extent[2], extent[3])
    
    
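    def kNN2DDens(xv, yv, resolution, neighbours, dim=2):
        """
        Smoothed density image: for every pixel of a resolution x resolution grid,
        1 / (sum of distances to its `neighbours` nearest data points).
        """
        # Create the tree
        tree = cKDTree(np.array([xv, yv]).T)
        # Pixel coordinates (x, y) of every grid cell, x varying fastest
        grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, dim)
        # Distances from each pixel to its `neighbours` closest data points
        # (needs neighbours >= 2: with k=1, query returns 1-D arrays and .sum(1) fails)
        dists = tree.query(grid, neighbours)
        # Inverse of the sum of distances to each grid point.
        inv_sum_dists = 1. / dists[0].sum(1)
    
        # Reshape: rows follow y, columns follow x
        im = inv_sum_dists.reshape(resolution, resolution)
        return im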
    
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 63]):
    
        if neighbours == 0:
            ax.plot(xs, ys, 'k.', markersize=5)
            ax.set_aspect('equal')
            ax.set_title("Scatter Plot")
        else:
    
            im = kNN2DDens(xv, yv, resolution, neighbours)
    
            ax.imshow(im, origin='lower', extent=extent, cmap=cm.Blues)
            ax.set_title("Smoothing over %d neighbours" % neighbours)
            ax.set_xlim(extent[0], extent[1])
            ax.set_ylim(extent[2], extent[3])
    
    plt.savefig('new.png', dpi=150, bbox_inches='tight')
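    One detail of kNN2DDens that is easy to miss is why its image comes out the right way up: the np.mgrid[...].T.reshape(...) grid lists pixels with x varying fastest, so after reshaping, rows follow y, matching imshow(origin='lower'). A tiny sketch with made-up points checks this: the pixel nearest a tight cluster at view coordinate (x=3, y=0) ends up at row 0, column 3 of the image.

```python
import numpy as np
from scipy.spatial import cKDTree

resolution = 4
grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, 2)
# Row k of grid is the pixel (x, y) = (k % resolution, k // resolution)

# Made-up cluster of data points around view coordinate (x=3, y=0)
pts = np.array([[3.0, 0.0], [3.1, 0.1], [2.9, 0.2]])
dists, _ = cKDTree(pts).query(grid, k=2)   # k >= 2 keeps dists 2-D
im = (1.0 / dists.sum(axis=1)).reshape(resolution, resolution)

peak = np.unravel_index(np.argmax(im), im.shape)   # (row, col) = (y, x)
```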
    
  • 2020-11-22 10:11

    Here's one I made on a 1-million-point set with 3 categories (colored red, green, and blue). Here's a link to the repository if you'd like to try the function: GitHub repo

    histplot(
        X,
        Y,
        labels,
        bins=2000,
        range=((-3,3),(-3,3)),
        normalize_each_label=True,
        colors = [
            [1,0,0],
            [0,1,0],
            [0,0,1]],
        gain=50)
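    The histplot function comes from the linked repository, and its code is not shown here. As a rough, hypothetical sketch of the general idea (one 2-D histogram per category, each normalized on its own, tinted with its own color and summed into an RGB image), something like the following could work. label_rgb_heatmap and its parameters are illustrative names only; how the repository actually treats normalize_each_label or gain is not known from this post, and the multiply-then-clip gain here is an assumption.

```python
import numpy as np

def label_rgb_heatmap(x, y, labels, colors, bins=200, range_=None, gain=1.0):
    """Hypothetical sketch: per-label normalized 2-D histograms blended into RGB."""
    x, y, labels = np.asarray(x), np.asarray(y), np.asarray(labels)
    colors = np.asarray(colors, dtype=float)   # (n_labels, 3); label k uses colors[k]
    if range_ is None:
        # One common extent for all labels so their cells line up
        range_ = [[x.min(), x.max()], [y.min(), y.max()]]
    image = np.zeros((bins, bins, 3))
    for k, color in enumerate(colors):
        sel = labels == k
        h, _, _ = np.histogram2d(x[sel], y[sel], bins=bins, range=range_)
        if h.max() > 0:
            h = h / h.max()                    # normalize each label on its own
        # histogram2d puts x on axis 0; transpose so rows follow y
        image += h.T[:, :, None] * color
    # Assumed meaning of gain: brighten faint regions, then clip to [0, 1]
    return np.clip(image * gain, 0.0, 1.0)

# Made-up data: three labels centered at different x positions
rng = np.random.default_rng(0)
n = 30_000
labels = rng.integers(0, 3, n)
x = rng.normal(labels - 1.0, 1.0)
y = rng.normal(0.0, 1.0, n)
rgb = label_rgb_heatmap(x, y, labels, [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                        bins=200, range_=((-3, 3), (-3, 3)), gain=5.0)
# e.g. plt.imshow(rgb, origin='lower', extent=[-3, 3, -3, 3])
```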
    
    0 讨论(0)