Question
I'm using seaborn for plotting data. Everything was fine until my mentor asked me how the plot is made, for example in the following code:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
x = np.random.normal(size=100)
sns.distplot(x)
plt.show()
The result of this code is:
My questions:
1- How does distplot manage to plot this?
2- Why does the plot start at -3 and end at 4?
3- Is there any parametric function or any specific mathematical function that distplot uses to plot the data like this?
I use distplot and kde to plot my data, but I would like to know the maths behind those functions.
Answer 1:
Here is some code trying to illustrate how the kde curve is drawn.
The code starts with a random sample of 100 xs.
These xs are shown in a histogram. With density=True the histogram is normalized so that its total area is 1. (By default, the bars of the histogram grow with the number of points. With density=True, the total area is calculated internally and each bar's height is divided by that area.)
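A minimal sketch of that normalization, assuming NumPy's `np.histogram` (which `plt.hist` uses for its bin counts) and a fixed random seed chosen only for reproducibility. It divides the raw counts by the total bar area by hand and compares the result with `density=True`:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the example is reproducible
xs = rng.normal(0, 1, 100)

# Raw counts per bin (what the histogram shows without density=True)
counts, edges = np.histogram(xs, bins=10)
widths = np.diff(edges)

# Normalize by hand: divide every bar height by the total area of all bars
total_area = np.sum(counts * widths)
manual_density = counts / total_area

# NumPy's own density=True result for the same bins
density, _ = np.histogram(xs, bins=edges, density=True)

print(np.allclose(manual_density, density))  # True
print(np.sum(density * widths))              # the bar areas now add up to about 1
```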
To draw the KDE, a Gaussian "bell" curve is centered on each of the N samples. These curves are summed and normalized by dividing by N.
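Written as a formula, with φ the standard normal density, the estimate is f(x) = (1/N) · Σᵢ φ((x − xᵢ) / σ) / σ. The sketch below evaluates that formula with NumPy broadcasting instead of a Python loop; the function name `kde` and the seed are illustrative choices, not part of seaborn. Because each bell has area 1 and the sum is divided by N, the curve's total area is 1, which the trapezoid sum at the end checks numerically:

```python
import numpy as np

def kde(x, samples, sigma):
    """Gaussian kernel density estimate evaluated at the points x.

    One normal "bell" of width sigma is centered on every sample,
    the bells are summed, and the sum is divided by the number of samples.
    """
    x = np.asarray(x, dtype=float)[:, None]              # shape (M, 1)
    samples = np.asarray(samples, dtype=float)[None, :]  # shape (1, N)
    z = (x - samples) / sigma                            # shape (M, N) via broadcasting
    bells = np.exp(-z ** 2 / 2) / (sigma * np.sqrt(2 * np.pi))
    return bells.sum(axis=1) / samples.shape[1]

rng = np.random.default_rng(1)
xs = rng.normal(0, 1, 100)
grid = np.linspace(xs.min() - 3, xs.max() + 3, 2001)
density = kde(grid, xs, sigma=0.4)

# Trapezoid rule: area under the curve should be (very close to) 1
area = np.sum((density[1:] + density[:-1]) / 2 * np.diff(grid))
print(round(area, 4))
```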
The sigma of these curves is a free parameter. By default it is calculated by Scott's rule, which scales the spread (standard deviation) of the data by N ** (-1/5); for 100 standard-normal points that is about 0.4 (the green curve in the example plot).
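A small sketch of Scott's rule for one-dimensional data, assuming the form used by SciPy's `gaussian_kde` (sample standard deviation times N ** (-1/5)); some other implementations multiply by an extra constant close to 1. The helper name `scott_bandwidth` is hypothetical:

```python
import numpy as np

def scott_bandwidth(samples):
    """Kernel width from Scott's rule for 1-D data: std * N ** (-1/5)."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    return samples.std(ddof=1) * n ** (-1 / 5)

print(round(100 ** (-1 / 5), 3))  # 0.398, the factor for N = 100

rng = np.random.default_rng(2)
xs = rng.normal(0, 1, 100)
print(scott_bandwidth(xs))        # close to 0.4, since the std of the data is close to 1
```

For data with a larger spread the kernel gets proportionally wider, which is why the 0.4 value only holds for roughly standard-normal data.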
The code below shows the result for different choices of sigma. Smaller sigmas follow the given data more closely; larger sigmas give a smoother curve. There is no perfect choice for sigma; it depends strongly on the data and on what is known (or guessed) about the underlying distribution.
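One way to see the trade-off numerically, as a sketch: a wider sigma spreads the same total area over a wider range, so the highest point of the curve can only go down (a Gaussian KDE with a larger sigma equals the smaller-sigma curve smoothed by another Gaussian, and smoothing never raises the maximum). This reuses the same `gauss` formula as the example code; the seed is arbitrary:

```python
import numpy as np

def gauss(x, mu, sigma):
    # Normal density with mean mu and standard deviation sigma
    return np.exp(-((x - mu) / sigma) ** 2 / 2) / (sigma * np.sqrt(2 * np.pi))

rng = np.random.default_rng(3)
N = 100
xs = rng.normal(0, 1, N)
x = np.linspace(xs.min() - 1, xs.max() + 1, 2000)

peaks = []
for sigma in (0.2, 0.4, 0.6, 0.8, 1.0):
    curve = sum(gauss(x, xi, sigma) for xi in xs) / N  # KDE with this sigma
    peaks.append(curve.max())
    print(f"sigma = {sigma:.1f}: highest point {curve.max():.3f}")
```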
import matplotlib.pyplot as plt
import numpy as np

def gauss(x, mu, sigma):
    # Normal ("bell") density with mean mu and standard deviation sigma
    return np.exp(-((x - mu) / sigma) ** 2 / 2) / (sigma * np.sqrt(2 * np.pi))

N = 100
xs = np.random.normal(0, 1, N)
# Histogram scaled so that the total area of the bars is 1
plt.hist(xs, density=True, label='Histogram', alpha=.4, ec='w')

x = np.linspace(xs.min() - 1, xs.max() + 1, 100)
for sigma in np.arange(.2, 1.2, .2):
    # KDE: one bell per sample, summed and divided by N
    plt.plot(x, sum(gauss(x, xi, sigma) for xi in xs) / N, label=f'$\\sigma = {sigma:.1f}$')
plt.xlim(x[0], x[-1])
plt.legend()
plt.show()
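To check that this hand-made sum of Gaussians is a genuine KDE and not just a look-alike, it can be compared with SciPy's `scipy.stats.gaussian_kde`. This sketch assumes SciPy is installed. With a scalar `bw_method`, SciPy multiplies the sample standard deviation by that factor to get the kernel width, so passing `sigma / xs.std(ddof=1)` makes the kernel width exactly `sigma`; with the default `bw_method` it applies Scott's rule. Which KDE backend a particular seaborn version uses internally is an assumption worth checking for your version:

```python
import numpy as np
from scipy.stats import gaussian_kde

def gauss(x, mu, sigma):
    # Normal density with mean mu and standard deviation sigma
    return np.exp(-((x - mu) / sigma) ** 2 / 2) / (sigma * np.sqrt(2 * np.pi))

rng = np.random.default_rng(4)
N = 100
xs = rng.normal(0, 1, N)
x = np.linspace(xs.min() - 1, xs.max() + 1, 200)

sigma = 0.4
manual = sum(gauss(x, xi, sigma) for xi in xs) / N

# A scalar bw_method is a factor applied to the sample standard deviation
scipy_kde = gaussian_kde(xs, bw_method=sigma / xs.std(ddof=1))
print(np.allclose(manual, scipy_kde(x)))  # True

# With the default bw_method ('scott'), the factor is N ** (-1/5)
default_kde = gaussian_kde(xs)
print(default_kde.factor, N ** (-1 / 5))
```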
PS: Instead of a histogram or a KDE, 100 random numbers can also be visualized as a set of short vertical lines (a "rug"):
plt.plot(np.repeat(xs, 3), np.tile((0, -0.05, np.nan), N), lw=1, c='k', alpha=0.5)
plt.ylim(ymin=-0.05)
or dots (jittered, so they don't overlap):
plt.scatter(xs, -np.random.rand(N)/10, s=1, color='crimson')
plt.ylim(ymin=-0.1)
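The short-line version relies on `plt.plot` leaving a gap wherever a coordinate is NaN. A tiny sketch with two made-up values shows the coordinate arrays it builds: for every sample, a top point at y = 0, a bottom point at y = -0.05, and a NaN that breaks the line before the next tick:

```python
import numpy as np

xs = np.array([1.5, 2.0])  # two made-up sample values
N = len(xs)

x_coords = np.repeat(xs, 3)                # each x value used three times
y_coords = np.tile((0, -0.05, np.nan), N)  # top, bottom, gap for every tick

print(x_coords)
print(y_coords)
```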
Source: https://stackoverflow.com/questions/61228160/how-does-distplot-calculate-the-kde-curve