Plotting 3D Decision Boundary From Linear SVM

隐瞒了意图╮ 2021-01-02 19:40

I've fit a 3-feature data set using sklearn.svm.SVC(). I can plot a point for each observation using matplotlib and Axes3D. I want to plot the decision boundary to see th…

2 answers
  •  孤城傲影
    2021-01-02 20:04

    Here is an example on a toy dataset. Note that 3D plotting in matplotlib has limited depth handling: it does not fully depth-sort the surface and the points against each other, so points that are behind the plane can appear to be in front of it. You may have to rotate the plot to see what is actually going on.

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (only required on matplotlib < 3.2)
    from sklearn.svm import SVC
    
    rs = np.random.RandomState(1234)
    
    # Generate some fake data: two Gaussian clouds centred at +1 and -1.
    n_samples = 200
    half = n_samples // 2  # integer division: slice indices and sizes must be ints in Python 3
    # X is the input features by row.
    X = np.zeros((n_samples, 3))
    X[:half] = rs.multivariate_normal( np.ones(3), np.eye(3), size=half)
    X[half:] = rs.multivariate_normal(-np.ones(3), np.eye(3), size=half)
    # Y is the class labels for each row of X.
    Y = np.zeros(n_samples)
    Y[half:] = 1
    
    # Fit the data with an svm
    svc = SVC(kernel='linear')
    svc.fit(X,Y)
    
    # The equation of the separating plane is given by all x in R^3 such that:
    # np.dot(svc.coef_[0], x) + svc.intercept_[0] = 0. Solve for the last
    # coordinate (z) to plot the plane as a function of x and y.
    
    z = lambda x,y: (-svc.intercept_[0]-svc.coef_[0][0]*x-svc.coef_[0][1]*y) / svc.coef_[0][2]
    
    tmp = np.linspace(-2,2,51)
    x,y = np.meshgrid(tmp,tmp)
    
    # Plot stuff.
    fig = plt.figure()
    ax  = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x, y, z(x,y))
    ax.plot3D(X[Y==0,0], X[Y==0,1], X[Y==0,2],'ob')
    ax.plot3D(X[Y==1,0], X[Y==1,1], X[Y==1,2],'sr')
    plt.show()
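
    The `z` lambda divides by `svc.coef_[0][2]`, so it breaks down when the plane is (nearly) parallel to the z axis. A minimal NumPy sketch that avoids this, assuming you only need grid arrays for `plot_surface`: solve for whichever coordinate has the largest absolute coefficient instead of always solving for z. The helper name `plane_mesh` and the coefficients in the usage lines are hypothetical stand-ins, not part of scikit-learn or the original answer.

```python
import numpy as np


def plane_mesh(w, b, lo=-2.0, hi=2.0, n=51):
    """Return (X, Y, Z) grids lying on the plane w . p + b = 0.

    Hypothetical helper: solves for the coordinate with the largest |w_i|,
    so the plane can still be drawn when w[2] is close to zero.
    """
    w = np.asarray(w, dtype=float)
    k = int(np.argmax(np.abs(w)))  # index of the coordinate we solve for
    if w[k] == 0.0:
        raise ValueError("w must be non-zero to define a plane")
    free = [i for i in range(3) if i != k]  # the two coordinates we grid over
    tmp = np.linspace(lo, hi, n)
    u, v = np.meshgrid(tmp, tmp)
    coords = [None, None, None]
    coords[free[0]] = u
    coords[free[1]] = v
    # From w[k] * p_k + w[f0] * u + w[f1] * v + b = 0:
    coords[k] = -(b + w[free[0]] * u + w[free[1]] * v) / w[k]
    return tuple(coords)


# Made-up coefficients standing in for svc.coef_[0] and svc.intercept_[0].
X, Y, Z = plane_mesh([0.5, -1.0, 2.0], 0.25)
residual = 0.5 * X - 1.0 * Y + 2.0 * Z + 0.25
print(np.abs(residual).max() < 1e-12)
```

    In the script above, `ax.plot_surface(*plane_mesh(svc.coef_[0], svc.intercept_[0]))` could replace the `z` lambda and the manual meshgrid. Solving for the largest coefficient is just one simple way to keep the division well conditioned.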
    

    EDIT (the key linear-algebra statement from the code comment above, written out more explicitly):

    # The equation of the separating plane is given by all x in R^3 such that:
    # np.dot(coefficients, x_vector) + intercept_value = 0. 
    # We should solve for the last coordinate: x_vector[2] == z
    # to plot the plane in terms of x and y.
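
    Written out with $w$ = `svc.coef_[0]` and $b$ = `svc.intercept_[0]`, the step from the plane equation to the `z` lambda is the following. It only works when $w_2 \neq 0$; if $w_2 = 0$ the plane is parallel to the z axis and the lambda divides by zero. As I understand scikit-learn's binary convention, points with $f(\mathbf{x}) > 0$ are predicted as `svc.classes_[1]` (label 1 here).

```latex
f(\mathbf{x}) = w_0 x + w_1 y + w_2 z + b = 0
\quad\Longrightarrow\quad
z = \frac{-b - w_0 x - w_1 y}{w_2}, \qquad w_2 \neq 0
```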
    
