Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
DB Tsai committed May 8, 2015
1 parent a784321 commit f98e711
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,17 @@ class LogisticRegression
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (summarizer, labelSummarizer) = instances.treeAggregate(
(new MultivariateOnlineSummarizer, new MultiClassSummarizer))( {
case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
(label: Double, features: Vector)) =>
(summarizer.add(features), labelSummarizer.add(label))
}, {
case ((summarizer1: MultivariateOnlineSummarizer, labelSummarizer1: MultiClassSummarizer),
(summarizer2: MultivariateOnlineSummarizer, labelSummarizer2: MultiClassSummarizer)) =>
(summarizer1.merge(summarizer2), labelSummarizer1.merge(labelSummarizer2))
(new MultivariateOnlineSummarizer, new MultiClassSummarizer))(
seqOp = (c, v) => (c, v) match {
case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
(label: Double, features: Vector)) =>
(summarizer.add(features), labelSummarizer.add(label))
},
combOp = (c1, c2) => (c1, c2) match {
case ((summarizer1: MultivariateOnlineSummarizer,
classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
classSummarizer2: MultiClassSummarizer)) =>
(summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2))
})

val histogram = labelSummarizer.histogram
Expand All @@ -123,15 +126,17 @@ class LogisticRegression
val numFeatures = summarizer.mean.size

if (numInvalid != 0) {
logError("Classification labels should be in {0 to " + (numClasses - 1) + "}. " +
"Found " + numInvalid + " invalid labels.")
throw new SparkException("Input validation failed.")
// Build the message once so the log entry and the thrown exception carry the same text.
// The literal "}. " after the upper bound closes the "{0 to k}" range and ends the sentence,
// so the message reads "... should be in {0 to k}. Found n invalid labels."
val msg = s"Classification labels should be in {0 to ${numClasses - 1}}. " +
  s"Found $numInvalid invalid labels."
logError(msg)
throw new SparkException(msg)
}

if (numClasses > 2) {
logError("Currently, LogisticRegression with ElasticNet in ML package only supports " +
"binary classification. Found " + numClasses + " in the input dataset.")
throw new SparkException("Input validation failed.")
val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
s"binary classification. Found $numClasses in the input dataset."
logError(msg)
throw new SparkException(msg)
}

val featuresMean = summarizer.mean.toArray
Expand Down Expand Up @@ -361,10 +366,13 @@ class MultiClassSummarizer private[ml] extends Serializable {
largeMap
}

/**
 * @return How many invalid inputs have been counted in total, i.e. the current value of
 *         `totalInvalidCnt`.
 */
def countInvalid: Long = {
  totalInvalidCnt
}

/**
 * @return The number of classes implied by the recorded labels: the largest key in
 *         `distinctMap` plus one (labels are treated as zero-based indices), or 0 when
 *         `distinctMap` is empty. Note this can exceed the count of distinct labels when
 *         some index below the maximum was never recorded.
 */
def numClasses: Int =
  // `max` on an empty key set throws, so answer 0 classes before anything has been recorded.
  if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1

/** @return The counts of each label in the input dataset. */
def histogram: Array[Long] = {
val result = Array.ofDim[Long](numClasses)
var i = 0
Expand All @@ -377,11 +385,20 @@ class MultiClassSummarizer private[ml] extends Serializable {
}

/**
* :: DeveloperApi ::
* LogisticAggregator computes the gradient and loss for the binary logistic loss function, as
* used in binary classification for samples in sparse or dense vectors, in an online fashion.
*
* Note that multinomial logistic loss is not supported yet!
*
* Two LogisticAggregators can be merged together to obtain a summary of the loss and gradient of
* the corresponding joint dataset.
*
* @param weights The weights/coefficients corresponding to the features.
* @param numClasses the number of possible outcomes for k classes classification problem in
* Multinomial Logistic Regression. By default, it is binary logistic regression
* so numClasses will be set to 2.
* Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
* @param featuresStd The standard deviation values of the features.
* @param featuresMean The mean values of the features.
*/
private class LogisticAggregator(
weights: Vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (summarizer, statCounter) = instances.treeAggregate(
(new MultivariateOnlineSummarizer, new StatCounter))( {
case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
(label: Double, features: Vector)) =>
(summarizer.add(features), statCounter.merge(label))
}, {
case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
(summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
(summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
(new MultivariateOnlineSummarizer, new StatCounter))(
seqOp = (c, v) => (c, v) match {
case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
(label: Double, features: Vector)) =>
(summarizer.add(features), statCounter.merge(label))
},
combOp = (c1, c2) => (c1, c2) match {
case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
(summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
(summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
})

val numFeatures = summarizer.mean.size
Expand Down

0 comments on commit f98e711

Please sign in to comment.