Plot normal distribution in 3D

无人及你 2020-12-29 15:38

I am trying to plot the joint distribution of two normally distributed variables.

The code below plots one normally distributed variable. What would the code be for plotting both together?

3 Answers
  • 2020-12-29 16:04

    The following adaptation of @Ianhi's code returns a contour-plot version of the 3D plot in that answer.

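    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import style
    from scipy.stats import multivariate_normal
    
    style.use('fivethirtyeight')
    
    # Parameters to set
    mu_x = 0
    variance_x = 3
    
    mu_y = 0
    variance_y = 15
    
    # Create the grid
    x = np.linspace(-10, 10, 500)
    y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)
    
    # One (x, y) row per grid point: shape (500 * 500, 2)
    pos = np.array([X.flatten(), Y.flatten()]).T
    
    # Independent variables: diagonal covariance matrix
    rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])
    
    # Evaluate the pdf and reshape it back onto the grid
    Z = rv.pdf(pos).reshape(X.shape)
    
    fig = plt.figure(figsize=(10, 10))
    ax0 = fig.add_subplot(111)
    # Passing X and Y labels the axes in data coordinates instead of array indices
    ax0.contour(X, Y, Z)
    
    plt.show()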
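
    The contour version and the 3D surface version evaluate the same density on the same grid; only the final plotting call differs. A minimal sketch that factors out the grid evaluation is below. The helper name `density_grid` and its `lim`/`n` parameters are hypothetical (not part of SciPy or Matplotlib), and it builds the coordinate array with `np.dstack` instead of flattening and reshaping.

```python
import numpy as np
from scipy.stats import multivariate_normal


def density_grid(mu, cov, lim=10.0, n=500):
    """Hypothetical helper: evaluate a bivariate normal pdf on an n x n grid over [-lim, lim]^2."""
    x = np.linspace(-lim, lim, n)
    y = np.linspace(-lim, lim, n)
    X, Y = np.meshgrid(x, y)
    # Stack X and Y along a new last axis: shape (n, n, 2), one (x, y) pair per node
    pos = np.dstack((X, Y))
    # pdf() treats the last axis as the coordinates, so Z has shape (n, n)
    Z = multivariate_normal(mu, cov).pdf(pos)
    return X, Y, Z
```

    Usage: `X, Y, Z = density_grid([0, 0], [[3, 0], [0, 15]])`, then `ax.contour(X, Y, Z)` on an ordinary axes for the 2D version, or `ax.plot_surface(X, Y, Z, cmap='viridis')` on a 3D axes.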
    

  • 2020-12-29 16:07

    It sounds like what you're looking for is a multivariate normal distribution. This is implemented in SciPy as scipy.stats.multivariate_normal. It's important to remember that you are passing a covariance matrix (variances on the diagonal) to the function, not standard deviations. To keep things simple, set the off-diagonal elements to zero:

    [ X variance ,      0     ]
    [      0     , Y variance ]
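
    With zero off-diagonal entries the two variables are independent, so the joint density is just the product of the two one-dimensional normal densities, f(x, y) = f_X(x) · f_Y(y). A quick numerical check of that, using the same means and variances as this answer (the test points are arbitrary; note that `scipy.stats.norm` takes the standard deviation as `scale`, not the variance):

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

mu_x, variance_x = 0.0, 3.0
mu_y, variance_y = 0.0, 15.0

# Joint density with a diagonal covariance matrix
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])

# A few arbitrary (x, y) points
pts = np.array([[0.0, 0.0], [1.0, -2.0], [-3.0, 4.5]])
joint = rv.pdf(pts)

# Product of the marginal densities; norm expects the standard deviation
product = (norm.pdf(pts[:, 0], loc=mu_x, scale=np.sqrt(variance_x))
           * norm.pdf(pts[:, 1], loc=mu_y, scale=np.sqrt(variance_y)))

print(np.allclose(joint, product))  # True
```

    With a nonzero off-diagonal term (correlated variables) the product no longer matches, which is why the full covariance matrix is needed in general.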
    

    Here is an example that uses this function to generate a 3D plot of the resulting distribution. I added a colormap to make the curves easier to see, but feel free to remove it.

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.stats import multivariate_normal
    from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2
    
    # Parameters to set
    mu_x = 0
    variance_x = 3
    
    mu_y = 0
    variance_y = 15
    
    # Create grid and multivariate normal
    x = np.linspace(-10, 10, 500)
    y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)
    pos = np.empty(X.shape + (2,))  # (x, y) pair at every grid node
    pos[:, :, 0] = X
    pos[:, :, 1] = Y
    rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])
    
    # Make a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') fails on newer Matplotlib
    ax.plot_surface(X, Y, rv.pdf(pos), cmap='viridis', linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')
    plt.show()
    

    Giving you this plot: [image: 3D surface plot of the joint density]
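
    The preallocated `pos` array in this answer and the flattened `pos` in the contour answer are two layouts of the same coordinates, and `np.dstack` is a third, more compact way to build the grid version. A small check that all three give the same density grid (a 50 x 50 grid is used here only to keep it fast):

```python
import numpy as np
from scipy.stats import multivariate_normal

rv = multivariate_normal([0, 0], [[3, 0], [0, 15]])
x = np.linspace(-10, 10, 50)
y = np.linspace(-10, 10, 50)
X, Y = np.meshgrid(x, y)

# 1) Fill a preallocated (50, 50, 2) array, as in this answer
pos_filled = np.empty(X.shape + (2,))
pos_filled[:, :, 0] = X
pos_filled[:, :, 1] = Y
z_filled = rv.pdf(pos_filled)

# 2) Flatten to (2500, 2), evaluate, then reshape back onto the grid
pos_flat = np.array([X.flatten(), Y.flatten()]).T
z_flat = rv.pdf(pos_flat).reshape(X.shape)

# 3) Stack X and Y along a new last axis in one call
z_stacked = rv.pdf(np.dstack((X, Y)))

print(np.allclose(z_filled, z_flat) and np.allclose(z_filled, z_stacked))  # True
```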

    Edit: the method used below was deprecated in Matplotlib v2.2 and removed in v3.1.

    A simpler version is available through matplotlib.mlab.bivariate_normal. It takes the parameters directly, so you don't need to build a covariance matrix:

    matplotlib.mlab.bivariate_normal(X, Y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0, sigmaxy=0.0)

    Here X and Y are again the result of a meshgrid. Note that sigmax and sigmay are standard deviations, not variances, hence the np.sqrt calls below. Using this to recreate the above plot:

    # Legacy code: requires Matplotlib < 3.1, where mlab.bivariate_normal still exists
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.mlab import bivariate_normal
    from mpl_toolkits.mplot3d import Axes3D
    
    # Parameters to set (standard deviations, not variances)
    mu_x = 0
    sigma_x = np.sqrt(3)
    
    mu_y = 0
    sigma_y = np.sqrt(15)
    
    # Create grid and evaluate the density
    x = np.linspace(-10, 10, 500)
    y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)
    Z = bivariate_normal(X, Y, sigmax=sigma_x, sigmay=sigma_y, mux=mu_x, muy=mu_y)
    
    #Make a 3D plot
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X, Y, Z,cmap='viridis',linewidth=0)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')
    plt.show()
    

    Giving: [image: the same 3D surface plot]
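
    On Matplotlib 3.1 and later the mlab import fails. If you want to keep the `bivariate_normal(X, Y, ...)` call style, one option is a small function of your own with the same parameters, built on `scipy.stats.multivariate_normal`. This is a hypothetical stand-in (not part of Matplotlib or SciPy) and a sketch under the assumption, matching the old mlab formula, that `sigmax`/`sigmay` are standard deviations and `sigmaxy` is the covariance:

```python
import numpy as np
from scipy.stats import multivariate_normal


def bivariate_normal(X, Y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0, sigmaxy=0.0):
    """Hypothetical replacement for the removed matplotlib.mlab.bivariate_normal.

    sigmax, sigmay: standard deviations; sigmaxy: covariance (assumed, as in old mlab).
    """
    # Build the covariance matrix from the standard deviations and covariance
    cov = [[sigmax ** 2, sigmaxy], [sigmaxy, sigmay ** 2]]
    rv = multivariate_normal([mux, muy], cov)
    # Evaluate on the meshgrid: stacking gives shape X.shape + (2,)
    return rv.pdf(np.dstack((X, Y)))
```

    With this defined, the `from matplotlib.mlab import bivariate_normal` line can be dropped and the rest of the script above works as before, apart from swapping `fig.gca(projection='3d')` for `fig.add_subplot(projection='3d')` on newer Matplotlib.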

  • 2020-12-29 16:15

    While the other answers are great, I wanted to achieve similar results while also illustrating the distribution with a scatter plot of the sample.

    More details can be found here: Python 3d plot of multivariate gaussian distribution

    The result looks like this: [image: "Gaussian sample and pdf" 3D plot]

    And is generated using the following code:

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2
    from scipy.stats import multivariate_normal
    
    
    # Sample parameters
    mu = np.array([0, 0])
    sigma = np.array([[0.7, 0.2], [0.2, 0.3]])
    rv = multivariate_normal(mu, sigma)
    sample = rv.rvs(500)
    
    # Bounds parameters
    x_abs = 2.5
    y_abs = 2.5
    x_grid, y_grid = np.mgrid[-x_abs:x_abs:.02, -y_abs:y_abs:.02]
    
    # (x, y) pair at every grid node
    pos = np.empty(x_grid.shape + (2,))
    pos[:, :, 0] = x_grid
    pos[:, :, 1] = y_grid
    
    levels = np.linspace(0, 1, 40)
    
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') fails on newer Matplotlib
    
    # Removes the grey panes in 3d plots
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    
    # The heatmap: filled contours of 0.1 * pdf, which keeps them close to z = 0
    ax.contourf(x_grid, y_grid, 0.1 * rv.pdf(pos),
                zdir='z', levels=0.1 * levels, alpha=0.9)
    
    # The wireframe
    ax.plot_wireframe(x_grid, y_grid, rv.pdf(pos), rstride=10, cstride=10, color='k')
    
    # The scatter. Note that the altitude is defined based on the pdf of the
    # random variable, raised by 5% so the points sit just above the surface
    ax.scatter(sample[:, 0], sample[:, 1], 1.05 * rv.pdf(sample), c='k', label='sample')
    
    ax.legend()  # needs at least one labelled artist, hence label='sample' above
    ax.set_title("Gaussian sample and pdf")
    ax.set_xlim3d(-x_abs, x_abs)
    ax.set_ylim3d(-y_abs, y_abs)
    ax.set_zlim3d(0, 1)
    
    plt.show()
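
    Two quick sanity checks on the parameters above: the sample's empirical mean and covariance should be close to `mu` and `sigma`, and the highest point of the density (at the mean) is 1 / (2π·sqrt(det Σ)), which shows how much headroom the z-limit of 1 leaves. The fixed `random_state=0` is an assumption added here for reproducibility; the code above does not seed the sampler.

```python
import numpy as np
from scipy.stats import multivariate_normal

mu = np.array([0, 0])
sigma = np.array([[0.7, 0.2], [0.2, 0.3]])
rv = multivariate_normal(mu, sigma)

# Seeded draw (assumption: the original code uses an unseeded rv.rvs(500))
sample = rv.rvs(size=500, random_state=0)

# Empirical mean and covariance should approximate mu and sigma
emp_mean = sample.mean(axis=0)
emp_cov = np.cov(sample, rowvar=False)

# Peak height of a bivariate normal pdf, reached at the mean
peak = 1.0 / (2 * np.pi * np.sqrt(np.linalg.det(sigma)))

print(emp_mean, emp_cov, sep="\n")
print(round(peak, 3))  # 0.386
```

    Since det Σ = 0.7·0.3 − 0.2² = 0.17, the surface tops out near 0.386, so the scatter points (at most about 1.05 × 0.386) stay well inside the plotted z-range.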
    