How to plot non-linear decision boundaries with a grid in R?

陌清茗 2021-02-06 17:36

I find this particular graph in ISLR (Figure 2.13) or ESL very well done. I can't guess how the authors would have made this in R. I know how to get the orange and blue points

2 Answers
  • 2021-02-06 17:55

    As I indicated in my comment, a solution was provided by @chl here on stats.stackexchange.com. Here it is, applied to your data set.

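    library(class)
    set.seed(pi)
    # Training data: 1000 uniform points in the unit square, class 1 above x1 + x2 = 1
    X <- t(replicate(1000, runif(2)))
    g <- ifelse(apply(X, 1, sum) <= 1, 0, 1)
    # 50 x 50 grid of query points (first coordinate varies fastest)
    xnew <- cbind(rep(seq(0, 1, length.out=50), 50),
                  rep(seq(0, 1, length.out=50), each=50))
    # kNN returns the vote share of the winning class; convert it to P(class 1)
    m <- knn(X, xnew, g, k=15, prob=TRUE)
    prob <- attr(m, "prob")
    prob <- ifelse(m=="1", prob, 1-prob)
    prob15 <- matrix(prob, 50)
    par(mar=rep(3, 4))
    # The decision boundary is the 0.5 contour of the class-1 probability
    contour(unique(xnew[, 1]), unique(xnew[, 2]), prob15, levels=0.5,
            labels="", xlab='', ylab='', axes=FALSE, lwd=2.5, asp=1)
    title(xlab=expression(italic('X')[1]), ylab=expression(italic('X')[2]), 
          line=1, family='serif', cex.lab=1.5)
    # Training points, semi-transparent fill by true class
    points(X, bg=ifelse(g==1, "#CA002070", "#0571B070"), pch=21)
    # Background grid dots, coloured by predicted class
    gd <- expand.grid(x=unique(xnew[, 1]), y=unique(xnew[, 2]))
    points(gd, pch=20, cex=0.4, col=ifelse(prob15 > 0.5, "#CA0020", "#0571B0"))
    box()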
    

    [Figure: decision boundary]

    (UPDATE: I changed the colour palette because the blue/yellow/purple thing was pretty hideous.)
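
    The same idea can be wrapped in a small function so that boundaries for several values of k can be compared side by side. This is only a sketch under stated assumptions: knn_prob_grid is a hypothetical helper name that is not part of the class package, the data are simulated the same way as above, and the grid size n and seed are arbitrary.

    ```r
    library(class)

    # Hypothetical helper (not part of any package): estimate P(class == 1)
    # with kNN on an n x n grid spanning the range of the training data.
    knn_prob_grid <- function(X, g, k = 15, n = 50) {
      xs <- seq(min(X[, 1]), max(X[, 1]), length.out = n)
      ys <- seq(min(X[, 2]), max(X[, 2]), length.out = n)
      grid <- as.matrix(expand.grid(x = xs, y = ys))  # x varies fastest
      m <- knn(X, grid, factor(g), k = k, prob = TRUE)
      p <- attr(m, "prob")             # vote share of the winning class
      p <- ifelse(m == "1", p, 1 - p)  # convert to P(class == 1)
      list(x = xs, y = ys, prob = matrix(p, n, n))  # prob[i, j] at (xs[i], ys[j])
    }

    # Simulated data, assumed to follow the same rule as in the answer above
    set.seed(1)
    X <- matrix(runif(2000), ncol = 2)
    g <- as.integer(rowSums(X) > 1)

    par(mfrow = c(1, 2), mar = rep(3, 4))
    for (k in c(1, 15)) {
      res <- knn_prob_grid(X, g, k = k)
      # The decision boundary is the 0.5 contour of the class-1 probability
      contour(res$x, res$y, res$prob, levels = 0.5, drawlabels = FALSE,
              lwd = 2.5, asp = 1)
      title(main = paste("k =", k))
      points(X, pch = 21, bg = ifelse(g == 1, "#CA002070", "#0571B070"))
    }
    ```

    Because expand.grid varies its first argument fastest, filling the probabilities column-wise into an n x n matrix gives exactly the layout contour() expects, with rows indexed by x and columns by y.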

  • 2021-02-06 18:04

    This was my silly attempt at approximation. Clearly the issues raised by @StephenKolassa are valid and not handled by this approximation.

    myCurve1 = function (x)
      abs(x[[1]] * sin(x[[1]]) + x[[2]] * sin(x[[2]]))
    myCurve2 = function (x)
      abs(x[[1]] * cos(x[[1]]) + x[[2]] * cos(x[[2]]))
    myCurve3 = function (x)
      abs(x[[1]] * tan(x[[1]]) + x[[2]] * tan(x[[2]]))
    
    tmp = function (myCurve, seed=99) {
      set.seed(seed)
      # 100 random points in the unit square, one per column (2 x 100 matrix)
      pts = replicate(100, runif(2))
      colors = ifelse(apply(pts, 2, myCurve) > 0.5, "orange", "blue")
      # Confound some: flip the label of 6 randomly chosen points
      swapInts = sample.int(length(colors), 6)
      for (i in swapInts) {
        if (colors[[i]] == "orange") {
          colors[[i]] = "blue"
        } else {
          colors[[i]] = "orange"
        }
      }
      # Classify every node of a fine grid; nodes within 0.005 of the
      # threshold are drawn as purple stars to trace the boundary
      gridPoints = seq(0, 1, 0.005)
      gridPoints = as.matrix(expand.grid(gridPoints, gridPoints))
      gridColors = vector("character", nrow(gridPoints))
      gridPch = vector("character", nrow(gridPoints))
      for (i in 1:nrow(gridPoints)) {
        val = myCurve(gridPoints[i, ])
        if (val > 0.505) {
          gridColors[[i]] = "orange"
          gridPch[[i]] = "."
        } else if (val < 0.495) {
          gridColors[[i]] = "blue"
          gridPch[[i]] = "."
        } else {
          gridColors[[i]] = "purple"
          gridPch[[i]] = "*"
        }
      }
      plot(x=gridPoints[ , 1], y=gridPoints[ , 2], col=gridColors, pch=gridPch)
      points(x=pts[1, ], y=pts[2, ], col=colors, lwd=2)
    }
    
    par(mfrow=c(1, 3))
    tmp(myCurve1)
    tmp(myCurve2)
    tmp(myCurve3)
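
    Since the class rule here is a known function thresholded at 0.5, the per-point loop over the grid could probably be replaced by a vectorized evaluation, with contour() drawing the boundary as a line instead of a band of purple stars. A minimal sketch, assuming a two-argument rewrite of myCurve1 (curve1 is a hypothetical name) and a coarser grid than the one used above:

    ```r
    # Same rule as myCurve1, rewritten to take vectors of coordinates
    # (hypothetical two-argument variant, not from the original answer).
    curve1 <- function(x, y) abs(x * sin(x) + y * sin(y))

    gx <- seq(0, 1, length.out = 51)
    z <- outer(gx, gx, curve1)         # z[i, j] = curve1(gx[i], gx[j])
    gd <- expand.grid(x = gx, y = gx)  # same ordering as as.vector(z)

    # Background grid coloured by which side of the threshold each node is on
    plot(gd$x, gd$y, pch = 20, cex = 0.3, asp = 1, xlab = "", ylab = "",
         col = ifelse(as.vector(z) > 0.5, "orange", "blue"))
    # Boundary as the 0.5 level set of the function
    contour(gx, gx, z, levels = 0.5, drawlabels = FALSE,
            lwd = 2, col = "purple", add = TRUE)
    ```

    This only traces the boundary of the generating rule; it does not address boundaries estimated from the noisy labels, which is what the kNN approach handles.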
    
