Removing Data Below A Line In A Scatterplot (Python)

春和景丽 2021-01-23 15:29

So I had code that graphed a 2D histogram of my dataset. I plotted it like so:

histogram = plt.hist2d(fehsc, ofesc, bins=nbins, range=[[-1,.5],[0.225,0.4]])


        
1 Answer
  • 2021-01-23 16:15

    You could define a mask for your data before plotting and then plot only the data points that meet your criteria. Below is an example in which all data points above a given line are plotted in green and all points on or below it are plotted in black.

    from matplotlib import pyplot as plt
    import numpy as np
    
    # the scatter-plot data
    xvals = np.random.rand(100)
    yvals = np.random.rand(100)
    
    # the line y = m*x + b
    b = 0.1
    m = 1
    x = np.linspace(0, 1, num=100)
    y = m*x + b
    
    # True for points strictly above the line, False for points on or below it
    mask = yvals > m*xvals + b
    
    plt.scatter(xvals[mask],yvals[mask],color='g')
    plt.scatter(xvals[~mask],yvals[~mask],color='k')
    plt.plot(x,y,'r')
    plt.show()
    

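    The result looks like this: [Figure: scatter plot with the red line y = x + 0.1; points above it in green, points below it in black]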

    Hope this helps.
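The same boolean-mask idea also applies directly to the hist2d call in the question: drop the points below the line before binning, and they never reach the histogram. Below is a minimal NumPy-only sketch. The helper names `above_line_mask` and `histogram_above_line` are hypothetical, the uniform random arrays are synthetic stand-ins for the asker's `fehsc` and `ofesc`, and the line y = 0.1*x + 0.3 is an arbitrary example cut, not anything from the original post.

```python
import numpy as np


def above_line_mask(x, y, m, b):
    """Boolean mask: True for points strictly above the line y = m*x + b."""
    return y > m * x + b


def histogram_above_line(x, y, m, b, bins, value_range):
    """Bin only the points above the line y = m*x + b.

    Points on or below the line are dropped *before* binning, so they never
    contribute to any bin. Returns (counts, xedges, yedges) exactly like
    np.histogram2d, with counts[i, j] holding x-bin i and y-bin j.
    """
    mask = above_line_mask(x, y, m, b)
    return np.histogram2d(x[mask], y[mask], bins=bins, range=value_range)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic stand-ins for the question's fehsc / ofesc arrays (assumption).
    fehsc = rng.uniform(-1.0, 0.5, 5000)
    ofesc = rng.uniform(0.225, 0.4, 5000)
    m, b = 0.1, 0.3  # hypothetical line; replace with the real cut
    value_range = [[-1, .5], [0.225, 0.4]]

    mask = above_line_mask(fehsc, ofesc, m, b)
    counts, xedges, yedges = histogram_above_line(fehsc, ofesc, m, b, 30, value_range)
    # Every kept point lies inside value_range, so the totals agree.
    print(counts.sum(), mask.sum())
```

With matplotlib, the filtered arrays can go straight into the question's call, e.g. `plt.hist2d(fehsc[mask], ofesc[mask], bins=nbins, range=[[-1,.5],[0.225,0.4]])`. Filtering individual points is exact, whereas the bin-based approach in the edit below decides per bin, using the bin center, so a bin that straddles the line is either kept whole or zeroed whole.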

    EDIT:

    If you want a 2D histogram in which the portion below the line is set to zero, first compute the histogram as a NumPy array with np.histogram2d, then set every bin whose center falls below the line to zero. After that, plot the array with plt.pcolormesh:

    from matplotlib import pyplot as plt
    import numpy as np
    
    # the scatter-plot data
    xvals = np.random.rand(1000)
    yvals = np.random.rand(1000)
    histogram, xbins, ybins = np.histogram2d(xvals, yvals, bins=50)
    # np.histogram2d returns counts indexed as [x_bin, y_bin]; transpose so
    # that rows correspond to y, as np.meshgrid and plt.pcolormesh expect
    histogram = histogram.T
    
    # computing the bin centers from the bin edges
    xcenters = 0.5*(xbins[:-1] + xbins[1:])
    ycenters = 0.5*(ybins[:-1] + ybins[1:])
    
    # the line y = m*x + b
    b = 0.1
    m = 1
    x = np.linspace(0, 1, num=100)
    y = m*x + b
    
    # zeroing every bin whose center lies below the line
    xmesh, ymesh = np.meshgrid(xcenters, ycenters)
    mask = m*xmesh + b > ymesh
    histogram[mask] = 0
    
    # making the plot; passing the bin edges gives one more boundary than bins
    mat = plt.pcolormesh(xbins, ybins, histogram)
    line = plt.plot(x, y, 'r')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.show()
    

    The result would be something like this: [Figure: 2D histogram drawn with pcolormesh, with the bins below the red line set to zero]
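Setting bins to zero still draws them, in the lowest colormap color. If the goal is to hide that region instead, a masked array is one alternative: pcolormesh draws masked cells in the colormap's "bad" color, which is transparent by default. The sketch below swaps in `np.ma.masked_where` for the zeroing step and uses non-square bins so the [y_bin, x_bin] orientation can be checked. The function name `histogram_masked_below_line` and its parameters are hypothetical, and the random data is synthetic.

```python
import numpy as np


def histogram_masked_below_line(x, y, m, b, bins, value_range=None):
    """2D histogram whose bins with centers below y = m*x + b are masked out.

    np.histogram2d returns counts indexed as [x_bin, y_bin]; they are
    transposed to [y_bin, x_bin] so rows follow y, matching the layout of
    np.meshgrid (default indexing='xy') and the C argument of pcolormesh.
    """
    counts, xedges, yedges = np.histogram2d(x, y, bins=bins, range=value_range)
    counts = counts.T  # shape (ny, nx)

    xcenters = 0.5 * (xedges[:-1] + xedges[1:])
    ycenters = 0.5 * (yedges[:-1] + yedges[1:])
    xmesh, ymesh = np.meshgrid(xcenters, ycenters)  # both shape (ny, nx)

    below = ymesh < m * xmesh + b
    # Masked bins are hidden rather than set to 0.
    return np.ma.masked_where(below, counts), xedges, yedges


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    xvals = rng.random(1000)
    yvals = rng.random(1000)
    hist, xedges, yedges = histogram_masked_below_line(
        xvals, yvals, m=1.0, b=0.1, bins=(40, 20), value_range=[[0, 1], [0, 1]]
    )
    print(hist.shape)  # (20, 40): 20 y bins as rows, 40 x bins as columns
    # With matplotlib: plt.pcolormesh(xedges, yedges, hist)
```

Because the mask is computed from bin centers, this keeps the same per-bin decision as the answer's zeroing approach; only the rendering of the removed region changes.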
