Question
I have a dataset that looks like this:
  Category Weekly_Date     a     b
  <chr>    <date>      <dbl> <dbl>
1 aa       2018-07-01  36.6    1.4
2 aa       2018-07-02   5.30   0
3 bb       2018-07-01   4.62   1.2
4 bb       2018-07-02   3.71   1.5
5 cc       2018-07-01   3.41  12
... ...    ...          ...   ...
I fitted a linear regression for each group separately:
fit_linreg <- train %>%
  group_by(Category) %>%
  do(model = lm(Target ~ Unit_price + Unit_discount, data = .))
Now I have different models for each category:
aa model1
bb model2
cc model3
So I need to apply each model to its own category when making predictions. How can I achieve that? (A dplyr solution is preferable.)
Answer 1:
If you nest your test data by group and join it with the models, you can then use map2 to make predictions on each group's data with the model trained for that group. See the example below with mtcars.
library(tidyverse)
x <- mtcars %>%
  group_by(gear) %>%
  do(model = lm(mpg ~ hp + wt, data = .))
x
Source: local data frame [3 x 2]
Groups: <by row>
# A tibble: 3 x 2
gear model
* <dbl> <list>
1 3 <S3: lm>
2 4 <S3: lm>
3 5 <S3: lm>
mtcars %>%
  group_by(gear) %>%
  nest() %>%
  inner_join(x) %>%
  mutate(preds = map2(model, data, predict)) %>%
  unnest(preds)
Joining, by = "gear"
# A tibble: 32 x 2
gear preds
<dbl> <dbl>
1 4 22.0
2 4 21.2
3 4 25.1
4 4 26.0
5 4 22.2
6 4 17.8
7 4 17.8
8 4 28.7
9 4 32.3
10 4 30.0
# ... with 22 more rows
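The example above predicts on the same rows the models were fitted to. A minimal sketch of the train/test case from the question follows, assuming tidyr >= 1.0 (for nest(data = -gear)) and a made-up hold-out split in which every fourth row of mtcars plays the role of the test set. The names idx, train, test, models and preds are illustrative only.

```r
library(dplyr)
library(tidyr)
library(purrr)

# Hypothetical split: every 4th row is held out as "new" data.
idx   <- seq(4, nrow(mtcars), by = 4)
train <- mtcars[-idx, ]
test  <- mtcars[idx, ]

# One model per gear, fitted on the training rows only.
models <- train %>%
  nest(data = -gear) %>%
  mutate(model = map(data, ~ lm(mpg ~ hp + wt, data = .x))) %>%
  select(gear, model)

# Nest the test rows the same way, match each group to its model,
# and pass the group's rows to predict() as newdata.
preds <- test %>%
  nest(data = -gear) %>%
  inner_join(models, by = "gear") %>%
  mutate(pred = map2(model, data, ~ predict(.x, newdata = .y))) %>%
  select(gear, data, pred) %>%
  unnest(c(data, pred))
```

inner_join() silently drops any test group that has no trained model; a left_join() plus an explicit check may be preferable if that situation should be reported instead.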
Answer 2:
Here's one approach. I'm using data.table to filter, but you can use dplyr instead; I just prefer the data.table syntax.
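library(data.table)
d <- as.data.table(mtcars)
cats <- unique(d$cyl)
# Fit one model per distinct value of cyl
m <- lapply(cats, function(z) {
  lm(formula = mpg ~ wt + hp + disp, data = d[cyl == z, ])
})
# Name each model after its cyl value so it can be looked up later
names(m) <- cats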
OUTPUT
> summary(m)
Length Class Mode
6 12 lm list
4 12 lm list
8 12 lm list
# Checking first model
> m[[1]]
Call:
lm(formula = mpg ~ wt + hp + disp, data = d[cyl == z, ])
Coefficients:
(Intercept) wt hp disp
30.27791 -3.89618 -0.01097 0.01610
> sapply(1:length(m), function(z) return(summary(m[[z]])$adj.r.squared))
[1] 0.4434228 0.5829574 0.3461900
I named the list because it might be easier to refer to models by name (aa or bb in your case). Hope this helps!
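A named list like this can then be used to predict group by group. The following is a sketch in base R only (so data.table is not needed), assuming the same kind of made-up hold-out split of mtcars; idx, train, test and the pred column are hypothetical names.

```r
# Hypothetical split: every 4th row is held out as "new" data.
idx   <- seq(4, nrow(mtcars), by = 4)
train <- mtcars[-idx, ]
test  <- mtcars[idx, ]

# split() names each piece after its cyl value ("4", "6", "8"),
# so the resulting list of models is named the same way.
m <- lapply(split(train, train$cyl), function(d) {
  lm(mpg ~ wt + hp + disp, data = d)
})

# Look up the model by group name and predict only that group's rows.
test$pred <- NA_real_
for (k in names(m)) {
  rows <- as.character(test$cyl) == k
  test$pred[rows] <- predict(m[[k]], newdata = test[rows, ])
}
```

Rows whose group has no model keep NA in pred, which makes missing categories easy to spot.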
Answer 3:
I find the nesting and un-nesting very unnatural, so here's my attempt.
Let's say you want the quality of the model's fit.
library(dplyr)
mtcars %>%
  group_by(cyl) %>%
  do(data.frame(r2 = summary(lm(mpg ~ wt, data = .))$r.squared))
#> # A tibble: 3 x 2
#> # Groups: cyl [3]
#> cyl r2
#> <dbl> <dbl>
#> 1 4 0.509
#> 2 6 0.465
#> 3 8 0.423
Let's say you want the residuals:
library(dplyr)
mtcars %>%
  group_by(cyl) %>%
  do(data.frame(resid = residuals(lm(mpg ~ wt, data = .))))
#> # A tibble: 32 x 2
#> # Groups: cyl [3]
#> cyl resid
#> <dbl> <dbl>
#> 1 4 -3.67
#> 2 4 2.84
#> 3 4 1.02
#> 4 4 5.25
#> 5 4 -0.0513
#> 6 4 4.69
#> 7 4 -4.15
#> 8 4 -1.34
#> 9 4 -1.49
#> 10 4 -0.627
#> # ... with 22 more rows
See ?do for why you need the embedded data.frame(). You'll probably want to include other columns in the result, not just the grouping variable and the residuals. I can't find a neat way to do this other than listing them!
library(dplyr)
mtcars %>%
  group_by(cyl) %>%
  do(data.frame(disp = .$disp,
                qsec = .$qsec,
                resid = residuals(lm(mpg ~ wt, data = .))))
#> # A tibble: 32 x 4
#> # Groups: cyl [3]
#> cyl disp qsec resid
#> <dbl> <dbl> <dbl> <dbl>
#> 1 4 108 18.6 -3.67
#> 2 4 147. 20 2.84
#> 3 4 141. 22.9 1.02
#> 4 4 78.7 19.5 5.25
#> 5 4 75.7 18.5 -0.0513
#> 6 4 71.1 19.9 4.69
#> 7 4 120. 20.0 -4.15
#> 8 4 79 18.9 -1.34
#> 9 4 120. 16.7 -1.49
#> 10 4 95.1 16.9 -0.627
#> # ... with 22 more rows
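One possible way to avoid listing the columns by hand, sketched here under the assumption that dplyr >= 0.8.1 is available (that is where group_modify() was introduced), is to add the residuals to each group's rows with mutate(). The name res is illustrative.

```r
library(dplyr)

# group_modify() calls the lambda once per group; .x holds that group's
# rows (without the grouping column), so mutate() keeps every other column.
res <- mtcars %>%
  group_by(cyl) %>%
  group_modify(~ mutate(.x, resid = residuals(lm(mpg ~ wt, data = .x)))) %>%
  ungroup()
```

The result has one row per original row, with cyl, all remaining mtcars columns, and the new resid column.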
Something that doesn't work
For the first example, I thought the following would work:
library(dplyr)
mtcars %>%
  group_by(cyl) %>%
  summarise(r2 = summary(lm(mpg ~ wt, data = .))$r.squared)
#> # A tibble: 3 x 2
#> cyl r2
#> <dbl> <dbl>
#> 1 4 0.753
#> 2 6 0.753
#> 3 8 0.753
But you can see all models have the same r2. That's because the model is being fit to all the data, not per cyl: in this pipeline . is the magrittr placeholder, so it stands for the whole data frame that was piped in rather than the current group's rows. Looking at the authors' code, I believe the Rcpp-optimised evaluation of mutate() and summarise() does not rebind it per group. But do() works as expected: it subsets the data by group before passing it to the expression to be evaluated. I see they are pondering this; see Hybrid Folding.
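For completeness, a sketch of a summarise() call that does fit one model per group, assuming dplyr >= 1.1.0 where pick() is available: pick(everything()) returns the current group's rows as a data frame, which is what . was expected to be. The name r2_by_cyl is illustrative.

```r
library(dplyr)

# pick(everything()) is evaluated per group, so lm() only sees the
# rows belonging to the current value of cyl.
r2_by_cyl <- mtcars %>%
  group_by(cyl) %>%
  summarise(r2 = summary(lm(mpg ~ wt, data = pick(everything())))$r.squared)
```

The three r2 values should now differ from one another and agree with the do() version of the same calculation.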
Source: https://stackoverflow.com/questions/52168341/make-prediction-for-each-group-differently