Как использовать Rcpp для ускорения цикла for?

Я создал цикл for, и я хотел бы ускорить его, используя библиотеку Rcpp. Я не очень хорошо знаком с С++. Не могли бы вы помочь мне быстрее выполнить мою функцию? Спасибо за помощь!

Я включил свой алгоритм, код вместе с вводом и выходом с sessionInfo.

Вот мой алгоритм:

если текущая цена выше предыдущей цены, отметьте (+1) в столбце TR

если текущая цена ниже предыдущей цены, отметьте (-1) в столбце TR

если текущая цена совпадает с предыдущей ценой,    отметьте то же, что и в предыдущей цене в столбце TR

Вот мой код:

price <- c(71.91, 71.82, 71.81, 71.81, 71.81, 71.82, 71.81, 71.81, 71.81, 
           71.82, 71.81, 71.81, 71.8, 71.81, 71.8, 71.81, 71.8, 71.8, 71.8, 
           71.8, 71.81, 71.81, 71.81, 71.81, 71.81, 71.81, 71.81, 71.81, 
           71.81, 71.82, 71.81, 71.81, 71.81, 71.81, 71.81, 71.81, 71.8, 
           71.8, 71.81, 71.81, 71.81, 71.81, 71.82, 71.82, 71.81, 71.81, 
           71.81, 71.81, 71.81, 71.81, 71.81, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.81, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.81, 71.81, 71.81, 71.82, 71.82, 
           71.81, 71.82, 71.82, 71.82, 71.81, 71.82, 71.82, 71.82, 71.81, 
           71.81, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 71.81, 
           71.82, 71.82, 71.82, 71.82, 71.83, 71.82, 71.82, 71.82, 71.81, 
           71.81, 71.81, 71.81, 71.81, 71.81, 71.81, 71.82, 71.82, 71.82, 
           71.81, 71.81, 71.81, 71.82, 71.82, 71.82, 71.82, 71.82, 71.82, 
           71.82, 71.82, 71.82, 71.82, 71.82, 71.83, 71.83, 71.83, 71.83, 
           71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 
           71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 
           71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 71.83, 
           71.83)

TR <- numeric(length(price)-1)
TR <- c(NA,TR)

for (i in 1: (length(price)-1)){

  if (price[i] == price[i+1]) {TR[i+1] = TR[i]}

  if (price[i] < price[i+1]) {TR[i+1] = 1}

  if (price[i] > price[i+1]) {TR[i+1] = -1}

}

И вот мой вывод: dput (TR) дает

c(NA, -1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, 1, 
  -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 
  -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 
  1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 
  1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
  -1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1)

и вот мой sessionInfo:

> sessionInfo()
R version 3.1.2 (2014-10-31)
Platform: x86_64-w64-mingw32/x64 (64-bit)

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] data.table_1.9.4

loaded via a namespace (and not attached):
[1] chron_2.3-45  plyr_1.8.1    Rcpp_0.11.1   reshape2_1.4  stringr_0.6.2 tools_3.1.2  

Ответ 1

Вы можете довольно просто перевести цикл for:

library(Rcpp)
cppFunction(
"IntegerVector proc(NumericVector x) {
  const int n = x.size();
  IntegerVector y(n);
  y[0] = NA_INTEGER;
  for (int i=1; i < n; ++i) {
    if (x[i] == x[i-1]) y[i] = y[i-1];
    else if (x[i] > x[i-1]) y[i] = 1;
    else y[i] = -1;
  }
  return y;
}")

Как обычно, вы можете получить довольно большое ускорение с помощью Rcpp по сравнению с циклом for в базе R:

proc.for <- function(price) {
  TR <- numeric(length(price)-1)
  TR <- c(NA,TR)
  for (i in 1: (length(price)-1)){
    if (price[i] == price[i+1]) {TR[i+1] = TR[i]}
    if (price[i] < price[i+1]) {TR[i+1] = 1}
    if (price[i] > price[i+1]) {TR[i+1] = -1}
  }
  return(TR)
}
proc.aaron <- function(price) {
  change <- sign(diff(price))
  good <- change != 0
  goodval <- change[good]
  c(NA, goodval[cumsum(good)])
}
proc.jbaums <- function(price) {
  TR <- sign(diff(price))
  TR[TR==0] <- TR[which(TR != 0)][findInterval(which(TR == 0), which(TR != 0))]
  TR
}

all.equal(proc(price), proc.for(price), proc.aaron(price), proc.jbaums(price))
# [1] TRUE
library(microbenchmark)
microbenchmark(proc(price), proc.for(price), proc.aaron(price), proc.jbaums(price))
# Unit: microseconds
#                expr     min       lq      mean   median       uq      max neval
#         proc(price)   1.871   2.5380   3.92111   3.1110   4.5880   15.318   100
#     proc.for(price) 408.200 448.2830 542.19766 484.1265 546.3255 1821.104   100
#   proc.aaron(price)  23.916  25.5770  33.53259  31.5420  35.8575  190.372   100
#  proc.jbaums(price)  33.536  38.8995  46.80109  43.4510  49.3555  112.306   100

Мы видим ускорение более чем на 100x по сравнению с циклом for и 10x по сравнению с векторизованными альтернативами предоставленного вектора.

Ускорение еще более значимо с большим вектором (длина 1 миллион тестируется здесь):

price.big <- rep(price, times=5000)
all.equal(proc(price.big), proc.for(price.big), proc.aaron(price.big), proc.jbaums(price.big))
# [1] TRUE
microbenchmark(proc(price.big), proc.for(price.big), proc.aaron(price.big), proc.jbaums(price.big))
# Unit: milliseconds
#                    expr         min          lq        mean      median          uq        max neval
#         proc(price.big)    1.442119    1.818494    5.094274    2.020437    2.771903   56.54321   100
#     proc.for(price.big) 2639.819536 2699.493613 2949.962241 2781.636460 3062.277930 4472.35369   100
#   proc.aaron(price.big)   91.499940   99.859418  132.519296  140.521212  147.462259  207.72813   100
#  proc.jbaums(price.big)  117.242451  138.528214  170.989065  170.606048  180.337074  487.13615   100

Теперь мы имеем ускорение 1000x по сравнению с циклом for и ускорением ~ 70x по сравнению с векторизованными R-функциями. Даже при таком размере неясно, существует ли много преимуществ Rcpp над векторизованными R-решениями, если функция вызывается только один раз, поскольку для компиляции кода Rcpp требуется, как минимум, 100 мс. Ускорение довольно привлекательно, если это фрагмент кода, который неоднократно вызывался в вашем анализе.

Ответ 2

Вы можете выполнить байтовую компиляцию. Также полезно посмотреть на цикл R, который использует ту же логику if-else-if-else, что и код Rcpp. С R 3.1.2 я получаю

f1 <- function(price) {
    TR <- numeric(length(price)-1)
    TR <- c(NA,TR)
    for (i in 1: (length(price)-1)){
        if (price[i] == price[i+1]) {TR[i+1] = TR[i]}
        if (price[i] < price[i+1]) {TR[i+1] = 1}
        if (price[i] > price[i+1]) {TR[i+1] = -1}
    }
    return(TR)
}

f2 <- function(price) {
    TR <- numeric(length(price)-1)
    TR <- c(NA,TR)
    for (i in 1: (length(price)-1)){
        if (price[i] == price[i+1]) {TR[i+1] = TR[i]}
        else if (price[i] < price[i+1]) {TR[i+1] = 1}
        else {TR[i+1] = -1}
    }
    return(TR)
}

library(compiler)
f1c <- cmpfun(f1)
f2c <- cmpfun(f2)

library(microbenchmark)
microbenchmark(f1(price), f2(price), f1c(price), f2c(price), times = 1000)
## Unit: microseconds
##       expr     min       lq     mean   median       uq       max neval  cld
##  f1(price) 536.619 570.3715 667.3520 586.2465 609.9280 45046.462  1000    d
##  f2(price) 328.592 351.2070 386.5895 365.0245 381.4850  1302.497  1000   c 
## f1c(price) 167.570 182.4645 218.9537 192.4780 204.7810  7843.291  1000  b  
## f2c(price)  96.644 107.4465 124.1324 113.5470 121.5365  1019.389  1000 a   

R-devel, который будет выпущен как R 3.2.0 в апреле, имеет ряд улучшений в механизме байтового кода для таких скалярных вычислений; там я получаю

microbenchmark(f1(price), f2(price), f1c(price), f2c(price), times = 1000)
## Unit: microseconds
##        expr     min       lq      mean   median       uq      max neval  cld
##   f1(price) 490.300 520.3845 559.19539 533.2050 548.6850 1330.219  1000    d
##   f2(price) 298.375 319.7475 348.71384 330.4535 342.6405 1813.113  1000   c 
##  f1c(price)  61.947  66.3255  68.01555  67.7270  69.5470  138.308  1000  b  
##  f2c(price)  36.334  38.9500  40.45085  40.1830  41.8610   55.909  1000 a   

Это приводит вас к тому же общему этапу, что и векторизованные решения в этом примере. Есть еще место для дальнейшего улучшения механизма байтового кода, который должен сделать его в будущих выпусках.

Все решения отличаются тем, как они обрабатывают значения NA/NaN, которые могут или не имеют значения для вас.

Ответ 3

Возможно, сначала попробуйте векторизацию. Хотя это, вероятно, будет не так быстро, как Rcpp, это более прямолинейно.

f2 <- function(price) {
    change <- sign(diff(price))
    good <- change != 0
    goodval <- change[good]
    c(NA, goodval[cumsum(good)])
}

И он по-прежнему значительно ускоряет работу над циклом R.

f1 <- function(price) {
    TR <- numeric(length(price)-1)
    TR <- c(NA,TR)
    for (i in 1: (length(price)-1)){
        if (price[i] == price[i+1]) {TR[i+1] = TR[i]}
        if (price[i] < price[i+1]) {TR[i+1] = 1}
        if (price[i] > price[i+1]) {TR[i+1] = -1}
    }
    TR
}

microbenchmark(f1(price), f2(price), times=100)
## Unit: microseconds
##       expr     min       lq      mean   median       uq      max neval cld
##  f1(price) 550.037 592.9830 756.20095 618.7910 703.8335 3042.530   100   b
##  f2(price)  36.915  39.3285  56.45267  45.5225  60.1965  184.536   100  a 

Ответ 4

Это легко может быть проиндексировано в R.

Например, с diff и findInterval:

TR <- sign(diff(price))
TR[TR==0] <- TR[which(TR != 0)][findInterval(which(TR == 0), which(TR != 0))]