Commit

fix and test binary search
yinxusen committed May 7, 2015
1 parent 2466322 commit fb30d79
Showing 2 changed files with 38 additions and 10 deletions.
23 changes: 15 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -39,12 +39,12 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])

/**
* Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
- * A bucket defined by splits x,y holds values in the range (x,y].
+ * A bucket defined by splits x,y holds values in the range [x,y).
* @group param
*/
val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
"Split points for mapping continuous features into buckets. With n splits, there are n+1" +
- " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+ " buckets. A bucket defined by splits x,y holds values in the range [x,y).",
Bucketizer.checkSplits)

/** @group getParam */
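The flipped interval is the substance of this hunk: under [x,y), a split value opens the bucket it names instead of closing the previous one. A minimal standalone sketch of the new semantics (illustrative code, not part of the diff; the Double.MinValue/Double.MaxValue padding mirrors what transform() does below):

```scala
// Illustrative sketch of the [x, y) bucket semantics; not code from this commit.
// Three user-supplied splits, padded on both sides, give four buckets:
//   0: [MinValue, -0.5)   1: [-0.5, 0.0)   2: [0.0, 0.5)   3: [0.5, MaxValue)
object BucketSemanticsDemo extends App {
  val splits = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
  for (v <- Seq(-0.9, -0.5, -0.1, 0.0, 0.2, 0.7)) {
    // Under [x, y), the last split <= v opens v's bucket: a split value like
    // -0.5 now lands in bucket 1, where the old (x, y] rule put it in bucket 0.
    val bucket = splits.lastIndexWhere(v >= _)
    println(s"$v -> bucket $bucket")
  }
}
```

This boundary flip is also why the setter test below moves -0.5 from expected bucket 0.0 to 1.0.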
@@ -85,7 +85,8 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
transformSchema(dataset.schema)
val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
val bucketizer = udf { feature: Double =>
- Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+ Bucketizer
+   .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
val newCol = bucketizer(dataset($(inputCol)))
val newField = prepOutputField(dataset.schema)
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
@@ -95,7 +96,6 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
val attr = new NominalAttribute(
name = Some($(outputCol)),
isOrdinal = Some(true),
- numValues = Some($(splits).size),
values = Some($(splits).map(_.toString)))

attr.toStructField()
@@ -131,20 +131,27 @@ object Bucketizer {
/**
* Binary search to find the bucket that holds each data point.
*/
- private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+ private[feature] def binarySearchForBuckets(
+     splits: Array[Double],
+     feature: Double,
+     lowerInclusive: Boolean,
+     upperInclusive: Boolean): Double = {
+   if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive))
+     throw new Exception(s"Feature $feature out of bounds; check your features or loosen the" +
+       s" lower/upper bound constraint.")
var left = 0
var right = splits.length - 2
while (left <= right) {
val mid = left + (right - left) / 2
val split = splits(mid)
- if ((feature > split) && (feature <= splits(mid + 1))) {
+ if ((feature >= split) && (feature < splits(mid + 1))) {
return mid
- } else if (feature <= split) {
+ } else if (feature < split) {
right = mid - 1
} else {
left = mid + 1
}
}
throw new Exception("Failed to find a bucket.")
throw new Exception(s"Failed to find a bucket for feature $feature.")
}
}
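Taken out of context, the patched loop maintains the invariant that bucket mid covers [splits(mid), splits(mid + 1)). Here is a self-contained sketch of the same search (an illustration of the fixed logic, not a copy of the Spark file; the lowerInclusive/upperInclusive bound check is dropped for brevity):

```scala
// Standalone sketch of the fixed binary search. splits is assumed sorted
// ascending; bucket i covers the half-open range [splits(i), splits(i + 1)).
def findBucket(splits: Array[Double], feature: Double): Double = {
  var left = 0
  var right = splits.length - 2            // highest valid bucket index
  while (left <= right) {
    val mid = left + (right - left) / 2    // overflow-safe midpoint
    if (feature >= splits(mid) && feature < splits(mid + 1)) {
      return mid                           // feature sits in bucket mid
    } else if (feature < splits(mid)) {
      right = mid - 1                      // feature lies below bucket mid
    } else {
      left = mid + 1                       // feature lies above bucket mid
    }
  }
  throw new RuntimeException(s"Failed to find a bucket for feature $feature.")
}
```

Because transform() pads the user splits with Double.MinValue and Double.MaxValue, every feature below Double.MaxValue matches some [splits(mid), splits(mid + 1)) test, so the terminal throw is unreachable for in-range input.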
@@ -17,18 +17,22 @@

package org.apache.spark.ml.feature

+ import scala.util.Random

import org.scalatest.FunSuite

+ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

test("Bucket continuous features with setter") {
val sqlContext = new SQLContext(sc)
- val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
+ val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
val buckets = Array(-0.5, 0.0, 0.5)
- val bucketizedData = Array(2.0, 0.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0)
+ val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(bucketizedData)).toDF("feature", "expected")

@@ -44,6 +48,23 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
}

test("Binary search for finding buckets") {
+ val data = Array.fill[Double](100)(Random.nextDouble())
+ val splits = Array.fill[Double](10)(Random.nextDouble()).sorted
+ val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+ val bsResult = Vectors.dense(
+   data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
+ val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+ assert(bsResult ~== lsResult absTol 1e-5)
}
}

+ object BucketizerSuite {
+   private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+     var i = 0
+     while (i < splits.size) {
+       if (feature < splits(i)) return i
+       i += 1
+     }
+     i
+   }
+ }
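The linear scan serves as the test's oracle: the first index i with feature < splits(i) equals the number of splits at or below the feature, which is exactly the half-open bucket index the wrapped binary search should report. A quick equivalence check (a hypothetical aside, not part of the suite; assumes strictly increasing splits):

```scala
// Hypothetical sanity check; these names are illustrative and not in the commit.
val splits = Array(0.2, 0.4, 0.6, 0.8)
val feature = 0.4
val oracle = splits.indexWhere(feature < _) match {
  case -1 => splits.length                // feature >= every split: last bucket
  case i  => i
}
val byCount = splits.count(_ <= feature)  // splits that open at or below feature
assert(oracle == byCount)                 // both give bucket 2 for feature = 0.4
```

Comparing bsResult and lsResult over 100 random points therefore exercises exactly the boundary handling that the (x,y] to [x,y) flip changed.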
