How to reproduce predict.svm in R? [closed]

Submitted by 冷眼眸甩不掉的悲伤 on 2019-12-09 01:02:10

Question


I want to train an SVM classifier in R and be able to use it in other software by exporting the relevant parameters. To do so, I first want to be able to reproduce the behavior of predict.svm() in R (using the e1071 package).

I trained the model based on the iris data.

library(e1071)  # provides svm() and predict.svm()
data(iris)

# simplify the data by removing the third label
ir <- iris[1:100,]
ir$Species <- as.factor(as.integer(ir$Species))

# train the model
m <- svm(Species ~ ., data=ir, cost=8)

# the model internally uses a scaled version of the data, example:
m$x.scale
# # # # # 
# $`scaled:center`
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
#        5.471        3.099        2.861        0.786 
#
# $`scaled:scale`
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
#    0.6416983    0.4787389    1.4495485    0.5651531 
# # # # #

# because the model uses scaled data, make a scaled data frame
irs <- ir
sc <- data.frame(m$x.scale)
for (col in row.names(sc)) {
    irs[[col]] <- (ir[[col]] - sc[[col, 1]]) / sc[[col, 2]]
}
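
The same scaling can also be done in a single call to base R's scale(), which accepts explicit center and scale vectors. The following is only a sketch: the helper name scale_with is hypothetical, and the centers and scales are recomputed from the training data as stand-ins for m$x.scale, on the assumption that svm() scaled each feature by its mean and standard deviation (the centers printed above match the column means).

```r
data(iris)
ir <- iris[1:100, ]

# Stand-ins for m$x.scale[["scaled:center"]] and m$x.scale[["scaled:scale"]]
# (assumption: svm() scaled each feature by its mean and standard deviation).
center <- colMeans(ir[, 1:4])
spread <- apply(ir[, 1:4], 2, sd)

# Hypothetical helper: scale the columns of x with given centers and scales.
scale_with <- function(x, center, scale) {
  x <- as.matrix(x)[, names(center), drop = FALSE]
  base::scale(x, center = center, scale = scale)
}

irs_mat <- scale_with(ir[, 1:4], center, spread)
round(colMeans(irs_mat), 10)  # column means are zero up to rounding error
```

With a fitted model, passing m$x.scale[["scaled:center"]] and m$x.scale[["scaled:scale"]] instead of the stand-ins should reproduce the scaled columns built by the loop.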

# a radial kernel function
k<-function(x,x1,gamma){
    return(exp(-gamma*sum((x-x1)^2)))
}

According to Hastie, Tibshirani, Friedman (2001), equation 12.24, the decision function at x can be written as the sum, over the support vectors, of each coefficient times the kernel evaluated at that support vector and x, plus the intercept. In R this sum corresponds to a matrix product.

$\hat{f}(x) = \sum_{i=1}^{N} \hat{\alpha}_i y_i K(x, x_i) + \hat{\beta}_0$, where m$coefs already stores the product $\hat{\alpha}_i y_i$.

# m$coefs contains the coefficients of the support vectors, m$SV 
# the support vectors, and m$rho the *negative* intercept
f<-function(x,m){
    return(t(m$coefs) %*% as.matrix(apply(m$SV,1,k,x,m$gamma)) - m$rho)
}

# a prediction function based on the sign of the decision function f
my.predict<-function(m,x){
    apply(x,1,function(y) sign(f(y,m)))
}
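
As a cross-check, the decision values can be computed for all rows at once and compared with the ones predict.svm() reports when called with decision.values = TRUE. This is a minimal sketch, assuming the default radial kernel and the same model as above; the variable names (xs, d2, K, dec_manual, dec_e1071) are illustrative only.

```r
library(e1071)

data(iris)
ir <- iris[1:100, ]
ir$Species <- as.factor(as.integer(ir$Species))
m <- svm(Species ~ ., data = ir, cost = 8)

# Scale the inputs with the centers and scales stored in the model.
xs <- scale(as.matrix(ir[, 1:4]),
            center = m$x.scale[["scaled:center"]],
            scale  = m$x.scale[["scaled:scale"]])

# Squared Euclidean distance between every row of xs and every support vector.
d2 <- outer(rowSums(xs^2), rowSums(m$SV^2), "+") - 2 * xs %*% t(m$SV)

# Radial kernel matrix (rows: observations, columns: support vectors).
K <- exp(-m$gamma * d2)

# Decision values: kernel matrix times coefficients, minus rho.
dec_manual <- as.vector(K %*% m$coefs - m$rho)

# Decision values as reported by predict.svm() itself.
pred <- predict(m, ir[, 1:4], decision.values = TRUE)
dec_e1071 <- as.vector(attr(pred, "decision.values"))

# Largest absolute difference; expected to be close to zero.
max(abs(dec_manual - dec_e1071))
```

Comparing the raw decision values rather than only their signs makes it easier to see whether a mismatch comes from the scaling, the kernel, or the intercept.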

# applying my prediction function to the scaled data frame should
# yield the same result as applying predict.svm() to the original data,
# so the table should show a one-to-one correspondence:
table(my.predict(m,irs[,1:4]),predict(m,ir[,1:4]))

# the unexpected result (produced by the code before the fix noted in the edit below):
# # # # #
#      1  2
#  -1  4 24
#  1  46 26
# # # # #

Can anyone explain where this is going wrong?

Edit: there was a minor error in my code; the corrected version now gives the following, expected result:

      1  2
  -1  0 50
  1  50  0

I hope this helps anyone facing the same problem.
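
For the original goal of using the model in other software, the quantities used by the decision function (support vectors, coefficients, rho, gamma, and the scaling vectors) could be written to plain CSV files. This is only a sketch under assumptions: the file names and the out_dir variable are hypothetical, a radial kernel is assumed, and the receiving software must reimplement the scaling and the kernel sum itself.

```r
library(e1071)

data(iris)
ir <- iris[1:100, ]
ir$Species <- as.factor(as.integer(ir$Species))
m <- svm(Species ~ ., data = ir, cost = 8)

# Hypothetical output location: a fresh temporary directory.
out_dir <- tempfile("svm_export_")
dir.create(out_dir)

# Support vectors (already scaled), one row per support vector.
write.csv(m$SV, file.path(out_dir, "support_vectors.csv"), row.names = FALSE)

# Coefficients (alpha_i * y_i), in the same row order as the support vectors.
write.csv(data.frame(coef = as.vector(m$coefs)),
          file.path(out_dir, "coefs.csv"), row.names = FALSE)

# Per-feature centers and scales needed to scale new inputs.
write.csv(data.frame(feature = names(m$x.scale[["scaled:center"]]),
                     center  = unname(m$x.scale[["scaled:center"]]),
                     scale   = unname(m$x.scale[["scaled:scale"]])),
          file.path(out_dir, "scaling.csv"), row.names = FALSE)

# Scalar parameters: negative intercept (rho) and kernel width (gamma).
write.csv(data.frame(rho = m$rho, gamma = m$gamma),
          file.path(out_dir, "scalars.csv"), row.names = FALSE)

list.files(out_dir)
```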


Answer 1:


There was a minor error in one of my functions. The edited version works.



Source: https://stackoverflow.com/questions/14375072/how-to-reproduce-predict-svm-in-r
