决策树 - 基于 RDD 的 API

决策树及其集成是机器学习中分类和回归任务的常用方法。决策树被广泛使用,因为它们易于解释、能够处理分类特征、可扩展到多分类设置、不需要特征缩放,并且能够捕获非线性和特征交互。树集成算法,如随机森林和提升算法,是分类和回归任务中性能最佳的算法之一。

spark.mllib 支持用于二分类和多分类以及回归的决策树,使用连续和分类特征。该实现按行对数据进行分区,允许对数百万个实例进行分布式训练。

树的集成(随机森林和梯度提升树)在集成指南中进行了描述。

基本算法

决策树是一种贪婪算法,它对特征空间执行递归二元划分。树为每个最底层(叶)分区预测相同的标签。通过从一组可能的分割中选择*最佳分割*来贪婪地选择每个分区,以便最大化树节点处的信息增益。换句话说,在每个树节点处选择的分割是从集合 $\underset{s}{\operatorname{argmax}} IG(D,s)$ 中选择的,其中 $IG(D,s)$ 是将分割 $s$ 应用于数据集 $D$ 时获得的信息增益。

节点不纯度和信息增益

*节点不纯度*是衡量节点处标签均匀程度的指标。当前实现为分类提供了两种不纯度度量(基尼不纯度和熵)和一种用于回归的不纯度度量(方差)。

不纯度任务公式描述
基尼不纯度 分类 $\sum_{i=1}^{C} f_i(1-f_i)$$f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。
分类 $\sum_{i=1}^{C} -f_ilog(f_i)$$f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。
方差 回归 $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ 是实例的标签,$N$ 是实例的数量,$\mu$ 是由 $\frac{1}{N} \sum_{i=1}^N y_i$ 给出的均值。

*信息增益*是父节点不纯度与两个子节点不纯度的加权和之间的差值。假设分割 $s$ 将大小为 $N$ 的数据集 $D$ 分割成大小分别为 $N_{left}$$N_{right}$ 的两个数据集 $D_{left}$$D_{right}$,则信息增益为

$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$

分割候选

连续特征

对于单机实现中的小型数据集,每个连续特征的分割候选通常是该特征的唯一值。一些实现对特征值进行排序,然后使用有序的唯一值作为分割候选,以加快树的计算速度。

对于大型分布式数据集,对特征值进行排序的成本很高。此实现通过对数据样本的一部分执行分位数计算来计算一组近似的分割候选。有序分割创建“箱”,可以使用 maxBins 参数指定此类箱的最大数量。

请注意,箱的数量不能大于实例数 $N$(这种情况很少见,因为默认的 maxBins 值为 32)。如果不满足条件,树算法会自动减少箱的数量。

分类特征

对于具有 $M$ 个可能值(类别)的分类特征,可以提出 $2^{M-1}-1$ 个分割候选。对于二元(0/1)分类和回归,我们可以通过按平均标签对分类特征值进行排序,将分割候选的数量减少到 $M-1$。(有关详细信息,请参阅统计学习基础第 9.2.4 节。)例如,对于一个具有一个分类特征的二元分类问题,该特征具有三个类别 A、B 和 C,其对应的标签 1 比例分别为 0.2、0.6 和 0.4,则分类特征的顺序为 A、C、B。两个分割候选是 A | C、B 和 A、C | B,其中 | 表示分割。

在多分类中,只要有可能,就会使用所有 $2^{M-1}-1$ 个可能的分割。当 $2^{M-1}-1$ 大于 maxBins 参数时,我们使用一种类似于二元分类和回归的方法(启发式方法)。$M$ 个分类特征值按不纯度排序,并考虑得到的 $M-1$ 个分割候选。

停止规则

当满足以下条件之一时,递归树构造在节点处停止

  1. 节点深度等于 maxDepth 训练参数。
  2. 没有分割候选导致信息增益大于 minInfoGain
  3. 没有分割候选产生至少具有 minInstancesPerNode 个训练实例的子节点。

使用技巧

我们通过讨论各种参数,提供了一些使用决策树的指南。参数大致按重要性降序排列。新用户应主要考虑“问题规范参数”部分和 maxDepth 参数。

问题规范参数

这些参数描述了您要解决的问题和您的数据集。应指定它们,并且不需要调整。

停止条件

这些参数决定了树何时停止构建(添加新节点)。调整这些参数时,请务必在留出的测试数据上进行验证,以避免过拟合。

可调参数

可以调整这些参数。调整时请务必在留出的测试数据上进行验证,以避免过拟合。

缓存和检查点

MLlib 1.2 添加了几个用于扩展到更大(更深)的树和树集合的功能。当 maxDepth 设置为较大时,打开节点 ID 缓存和检查点可能很有用。当 numTrees 设置为较大时,这些参数对 RandomForest 也很有用。

节点 ID 缓存会生成一系列 RDD(每次迭代 1 个)。这种长谱系会导致性能问题,但对中间 RDD 进行检查点可以缓解这些问题。请注意,检查点仅在 useNodeIdCache 设置为 true 时适用。

扩展

计算量大约与训练实例数、特征数和 maxBins 参数成线性比例。通信量大约与特征数和 maxBins 成线性比例。

实现的算法读取稀疏和密集数据。但是,它没有针对稀疏输入进行优化。

示例

分类

下面的示例演示了如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint 的 RDD,然后使用基尼杂质作为杂质度量和最大树深度为 5 的决策树执行分类。计算测试误差以衡量算法的准确性。

有关 API 的更多详细信息,请参阅 DecisionTree Python 文档DecisionTreeModel Python 文档

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
    lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
在 Spark 存储库的“examples/src/main/python/mllib/decision_tree_classification_example.py”中查找完整的示例代码。

有关 API 的详细信息,请参阅 DecisionTree Scala 文档DecisionTreeModel Scala 文档

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
在 Spark 存储库的“examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala”中查找完整的示例代码。

有关 API 的详细信息,请参阅 DecisionTree Java 文档DecisionTreeModel Java 文档

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model for classification.
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr =
  predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();

System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
在 Spark 存储库的“examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java”中查找完整的示例代码。

回归

下面的示例演示了如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint 的 RDD,然后使用方差作为杂质度量和最大树深度为 5 的决策树执行回归。最后计算均方误差 (MSE) 以评估 拟合优度

有关 API 的更多详细信息,请参阅 DecisionTree Python 文档DecisionTreeModel Python 文档

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda lp: (lp[0] - lp[1]) * (lp[0] - lp[1])).sum() /\
    float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
在 Spark 存储库的“examples/src/main/python/mllib/decision_tree_regression_example.py”中查找完整的示例代码。

有关 API 的详细信息,请参阅 DecisionTree Scala 文档DecisionTreeModel Scala 文档

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
println(s"Test Mean Squared Error = $testMSE")
println(s"Learned regression tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
在 Spark 存储库的“examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala”中查找完整的示例代码。

有关 API 的详细信息,请参阅 DecisionTree Java 文档DecisionTreeModel Java 文档

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "variance";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model.
DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testMSE = predictionAndLabel.mapToDouble(pl -> {
  double diff = pl._1() - pl._2();
  return diff * diff;
}).mean();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
在 Spark 存储库的“examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java”中查找完整的示例代码。