How can I subsample an array according to its density? (Remove frequent values, keep rare ones)

遇见更好的自我 2020-12-10 00:32

I have a problem: I want to plot a data distribution in which some values occur frequently while others are quite rare. There are around 30,000 points in total. R…

2 answers
  •  囚心锁ツ
    2020-12-10 00:43

    Consider the following function. It bins the data into equal-width bins along the x axis and

    • if a bin contains one or two points, keeps those points;
    • if a bin contains more points, keeps only the points with the minimum and maximum y value;
    • adds the first and last point if necessary, so that the filtered data covers the same x range as the original.

    This keeps the original data in regions of low density while significantly reducing the amount of data to plot in regions of high density. Because the minimum and maximum of every bin are kept, all visible features of the curve are preserved as long as the binning is fine enough.

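    import numpy as np
    np.random.seed(42)
    
    def filt(x, y, bins):
        """Thin out (x, y) data: keep every point of sparsely populated bins,
        but only the min-y and max-y points of densely populated ones.
        Assumes x is sorted in ascending order."""
        d = np.digitize(x, bins)  # bin index of every x value
        xfilt = []
        yfilt = []
        for i in np.unique(d):  # visit occupied bins only
            xi = x[d == i]
            yi = y[d == i]
            if len(xi) <= 2:
                # sparse bin: keep the original points
                xfilt.extend(list(xi))
                yfilt.extend(list(yi))
            else:
                # dense bin: keep only the points with the largest and smallest y
                xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
                yfilt.extend([yi.max(), yi.min()])
        # prepend/append the first/last point if it was not kept already,
        # so the filtered data spans the same x range as the original
        if x[0] not in xfilt:
            xfilt = [x[0]] + xfilt
            yfilt = [y[0]] + yfilt
        if x[-1] not in xfilt:
            xfilt.append(x[-1])
            yfilt.append(y[-1])
        # sort by x, since the max/min pair of a bin may be out of x order
        sort = np.argsort(xfilt)
        return np.array(xfilt)[sort], np.array(yfilt)[sort]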
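
Formally, writing S_b for the set of points whose x falls into bin b and G for the number of occupied bins, filt keeps the set K_b in every bin. This also bounds the size of its output independently of the total number of points:

```latex
K_b =
\begin{cases}
S_b, & |S_b| \le 2,\\
\left\{\operatorname*{arg\,min}_{i \in S_b} y_i,\; \operatorname*{arg\,max}_{i \in S_b} y_i\right\}, & |S_b| > 2,
\end{cases}
\qquad
N_{\text{kept}} \le \sum_b \min\bigl(|S_b|, 2\bigr) + 2 \le 2G + 2.
```

The +2 accounts for the first and last point, which are added back only when missing. For bins = np.linspace(x.min(), x.max(), 301), np.digitize places x.max() (which equals the last edge) into one extra index, so G is at most 301 and the output is capped at 2·301 + 2 = 604 points.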
    

    To illustrate the concept, let's use some toy data:

    x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
    y = np.array([4,2,5,3, 7,3,5,5, 2, 4,  5,  2,5,3])
    bins = np.linspace(0,30,7)
    

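    Then call xf, yf = filt(x,y,bins) and plot the original and the filtered data together to compare them. Only the bins that hold more than two points are thinned out, so the 14 toy points are reduced to 11 (including the first and last point, which are added back).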
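
The first step inside filt is np.digitize, which returns for every x the index i with bins[i-1] <= x < bins[i]. Running just that step on the toy data (a small sketch that uses only NumPy, no new helpers) shows which bins are dense enough to be thinned:

```python
import numpy as np

# Toy data and bin edges from above
x = np.array([1, 2, 3, 4, 6, 7, 8, 9, 11, 14, 17, 26, 28, 29])
y = np.array([4, 2, 5, 3, 7, 3, 5, 5, 2, 4, 5, 2, 5, 3])
bins = np.linspace(0, 30, 7)  # edges 0, 5, 10, 15, 20, 25, 30

# Bin index of every point (the same call filt makes)
d = np.digitize(x, bins)
print(d)  # [1 1 1 1 2 2 2 2 3 3 4 6 6 6]

# Number of points in every occupied bin
ids, counts = np.unique(d, return_counts=True)
print(ids, counts)  # [1 2 3 4 6] [4 4 2 1 3]

# Bins with more than two points are reduced to their min-y and max-y points
dense = ids[counts > 2]
print(dense)  # [1 2 6]

# Inside dense bin 2, only the points with the largest and smallest y survive
xi, yi = x[d == 2], y[d == 2]
print(xi[np.argmax(yi)], xi[np.argmin(yi)])  # 6 7
```

Bin 5 (20 <= x < 25) is empty, so it does not appear in ids, just as filt only visits the occupied bins returned by np.unique(d).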

    The following shows the use case from the question, with some 30,000 data points. This technique reduces the number of plotted points from 30,000 to about 500. That number of course depends on the binning, here 300 bins. In this case the function takes ~10 ms to compute, which is not super fast but still a large improvement over plotting all the points.

    import matplotlib.pyplot as plt
    
    # Generate some data
    x = np.sort(np.random.rayleigh(3, size=30000))
    y = np.cumsum(np.random.randn(len(x)))+250
    # Decide for a number of bins
    bins = np.linspace(x.min(),x.max(),301)
    # Filter data
    xf, yf = filt(x,y,bins) 
    
    # Plot results
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8), 
                                        gridspec_kw=dict(height_ratios=[1,2,2]))
    
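    # top: histogram of x on a log scale, showing how unevenly the points are distributed
    ax1.hist(x, bins=bins)
    ax1.set_yscale("log")
    ax1.set_yticks([1,10,100,1000])
    
    # middle: all original points
    ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))
    
    # bottom: the min/max-filtered points
    ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))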
    
    for ax in [ax2, ax3]:
        ax.legend()
    plt.show()
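
Most of the ~10 ms mentioned above is spent in the Python loop over the bins. If that matters, the same keep rule can be expressed without a loop. The sketch below is not part of the original answer: filt_vectorized is a hypothetical name, and, like filt, it assumes x is sorted and non-empty. It relies on one observation: for a bin with one or two points, "keep all points" is the same as "keep the min-y and max-y point", so every bin can be handled identically once the points are sorted by (bin, y) with np.lexsort.

```python
import numpy as np


def filt_vectorized(x, y, bins):
    """Hypothetical loop-free variant of filt() (not from the original answer).

    Keeps, for every occupied bin, the points with the smallest and largest y
    (which is all points when a bin holds one or two), plus the first and last
    point. Assumes x is sorted ascending and non-empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    d = np.digitize(x, bins)  # bin index of every point

    # Sort point indices by bin, then by y within each bin (last key is primary)
    order = np.lexsort((y, d))
    d_sorted = d[order]

    # Each bin is now a contiguous run in `order`: find where runs start and end
    starts = np.flatnonzero(np.r_[True, d_sorted[1:] != d_sorted[:-1]])
    ends = np.r_[starts[1:], len(order)] - 1

    # Min-y point (run start) and max-y point (run end) of every bin,
    # plus the first and last point; np.unique drops duplicates and sorts
    keep = np.unique(np.r_[order[starts], order[ends], 0, len(x) - 1])
    return x[keep], y[keep]  # index order equals x order because x is sorted


# Toy data from above
x = np.array([1, 2, 3, 4, 6, 7, 8, 9, 11, 14, 17, 26, 28, 29])
y = np.array([4, 2, 5, 3, 7, 3, 5, 5, 2, 4, 5, 2, 5, 3])
xf, yf = filt_vectorized(x, y, np.linspace(0, 30, 7))
print(xf.tolist())  # [1, 2, 3, 6, 7, 11, 14, 17, 26, 28, 29]
print(yf.tolist())  # [4, 2, 5, 7, 3, 2, 4, 5, 2, 5, 3]
```

On the toy data this returns the same 11 points as filt. When several points in a bin share the same extreme y value, the two functions may pick different ones of them. Whether it is actually faster on your data is worth measuring rather than assuming.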
    
