决策树 - 基于RDD的API

决策树及其集成方法是机器学习分类和回归任务的常用方法。决策树被广泛使用,因为它们易于解释,能够处理类别特征,可扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性关系和特征交互。随机森林和梯度提升等树集成算法在分类和回归任务中表现出色。

spark.mllib 支持二元和多类分类以及回归的决策树,可使用连续和类别特征。该实现按行划分数据,允许使用数百万实例进行分布式训练。

树的集成(随机森林和梯度提升树)在集成指南中描述。

基本算法

决策树是一种贪心算法,它对特征空间执行递归二元划分。树为每个最底层(叶)划分预测相同的标签。每个划分都是通过从一组可能的分割中选择最佳分割来贪心地选择的,以最大化树节点处的信息增益。换句话说,在每个树节点选择的分割是从集合 $\underset{s}{\operatorname{argmax}} IG(D,s)$ 中选出的,其中 $IG(D,s)$ 是对数据集 $D$ 应用分割 $s$ 时的信息增益。

节点不纯度与信息增益

节点不纯度是衡量节点处标签同质性的一种度量。当前的实现为分类提供了两种不纯度度量(基尼不纯度和熵),为回归提供了一种不纯度度量(方差)。

不纯度任务公式描述
基尼不纯度 分类 $\sum_{i=1}^{C} f_i(1-f_i)$$f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。
分类 $\sum_{i=1}^{C} -f_ilog(f_i)$$f_i$ 是节点处标签 $i$ 的频率,$C$ 是唯一标签的数量。
方差 回归 $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ 是实例的标签,$N$ 是实例的数量,$\mu$ 是由 $\frac{1}{N} \sum_{i=1}^N y_i$ 给出的平均值。

信息增益是父节点不纯度与两个子节点不纯度的加权和之间的差值。假设分割 $s$ 将大小为 $N$ 的数据集 $D$ 划分为大小分别为 $N_{left}$$N_{right}$ 的两个数据集 $D_{left}$$D_{right}$,则信息增益为

$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$

分割候选

连续特征

对于单机实现中的小型数据集,每个连续特征的分割候选通常是该特征的唯一值。一些实现会对特征值进行排序,然后使用排序后的唯一值作为分割候选,以加快树的计算。

对于大型分布式数据集,对特征值进行排序开销很大。此实现通过对数据采样部分进行分位数计算来计算近似的分割候选集。有序分割创建“分箱”,此类分箱的最大数量可以使用 maxBins 参数指定。

请注意,分箱的数量不能大于实例的数量 $N$(这种情况很少见,因为默认的 maxBins 值是 32)。如果条件不满足,树算法会自动减少分箱的数量。

类别特征

对于具有 $M$ 个可能值(类别)的类别特征,可以产生 $2^{M-1}-1$ 个分割候选。对于二元(0/1)分类和回归,我们可以通过根据平均标签对类别特征值进行排序,将分割候选的数量减少到 $M-1$。(详见 统计机器学习要素的第 9.2.4 节。)例如,对于一个二元分类问题,其中一个类别特征有三个类别 A、B 和 C,它们对应标签 1 的比例分别为 0.2、0.6 和 0.4,则类别特征按 A、C、B 的顺序排列。两个分割候选是 A | C, B 和 A , C | B,其中 | 表示分割。

在多类分类中,尽可能使用所有 $2^{M-1}-1$ 种可能的分割。当 $2^{M-1}-1$ 大于 maxBins 参数时,我们使用一种(启发式)方法,类似于二元分类和回归中使用的那个方法。将 $M$ 个类别特征值按不纯度排序,然后考虑由此产生的 $M-1$ 个分割候选。

停止规则

当满足以下任一条件时,递归树构建在节点处停止:

  1. 节点深度等于 maxDepth 训练参数。
  2. 没有分割候选能够带来大于 minInfoGain 的信息增益。
  3. 没有分割候选能产生子节点,且每个子节点至少包含 minInstancesPerNode 个训练实例。

使用技巧

我们通过讨论各种参数来提供一些使用决策树的指导。参数大致按重要性降序排列。新用户应主要关注“问题规范参数”部分和 maxDepth 参数。

问题规范参数

这些参数描述了您要解决的问题和您的数据集。它们应该被指定,并且不需要调优。

  • algo: 决策树的类型,可以是 Classification(分类)或 Regression(回归)。

  • numClasses: 类别数量(仅适用于 Classification)。

  • categoricalFeaturesInfo: 指定哪些特征是类别型的,以及每个类别型特征可以有多少个类别值。这以从特征索引到特征基数(类别数量)的映射形式给出。不在该映射中的任何特征都将被视为连续特征。

    • 例如,Map(0 -> 2, 4 -> 10) 指定特征 0 是二元的(取值 01),并且特征 4 有 10 个类别(取值 {0, 1, ..., 9})。请注意,特征索引是从 0 开始的:特征 04 是实例特征向量的第 1 个和第 5 个元素。
    • 请注意,您不必指定 categoricalFeaturesInfo。算法仍然会运行并可能获得合理的结果。但是,如果正确指定了类别特征,性能应该会更好。

停止准则

这些参数决定了树何时停止构建(添加新节点)。在调整这些参数时,请务必在保留的测试数据上进行验证,以避免过拟合。

  • maxDepth: 树的最大深度。更深的树表达能力更强(可能实现更高的准确性),但它们的训练成本也更高,并且更容易过拟合。

  • minInstancesPerNode: 为了让节点进一步分割,其每个子节点必须至少接收到此数量的训练实例。这通常与 RandomForest 一起使用,因为那些模型通常比单个树训练得更深。

  • minInfoGain: 为了让节点进一步分割,分割改进量(以信息增益衡量)必须至少达到此值。

可调参数

这些参数可以进行调优。在调优时请务必在保留的测试数据上进行验证,以避免过拟合。

  • maxBins: 对连续特征进行离散化时使用的分箱数量。
    • 增加 maxBins 允许算法考虑更多的分割候选并进行更细粒度的分割决策。但是,这也会增加计算和通信开销。
    • 请注意,maxBins 参数必须至少是任何类别特征的最大类别数量 $M$
  • maxMemoryInMB: 用于收集足够统计信息的内存量。
    • 默认值保守地设置为 256 MiB,以允许决策算法在大多数场景中工作。增加 maxMemoryInMB 可以通过减少数据遍历次数来加快训练(如果内存可用)。然而,随着 maxMemoryInMB 的增长,收益可能会递减,因为每次迭代的通信量可能与 maxMemoryInMB 成比例。
    • 实现细节: 为了加快处理速度,决策树算法会收集一组待分割节点(而不是一次一个节点)的统计信息。一次可以处理的节点数量由内存需求决定(内存需求因特征而异)。maxMemoryInMB 参数以兆字节为单位指定了每个 Worker 可以用于这些统计信息的内存限制。
  • subsamplingRate: 用于学习决策树的训练数据比例。此参数对于训练树集成模型(使用 RandomForestGradientBoostedTrees)最为相关,其中对原始数据进行二次采样可能很有用。对于训练单个决策树,此参数的用处较小,因为训练实例的数量通常不是主要约束。

  • impurity: 用于在候选分割之间进行选择的不纯度度量(如上所述)。此度量必须与 algo 参数匹配。

缓存与检查点

MLlib 1.2 增加了多项功能,用于扩展到更大(更深)的树和树集成。当 maxDepth 设置得很大时,开启节点 ID 缓存和检查点可能很有用。这些参数对于 RandomForestnumTrees 设置得很大时也很有用。

  • useNodeIdCache: 如果此项设置为 true,算法将避免在每次迭代时将当前模型(一棵或多棵树)传递给执行器。
    • 这对于深层树(加快 Worker 上的计算)和大型随机森林(减少每次迭代的通信量)可能很有用。
    • 实现细节: 默认情况下,算法会将当前模型通信给执行器,以便执行器可以将训练实例与树节点进行匹配。当此设置开启时,算法将转而缓存此信息。

节点 ID 缓存会生成一系列 RDD(每次迭代一个)。这种长血缘关系可能会导致性能问题,但对中间 RDD 进行检查点可以缓解这些问题。请注意,检查点仅在 useNodeIdCache 设置为 true 时适用。

  • checkpointDir: 用于检查点节点 ID 缓存 RDD 的目录。

  • checkpointInterval: 检查点节点 ID 缓存 RDD 的频率。设置得太低会导致写入 HDFS 产生额外的开销;设置得太高可能会在执行器失败且 RDD 需要重新计算时导致问题。

扩展性

计算量近似与训练实例数量、特征数量和 maxBins 参数呈线性关系。通信量近似与特征数量和 maxBins 呈线性关系。

实现的算法读取稀疏和密集数据。但是,它并未针对稀疏输入进行优化。

示例

分类

下面的示例演示如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint 的 RDD,然后使用决策树执行分类,其中基尼不纯度作为不纯度度量,最大树深度为 5。计算测试误差以衡量算法准确性。

有关 API 的更多详细信息,请参阅 DecisionTree Python 文档DecisionTreeModel Python 文档

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                     impurity='gini', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(
    lambda lp: lp[0] != lp[1]).count() / float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
完整示例代码请参见 Spark 仓库中 "examples/src/main/python/mllib/decision_tree_classification_example.py"。

有关 API 的详细信息,请参阅 DecisionTree Scala 文档DecisionTreeModel Scala 文档

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
完整示例代码请参见 Spark 仓库中 "examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala"。

有关 API 的详细信息,请参阅 DecisionTree Java 文档DecisionTreeModel Java 文档

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model for classification.
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testErr =
  predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();

System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
完整示例代码请参见 Spark 仓库中 "examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java"。

回归

下面的示例演示如何加载 LIBSVM 数据文件,将其解析为 LabeledPoint 的 RDD,然后使用决策树执行回归,其中方差作为不纯度度量,最大树深度为 5。最后计算均方误差 (MSE) 以评估拟合优度

有关 API 的更多详细信息,请参阅 DecisionTree Python 文档DecisionTreeModel Python 文档

from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
                                    impurity='variance', maxDepth=5, maxBins=32)

# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda lp: (lp[0] - lp[1]) * (lp[0] - lp[1])).sum() /\
    float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression tree model:')
print(model.toDebugString())

# Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
完整示例代码请参见 Spark 仓库中 "examples/src/main/python/mllib/decision_tree_regression_example.py"。

有关 API 的详细信息,请参阅 DecisionTree Scala 文档DecisionTreeModel Scala 文档

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
println(s"Test Mean Squared Error = $testMSE")
println(s"Learned regression tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sc, "target/tmp/myDecisionTreeRegressionModel")
val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel")
完整示例代码请参见 Spark 仓库中 "examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala"。

有关 API 的详细信息,请参阅 DecisionTree Java 文档DecisionTreeModel Java 文档

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);

// Load and parse the data file.
String datapath = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
// Split the data into training and test sets (30% held out for testing)
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "variance";
int maxDepth = 5;
int maxBins = 32;

// Train a DecisionTree model.
DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
  testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double testMSE = predictionAndLabel.mapToDouble(pl -> {
  double diff = pl._1() - pl._2();
  return diff * diff;
}).mean();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());

// Save and load model
model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
DecisionTreeModel sameModel = DecisionTreeModel
  .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
完整示例代码请参见 Spark 仓库中 "examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java"。