Predict with step_naomit and retain ID using tidymodels

百般思念 提交于 2020-04-30 06:40:10

问题


I am trying to retain a row ID when predicting with a Random Forest model so that I can merge the predictions back onto the original dataframe. I am using step_naomit in the recipe, which removes the rows with missing data when I bake the training data, but it also removes the records with missing data from the testing data. Unfortunately, I don't have an ID that tells me which records were removed, so I cannot accurately merge the predictions back on.

I have tried to add an ID column to the original data, but bake will remove any variable not included in the formula (and I don't want to include ID in the formula). I also thought I might be able to retain the row.names from the original table to merge on, but it appears the row names are reset upon baking as well.

I realize I can remove the NA values prior to the recipe to solve this problem, but then what is the point of step_naomit in the recipe? I also tried skip=TRUE in the step_naomit, but then I get an error for missing data when fitting the model (only for random forest). I feel I am missing something here in tidymodels that would allow me to retain all the rows prior to baking?

See example:


## R 3.6.1 ON WINDOWS 10 MACHINE

require(tidyverse)
require(tidymodels)
require(ranger)

# Fix the RNG state so the split below is reproducible.
set.seed(123)

# Blank out Petal.Width on every row whose rounded Sepal.Width is even, so the
# example data contains missing values in one numeric column.
temp <- dplyr::mutate(
    iris,
    Petal.Width = if_else(
        round(Sepal.Width) %% 2 == 0,
        NA_real_, ## INTRODUCE NA VALUES
        Petal.Width
    )
)

# prop = 0.8 requests roughly 80% of the rows for the training portion.
mySplit <- rsample::initial_split(temp, prop = 0.8)

# Build and prep a recipe on `dataFrame`.
#
# The recipe declares Petal.Width as the outcome and every other column as a
# predictor, then drops any row with an NA in a numeric column.
#
# dataFrame: the data used both as the recipe template and to train the steps.
# Returns:   a prepped recipe that can be passed to bake().
myRecipe <- function(dataFrame) {
    recipes::recipe(Petal.Width ~ ., data = dataFrame) %>%
        step_naomit(all_numeric()) %>%
        # prep() names its data argument `training`; the previous `data =`
        # was not a prep() argument and was silently absorbed by `...`.
        prep(training = dataFrame)
}

# Fit a ranger random forest on the baked training portion of `mySplit` and
# return the baked test portion with a `myPrediction` column appended.
#
# mySplit:  split object handed to training() / testing().
# myRecipe: function that takes a data frame and returns a prepped recipe.
#
# NOTE(review): the recipe defined in this file uses Petal.Width as its
# outcome, while the model below is fit on Sepal.Width -- confirm which
# response is intended.
myPred <- function(mySplit,myRecipe) {

    fit_rows <- training(mySplit)
    holdout_rows <- testing(mySplit)

    # The recipe is prepped on the training rows only and reused for both bakes.
    prepped <- myRecipe(fit_rows)
    baked_train <- bake(prepped, new_data = fit_rows)

    forest_spec <- set_engine(
        rand_forest(mode = "regression", mtry = 3, trees = 50),
        "ranger",
        importance = 'impurity'
    )
    forest_fit <- fit(forest_spec, Sepal.Width ~ ., data = baked_train)

    baked_test <- bake(prepped, new_data = holdout_rows)
    preds <- unlist(predict(forest_fit, new_data = baked_test))

    bind_cols(baked_test, myPrediction = preds)

}

# Run the whole pipeline on the split created above.
getPredictions <- myPred(mySplit,myRecipe)

# Number of test rows left after baking (rows with NA were dropped by step_naomit).
nrow(getPredictions)

##  21 ROWS

# Largest row name in the baked output; per the result below it equals the row
# count, i.e. the original row names are no longer available.
max(as.numeric(row.names(getPredictions)))

##  21

# Number of rows in the raw test portion, before any NA removal.
nrow(testing(mySplit))

##  29 ROWS

# Largest row name in the raw test portion; per the result below it still
# reflects a row position in the original 150-row data.
max(as.numeric(row.names(testing(mySplit))))

##  150

回答1:


To be able to keep track of which observations were removed we need to give the original dataset an id variable.

# Same NA injection as in the question, plus a row-number `id` column so each
# row can be identified after baking.
temp <- dplyr::mutate(
    iris,
    Petal.Width = if_else(
        round(Sepal.Width) %% 2 == 0,
        NA_real_, ## INTRODUCE NA VALUES
        Petal.Width
    ),
    id = row_number() #<<<<
)

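Then we use update_role() to first designate it as an "id variable", then remove it as a predictor so it doesn't become part of the modeling process. Note that bake() still returns the id column (that is what lets us match predictions back to the original rows), so a model formula such as `y ~ .` will pick it up unless it is excluded when fitting. And that is it. Everything else should work like before. Below is the fully updated code with #<<<< to denote my changes.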

require(tidyverse)
#> Loading required package: tidyverse
require(tidymodels)
#> Loading required package: tidymodels
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo
#> ── Attaching packages ───────────────────── tidymodels 0.0.3 ──
#> ✔ broom     0.5.2     ✔ recipes   0.1.7
#> ✔ dials     0.0.3     ✔ rsample   0.0.5
#> ✔ infer     0.5.0     ✔ yardstick 0.0.4
#> ✔ parsnip   0.0.4
#> ── Conflicts ──────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard() masks purrr::discard()
#> ✖ dplyr::filter()   masks stats::filter()
#> ✖ recipes::fixed()  masks stringr::fixed()
#> ✖ dplyr::lag()      masks stats::lag()
#> ✖ dials::margin()   masks ggplot2::margin()
#> ✖ dials::offset()   masks stats::offset()
#> ✖ yardstick::spec() masks readr::spec()
#> ✖ recipes::step()   masks stats::step()
require(ranger)
#> Loading required package: ranger

# Fix the RNG state so the split below is reproducible.
set.seed(1234)

# Blank out Petal.Width on every row whose rounded Sepal.Width is even, and tag
# each row with its position so it can be identified after baking.
temp <- dplyr::mutate(
    iris,
    Petal.Width = if_else(
        round(Sepal.Width) %% 2 == 0,
        NA_real_, ## INTRODUCE NA VALUES
        Petal.Width
    ),
    id = row_number() #<<<<
)

# prop = 0.8 requests roughly 80% of the rows for the training portion.
mySplit <- rsample::initial_split(temp, prop = 0.8)

# Build and prep a recipe on `dataFrame`, keeping `id` as a non-predictor.
#
# The formula makes Petal.Width the outcome and every other column a
# predictor; `id` is then moved to the "id variable" role. Rows with an NA in
# any numeric column are dropped.
#
# dataFrame: data with an `id` column, used as the recipe template and to
#            train the steps.
# Returns:   a prepped recipe that can be passed to bake().
myRecipe <- function(dataFrame) {
    recipes::recipe(Petal.Width ~ ., data = dataFrame) %>%
        # Changing the role of `id` is sufficient to take it out of the
        # predictors. The former follow-up `update_role(-id, 'predictor')` was
        # dropped: it also matched Petal.Width and overwrote its outcome role.
        update_role(id, new_role = "id variable") %>%  #<<<<
        step_naomit(all_numeric()) %>%
        # prep() names its data argument `training`; the previous `data =`
        # was not a prep() argument and was silently absorbed by `...`.
        prep(training = dataFrame)
}

# Fit a ranger random forest on the baked training portion of `mySplit` and
# return the baked test portion (all columns, including any id) with a
# `myPrediction` column appended.
#
# mySplit:  split object handed to training() / testing().
# myRecipe: function that takes a data frame and returns a prepped recipe.
#
# NOTE(review): the recipe in this file uses Petal.Width as its outcome, while
# the model below is fit on Sepal.Width -- confirm which response is intended.
# NOTE(review): the predictions recorded in the output further down this file
# were presumably produced while `id` was still part of the model, so a re-run
# is expected to give slightly different values.
myPred <- function(mySplit,myRecipe) {

    train_set <- training(mySplit)
    test_set <- testing(mySplit)

    train_prep <- myRecipe(train_set)

    analysis_processed <- bake(train_prep, new_data = train_set)

    # bake() hands back every column, including ones such as `id` whose recipe
    # role is not "predictor"/"outcome". With a `~ .` formula those would
    # become model inputs, so fit only on the columns the recipe treats as
    # model variables.
    roles <- summary(train_prep)
    model_vars <- roles$variable[roles$role %in% c("predictor", "outcome")]
    model_data <- analysis_processed[
        intersect(names(analysis_processed), model_vars)
    ]

    model <- rand_forest(
            mode = "regression",
            mtry = 3,
            trees = 50) %>%
        set_engine("ranger", importance = 'impurity') %>%
        fit(Sepal.Width ~ ., data = model_data)

    # The test set keeps its id column so predictions can be traced back;
    # predict() only uses the variables the model was fit on.
    test_processed <- bake(train_prep, new_data = test_set)

    predictions <- predict(model, new_data = test_processed)

    test_processed %>%
        bind_cols(myPrediction = predictions$.pred)

}

# Run the whole pipeline on the split created above.
getPredictions <- myPred(mySplit, myRecipe)

# Print the result: the baked test rows, still carrying `id`, plus the
# prediction column.
getPredictions
#> # A tibble: 23 x 7
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species     id myPrediction
#>           <dbl>       <dbl>        <dbl>       <dbl> <fct>    <int>        <dbl>
#>  1          4.6         3.1          1.5         0.2 setosa       4         3.24
#>  2          4.3         3            1.1         0.1 setosa      14         3.04
#>  3          5.1         3.4          1.5         0.2 setosa      40         3.22
#>  4          5.9         3            4.2         1.5 versico…    62         2.98
#>  5          6.7         3.1          4.4         1.4 versico…    66         2.92
#>  6          6           2.9          4.5         1.5 versico…    79         3.03
#>  7          5.7         2.6          3.5         1   versico…    80         2.79
#>  8          6           2.7          5.1         1.6 versico…    84         3.12
#>  9          5.8         2.6          4           1.2 versico…    93         2.79
#> 10          6.2         2.9          4.3         1.3 versico…    98         2.88
#> # … with 13 more rows

# removed ids
# (ids present in the raw test portion but missing from the predictions,
# i.e. the test rows that were dropped during baking)
setdiff(testing(mySplit)$id, getPredictions$id)
#> [1]   5  28  47  70  90 132

Created on 2019-11-26 by the reprex package (v0.3.0)



来源:https://stackoverflow.com/questions/59043199/predict-with-step-naomit-and-retain-id-using-tidymodels

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!