决策树 - 基于 RDD 的 API
决策树及其集成是用于分类和回归机器学习任务的常用方法。决策树被广泛使用,因为它们易于解释、处理分类特征、扩展到多类分类设置、不需要特征缩放,并且能够捕获非线性和特征交互。诸如随机森林和提升之类的树集成算法是分类和回归任务的最佳执行者之一。
spark.mllib
支持用于二元和多类分类以及回归的决策树,同时使用连续和分类特征。该实现按行对数据进行分区,从而允许使用数百万个实例进行分布式训练。
树的集成(随机森林和梯度提升树)在集成指南中描述。
基本算法
决策树是一种贪心算法,它对特征空间执行递归二元分区。该树为每个最底部(叶)分区预测相同的标签。通过从一组可能的分割中选择最佳分割来贪婪地选择每个分区,以便最大化树节点处的信息增益。换句话说,在每个树节点处选择的分割是从集合 $\underset{s}{\operatorname{argmax}} IG(D,s)$
中选择的,其中 $IG(D,s)$
是当分割 $s$
应用于数据集 $D$
时的信息增益。
节点不纯度和信息增益
节点不纯度 是节点处标签同质性的度量。当前实现为分类(Gini 不纯度和熵)提供两个不纯度度量,为回归(方差)提供一个不纯度度量。
不纯度 | 任务 | 公式 | 描述 |
---|---|---|---|
Gini 不纯度 | 分类 | $\sum_{i=1}^{C} f_i(1-f_i)$ | $f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。 |
熵 | 分类 | $\sum_{i=1}^{C} -f_ilog(f_i)$ | $f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。 |
方差 | 回归 | $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$ | $y_i$ 是实例的标签,$N$ 是实例的数量,$\mu$ 是由 $\frac{1}{N} \sum_{i=1}^N y_i$ 给出的均值。 |
信息增益是父节点不纯度与两个子节点不纯度的加权和之间的差。假设分割 $s$ 将大小为 $N$
的数据集 $D$
分区为大小分别为 $N_{left}$
和 $N_{right}$
的两个数据集 $D_{left}$
和 $D_{right}$
,则信息增益为
$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$
分割候选
连续特征
对于单机实现中的小型数据集,每个连续特征的分割候选通常是该特征的唯一值。一些实现对特征值进行排序,然后使用排序的唯一值作为分割候选,以加快树计算。
对于大型分布式数据集,排序特征值代价高昂。此实现通过对数据的采样分数执行分位数计算来计算一组近似的分割候选。有序分割创建“桶”,并且可以使用 maxBins
参数指定此类桶的最大数量。
请注意,桶的数量不能大于实例的数量 $N$
(罕见情况,因为默认的 maxBins
值为 32)。如果条件不满足,树算法会自动减少桶的数量。
分类特征
对于具有 $M$
个可能值(类别)的分类特征,可以提出 $2^{M-1}-1$
个分割候选。对于二元 (0/1) 分类和回归,我们可以通过按平均标签对分类特征值进行排序,将分割候选的数量减少到 $M-1$
。(有关详细信息,请参见统计机器学习的要素中的第 9.2.4 节。)例如,对于一个二元分类问题,其中一个分类特征具有三个类别 A、B 和 C,其标签 1 的相应比例为 0.2、0.6 和 0.4,则分类特征的顺序为 A、C、B。两个分割候选是 A | C, B 和 A , C | B,其中 | 表示分割。
在多类分类中,只要有可能,都会使用所有 $2^{M-1}-1$
个可能的分割。当 $2^{M-1}-1$
大于 maxBins
参数时,我们使用类似于用于二元分类和回归的方法(启发式)方法。按不纯度对 $M$
个分类特征值进行排序,并考虑生成的 $M-1$
个分割候选。
停止规则
当满足以下条件之一时,递归树构建会在节点处停止
- 节点深度等于
maxDepth
训练参数。 - 没有分割候选导致的信息增益大于
minInfoGain
。 - 没有分割候选生成具有至少
minInstancesPerNode
个训练实例的子节点。
使用技巧
我们通过讨论各种参数来包括一些使用决策树的指南。这些参数大致按重要性降序排列。新用户应主要考虑“问题规范参数”部分和 maxDepth
参数。
问题规范参数
这些参数描述了您要解决的问题和数据集。 应该指定它们,不需要调整。
-
algo
: 决策树的类型,可以是Classification
或Regression
。 -
numClasses
: 类的数量(仅用于Classification
)。 -
categoricalFeaturesInfo
: 指定哪些特征是分类的,以及每个分类特征可以采用多少个分类值。 这是作为从特征索引到特征元数(类别数)的映射给出的。 此映射中不存在的任何特征都将视为连续特征。- 例如,
Map(0 -> 2, 4 -> 10)
指定特征0
是二元的(取值0
或1
),并且特征4
具有 10 个类别(值{0, 1, ..., 9}
)。 请注意,特征索引是基于 0 的:特征0
和4
是实例的特征向量的第 1 个和第 5 个元素。 - 请注意,您不必指定
categoricalFeaturesInfo
。 该算法仍将运行,并且可能会获得合理的结果。 但是,如果正确指定了分类特征,则性能应该更好。
- 例如,
停止标准
这些参数确定树何时停止构建(添加新节点)。 在调整这些参数时,请务必在保留的测试数据上进行验证,以避免过度拟合。
-
maxDepth
: 树的最大深度。更深的树具有更强的表达能力(可能允许更高的准确率),但它们的训练成本也更高,并且更容易过拟合。 -
minInstancesPerNode
: 为了进一步拆分一个节点,它的每个子节点必须至少接收到这个数量的训练实例。这通常与 RandomForest 一起使用,因为它们的训练通常比单个树更深。 -
minInfoGain
: 为了进一步拆分一个节点,拆分必须至少改善这么多(在信息增益方面)。
可调参数
这些参数可以进行调整。在调整时要小心在保留的测试数据上进行验证,以避免过拟合。
maxBins
: 离散化连续特征时使用的 bin 的数量。- 增加
maxBins
允许算法考虑更多的拆分候选对象并做出细粒度的拆分决策。但是,它也会增加计算和通信成本。 - 请注意,对于任何类别特征,
maxBins
参数必须至少是最大类别数$M$
。
- 增加
maxMemoryInMB
: 用于收集充分统计信息的内存量。- 默认值被保守地选择为 256 MiB,以允许决策算法在大多数情况下工作。增加
maxMemoryInMB
可以通过允许更少的数据传递来加快训练速度(如果内存可用)。但是,随着maxMemoryInMB
的增长,可能会出现收益递减的情况,因为每次迭代的通信量可能与maxMemoryInMB
成正比。 - 实现细节:为了更快的处理,决策树算法收集关于节点组的统计信息来进行拆分(而不是一次一个节点)。一个组中可以处理的节点数由内存需求决定(每个特征的需求各不相同)。
maxMemoryInMB
参数指定每个 worker 可以用于这些统计信息的内存限制(以兆字节为单位)。
- 默认值被保守地选择为 256 MiB,以允许决策算法在大多数情况下工作。增加
-
subsamplingRate
: 用于学习决策树的训练数据 fraction。此参数与训练树的集合(使用RandomForest
和GradientBoostedTrees
)最相关,在这些情况下,对原始数据进行子采样可能很有用。 对于训练单个决策树,此参数不太有用,因为训练实例的数量通常不是主要约束。 impurity
: 用于在候选拆分之间进行选择的杂质度量(如上所述)。此度量必须与algo
参数匹配。
缓存和检查点
MLlib 1.2 添加了几个特性,用于扩展到更大(更深)的树和树的集合。当 maxDepth
设置为较大值时,打开节点 ID 缓存和检查点可能会很有用。当 numTrees
设置为较大值时,这些参数对于 RandomForest 也很有用。
useNodeIdCache
: 如果此项设置为 true,则算法将避免在每次迭代时将当前模型(树或多棵树)传递给 executors。- 这对于深树(加快 workers 上的计算速度)和大型 Random Forests(减少每次迭代时的通信量)非常有用。
- 实现细节:默认情况下,该算法将当前模型传递给 executors,以便 executors 可以将训练实例与树节点匹配。当打开此设置时,算法将改为缓存此信息。
节点 ID 缓存生成一系列 RDD(每次迭代 1 个)。这种长 lineage 可能会导致性能问题,但检查点中间 RDD 可以缓解这些问题。请注意,只有在 useNodeIdCache
设置为 true 时,才能应用检查点。
-
checkpointDir
: 用于检查点节点 ID 缓存 RDD 的目录。 -
checkpointInterval
: 检查点节点 ID 缓存 RDD 的频率。将其设置得太低会导致写入 HDFS 的额外开销;将其设置得太高会导致 executors 失败并且需要重新计算 RDD 时出现问题。
扩展
计算量大约与训练实例的数量、特征的数量和 maxBins
参数成线性关系。通信量大约与特征的数量和 maxBins
成线性关系。
实现的算法读取稀疏和密集数据。但是,它没有针对稀疏输入进行优化。
例子
分类
下面的示例演示了如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint
的 RDD,然后使用以 Gini 杂质作为杂质度量和最大树深度为 5 的决策树执行分类。计算测试误差以衡量算法的准确性。
有关 API 的更多详细信息,请参阅 DecisionTree
Python 文档 和 DecisionTreeModel
Python 文档。
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
impurity='gini', maxDepth=5, maxBins=32)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())
# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
有关 API 的详细信息,请参阅 DecisionTree
Scala 文档 和 DecisionTreeModel
Scala 文档。
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")
// Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
有关 API 的详细信息,请参阅 DecisionTree
Java 文档 和 DecisionTreeModel
Java 文档。
import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;
// Train a DecisionTree model for classification.
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr =
predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());
// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
DecisionTreeModel sameModel = DecisionTreeModel
.load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
回归
下面的示例演示了如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint
的 RDD,然后使用以方差作为杂质度量和最大树深度为 5 的决策树执行回归。最后计算均方误差 (MSE) 以评估 拟合优度。
有关 API 的更多详细信息,请参阅 DecisionTree
Python 文档 和 DecisionTreeModel
Python 文档。
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
impurity='variance', maxDepth=5, maxBins=32)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda lp: (lp[0] - lp[1]) * (lp[0] - lp[1])).sum() /\
float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())
# Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
有关 API 的详细信息,请参阅 DecisionTree
Scala 文档 和 DecisionTreeModel
Scala 文档。
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
maxDepth, maxBins)
// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
println(s"Test Mean Squared Error = $testMSE")
println(s"Learned regression tree model:\n ${model.toDebugString}")
// Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
有关 API 的详细信息,请参阅 DecisionTree
Java 文档 和 DecisionTreeModel
Java 文档。
import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "variance";
int maxDepth = 5;
int maxBins = 32;
// Train a DecisionTree model.
DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testMSE = predictionAndLabel.mapToDouble(pl -> {
double diff = pl._1() - pl._2();
return diff * diff;
}).mean();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());
// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
DecisionTreeModel sameModel = DecisionTreeModel
.load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");