// Spark ML decision-tree walkthrough (spark-shell script): trains a
// DecisionTreeClassifier and a DecisionTreeRegressor on the bundled
// sample_libsvm_data.txt via ML Pipelines, then prints each learned tree
// and the predictions on a held-out split.
//
// NOTE(review): this uses the pre-2.0 API surface (SQLContext plus
// mllib's MLUtils.loadLibSVMFile). On Spark 2.x+ the mllib vectors
// produced by loadLibSVMFile(...).toDF() are NOT accepted by
// ml.feature.VectorIndexer; there, load with
// spark.read.format("libsvm").load(path) instead — confirm target version.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.util.MLUtils

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// ---- Decision Tree Classifier ----------------------------------------------

// Load the libsvm-format sample data as a DataFrame with "label"/"features".
val data = MLUtils.loadLibSVMFile(sc, "./data/sample_libsvm_data.txt").toDF()

// Map string/double labels to contiguous indices (required by the classifier).
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Translate predicted indices back to the original label values.
val indexToString = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Automatically flag features with <= 4 distinct values as categorical.
val vectorIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

val classifier = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, vectorIndexer, classifier, indexToString))

// 70/30 train/test split (random seed not fixed, so results vary per run).
val Array(train, test) = data.randomSplit(Array(0.7, 0.3))

val model = pipeline.fit(train)

// Stage 2 of the fitted pipeline is the trained tree model.
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
// Print the tree model (bare toDebugString only echoes in a REPL; println
// makes the script print it when submitted as a job too).
println(treeModel.toDebugString)

val prediction = model.transform(test)
prediction.show()

// ---- Decision Tree Regressor -----------------------------------------------

// Regression trains directly on the raw "label" column, so no label indexing
// (and no IndexToString) stage is needed here.
val regressor = new DecisionTreeRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

val pipelineReg = new Pipeline().setStages(Array(vectorIndexer, regressor))

val modelReg = pipelineReg.fit(train)

// Stage 1 of the fitted regression pipeline is the trained tree model.
val treeModelReg = modelReg.stages(1).asInstanceOf[DecisionTreeRegressionModel]
// Print the tree model.
println(treeModelReg.toDebugString)

val predictionReg = modelReg.transform(test)
predictionReg.show()
0 Comments
Leave a Reply.
Archives
October 2016
Categories
All
|