《R语言数据高…" /> 《R语言数据高…" />
绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
用C++给R语言加速:Rcpp简单用法
2020-06-22 13:36:35

作者:黄天元,复旦大学博士在读,热爱数据科学与开源工具(R),致力于利用数据科学迅速积累行业经验优势和科学知识发现,涉猎内容包括但不限于信息计量、机器学习、数据可视化、应用统计建模、知识图谱等,著有《R语言高效数据处理指南》(《R语言数据高效处理指南》(黄天元)【摘要 书评 试读】- 京东图书)。知乎专栏:R语言数据挖掘邮箱:huang.tian-yuan@qq.com.欢迎合作交流。

近要做一个小任务,它的描述非常简单,就是识别向量的变化。比如一个整数序列是“4,4,4,5,5,6,6,5,5,5,4,4”,那么我们要根据数字的连续关系来分组,输出应该是“1,1,1,2,2,3,3,4,4,4,5,5”。这个函数用R写起来非常简单,稍加思考草拟如下:

get_id = function(x){
  z = vector()
  y = NULL
  for (i in seq_along(x)) {
    if(i == 1) y = 1
    else if(x[i] != x[i-1]) y = y + 1
    z = c(z,y)
  }
  z
}

不妨做个小测试:

get_id = function(x){
  z = vector()
  y = NULL
  for (i in seq_along(x)) {
    if(i == 1) y = 1
    else if(x[i] != x[i-1]) y = y + 1
    z = c(z,y)
  }
  z
}
c(rep(33L,3),rep(44L,4),rep(33L,3))
#>  [1] 33 33 33 44 44 44 44 33 33 33
get_id(c(rep(33L,3),rep(44L,4),rep(33L,3)))
#>  [1] 1 1 1 2 2 2 2 3 3 3

得到的结果是准确的,而且按照这些代码,基本可以识别不同的数据类型,只要这些数据能够用“==”来判断是否相同(可能用setequal函数的健壮性更好)。

但是当数据量很大的时候,这样写是否足够快,就很重要了。这意味着看你要等一小时、一天还是一个月。想起自己小时候还学过C++,就希望尝试用Rcpp来加速,拟了代码如下:

library(Rcpp)

# 函数名称为get_id_c
cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

需要声明的是,C++需要定义数据类型,因为任务是正整数,所以函数就接受一个整数向量,输出一个整数向量。多年不用C++,写这么一段代码居然调试过程就出了3次错,惭愧。但是对性能的提升效果非常显著,我们先做一些简单尝试。先尝试1万个整数:

library(pacman)
p_load(tidyfst)  

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

0.14s和0.00s的区别,可能体会不够深。那么来10万个整数试试:

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

13s vs 0s,有点不能忍了。那么,我们来100万个整数再来试试:

# 不要尝试这个
sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

# 可以尝试这个
sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

好的,关于这段代码:

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

可以不要尝试了,因为直接卡死了。但是如果用Rcpp构造的函数,那么就放心试吧,我们还远远没有探知其算力上限。可以观察一下结果:

我们可以看到,1亿个整数,也就是0.69秒;10亿是7.15秒。虽然想尝试百亿,但是我的计算机内存已经不够了。

总结一下,永远不敢说R的速度不够快,只是自己代码写得烂而已(尽管完成了实现,其实get_id这个函数优化的空间是很多的),方法总比问题多。不多说了,去温习C++,学习Rcpp去了。后面如果有闲暇,来做一个Rcpp的学习系列。放一个核心资料链接:

Seamless R and C++ Integrationrcpp.org

根据评论区提示,重新写了R的代码再来比较。代码如下:

library(pacman)
p_load(Rcpp,tidyfst)

get_id = function(x){
  z = vector()
  for (i in seq_along(x)) {
    if(i == 1) z[i] = 1
    else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
    else z[i] = z[i-1]
  }
  z
}

cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

1万:

10万:

100万:

1000万:

1亿:

结论:还是Rcpp香。


三更:R代码提前设置向量长度,再比较。

library(pacman)
p_load(Rcpp,tidyfst)

get_id = function(x){
  z = integer(length(x))
  for (i in seq_along(x)) {
    if(i == 1) z[i] = 1
    else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
    else z[i] = z[i-1]
  }
  z
}

cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)


四更:对于这个任务来讲,data.table的rleid函数是快的。R语言中的终极魔咒,能找到现成的千万不要自己写,总有巨佬在前头。不过直到10亿个整数才有难以忍受的差距。

1亿
10亿

分享好友

分享这个小栈给你的朋友们,一起进步吧。

R语言
创建时间:2020-06-15 11:46:51
R是用于统计分析、绘图的语言和操作环境。R是属于GNU系统的一个自由、免费、源代码开放的软件,它是一个用于统计计算和统计制图的工具。
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

技术专家

查看更多
  • 小雨滴
    专家
戳我,来吐槽~