added numclasses to tree runner, predict logic for multiclass, add multiclass option to train
manishamde committed May 12, 2014
1 parent 75f2bfc commit 6b912dc
Showing 4 changed files with 69 additions and 24 deletions.
DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
   case class Params(
       input: String = null,
       algo: Algo = Classification,
+      numClasses: Int = 2,
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100)
@@ -68,6 +69,9 @@ object DecisionTreeRunner {
     opt[Int]("maxDepth")
       .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
       .action((x, c) => c.copy(maxDepth = x))
+    opt[Int]("numClasses")
+      .text(s"number of classes for classification, default: ${defaultParams.numClasses}")
+      .action((x, c) => c.copy(numClasses = x))
     opt[Int]("maxBins")
       .text(s"max number of bins, default: ${defaultParams.maxBins}")
       .action((x, c) => c.copy(maxBins = x))
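A quick illustrative sketch (not part of the commit) of what the new option does at parse time: the scopt action above simply copies the `Params` defaults with `numClasses` overridden. It assumes the remaining `Params` fields, not all of which may be shown in this hunk, keep their defaults.

```scala
// Hypothetical stand-alone illustration of the scopt action above:
// passing "--numClasses 3" amounts to copying the defaults with the new field set.
val defaultParams = Params()                       // numClasses = 2 by default
val parsed = defaultParams.copy(numClasses = 3)    // effect of `--numClasses 3`
assert(parsed.numClasses == 3 && parsed.maxDepth == 5)
```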
@@ -139,12 +143,8 @@ object DecisionTreeRunner {
    */
   private def accuracyScore(
       model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
   }
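Since accuracyScore now compares the predicted class index directly against the label instead of thresholding at 0.5, it works unchanged for any number of classes. A minimal sketch with made-up labels and predictions, not taken from the patch:

```scala
// Made-up multiclass labels and predictions; exact equality replaces the old 0.5 threshold.
val labels      = Seq(0.0, 2.0, 1.0, 2.0)
val predictions = Seq(0.0, 1.0, 1.0, 2.0)
val accuracy = labels.zip(predictions).count { case (l, p) => l == p }.toDouble / labels.size
// accuracy == 0.75 (3 of 4 correct)
```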
65 changes: 54 additions & 11 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -231,19 +231,43 @@ object DecisionTree extends Serializable with Logging {
    * @param maxDepth maximum depth of the tree
    * @return a DecisionTreeModel that can be used for prediction
    */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo,impurity,maxDepth)
+    // Converting from standard instance format to weighted input format for tree training
+    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
+    new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
+  }
+
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   *              training data
+   * @param algo algorithm, classification or regression
+   * @param impurity impurity criterion used for information gain calculation
+   * @param maxDepth maximum depth of the tree
+   * @param numClasses number of classes for classification
+   * @return a DecisionTreeModel that can be used for prediction
+   */
   def train(
       input: RDD[LabeledPoint],
       algo: Algo,
       impurity: Impurity,
-      maxDepth: Int): DecisionTreeModel = {
-    val strategy = new Strategy(algo,impurity,maxDepth)
+      maxDepth: Int,
+      numClasses: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo,impurity,maxDepth,numClasses)
     // Converting from standard instance format to weighted input format for tree training
     val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
     new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
   }
 
-  // TODO: Add multiclass classification support
-
   // TODO: Add sample weight support
 
   /**
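The new overload threads numClasses into Strategy. A hedged usage sketch of the overload added above; the RDD `data` and its three-class labels are assumed here and are not part of the patch:

```scala
// Sketch: training a 3-class tree with the new overload.
// Assumes `data: RDD[LabeledPoint]` already exists with labels in {0.0, 1.0, 2.0}.
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

val model = DecisionTree.train(data, Classification, Gini, 5, 3)   // maxDepth = 5, numClasses = 3
val firstPrediction = model.predict(data.first().features)         // a class index: 0.0, 1.0 or 2.0
```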
@@ -258,6 +282,7 @@ object DecisionTree extends Serializable with Logging {
    * @param algo classification or regression
    * @param impurity criterion used for information gain calculation
    * @param maxDepth maximum depth of the tree
+   * @param numClasses number of classes for classification
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy algorithm for calculating quantiles
    * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -272,11 +297,12 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int,
+      numClasses: Int,
       maxBins: Int,
       quantileCalculationStrategy: QuantileStrategy,
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
-      categoricalFeaturesInfo)
+    val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
+      quantileCalculationStrategy, categoricalFeaturesInfo)
     // Converting from standard instance format to weighted input format for tree training
     val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
     new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
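Because numClasses now sits between maxDepth and maxBins, existing positional calls to this fuller overload shift meaning. A hedged call sketch with the new argument order; the RDD `data` and the categorical-feature map are assumed, not taken from the patch:

```scala
// Sketch of the fuller overload with the new argument order (assumed data, not from the patch).
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort
import org.apache.spark.mllib.tree.impurity.Gini

val categoricalFeaturesInfo = Map(0 -> 4)   // feature 0 is categorical with 4 distinct values
val model = DecisionTree.train(
  data, Classification, Gini,
  5,    // maxDepth
  3,    // numClasses (new, inserted before maxBins)
  100,  // maxBins
  Sort, categoricalFeaturesInfo)
```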
@@ -737,10 +763,26 @@ object DecisionTree extends Serializable with Logging {
           }
         }
 
-        //TODO: Make multiclass modification here
-        val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)
+        val totalCount = leftTotalCount + rightTotalCount
 
-        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+        // Sum of count for each label
+        val leftRightCounts: Array[Double]
+          = leftCounts.zip(rightCounts)
+            .map { case (leftCount, rightCount) => leftCount + rightCount }
+
+        def indexOfLargest(array: Seq[Double]): Int = {
+          val result = array.foldLeft((-1, Double.MinValue, 0)) {
+            case ((maxIndex, maxValue, currentIndex), currentValue) =>
+              if (currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
+              else (maxIndex, maxValue, currentIndex + 1)
+          }
+          if (result._1 < 0) 0 else result._1
+        }
+
+        val predict = indexOfLargest(leftRightCounts)
+        val prob = leftRightCounts(predict) / totalCount
+
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
       case Regression =>
         val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
         val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
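For the classification case, the prediction at a candidate split is now the class with the largest total count across the left and right children, and prob is that class's share of all instances. A worked sketch with made-up counts, written with maxBy for brevity rather than the fold used above:

```scala
// Made-up per-class counts for the left and right sides of one candidate split.
val leftCounts  = Array(10.0, 30.0, 5.0)
val rightCounts = Array(20.0, 15.0, 20.0)

val totalCount = leftCounts.sum + rightCounts.sum                               // 100.0
val leftRightCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }  // Array(30, 45, 25)
val predict = leftRightCounts.indices.maxBy(i => leftRightCounts(i))            // class 1
val prob = leftRightCounts(predict) / totalCount                                // 0.45
```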
@@ -793,8 +835,9 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Extracts left and right split aggregates.
    * @param binData Array[Double] of size 2*numFeatures*numSplits
-   * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
-   *         Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
+   * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
+   *         Array[Array[Array[Double\]\]\]) where each array is of size (numFeatures,
+   *         numBins - 1, numClasses)
    */
   def extractLeftRightNodeAggregates(
       binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
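The corrected return type is three-dimensional. A small sketch of the shape the new doc describes; the concrete sizes are invented for illustration:

```scala
// Invented sizes, just to make the (numFeatures, numBins - 1, numClasses) shape concrete.
val numFeatures = 4
val numBins = 10
val numClasses = 3
val leftNodeAgg  = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
// leftNodeAgg(f)(s)(c): aggregate for class c, to the left of split s, for feature f
```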
Strategy.scala
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param algo classification or regression
  * @param impurity criterion used for information gain calculation
  * @param maxDepth maximum depth of the tree
+ * @param numClassesForClassification number of classes for classification. The default value
+ *                                     of 2 leads to binary classification.
  * @param maxBins maximum number of bins used for splitting features
  * @param quantileCalculationStrategy algorithm for calculating quantiles
  * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -37,20 +39,18 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                zero-indexed.
  * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
  *                      128 MB.
- * @param numClassesForClassification number of classes for classification. Default value is 2
- *                                     leads to binary classification
  *
  */
 @Experimental
 class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
+    val numClassesForClassification: Int = 2,
     val maxBins: Int = 100,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128,
-    val numClassesForClassification: Int = 2) extends Serializable {
+    val maxMemoryInMB: Int = 128) extends Serializable {
 
   require(numClassesForClassification >= 2)
   val isMultiClassification = numClassesForClassification > 2
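Since numClassesForClassification moves from the last constructor position to just before maxBins, positional Strategy call sites that pass maxBins or later arguments can silently change meaning; named arguments sidestep that. A hedged construction sketch, with imports and values assumed rather than taken from the patch:

```scala
// Sketch: constructing a multiclass Strategy with named arguments (assumed call site).
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  numClassesForClassification = 3,   // new position: before maxBins
  maxBins = 100)
// strategy.isMultiClassification == true
```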
InformationGainStats.scala
@@ -27,17 +27,19 @@ import org.apache.spark.annotation.DeveloperApi
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
  * @param predict predicted value
+ * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
     val rightImpurity: Double,
-    val predict: Double) extends Serializable {
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
   }
 }
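A brief sketch of the widened constructor with made-up statistics: for classification splits the new prob field records the predicted class's share, while regression callers can omit it and rely on the 0.0 default.

```scala
// Made-up split statistics; prob defaults to 0.0 when omitted (e.g. for regression).
import org.apache.spark.mllib.tree.model.InformationGainStats

val classificationStats = new InformationGainStats(0.12, 0.64, 0.48, 0.55, predict = 1.0, prob = 0.45)
val regressionStats     = new InformationGainStats(0.08, 1.90, 1.20, 1.50, predict = 3.2)
println(classificationStats)   // the updated toString also reports prob
```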
