• gridsearchcv
    • 功能介绍
    • 参数说明
    • 脚本示例
      • 脚本代码
      • 脚本结果

    gridsearchcv

    功能介绍

    gridsearch是通过参数数组组成的网格,对其中每一组输入参数的组合分别进行训练、预测、评估,取得评估指标最优的模型,作为最终的返回模型。

    cv为交叉验证:将数据切分为k份(k-folds),每次用其中k-1份数据做训练,对剩余一份数据做预测和评估,得到一个评估结果。

    此函数用cv方法得到每一个grid对应参数的评估结果,得到最优模型

    参数说明

    名称 中文名称 描述 类型 是否必须? 默认值
    NumFolds 折数 交叉验证的参数,数据的折数(大于等于2) Integer 10
    ParamGrid 参数网格 指定参数的网格 ParamGrid —-
    Estimator Estimator 用于调优的Estimator Estimator —-
    TuningEvaluator 评估指标 用于选择最优模型的评估指标 TuningEvaluator —-

    脚本示例

    脚本代码

    def adult(url):
        """Build a CSV batch source for the Adult census dataset.

        Args:
            url: Location of the CSV file to read, e.g. the train or test
                split of the Adult dataset.

        Returns:
            The configured ``CsvSourceBatchOp``. Its schema declares the
            fourteen attribute columns followed by the string ``label`` column.
        """
        data = (
            CsvSourceBatchOp()
            # Read from the caller-supplied location. The path used to be
            # hard-coded to the training file, so ``url`` was ignored and
            # adult_test() silently loaded the training split.
            .setFilePath(url)
            .setSchemaStr(
                'age bigint, workclass string, fnlwgt bigint,'
                'education string, education_num bigint,'
                'marital_status string, occupation string,'
                'relationship string, race string, sex string,'
                'capital_gain bigint, capital_loss bigint,'
                'hours_per_week bigint, native_country string,'
                'label string'
            )
        )
        return data
    def adult_train():
        """Return the batch source for the Adult training split."""
        train_url = 'http://alink-dataset.cn-hangzhou.oss.aliyun-inc.com/csv/adult_train.csv'
        return adult(train_url)
    def adult_test():
        """Return the batch source built from the Adult test-split URL."""
        test_url = 'http://alink-dataset.cn-hangzhou.oss.aliyun-inc.com/csv/adult_test.csv'
        return adult(test_url)
    def adult_numerical_feature_strs():
        """Return the names of the numeric Adult feature columns.

        A new list is created on every call, so callers may mutate the result
        without affecting later calls.
        """
        numeric_columns = (
            "age", "fnlwgt", "education_num",
            "capital_gain", "capital_loss", "hours_per_week",
        )
        return list(numeric_columns)
    def adult_categorical_feature_strs():
        """Return the names of the categorical Adult feature columns.

        A new list is created on every call, so callers may mutate the result
        without affecting later calls.
        """
        categorical_columns = (
            "workclass", "education", "marital_status",
            "occupation", "relationship", "race", "sex",
            "native_country",
        )
        return list(categorical_columns)
    def adult_features_strs():
        """Return all Adult feature column names, numeric ones first."""
        # Both helpers hand back fresh lists, so concatenating them yields
        # the same result as extending the numeric list in place.
        return adult_numerical_feature_strs() + adult_categorical_feature_strs()
    def rf_grid_search_cv(featureCols, categoryFeatureCols, label, metric):
        """Assemble a cross-validated grid search over a random forest.

        Args:
            featureCols: Feature column names passed to ``setFeatureCols``.
            categoryFeatureCols: Column names passed to ``setCategoricalCols``.
            label: Label column name, used by the classifier and the evaluator.
            metric: Metric name passed to the tuning evaluator.

        Returns:
            A ``GridSearchCV`` configured with 2 folds. Nothing is fitted here.
        """
        forest = (
            RandomForestClassifier()
            .setFeatureCols(featureCols)
            .setCategoricalCols(categoryFeatureCols)
            .setLabelCol(label)
            .setPredictionCol('prediction')
            .setPredictionDetailCol('prediction_detail')
        )

        # 3 subsampling ratios x 3 tree counts = 9 candidate settings.
        search_space = (
            ParamGrid()
            .addGrid(forest, 'SUBSAMPLING_RATIO', [1.0, 0.99, 0.98])
            .addGrid(forest, 'NUM_TREES', [3, 6, 9])
        )

        # The evaluator reads the same detail column the classifier writes.
        scorer = (
            BinaryClassificationTuningEvaluator()
            .setLabelCol(label)
            .setPredictionDetailCol("prediction_detail")
            .setMetricName(metric)
        )

        return (
            GridSearchCV()
            .setEstimator(forest)
            .setParamGrid(search_space)
            .setTuningEvaluator(scorer)
            .setNumFolds(2)
        )
    def rf_grid_search_tv(featureCols, categoryFeatureCols, label, metric):
        """Assemble a train/validation-split grid search over a random forest.

        Args:
            featureCols: Feature column names passed to ``setFeatureCols``.
            categoryFeatureCols: Column names passed to ``setCategoricalCols``.
            label: Label column name, used by the classifier and the evaluator.
            metric: Metric name passed to the tuning evaluator.

        Returns:
            A configured ``GridSearchTVSplit``. Nothing is fitted here.
        """
        forest = (
            RandomForestClassifier()
            .setFeatureCols(featureCols)
            .setCategoricalCols(categoryFeatureCols)
            .setLabelCol(label)
            .setPredictionCol('prediction')
            .setPredictionDetailCol('prediction_detail')
        )

        # 3 subsampling ratios x 3 tree counts = 9 candidate settings.
        search_space = (
            ParamGrid()
            .addGrid(forest, 'SUBSAMPLING_RATIO', [1.0, 0.99, 0.98])
            .addGrid(forest, 'NUM_TREES', [3, 6, 9])
        )

        # The evaluator reads the same detail column the classifier writes.
        scorer = (
            BinaryClassificationTuningEvaluator()
            .setLabelCol(label)
            .setPredictionDetailCol("prediction_detail")
            .setMetricName(metric)
        )

        return (
            GridSearchTVSplit()
            .setEstimator(forest)
            .setParamGrid(search_space)
            .setTuningEvaluator(scorer)
        )
    def tuningcv(cv_estimator, input):
        """Fit ``cv_estimator`` on ``input`` and return whatever ``fit`` returns.

        Args:
            cv_estimator: Object exposing ``fit``, e.g. the search built by
                rf_grid_search_cv.
            input: Training data handed to ``fit`` unchanged.
        """
        fitted_model = cv_estimator.fit(input)
        return fitted_model
    def tuningtv(tv_estimator, input):
        """Fit ``tv_estimator`` on ``input`` and return whatever ``fit`` returns.

        Args:
            tv_estimator: Object exposing ``fit``, e.g. the search built by
                rf_grid_search_tv.
            input: Training data handed to ``fit`` unchanged.
        """
        fitted_model = tv_estimator.fit(input)
        return fitted_model
    def main():
        """Run both tuning examples on the Adult training data and print reports.

        First the cross-validation grid search, then the train/validation-split
        grid search; both use the 'label' column and the 'AUC' metric.
        """
        print('rf cv tuning')
        # Build the search before loading the data, as the original call did.
        cv_search = rf_grid_search_cv(
            adult_features_strs(), adult_categorical_feature_strs(), 'label', 'AUC'
        )
        cv_model = tuningcv(cv_search, adult_train())
        print(cv_model.getReport())

        print('rf tv tuning')
        tv_search = rf_grid_search_tv(
            adult_features_strs(), adult_categorical_feature_strs(), 'label', 'AUC'
        )
        tv_model = tuningtv(tv_search, adult_train())
        print(tv_model.getReport())

    # Run the example as soon as the script is executed.
    main()

    脚本结果

    1. rf cv tuning
    2. com.alibaba.alink.pipeline.tuning.GridSearchCV
    3. [ {
    4. "param" : [ {
    5. "stage" : "RandomForestClassifier",
    6. "paramName" : "numTrees",
    7. "paramValue" : 3
    8. }, {
    9. "stage" : "RandomForestClassifier",
    10. "paramName" : "subsamplingRatio",
    11. "paramValue" : 1.0
    12. } ],
    13. "metric" : 0.8922549257899725
    14. }, {
    15. "param" : [ {
    16. "stage" : "RandomForestClassifier",
    17. "paramName" : "numTrees",
    18. "paramValue" : 3
    19. }, {
    20. "stage" : "RandomForestClassifier",
    21. "paramName" : "subsamplingRatio",
    22. "paramValue" : 0.99
    23. } ],
    24. "metric" : 0.8920255970548456
    25. }, {
    26. "param" : [ {
    27. "stage" : "RandomForestClassifier",
    28. "paramName" : "numTrees",
    29. "paramValue" : 3
    30. }, {
    31. "stage" : "RandomForestClassifier",
    32. "paramName" : "subsamplingRatio",
    33. "paramValue" : 0.98
    34. } ],
    35. "metric" : 0.8944982480437225
    36. }, {
    37. "param" : [ {
    38. "stage" : "RandomForestClassifier",
    39. "paramName" : "numTrees",
    40. "paramValue" : 6
    41. }, {
    42. "stage" : "RandomForestClassifier",
    43. "paramName" : "subsamplingRatio",
    44. "paramValue" : 1.0
    45. } ],
    46. "metric" : 0.8923867598288401
    47. }, {
    48. "param" : [ {
    49. "stage" : "RandomForestClassifier",
    50. "paramName" : "numTrees",
    51. "paramValue" : 6
    52. }, {
    53. "stage" : "RandomForestClassifier",
    54. "paramName" : "subsamplingRatio",
    55. "paramValue" : 0.99
    56. } ],
    57. "metric" : 0.9012141767959505
    58. }, {
    59. "param" : [ {
    60. "stage" : "RandomForestClassifier",
    61. "paramName" : "numTrees",
    62. "paramValue" : 6
    63. }, {
    64. "stage" : "RandomForestClassifier",
    65. "paramName" : "subsamplingRatio",
    66. "paramValue" : 0.98
    67. } ],
    68. "metric" : 0.8993774036693788
    69. }, {
    70. "param" : [ {
    71. "stage" : "RandomForestClassifier",
    72. "paramName" : "numTrees",
    73. "paramValue" : 9
    74. }, {
    75. "stage" : "RandomForestClassifier",
    76. "paramName" : "subsamplingRatio",
    77. "paramValue" : 1.0
    78. } ],
    79. "metric" : 0.8981738808130779
    80. }, {
    81. "param" : [ {
    82. "stage" : "RandomForestClassifier",
    83. "paramName" : "numTrees",
    84. "paramValue" : 9
    85. }, {
    86. "stage" : "RandomForestClassifier",
    87. "paramName" : "subsamplingRatio",
    88. "paramValue" : 0.99
    89. } ],
    90. "metric" : 0.9029671873892725
    91. }, {
    92. "param" : [ {
    93. "stage" : "RandomForestClassifier",
    94. "paramName" : "numTrees",
    95. "paramValue" : 9
    96. }, {
    97. "stage" : "RandomForestClassifier",
    98. "paramName" : "subsamplingRatio",
    99. "paramValue" : 0.98
    100. } ],
    101. "metric" : 0.905228896323363
    102. } ]
    103. rf tv tuning
    104. com.alibaba.alink.pipeline.tuning.GridSearchTVSplit
    105. [ {
    106. "param" : [ {
    107. "stage" : "RandomForestClassifier",
    108. "paramName" : "numTrees",
    109. "paramValue" : 3
    110. }, {
    111. "stage" : "RandomForestClassifier",
    112. "paramName" : "subsamplingRatio",
    113. "paramValue" : 1.0
    114. } ],
    115. "metric" : 0.9022694229691741
    116. }, {
    117. "param" : [ {
    118. "stage" : "RandomForestClassifier",
    119. "paramName" : "numTrees",
    120. "paramValue" : 3
    121. }, {
    122. "stage" : "RandomForestClassifier",
    123. "paramName" : "subsamplingRatio",
    124. "paramValue" : 0.99
    125. } ],
    126. "metric" : 0.8963559966080328
    127. }, {
    128. "param" : [ {
    129. "stage" : "RandomForestClassifier",
    130. "paramName" : "numTrees",
    131. "paramValue" : 3
    132. }, {
    133. "stage" : "RandomForestClassifier",
    134. "paramName" : "subsamplingRatio",
    135. "paramValue" : 0.98
    136. } ],
    137. "metric" : 0.9041948454957178
    138. }, {
    139. "param" : [ {
    140. "stage" : "RandomForestClassifier",
    141. "paramName" : "numTrees",
    142. "paramValue" : 6
    143. }, {
    144. "stage" : "RandomForestClassifier",
    145. "paramName" : "subsamplingRatio",
    146. "paramValue" : 1.0
    147. } ],
    148. "metric" : 0.8982021117392784
    149. }, {
    150. "param" : [ {
    151. "stage" : "RandomForestClassifier",
    152. "paramName" : "numTrees",
    153. "paramValue" : 6
    154. }, {
    155. "stage" : "RandomForestClassifier",
    156. "paramName" : "subsamplingRatio",
    157. "paramValue" : 0.99
    158. } ],
    159. "metric" : 0.9031851535310546
    160. }, {
    161. "param" : [ {
    162. "stage" : "RandomForestClassifier",
    163. "paramName" : "numTrees",
    164. "paramValue" : 6
    165. }, {
    166. "stage" : "RandomForestClassifier",
    167. "paramName" : "subsamplingRatio",
    168. "paramValue" : 0.98
    169. } ],
    170. "metric" : 0.9034443322241488
    171. }, {
    172. "param" : [ {
    173. "stage" : "RandomForestClassifier",
    174. "paramName" : "numTrees",
    175. "paramValue" : 9
    176. }, {
    177. "stage" : "RandomForestClassifier",
    178. "paramName" : "subsamplingRatio",
    179. "paramValue" : 1.0
    180. } ],
    181. "metric" : 0.8993474753000145
    182. }, {
    183. "param" : [ {
    184. "stage" : "RandomForestClassifier",
    185. "paramName" : "numTrees",
    186. "paramValue" : 9
    187. }, {
    188. "stage" : "RandomForestClassifier",
    189. "paramName" : "subsamplingRatio",
    190. "paramValue" : 0.99
    191. } ],
    192. "metric" : 0.9090250137144916
    193. }, {
    194. "param" : [ {
    195. "stage" : "RandomForestClassifier",
    196. "paramName" : "numTrees",
    197. "paramValue" : 9
    198. }, {
    199. "stage" : "RandomForestClassifier",
    200. "paramName" : "subsamplingRatio",
    201. "paramValue" : 0.98
    202. } ],
    203. "metric" : 0.9129786771786127
    204. } ]