Spark Random Forest: Cross-Validation, Pipeline, and Model Evaluation Code Example

package cn.itcast.tags.ml.classification

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}

object RfModel {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .getOrCreate()
    import org.apache.spark.sql.functions._
    import spark.implicits._

    // 1. Load the data (LIBSVM format)
    val dataframe: DataFrame = spark.read
      .format("libsvm")
      .load("datas/ship/total001.txt")

    // Split the dataset into training and test sets
    val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))
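    // Note: randomSplit without a seed is not reproducible across runs; Spark also
    // accepts a seed, e.g. dataframe.randomSplit(Array(0.8, 0.2), seed = 123L),
    // if a repeatable split is needed.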

    // 2. Feature engineering: feature extraction, transformation, and selection
    // 2.1. Convert the label values into indices, ranging from 0 to K-1
    val labelIndexer: StringIndexerModel = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("index_label")
      .fit(dataframe)
    val df1: DataFrame = labelIndexer.transform(dataframe)

    // 2.2. Handle categorical features: any feature column with at most K distinct
    //      values is treated as categorical and automatically converted to indices
    val featureIndexer: VectorIndexerModel = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("index_features")
      // Detect which feature columns are categorical and index their values
      .setMaxCategories(4) // maximum number of distinct values for a column to count as categorical
      .fit(df1)
    val df2: DataFrame = featureIndexer.transform(df1)
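    // For example, a feature column holding only the values {0.0, 1.0, 2.0} (at most
    // 4 distinct values) is re-encoded as a categorical index, while columns with
    // more distinct values pass through unchanged as continuous features.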

    // 3. Create the random forest classifier
    val rf = new RandomForestClassifier()
      .setLabelCol("index_label")
      .setFeaturesCol("index_features")
      // .setSubsamplingRate(1.0)

    // 4. Build the Pipeline and set its stages: each stage is either an algorithm
    //    (an Estimator that learns a model) or a model (a Transformer)
    val pipeline: Pipeline = new Pipeline()
      // Stages are applied in dependency order
      .setStages(Array(labelIndexer, featureIndexer, rf))

    // Hyperparameter grid searched during cross-validation
    val paramGrid: Array[ParamMap] = new ParamGridBuilder()
      .addGrid(rf.maxDepth, Array(5, 10, 15, 20, 25, 30))
      .addGrid(rf.impurity, Array("gini", "entropy"))
      .addGrid(rf.maxBins, Array(32, 64))
      .addGrid(rf.numTrees, Array(5, 10, 20, 30, 40, 50))
      .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
      .build()
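    // This grid has 6 * 2 * 2 * 6 * 2 = 288 parameter combinations; with the 4-fold
    // cross-validation below, that means 288 * 4 = 1152 model fits, so expect a long run.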

    // Multiclass evaluator
    val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("index_label")
      .setPredictionCol("prediction")
      // Supported metric names: f1, weightedPrecision, weightedRecall, accuracy
      .setMetricName("accuracy")

    // Cross-validation over the parameter grid
    val validator: CrossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(4)
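    // On Spark 2.3+, CrossValidator also supports .setParallelism(n) to fit candidate
    // models in parallel (an optional tuning knob, not part of the original code).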

    // Train the model: fit the cross-validator on the training data
    println("Training model...")
    val model: CrossValidatorModel = validator.fit(trainingDF)
    println(model.toString())

    // 5. Evaluate the model: compute accuracy on the test set
    val predictionDF: DataFrame = model.transform(testingDF)
    predictionDF.printSchema()
    predictionDF
      .select( $"probability", $"prediction",$"index_label")
      .show(100, truncate = false)

    val accuracy: Double = evaluator.evaluate(predictionDF)
    println(s"Accuracy = $accuracy")

    spark.stop()
  }

}
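
After training, it is often useful to see which hyperparameter combination won the grid search. Below is a minimal sketch (assuming the pipeline layout above, where the random forest is the third stage) that pulls the winning model out of the CrossValidatorModel:

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.RandomForestClassificationModel

// bestModel is the pipeline refit on the full training set with the best ParamMap
val bestPipeline = model.bestModel.asInstanceOf[PipelineModel]
// The random forest sits at stage index 2, after the two indexers
val bestRf = bestPipeline.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"numTrees = ${bestRf.getNumTrees}")
println(s"maxDepth = ${bestRf.getMaxDepth}")
println(s"impurity = ${bestRf.getImpurity}")
// avgMetrics holds the mean cross-validated accuracy for each ParamMap in the grid
println(s"best avg accuracy = ${model.avgMetrics.max}")

The best pipeline can also be persisted for later scoring with bestPipeline.write.overwrite().save(path), where path is a placeholder for your output directory.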

Source: https://www.icode9.com/content-1-901201.html
