
categorical feature support
Signed-off-by: Manish Amde <manish9ue@gmail.com>
manishamde committed Feb 28, 2014
1 parent 6df35b9 commit dbb7ac1
Showing 5 changed files with 185 additions and 53 deletions.
127 changes: 96 additions & 31 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
//Cache input RDD for speedup during multiple passes
input.cache()

- val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
logDebug("numSplits = " + bins(0).length)
strategy.numBins = bins(0).length

@@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

logDebug("algo = " + strategy.algo)



breakable {
for (level <- 0 until maxDepth){

@@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging {
val featureIndex = filter.split.feature
val threshold = filter.split.threshold
val comparison = filter.comparison
- comparison match {
-   case(-1) => if (features(featureIndex) > threshold) return false
-   case(0) => if (features(featureIndex) != threshold) return false
-   case(1) => if (features(featureIndex) <= threshold) return false
- }
+ val categories = filter.split.categories
+ val isFeatureContinuous = filter.split.featureType == Continuous
+ val feature = features(featureIndex)
+ if (isFeatureContinuous){
+   comparison match {
+     case(-1) => if (feature > threshold) return false
+     case(1) => if (feature <= threshold) return false
+   }
+ } else {
+   val containsFeature = categories.contains(feature)
+   comparison match {
+     case(-1) => if (!containsFeature) return false
+     case(1) => if (containsFeature) return false
+   }
+
+ }
}
true
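A standalone sketch of the left/right convention the new filter logic encodes (simplified stand-in types, not the committed API): comparison -1 means the sample must fall on the split's left side, 1 on the right; continuous features compare against the threshold, while categorical features test membership in the split's category list.

object FilterConventionSketch {
  // comparison: -1 => sample must lie in the left child, 1 => in the right child
  def passes(featureValue: Double, threshold: Double, categories: List[Double],
             isContinuous: Boolean, comparison: Int): Boolean = {
    if (isContinuous) {
      comparison match {
        case -1 => featureValue <= threshold  // left child: values at or below the threshold
        case 1  => featureValue > threshold   // right child: values above it
      }
    } else {
      val inSet = categories.contains(featureValue)
      comparison match {
        case -1 => inSet   // left child: category belongs to the split's subset
        case 1  => !inSet  // right child: it does not
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(passes(2.0, 1.5, Nil, isContinuous = true, comparison = 1))              // true: 2.0 > 1.5
    println(passes(0.0, 0.0, List(0.0, 2.0), isContinuous = false, comparison = -1)) // true: 0.0 in subset
  }
}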
@@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging {
/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
//logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
- //TODO: Do binary search
- for (binIndex <- 0 until strategy.numBins) {
-   val bin = bins(featureIndex)(binIndex)
-   //TODO: Remove this requirement post basic functional
-   val lowThreshold = bin.lowSplit.threshold
-   val highThreshold = bin.highSplit.threshold
-   val features = labeledPoint.features
-   if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
-     return binIndex
-   }
- }
- throw new UnknownError("no bin was found.")
+
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinous){
+   //TODO: Do binary search
+   for (binIndex <- 0 until strategy.numBins) {
+     val bin = bins(featureIndex)(binIndex)
+     //TODO: Remove this requirement post basic functional
+     val lowThreshold = bin.lowSplit.threshold
+     val highThreshold = bin.highSplit.threshold
+     val features = labeledPoint.features
+     if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
+       return binIndex
+     }
+   }
+   throw new UnknownError("no bin was found for continuous variable.")
+ } else {
+   for (binIndex <- 0 until strategy.numBins) {
+     val bin = bins(featureIndex)(binIndex)
+     //TODO: Remove this requirement post basic functional
+     val category = bin.category
+     val features = labeledPoint.features
+     if (category == features(featureIndex)) {
+       return binIndex
+     }
+   }
+   throw new UnknownError("no bin was found for categorical variable.")
+
+ }

}
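The continuous branch above still scans every bin (the //TODO asks for binary search). A hedged sketch of what that lookup could look like under the same half-open bin convention, using plain (low, high) threshold pairs rather than the real Bin class:

object BinarySearchSketch {
  // bins are sorted, non-overlapping [low, high) intervals covering the feature range
  def findBin(value: Double, bins: Array[(Double, Double)]): Int = {
    var lo = 0
    var hi = bins.length - 1
    while (lo <= hi) {
      val mid = (lo + hi) >>> 1
      val (low, high) = bins(mid)
      if (low <= value && value < high) return mid  // same test as the linear scan in the diff
      else if (value < low) hi = mid - 1
      else lo = mid + 1
    }
    throw new UnknownError("no bin was found for continuous variable.")
  }

  def main(args: Array[String]): Unit = {
    val bins = Array((Double.MinValue, 1.0), (1.0, 2.0), (2.0, Double.MaxValue))
    println(findBin(1.5, bins)) // 1
  }
}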

@@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging {
@return a tuple of (splits, bins) where splits is an Array[Array[Split]] of size (numFeatures, numSplits-1) and bins is an
Array[Array[Bin]] of size (numFeatures, numSplits+1)
*/
- def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
+ def findSplitsBins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {

val count = input.count()

@@ -603,31 +628,71 @@
logDebug("stride = " + stride)
for (index <- 0 until numBins-1) {
val sampleIndex = (index+1)*stride.toInt
- val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
+ val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
splits(featureIndex)(index) = split
}
} else {
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
- for (index <- 0 until maxFeatureValue){
-   //TODO: Sort by centriod
-   val split = new Split(featureIndex,index,Categorical)
-   splits(featureIndex)(index) = split
- }
+
+ require(maxFeatureValue < numBins, "number of categories should be less than number of bins")
+
+ val centriodForCategories
+   = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+     .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+ //Checking for missing categorical variables
+ val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]()
+ for (i <- 0 until maxFeatureValue){
+   if (centriodForCategories.contains(i)){
+     fullCentriodForCategories(i) = centriodForCategories(i)
+   } else {
+     fullCentriodForCategories(i) = Double.MaxValue
+   }
+ }
+
+ val categoriesSortedByCentriod
+   = fullCentriodForCategories.toList sortBy {_._2}
+
+ logDebug("centriod for categorical variable = " + categoriesSortedByCentriod)
+
+ var categoriesForSplit = List[Double]()
+ categoriesSortedByCentriod.iterator.zipWithIndex foreach {
+   case((key, value), index) => {
+     categoriesForSplit = key :: categoriesForSplit
+     splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical,categoriesForSplit)
+     bins(featureIndex)(index) = {
+       if(index == 0) {
+         new Bin(new DummyCategoricalSplit(featureIndex,Categorical),splits(featureIndex)(0),Categorical,key)
+       }
+       else {
+         new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Categorical,key)
+       }
+     }
+   }
+ }
}
}

//Find all bins
for (featureIndex <- 0 until numFeatures){
- bins(featureIndex)(0)
-   = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
- for (index <- 1 until numBins - 1){
-   val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
-   bins(featureIndex)(index) = bin
- }
- bins(featureIndex)(numBins-1)
-   = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(Continuous),Continuous)
+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinous) { //bins for categorical variables are already assigned
+   bins(featureIndex)(0)
+     = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
+   for (index <- 1 until numBins - 1){
+     val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous,Double.MinValue)
+     bins(featureIndex)(index) = bin
+   }
+   bins(featureIndex)(numBins-1)
+     = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous),Continuous,Double.MinValue)
+ } else {
+   val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+   for (i <- maxFeatureValue until numBins){
+     bins(featureIndex)(i)
+       = new Bin(new DummyCategoricalSplit(featureIndex,Categorical),new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue)
+   }
+ }
}

(splits,bins)
}
case MinMax => {
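The categorical branch of findSplitsBins orders categories by the average label of the samples in each category (the code's "centriod"), assigns Double.MaxValue to categories never seen in the sample, and then forms candidate splits from growing prefixes of that ordering. A worked standalone sketch with made-up data (not the committed code):

object CentroidOrderingSketch {
  def main(args: Array[String]): Unit = {
    // (categoryValue, label) pairs; category 2.0 never occurs, so it sorts last
    val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (3.0, 0.0))
    val maxFeatureValue = 4

    // mean label per category, mirroring the groupBy in findSplitsBins
    val centroids = samples.groupBy(_._1).map { case (cat, xs) =>
      cat -> xs.map(_._2).sum / xs.length
    }
    val full = (0 until maxFeatureValue).map { c =>
      c.toDouble -> centroids.getOrElse(c.toDouble, Double.MaxValue)
    }
    val ordered = full.sortBy(_._2).map(_._1)
    println(ordered) // Vector(3.0, 1.0, 0.0, 2.0): lowest mean label first

    // candidate splits are nested prefixes of the ordering, as categoriesForSplit grows
    val candidateSubsets = (1 to ordered.length).map(k => ordered.take(k).toList)
    candidateSubsets.foreach(println) // List(3.0), List(3.0, 1.0), ...
  }
}

For binary labels this is the standard trick: sorting categories by mean label and scanning only prefix subsets finds the best split without enumerating all 2^k category subsets.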
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -18,6 +18,6 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

- case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) {
+ case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {

}
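With the new category field, a continuous bin carries a placeholder (the diff uses Double.MinValue) while a categorical bin records the single category it holds. A hypothetical construction, assuming the Split and Bin shapes from this commit:

import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.model.{Bin, DummyCategoricalSplit, Split}

object BinConstructionSketch {
  def main(args: Array[String]): Unit = {
    // categorical bin holding category 1.0 on feature 0
    val catBin = new Bin(
      new DummyCategoricalSplit(0, Categorical),
      new Split(0, Double.MinValue, Categorical, List(1.0)),
      Categorical, 1.0)

    // continuous bin covering [0.5, 1.5) on feature 1; category is just a sentinel
    val contBin = new Bin(
      new Split(1, 0.5, Continuous, List()),
      new Split(1, 1.5, Continuous, List()),
      Continuous, Double.MinValue)

    println(catBin.category)     // 1.0
    println(contBin.featureType) // Continuous
  }
}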
15 changes: 12 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.tree.configuration.FeatureType._

class Node ( val id : Int,
val predict : Double,
@@ -49,10 +50,18 @@ class Node ( val id : Int,
if (isLeaf) {
predict
} else{
- if (feature(split.get.feature) <= split.get.threshold) {
-   leftNode.get.predictIfLeaf(feature)
+ if (split.get.featureType == Continuous) {
+   if (feature(split.get.feature) <= split.get.threshold) {
+     leftNode.get.predictIfLeaf(feature)
+   } else {
+     rightNode.get.predictIfLeaf(feature)
+   }
} else {
-   rightNode.get.predictIfLeaf(feature)
+   if (split.get.categories.contains(feature(split.get.feature))) {
+     leftNode.get.predictIfLeaf(feature)
+   } else {
+     rightNode.get.predictIfLeaf(feature)
+   }
}
}
}
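A standalone sketch of the routing predictIfLeaf now performs (simplified stand-ins for Node and Split, not the real classes): continuous splits send a sample left when the value is at or below the threshold; categorical splits send it left when the value appears in the split's category list.

object PredictRoutingSketch {
  sealed trait SimpleNode { def predict(features: Array[Double]): Double }
  final case class Leaf(value: Double) extends SimpleNode {
    def predict(features: Array[Double]): Double = value
  }
  final case class Internal(feature: Int, threshold: Double, categories: List[Double],
                            isContinuous: Boolean, left: SimpleNode, right: SimpleNode)
      extends SimpleNode {
    def predict(features: Array[Double]): Double = {
      val goLeft =
        if (isContinuous) features(feature) <= threshold  // continuous rule
        else categories.contains(features(feature))       // categorical rule
      if (goLeft) left.predict(features) else right.predict(features)
    }
  }

  def main(args: Array[String]): Unit = {
    // categorical root: categories {0.0, 2.0} go left; left child splits on feature 1 at 0.5
    val tree = Internal(0, Double.MinValue, List(0.0, 2.0), isContinuous = false,
      left = Internal(1, 0.5, Nil, isContinuous = true, left = Leaf(1.0), right = Leaf(0.0)),
      right = Leaf(0.0))
    println(tree.predict(Array(2.0, 0.3))) // 1.0: category in subset, then 0.3 <= 0.5
  }
}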
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -18,11 +18,14 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

- case class Split(feature: Int, threshold : Double, featureType : FeatureType){
-   override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType
+ case class Split(feature: Int, threshold : Double, featureType : FeatureType, categories : List[Double]){
+   override def toString =
+     "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories
}

- class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind)
+ class DummyLowSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MinValue, kind, List())

- class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind)
+ class DummyHighSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())

+ class DummyCategoricalSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List())

