-
Notifications
You must be signed in to change notification settings - Fork 28.3k
Commit
…uator to get correct cross validation JIRA: https://issues.apache.org/jira/browse/SPARK-8468 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6905 from viirya/cv_min and squashes the following commits: 930d3db [Liang-Chi Hsieh] Fix python unit test and add document. d632135 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cv_min 16e3b2c [Liang-Chi Hsieh] Take the negative instead of reciprocal. c3dd8d9 [Liang-Chi Hsieh] For comments. b5f52c1 [Liang-Chi Hsieh] Add param to CrossValidator for choosing whether to maximize evaulation value. (cherry picked from commit 0b89951) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,11 +20,12 @@ package org.apache.spark.ml.tuning | |
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.ml.classification.LogisticRegression | ||
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} | ||
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.param.shared.HasInputCol | ||
import org.apache.spark.ml.regression.LinearRegression | ||
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput | ||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} | ||
import org.apache.spark.sql.{DataFrame, SQLContext} | ||
import org.apache.spark.sql.types.StructType | ||
|
||
|
@@ -57,6 +58,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { | |
assert(parent.getMaxIter === 10) | ||
} | ||
|
||
test("cross validation with linear regression") { | ||
val dataset = sqlContext.createDataFrame( | ||
sc.parallelize(LinearDataGenerator.generateLinearInput( | ||
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) | ||
|
||
val trainer = new LinearRegression | ||
val lrParamMaps = new ParamGridBuilder() | ||
.addGrid(trainer.regParam, Array(1000.0, 0.001)) | ||
.addGrid(trainer.maxIter, Array(0, 10)) | ||
.build() | ||
val eval = new RegressionEvaluator() | ||
val cv = new CrossValidator() | ||
.setEstimator(trainer) | ||
.setEstimatorParamMaps(lrParamMaps) | ||
.setEvaluator(eval) | ||
.setNumFolds(3) | ||
val cvModel = cv.fit(dataset) | ||
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] | ||
assert(parent.getRegParam === 0.001) | ||
assert(parent.getMaxIter === 10) | ||
assert(cvModel.avgMetrics.length === lrParamMaps.length) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
viirya
Author
Member
|
||
|
||
eval.setMetricName("r2") | ||
val cvModel2 = cv.fit(dataset) | ||
val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] | ||
assert(parent2.getRegParam === 0.001) | ||
assert(parent2.getMaxIter === 10) | ||
assert(cvModel2.avgMetrics.length === lrParamMaps.length) | ||
} | ||
|
||
test("validateParams should check estimatorParamMaps") { | ||
import CrossValidatorSuite._ | ||
|
||
|
@viirya @jkbradley Seems this one breaks the 1.4 build? I see the following in jenkins build.