How to correctly use scikit-learn's Gaussian Process for a 2D-inputs, 1D-output regression?

情深已故 2020-12-28 14:54

Prior to posting, I did a lot of searching and found this question, which might be exactly my problem. I tried what is proposed in the answer, but unfortunately this di

2 Answers
  •  一生所求
    2020-12-28 15:31

    You're using two features to predict a third quantity. Rather than a 3D plot like plot_surface, it's usually clearer to use a 2D plot that shows the third dimension as color, like hist2d or pcolormesh. Here's a complete example using data and code similar to those in the question:

    from itertools import product
    import numpy as np
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # only needed for projection='3d' on matplotlib < 3.2
    
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
    
    X = np.array([[0,0],[2,0],[4,0],[6,0],[8,0],[10,0],[12,0],[14,0],[16,0],[0,2],
                  [2,2],[4,2],[6,2],[8,2],[10,2],[12,2],[14,2],[16,2]])
    
    y = np.array([-54,-60,-62,-64,-66,-68,-70,-72,-74,-60,-62,-64,-66,
                  -68,-70,-72,-74,-76])
    
    # Input space: n evenly spaced values along each feature
    n = 50
    x1 = np.linspace(X[:,0].min(), X[:,0].max(), n)  # p
    x2 = np.linspace(X[:,1].min(), X[:,1].max(), n)  # q
    
    # anisotropic RBF: one length scale per input feature
    kernel = C(1.0, (1e-3, 1e3)) * RBF([5,5], (1e-2, 1e2))
    gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=15)
    
    gp.fit(X, y)
    
    # every (x1, x2) combination: an n*n grid, not just n points along the diagonal
    x1x2 = np.array(list(product(x1, x2)))
    # return_std=True returns the predictive standard deviation, not the MSE
    y_pred, y_std = gp.predict(x1x2, return_std=True)
    
    X0p, X1p = x1x2[:,0].reshape(n,n), x1x2[:,1].reshape(n,n)
    Zp = np.reshape(y_pred, (n,n))
    
    # alternative way to generate an equivalent grid (the arrays come out
    # transposed, which pcolormesh handles fine), with a single predict call
    # X0p, X1p = np.meshgrid(x1, x2)
    # Zp = gp.predict(np.column_stack([X0p.ravel(), X1p.ravel()])).reshape(X0p.shape)
    
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111)
    ax.pcolormesh(X0p, X1p, Zp)
    
    plt.show()
    

    Output: [figure: pcolormesh plot of the GP predictions]

    Kinda plain looking, but so was my example data. In general, you shouldn't expect particularly interesting results with so few data points.
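
With this few points, the model's uncertainty is worth looking at as well. predict(..., return_std=True) already returns the predictive standard deviation as its second value, so it can be plotted next to the mean. This is a sketch, assuming matplotlib 3.3+ (for shading="auto"); normalize_y=True and random_state=0 are my own additions, not part of the code above. normalize_y centers and scales y before fitting, which tends to suit a zero-mean GP prior better when the targets sit around -65.

```python
import numpy as np
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

# same training data as in the example above
X = np.array([[0,0],[2,0],[4,0],[6,0],[8,0],[10,0],[12,0],[14,0],[16,0],[0,2],
              [2,2],[4,2],[6,2],[8,2],[10,2],[12,2],[14,2],[16,2]])
y = np.array([-54,-60,-62,-64,-66,-68,-70,-72,-74,-60,-62,-64,-66,
              -68,-70,-72,-74,-76])

kernel = C(1.0, (1e-3, 1e3)) * RBF([5, 5], (1e-2, 1e2))
# normalize_y and random_state are assumptions added for this sketch:
# normalize_y centers/scales y, random_state makes the optimizer restarts reproducible
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=15,
                              normalize_y=True, random_state=0)
gp.fit(X, y)

# "ij"-indexed grid: the same layout that product(x1, x2) + reshape produces
x1 = np.linspace(X[:, 0].min(), X[:, 0].max(), 50)
x2 = np.linspace(X[:, 1].min(), X[:, 1].max(), 50)
X0p, X1p = np.meshgrid(x1, x2, indexing="ij")
grid = np.column_stack([X0p.ravel(), X1p.ravel()])

# one vectorized call returns both the mean and the standard deviation
y_pred, y_std = gp.predict(grid, return_std=True)
Zp = y_pred.reshape(X0p.shape)
Sp = y_std.reshape(X0p.shape)

fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(12, 5))
m = ax_mean.pcolormesh(X0p, X1p, Zp, shading="auto")
fig.colorbar(m, ax=ax_mean, label="predicted mean")
s = ax_std.pcolormesh(X0p, X1p, Sp, shading="auto")
fig.colorbar(s, ax=ax_std, label="predictive std")
for ax in (ax_mean, ax_std):
    ax.scatter(X[:, 0], X[:, 1], c="k", s=10)  # mark the training inputs
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
plt.show()
```

The standard deviation should be close to zero at the training inputs (the default alpha adds almost no noise) and grow between and away from them, which is a quick way to see how little the fit is constrained by 18 points.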

    Also, if you do want a surface plot, you can just replace the add_subplot and pcolormesh lines with what you originally had (more or less):

    # projection='3d' needs the Axes3D import on matplotlib < 3.2
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X0p, X1p, Zp, rstride=1, cstride=1, cmap='jet', linewidth=0, antialiased=False)
    

    Output: [figure: 3D surface plot of the GP predictions]
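
Both plots rely on X0p and X1p covering every (x1, x2) combination. Here is a small numpy-only sketch (a 5 x 5 grid instead of 50 x 50, chosen only so the shapes are easy to read) of why stacking the two axes with np.array([x1, x2]).T covers just the diagonal, and how the product() grid relates to np.meshgrid:

```python
from itertools import product

import numpy as np

n = 5  # small grid for readability; the example above uses 50
x1 = np.linspace(0, 16, n)
x2 = np.linspace(0, 2, n)

# Stacking the 1-D axes pairs x1[i] with x2[i] only: n points on the diagonal
diag = np.array([x1, x2]).T
print(diag.shape)  # (5, 2)

# itertools.product pairs every x1 value with every x2 value: the full n*n grid
x1x2 = np.array(list(product(x1, x2)))
print(x1x2.shape)  # (25, 2)

# product() advances x2 fastest, so reshaping each column to (n, n) gives the
# same arrays as np.meshgrid with matrix ("ij") indexing
X0p, X1p = x1x2[:, 0].reshape(n, n), x1x2[:, 1].reshape(n, n)
G0, G1 = np.meshgrid(x1, x2, indexing="ij")
print(np.array_equal(X0p, G0) and np.array_equal(X1p, G1))  # True

# With the default "xy" indexing, meshgrid returns the transposed arrays
H0, H1 = np.meshgrid(x1, x2)
print(np.array_equal(H0, G0.T) and np.array_equal(H1, G1.T))  # True
```

Either layout works for pcolormesh and plot_surface, as long as Zp is reshaped in the same order as the points that were passed to predict.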
