绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
R语言机器学习笔记(七):mlr模型调参
2020-07-15 15:43:46

前文提要:

HopeR:R语言机器学习笔记(一):mlr总纲

HopeR:R语言机器学习笔记(二):mlr任务定义

HopeR:R语言机器学习笔记(三):mlr学习器定义

HopeR:R语言机器学习笔记(四):mlr模型训练

HopeR:R语言机器学习笔记(五):mlr模型预测

HopeR:R语言机器学习笔记(六):mlr数据预处理


一个模型可能会有很多超参数,这些超参数的排列组合,我们必须不断尝试才能够知道哪个效果好。mlr能够让我们设置一个网格,然后对这些参数组合进行尝试,然后选择效果好的模型。要做调参,需要做三个步骤:

1、定义搜索空间;

2、定义优化算法;

3、确立评估方法。

定义搜索空间

比如SVM方法有两个参数(C和gamma)可以调(SVM的两个参数 C 和 gamma),那么我们就可以这样设置网格:

discrete_ps = makeParamSet(
  makeDiscreteParam("C", values = c(0.5, 1.0, 1.5, 2.0)),
  makeDiscreteParam("sigma", values = c(0.5, 1.0, 1.5, 2.0))
)
print(discrete_ps)
##           Type len Def      Constr Req Tunable Trafo
## C     discrete   -   - 0.5,1,1.5,2   -    TRUE     -
## sigma discrete   -   - 0.5,1,1.5,2   -    TRUE     -

makeDiscreteParam函数可以让我们设置离散型的网格,也可以定义连续数值空间作为搜索的范围:

num_ps = makeParamSet(
  makeNumericParam("C", lower = -10, upper = 10, trafo = function(x) 10^x),
  makeNumericParam("sigma", lower = -10, upper = 10, trafo = function(x) 10^x)
)

这样,就定义了C和sigma在10的-10次方到10次方之间来做搜索。

定义优化算法

优化算法,就是在我们的搜索空间中如何来取参数的组合。比如利用makeTuneControlGrid()函数,我们就相当于用了离散数值之间的所有组合。

ctrl = makeTuneControlGrid()

如果这个算法用在之前定义好的discrete_ps中,我们就相当于设立了一个4*4=16的搜索网格。如果要对连续数值空间定义算法,可以确立其精度,然后均与地进行参数设置:

ctrl = makeTuneControlGrid(resolution = 15L)

上面这个设置,就从10的-10次方到10次方均匀选取了15个数值(相当于10 ^ seq(-10, 10, length.out = 15)),来作为验证参数。我们可以知道,这种方法会让计算负担非常大,因此我们也有一个可以随机尝试的方法:

ctrl = makeTuneControlRandom(maxit = 200L)

上面这个方法就相当于在这个空间中随机采样200次。如果觉得覆盖面不够广,可以再增加这个数量。

执行调试

以上面的discrete_ps搜索空间为例,我们执行一次调试。这里,我们还需要设置重采样方法rdesc,这里设置为3折交叉验证。

rdesc = makeResampleDesc("CV", iters = 3L)

discrete_ps = makeParamSet(
  makeDiscreteParam("C", values = c(0.5, 1.0, 1.5, 2.0)),
  makeDiscreteParam("sigma", values = c(0.5, 1.0, 1.5, 2.0))
)
ctrl = makeTuneControlGrid()
rdesc = makeResampleDesc("CV", iters = 3L)
res = tuneParams("classif.ksvm", task = iris.task, resampling = rdesc,
  par.set = discrete_ps, control = ctrl)
## [Tune] Started tuning learner classif.ksvm for parameter set:
##           Type len Def      Constr Req Tunable Trafo
## C     discrete   -   - 0.5,1,1.5,2   -    TRUE     -
## sigma discrete   -   - 0.5,1,1.5,2   -    TRUE     -
## With control class: TuneControlGrid
## Imputation value: 1
## [Tune-x] 1: C=0.5; sigma=0.5
## [Tune-y] 1: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 2: C=1; sigma=0.5
## [Tune-y] 2: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 3: C=1.5; sigma=0.5
## [Tune-y] 3: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 4: C=2; sigma=0.5
## [Tune-y] 4: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 5: C=0.5; sigma=1
## [Tune-y] 5: mmce.test.mean=0.0533333; time: 0.0 min
## [Tune-x] 6: C=1; sigma=1
## [Tune-y] 6: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 7: C=1.5; sigma=1
## [Tune-y] 7: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 8: C=2; sigma=1
## [Tune-y] 8: mmce.test.mean=0.0400000; time: 0.0 min
## [Tune-x] 9: C=0.5; sigma=1.5
## [Tune-y] 9: mmce.test.mean=0.0533333; time: 0.0 min
## [Tune-x] 10: C=1; sigma=1.5
## [Tune-y] 10: mmce.test.mean=0.0533333; time: 0.0 min
## [Tune-x] 11: C=1.5; sigma=1.5
## [Tune-y] 11: mmce.test.mean=0.0466667; time: 0.0 min
## [Tune-x] 12: C=2; sigma=1.5
## [Tune-y] 12: mmce.test.mean=0.0466667; time: 0.0 min
## [Tune-x] 13: C=0.5; sigma=2
## [Tune-y] 13: mmce.test.mean=0.0600000; time: 0.0 min
## [Tune-x] 14: C=1; sigma=2
## [Tune-y] 14: mmce.test.mean=0.0533333; time: 0.0 min
## [Tune-x] 15: C=1.5; sigma=2
## [Tune-y] 15: mmce.test.mean=0.0466667; time: 0.0 min
## [Tune-x] 16: C=2; sigma=2
## [Tune-y] 16: mmce.test.mean=0.0533333; time: 0.0 min
## [Tune] Result: C=1.5; sigma=0.5 : mmce.test.mean=0.0400000

res
## Tune result:
## Op. pars: C=1.5; sigma=0.5
## mmce.test.mean=0.0400000

这里默认采用MMCE作为损失函数,即让错误率保持到小。可以这样观察是否采用了这种衡量方法:

# error rate
mmce$minimize
## [1] TRUE

# accuracy
acc$minimize
## [1] FALSE

获得调参结果

后获得的结果是一个对象,可以直接对其中的优结果进行提取:

res$x
## $C
## [1] 15.52092
## 
## $sigma
## [1] 0.03449146

res$y
## acc.test.mean   acc.test.sd 
##          0.96          0.00

然后,我们可以根据这个结果,来训练一个优的模型:

lrn = setHyperPars(makeLearner("classif.ksvm"), C = res$x$C, sigma = res$x$sigma)
lrn
## Learner classif.ksvm from package kernlab
## Type: classif
## Name: Support Vector Machines; Short name: ksvm
## Class: classif.ksvm
## Properties: twoclass,multiclass,numerics,factors,prob,class.weights
## Predict-Type: response
## Hyperparameters: fit=FALSE,C=15.5,sigma=0.0345

m = train(lrn, iris.task)
predict(m, task = iris.task)

## Prediction: 150 observations
## predict.type: response
## threshold: 
## time: 0.00
##   id  truth response
## 1  1 setosa   setosa
## 2  2 setosa   setosa
## 3  3 setosa   setosa
## 4  4 setosa   setosa
## 5  5 setosa   setosa
## 6  6 setosa   setosa
## ... (#rows: 150, #cols: 3)

调参的过程可以使用generateHyperParsEffectData函数获得,并采用plotHyperParsEffect进行可视化。详细的教程,见官方参考链接。


参考链接:

mlr.mlr-org.com/article

作者:黄天元,复旦大学博士在读,热爱数据科学与开源工具(R),致力于利用数据科学迅速积累行业经验优势和科学知识发现,涉猎内容包括但不限于信息计量、机器学习、数据可视化、应用统计建模、知识图谱等,著有《R语言高效数据处理指南》(《R语言数据高效处理指南》(黄天元)【摘要 书评 试读】- 京东图书)。


分享好友

分享这个小栈给你的朋友们,一起进步吧。

R语言
创建时间:2020-06-15 11:46:51
R是用于统计分析、绘图的语言和操作环境。R是属于GNU系统的一个自由、免费、源代码开放的软件,它是一个用于统计计算和统计制图的工具。
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

技术专家

查看更多
  • 小雨滴
    专家
戳我,来吐槽~