绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
R语言机器学习笔记(五):mlr模型预测
2020-07-13 10:16:44

作者:黄天元,复旦大学博士在读,热爱数据科学与开源工具(R),致力于利用数据科学迅速积累行业经验优势和科学知识发现,涉猎内容包括但不限于信息计量、机器学习、数据可视化、应用统计建模、知识图谱等,著有《R语言高效数据处理指南》(《R语言数据高效处理指南》(黄天元)【摘要 书评 试读】- 京东图书)。知乎专栏:R语言数据挖掘邮箱:huang.tian-yuan@qq.com.欢迎合作交流。

前文提要:

HopeR:R语言机器学习笔记(一):mlr总纲

HopeR:R语言机器学习笔记(二):mlr任务定义

HopeR:R语言机器学习笔记(三):mlr学习器定义

HopeR:R语言机器学习笔记(四):mlr模型训练

在mlr包中,预测非常简单,用predict函数即可。如果在任务定义中,已经指定了数据集,还使用了其中的一部分来做训练,希望用剩下的数据来做验证,那么可以使用subset参数来完成,举例如下:

n = getTaskSize(bh.task)
train.set = seq(1, n, by = 2)
test.set = seq(2, n, by = 2)
lrn = makeLearner("regr.gbm", n.trees = 100)
mod = train(lrn, bh.task, subset = train.set)

task.pred = predict(mod, task = bh.task, subset = test.set)
task.pred

## Prediction: 253 observations
## predict.type: response
## threshold: 
## time: 0.00
##    id truth response
## 2   2  21.6 22.27841
## 4   4  33.4 38.92484
## 6   6  28.7 25.36387
## 8   8  27.1 15.61735
## 10 10  18.9 16.98708
## 12 12  18.9 20.41749
## ... (#rows: 253, #cols: 3) 

train的时候,subset是train.set;predict的时候,subset是test.set,后得到task.pred值,包含真实值(truth)和预测值(response)。如果数据不再训练的数据框中,那么就需要使用newdata参数来指定。

n = nrow(iris)
iris.train = iris[seq(1, n, by = 2), -5]
iris.test = iris[seq(2, n, by = 2), -5]
task = makeClusterTask(data = iris.train)
mod = train("cluster.kmeans", task)

newdata.pred = predict(mod, newdata = iris.test)
newdata.pred

## Prediction: 75 observations
## predict.type: response
## threshold: 
## time: 0.00
##    response
## 2         2
## 4         2
## 6         2
## 8         2
## 10        2
## 12        2
## ... (#rows: 75, #cols: 1)

如果需要直接获得这些预测值,可以将其转化为数据框。

### Result of predict with data passed via task argument
head(as.data.frame(task.pred))
##    id truth response
## 2   2  21.6 22.52737
## 4   4  33.4 36.06190
## 6   6  28.7 24.75354
## 8   8  27.1 16.90299
## 10 10  18.9 17.25558
## 12 12  18.9 20.54365

### Result of predict with data passed via newdata argument
head(as.data.frame(newdata.pred))
##    response
## 2         1
## 4         1
## 6         1
## 8         1
## 10        1
## 12        1

当然,如果只想取向量,获得结果是一个对象,也有响应的方法可以取出来(getPredictionTruthgetPredictionResponse)。

对于回归问题,还可以得到预测值的标准误(Standard Error,SE),在predict.type中设置为“se”即可:

### Create learner and specify predict.type
lrn.lm = makeLearner("regr.lm", predict.type = 'se')
mod.lm = train(lrn.lm, bh.task, subset = train.set)
task.pred.lm = predict(mod.lm, task = bh.task, subset = test.set)
task.pred.lm

## Prediction: 253 observations
## predict.type: se
## threshold: 
## time: 0.00
##    id truth response        se
## 2   2  21.6 24.83734 0.7501615
## 4   4  33.4 28.38206 0.8742590
## 6   6  28.7 25.16725 0.8652139
## 8   8  27.1 19.38145 1.1963265
## 10 10  18.9 18.66449 1.1793944
## 12 12  18.9 21.25802 1.0727918
## ... (#rows: 253, #cols: 4)

而对于分类和聚类问题,则可以得到其判断为每一个类别的似然概率,只要将predict.type设置为“prob”即可:

lrn = makeLearner("cluster.cmeans", predict.type = "prob")
mod = train(lrn, mtcars.task)

pred = predict(mod, task = mtcars.task)
head(getPredictionProbabilities(pred))
##                             1          2
## Mazda RX4         0.020400964 0.97959904
## Mazda RX4 Wag     0.020360747 0.97963925
## Datsun 710        0.007341207 0.99265879
## Hornet 4 Drive    0.457052250 0.54294775
## Hornet Sportabout 0.981291168 0.01870883
## Valiant           0.242514386 0.75748561

lrn = makeLearner("classif.rpart", predict.type = "prob")
mod = train(lrn, iris.task)

pred = predict(mod, newdata = iris)
head(as.data.frame(pred))
##    truth prob.setosa prob.versicolor prob.virginica response
## 1 setosa           1               0              0   setosa
## 2 setosa           1               0              0   setosa
## 3 setosa           1               0              0   setosa
## 4 setosa           1               0              0   setosa
## 5 setosa           1               0              0   setosa
## 6 setosa           1               0              0   setosa

响应的面向对象提取方法为getPredictionProbabilities。对于分类问题而言,还可以使用calculateConfusionMatrix来获得混淆矩阵。同时,可以使用setThreshold函数来更改判别阈值,如似然概率达到90%才能够认为它属于某一个类别,这个功能在二分类和多分类中均可以使用。

后,plotLearnerPrediction函数可以对预测结果进行一定的可视化。不过由于数据是多维的,一般需要自定义选择1~2个特征来做。


参考链接:

Predicting Outcomes for New Data

分享好友

分享这个小栈给你的朋友们,一起进步吧。

R语言
创建时间:2020-06-15 11:46:51
R是用于统计分析、绘图的语言和操作环境。R是属于GNU系统的一个自由、免费、源代码开放的软件,它是一个用于统计计算和统计制图的工具。
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

技术专家

查看更多
  • 小雨滴
    专家
戳我,来吐槽~