// Tree-ensemble examples built on the Spark ML Pipeline API:
//   1. RandomForestClassifier   2. RandomForestRegressor
//   3. GBTClassifier            4. GBTRegressor
// Expects a SparkContext named `sc` to be in scope (e.g. when pasted into spark-shell).

// Explicit import so `new SQLContext(sc)` resolves; harmless if it is already imported elsewhere.
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.classification.GBTClassificationModel
// Previously missing: GBTRegressor is instantiated in the GBT regression example below.
import org.apache.spark.ml.regression.GBTRegressor

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._ // provides the RDD -> DataFrame `toDF()` conversion

// Load the LibSVM sample file as a DataFrame with "label" and "features" columns.
val data = MLUtils.loadLibSVMFile(sc, "./data/sample_libsvm_data.txt").toDF()

// Feature/label transformers shared by the examples below.
// NOTE: both indexers are fitted on the full dataset, i.e. before the train/test split.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

val vectorIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Maps the numeric "prediction" column back to the original label strings.
val indexToString = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// ---------------------------------------------------------------------------
// RandomForestClassifier Example
// ---------------------------------------------------------------------------
val classifier = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(3)

// 70/30 train/test split, reused by every example below.
// No seed is passed, so the split is not pinned by this script.
val Array(train, test) = data.randomSplit(Array(0.7, 0.3))

val stages = Array(labelIndexer, vectorIndexer, classifier, indexToString)
val pipeline = new Pipeline().setStages(stages)
val model = pipeline.fit(train)

// Index 2 is the position of `classifier` in `stages`, so the fitted stage there is the forest.
val modelRFC = model.stages(2).asInstanceOf[RandomForestClassificationModel]

// to see feature Importances
println(modelRFC.featureImportances)

// to inspect rules of each tree in RF
println(modelRFC.toDebugString)

val predictions = model.transform(test)
predictions.show()

// ---------------------------------------------------------------------------
// RandomForestRegressor Example
// ---------------------------------------------------------------------------
// Regresses on the raw "label" column, so no label indexing / IndexToString stage is used.
val regressor = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

val stagesReg = Array(vectorIndexer, regressor)
val pipelineReg = new Pipeline().setStages(stagesReg)
val modelReg = pipelineReg.fit(train)

val predictionsReg = modelReg.transform(test)
predictionsReg.show()

// ---------------------------------------------------------------------------
// GBT Classification Example
// ---------------------------------------------------------------------------
val classifierGBT = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

val stagesGBT = Array(labelIndexer, vectorIndexer, classifierGBT, indexToString)
val pipelineGBTC = new Pipeline().setStages(stagesGBT)
val modelGBTC = pipelineGBTC.fit(train)

// Index 2 is the position of `classifierGBT` in `stagesGBT`.
val modelGBTTree = modelGBTC.stages(2).asInstanceOf[GBTClassificationModel]
println(modelGBTTree.toDebugString)

val predictionsGBTC = modelGBTC.transform(test)
// show() prints the table itself and returns Unit; wrapping it in println only added a stray "()".
predictionsGBTC.show()

// ---------------------------------------------------------------------------
// GBT Regression Example
// ---------------------------------------------------------------------------
val regressorGBT = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

val pipelineGBTR = new Pipeline().setStages(Array(vectorIndexer, regressorGBT))
val modelGBTR = pipelineGBTR.fit(train)

val predictionsGBTR = modelGBTR.transform(test)
predictionsGBTR.show()
0 Comments
Leave a Reply. |
Archives
October 2016
Categories
All
|