From a6cd31108f0d73ce6823daafe8447677e03cfd13 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 6 Aug 2014 12:28:35 -0700 Subject: [PATCH 01/83] [SPARK-2678][Core][SQL] A workaround for SPARK-2678 JIRA issues: - Main: [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678) - Related: [SPARK-2874](https://issues.apache.org/jira/browse/SPARK-2874) Related PR: - #1715 This PR is both a fix for SPARK-2874 and a workaround for SPARK-2678. Fixing SPARK-2678 completely requires some API level changes that need further discussion, and we decided not to include it in Spark 1.1 release. As currently SPARK-2678 only affects Spark SQL scripts, this workaround is enough for Spark 1.1. Command line option handling logic in bash scripts looks somewhat dirty and duplicated, but it helps to provide a cleaner user interface as well as retain full downward compatibility for now. Author: Cheng Lian Closes #1801 from liancheng/spark-2874 and squashes the following commits: 8045d7a [Cheng Lian] Make sure test suites pass 8493a9e [Cheng Lian] Using eval to retain quoted arguments aed523f [Cheng Lian] Fixed typo in bin/spark-sql f12a0b1 [Cheng Lian] Worked arount SPARK-2678 daee105 [Cheng Lian] Fixed usage messages of all Spark SQL related scripts --- bin/beeline | 29 ++------ bin/spark-sql | 66 +++++++++++++++++-- .../spark/deploy/SparkSubmitArguments.scala | 39 ++++------- .../spark/deploy/SparkSubmitSuite.scala | 12 ++++ sbin/start-thriftserver.sh | 50 ++++++++++++-- .../hive/thriftserver/HiveThriftServer2.scala | 1 - .../sql/hive/thriftserver/CliSuite.scala | 19 +++--- .../thriftserver/HiveThriftServer2Suite.scala | 23 ++++--- 8 files changed, 164 insertions(+), 75 deletions(-) diff --git a/bin/beeline b/bin/beeline index 09fe366c609fa..1bda4dba50605 100755 --- a/bin/beeline +++ b/bin/beeline @@ -17,29 +17,14 @@ # limitations under the License. # -# Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +# +# Shell script for starting BeeLine -# Find the java binary -if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" -else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi -fi +# Enter posix mode for bash +set -o posix -# Compute classpath using external script -classpath_output=$($FWDIR/bin/compute-classpath.sh) -if [[ "$?" != "0" ]]; then - echo "$classpath_output" - exit 1 -else - CLASSPATH=$classpath_output -fi +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" CLASS="org.apache.hive.beeline.BeeLine" -exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" +exec "$FWDIR/bin/spark-class" $CLASS "$@" diff --git a/bin/spark-sql b/bin/spark-sql index bba7f897b19bc..61ebd8ab6dec8 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -23,14 +23,72 @@ # Enter posix mode for bash set -o posix +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/spark-sql [options]" +function usage { + echo "Usage: ./sbin/spark-sql [options] [cli option]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|--help" + pattern+="\|=======" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "CLI options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +CLI_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) + ensure_arg_number $# 2 + CLI_ARGS+=($1); shift + CLI_ARGS+=($1); shift + ;; + + -e) + ensure_arg_number $# 2 + CLI_ARGS+=($1); shift + CLI_ARGS+=(\"$1\"); shift + ;; + + -s | --silent) + CLI_ARGS+=($1); shift + ;; + + -v | --verbose) + # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose + CLI_ARGS+=($1) + SUBMISSION_ARGS+=($1); shift + ;; + + *) + SUBMISSION_ARGS+=($1); shift + ;; + esac +done + +eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${CLI_ARGS[*]} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9391f24e71ed7..087dd4d633db0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -220,6 +220,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { var inSparkOpts = true + val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r // Delineates parsing of Spark options from parsing of user options. parse(opts) @@ -322,33 +323,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { verbose = true parse(tail) + case EQ_SEPARATED_OPT(opt, value) :: tail => + parse(opt :: value :: tail) + + case value :: tail if value.startsWith("-") => + SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.") + case value :: tail => - if (inSparkOpts) { - value match { - // convert --foo=bar to --foo bar - case v if v.startsWith("--") && v.contains("=") && v.split("=").size == 2 => - val parts = v.split("=") - parse(Seq(parts(0), parts(1)) ++ tail) - case v if v.startsWith("-") => - val errMessage = s"Unrecognized option '$value'." - SparkSubmit.printErrorAndExit(errMessage) - case v => - primaryResource = - if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { - Utils.resolveURI(v).toString - } else { - v - } - inSparkOpts = false - isPython = SparkSubmit.isPython(v) - parse(tail) + primaryResource = + if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) { + Utils.resolveURI(value).toString + } else { + value } - } else { - if (!value.isEmpty) { - childArgs += value - } - parse(tail) - } + isPython = SparkSubmit.isPython(value) + childArgs ++= tail case Nil => } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a5cdcfb5de03b..7e1ef80c84561 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,18 @@ class SparkSubmitSuite extends FunSuite with Matchers { appArgs.childArgs should be (Seq("some", "--weird", "args")) } + test("handles arguments to user program with name collision") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "userjar.jar", + "--master", "local", + "some", + "--weird", "args") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 8398e6f19b511..603f50ae13240 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -26,11 +26,53 @@ set -o posix # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/start-thriftserver [options]" +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + +function usage { + echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|=======" + pattern+="\|--help" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "Thrift server options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +THRIFT_SERVER_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + --hiveconf) + ensure_arg_number $# 2 + THRIFT_SERVER_ARGS+=($1); shift + THRIFT_SERVER_ARGS+=($1); shift + ;; + + *) + SUBMISSION_ARGS+=($1); shift + ;; + esac +done + +eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${THRIFT_SERVER_ARGS[*]} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 08d3f983d9e71..6f7942aba314a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -40,7 +40,6 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 69f19f826a802..2bf8cfdcacd22 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.{BufferedReader, InputStreamReader, PrintWriter} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { @@ -27,15 +28,15 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { val METASTORE_PATH = TestUtils.getMetastorePath("cli") override def beforeAll() { - val pb = new ProcessBuilder( - "../../bin/spark-sql", - "--master", - "local", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) - + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val commands = + s"""../../bin/spark-sql + | --master local + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(commands: _*) process = pb.start() outputWriter = new PrintWriter(process.getOutputStream, true) inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index b7b7c9957ac34..78bffa2607349 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -25,6 +25,7 @@ import java.io.{BufferedReader, InputStreamReader} import java.net.ServerSocket import java.sql.{Connection, DriverManager, Statement} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging @@ -63,16 +64,18 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt // Forking a new process to start the Hive Thrift server. The reason to do this is it is // hard to clean up Hive resources entirely, so we just start a new process and kill // that process for cleanup. - val defaultArgs = Seq( - "../../sbin/start-thriftserver.sh", - "--master local", - "--hiveconf", - "hive.root.logger=INFO,console", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") - val pb = new ProcessBuilder(defaultArgs ++ args) + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val command = + s"""../../sbin/start-thriftserver.sh + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(command ++ args: _*) val environment = pb.environment() environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) From d614967b0bad1e6c5277d612602ec0a653a00258 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 6 Aug 2014 12:58:24 -0700 Subject: [PATCH 02/83] [SPARK-2627] [PySpark] have the build enforce PEP 8 automatically As described in [SPARK-2627](https://issues.apache.org/jira/browse/SPARK-2627), we'd like Python code to automatically be checked for PEP 8 compliance by Jenkins. This pull request aims to do that. Notes: * We may need to install [`pep8`](https://pypi.python.org/pypi/pep8) on the build server. * I'm expecting tests to fail now that PEP 8 compliance is being checked as part of the build. I'm fine with cleaning up any remaining PEP 8 violations as part of this pull request. * I did not understand why the RAT and scalastyle reports are saved to text files. I did the same for the PEP 8 check, but only so that the console output style can match those for the RAT and scalastyle checks. The PEP 8 report is removed right after the check is complete. * Updates to the ["Contributing to Spark"](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) guide will be submitted elsewhere, as I don't believe that text is part of the Spark repo. Author: Nicholas Chammas Author: nchammas Closes #1744 from nchammas/master and squashes the following commits: 274b238 [Nicholas Chammas] [SPARK-2627] [PySpark] minor indentation changes 983d963 [nchammas] Merge pull request #5 from apache/master 1db5314 [nchammas] Merge pull request #4 from apache/master 0e0245f [Nicholas Chammas] [SPARK-2627] undo erroneous whitespace fixes bf30942 [Nicholas Chammas] [SPARK-2627] PEP8: comment spacing 6db9a44 [nchammas] Merge pull request #3 from apache/master 7b4750e [Nicholas Chammas] merge upstream changes 91b7584 [Nicholas Chammas] [SPARK-2627] undo unnecessary line breaks 44e3e56 [Nicholas Chammas] [SPARK-2627] use tox.ini to exclude files b09fae2 [Nicholas Chammas] don't wrap comments unnecessarily bfb9f9f [Nicholas Chammas] [SPARK-2627] keep up with the PEP 8 fixes 9da347f [nchammas] Merge pull request #2 from apache/master aa5b4b5 [Nicholas Chammas] [SPARK-2627] follow Spark bash style for if blocks d0a83b9 [Nicholas Chammas] [SPARK-2627] check that pep8 downloaded fine dffb5dd [Nicholas Chammas] [SPARK-2627] download pep8 at runtime a1ce7ae [Nicholas Chammas] [SPARK-2627] space out test report sections 21da538 [Nicholas Chammas] [SPARK-2627] it's PEP 8, not PEP8 6f4900b [Nicholas Chammas] [SPARK-2627] more misc PEP 8 fixes fe57ed0 [Nicholas Chammas] removing merge conflict backups 9c01d4c [nchammas] Merge pull request #1 from apache/master 9a66cb0 [Nicholas Chammas] resolving merge conflicts a31ccc4 [Nicholas Chammas] [SPARK-2627] miscellaneous PEP 8 fixes beaa9ac [Nicholas Chammas] [SPARK-2627] fail check on non-zero status 723ed39 [Nicholas Chammas] always delete the report file 0541ebb [Nicholas Chammas] [SPARK-2627] call Python linter from run-tests 12440fa [Nicholas Chammas] [SPARK-2627] add Scala linter 61c07b9 [Nicholas Chammas] [SPARK-2627] add Python linter 75ad552 [Nicholas Chammas] make check output style consistent --- dev/lint-python | 60 +++++++++++ dev/lint-scala | 23 ++++ dev/run-tests | 13 ++- dev/scalastyle | 2 +- python/pyspark/accumulators.py | 7 ++ python/pyspark/broadcast.py | 1 + python/pyspark/conf.py | 1 + python/pyspark/context.py | 25 +++-- python/pyspark/daemon.py | 5 +- python/pyspark/files.py | 1 + python/pyspark/java_gateway.py | 1 + python/pyspark/mllib/_common.py | 5 +- python/pyspark/mllib/classification.py | 8 ++ python/pyspark/mllib/clustering.py | 3 + python/pyspark/mllib/linalg.py | 2 + python/pyspark/mllib/random.py | 14 +-- python/pyspark/mllib/recommendation.py | 2 + python/pyspark/mllib/regression.py | 12 +++ python/pyspark/mllib/stat.py | 1 + python/pyspark/mllib/tests.py | 11 +- python/pyspark/mllib/tree.py | 4 +- python/pyspark/mllib/util.py | 1 + python/pyspark/rdd.py | 22 ++-- python/pyspark/rddsampler.py | 4 + python/pyspark/resultiterable.py | 2 + python/pyspark/serializers.py | 21 +++- python/pyspark/shuffle.py | 20 ++-- python/pyspark/sql.py | 66 ++++++++---- python/pyspark/storagelevel.py | 1 + python/pyspark/tests.py | 143 ++++++++++++++----------- python/test_support/userlibrary.py | 2 + tox.ini | 1 + 32 files changed, 348 insertions(+), 136 deletions(-) create mode 100755 dev/lint-python create mode 100755 dev/lint-scala diff --git a/dev/lint-python b/dev/lint-python new file mode 100755 index 0000000000000..4efddad839387 --- /dev/null +++ b/dev/lint-python @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" + +cd $SPARK_ROOT_DIR + +# Get pep8 at runtime so that we don't rely on it being installed on the build server. +#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +#+ TODOs: +#+ - Dynamically determine latest release version of pep8 and use that. +#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" + +curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +curl_status=$? + +if [ $curl_status -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit $curl_status +fi + + +# There is no need to write this output to a file +#+ first, but we do so so that the check status can +#+ be output before the report, like with the +#+ scalastyle and RAT checks. +python $PEP8_SCRIPT_PATH ./python > "$PEP8_REPORT_PATH" +pep8_status=${PIPESTATUS[0]} #$? + +if [ $pep8_status -ne 0 ]; then + echo "PEP 8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP 8 checks passed." +fi + +rm -f "$PEP8_REPORT_PATH" +rm "$PEP8_SCRIPT_PATH" + +exit $pep8_status diff --git a/dev/lint-scala b/dev/lint-scala new file mode 100755 index 0000000000000..c676dfdf4f44e --- /dev/null +++ b/dev/lint-scala @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +"$SCRIPT_DIR/scalastyle" diff --git a/dev/run-tests b/dev/run-tests index d401c90f41d7b..0e24515d1376c 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -66,16 +66,25 @@ fi set -e set -o pipefail +echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" dev/check-license +echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" -dev/scalastyle +dev/lint-scala +echo "" +echo "=========================================================================" +echo "Running Python style checks" +echo "=========================================================================" +dev/lint-python + +echo "" echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" @@ -89,11 +98,13 @@ fi echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" ./python/run-tests +echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" diff --git a/dev/scalastyle b/dev/scalastyle index d9f2b91a3a091..b53053a04ff42 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -30,5 +30,5 @@ if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 else - echo -e "Scalastyle checks passed.\n" + echo -e "Scalastyle checks passed." fi diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 45d36e5d0e764..f133cf6f7befc 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -110,6 +110,7 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ A shared variable that can be accumulated, i.e., has a commutative and associative "add" operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} @@ -166,6 +167,7 @@ def __repr__(self): class AccumulatorParam(object): + """ Helper object that defines how to accumulate values of a given type. """ @@ -186,6 +188,7 @@ def addInPlace(self, value1, value2): class AddingAccumulatorParam(AccumulatorParam): + """ An AccumulatorParam that uses the + operators to add values. Designed for simple types such as integers, floats, and lists. Requires the zero value for the underlying type @@ -210,6 +213,7 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ This handler will keep polling updates from the same socket until the server is shutdown. @@ -228,7 +232,9 @@ def handle(self): # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) + class AccumulatorServer(SocketServer.TCPServer): + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. @@ -239,6 +245,7 @@ def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 43f40f8783bfd..f3e64989ed564 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -45,6 +45,7 @@ def _from_id(bid): class Broadcast(object): + """ A broadcast variable created with L{SparkContext.broadcast()}. diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b4c82f519bd53..fb716f6753a45 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -56,6 +56,7 @@ class SparkConf(object): + """ Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2e80eb50f2207..4001ecab5ea00 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -47,6 +47,7 @@ class SparkContext(object): + """ Main entry point for Spark functionality. A SparkContext represents the connection to a Spark cluster, and can be used to create L{RDD}s and @@ -213,7 +214,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): if instance: if (SparkContext._active_spark_context and - SparkContext._active_spark_context != instance): + SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite @@ -406,7 +407,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits, batchSize) + keyConverter, valueConverter, minSplits, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -437,7 +438,8 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -465,7 +467,8 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -496,7 +499,8 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -523,8 +527,9 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, jconf = self._dictToJavaMap(conf) batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() - jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf, batchSize) + jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): @@ -555,8 +560,7 @@ def union(self, rdds): first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] rest = ListConverter().convert(rest, self._gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self, - rdds[0]._jrdd_deserializer) + return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -568,8 +572,7 @@ def broadcast(self, value): pickleSer = PickleSerializer() pickled = pickleSer.dumps(value) jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, - self._pickled_broadcast_vars) + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index b00da833d06f1..e73538baf0b93 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -43,7 +43,7 @@ def worker(sock): """ # Redirect stdout to stderr os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 + sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) @@ -134,8 +134,7 @@ def handle_sigchld(*args): try: os.kill(worker_pid, signal.SIGKILL) except OSError: - pass # process already died - + pass # process already died if listen_sock in ready_fds: sock, addr = listen_sock.accept() diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 57ee14eeb7776..331de9a9b2212 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -19,6 +19,7 @@ class SparkFiles(object): + """ Resolves paths to files added through L{SparkContext.addFile()}. diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2c129679f47f3..37386ab0d7d49 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -65,6 +65,7 @@ def preexec_func(): # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): + def __init__(self, stream): Thread.__init__(self) self.daemon = True diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 9c1565affbdac..db341da85f865 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -72,9 +72,9 @@ # Python interpreter must agree on what endian the machine is. -DENSE_VECTOR_MAGIC = 1 +DENSE_VECTOR_MAGIC = 1 SPARSE_VECTOR_MAGIC = 2 -DENSE_MATRIX_MAGIC = 3 +DENSE_MATRIX_MAGIC = 3 LABELED_POINT_MAGIC = 4 @@ -443,6 +443,7 @@ def _serialize_rating(r): class RatingDeserializer(Serializer): + def loads(self, stream): length = struct.unpack("!i", stream.read(4))[0] ba = stream.read(length) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 5ec1a8084d269..ffdda7ee19302 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -31,6 +31,7 @@ class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression. >>> data = [ @@ -60,6 +61,7 @@ class LogisticRegressionModel(LinearModel): >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -72,6 +74,7 @@ def predict(self, x): class LogisticRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -108,6 +111,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class SVMModel(LinearModel): + """A support vector machine. >>> data = [ @@ -131,6 +135,7 @@ class SVMModel(LinearModel): >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -138,6 +143,7 @@ def predict(self, x): class SVMWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): @@ -173,6 +179,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class NaiveBayesModel(object): + """ Model for Naive Bayes classifiers. @@ -213,6 +220,7 @@ def predict(self, x): class NaiveBayes(object): + @classmethod def train(cls, data, lambda_=1.0): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b380e8f6c8725..a0630d1d5c58b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -27,6 +27,7 @@ class KMeansModel(object): + """A clustering model derived from the k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) @@ -55,6 +56,7 @@ class KMeansModel(object): >>> type(model.clusterCenters) """ + def __init__(self, centers): self.centers = centers @@ -76,6 +78,7 @@ def predict(self, x): class KMeans(object): + @classmethod def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 54720c2324ca6..9a239abfbbeb1 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -27,6 +27,7 @@ class SparseVector(object): + """ A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's {scipy.sparse} data types. @@ -192,6 +193,7 @@ def __ne__(self, other): class Vectors(object): + """ Factory methods for working with vectors. Note that dense vectors are simply represented as NumPy array objects, so there is no need diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 36e710dbae7a8..eb496688b6eef 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -24,7 +24,9 @@ from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector from pyspark.serializers import NoOpSerializer + class RandomRDDGenerators: + """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -53,7 +55,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -77,7 +79,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -98,7 +100,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod @@ -118,7 +120,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod @@ -138,7 +140,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod @@ -161,7 +163,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6c385042ffa5f..e863fc249ec36 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -26,6 +26,7 @@ class MatrixFactorizationModel(object): + """A matrix factorisation model trained by regularized alternating least-squares. @@ -58,6 +59,7 @@ def predictAll(self, usersProducts): class ALS(object): + @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): sc = ratings.context diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 041b119269427..d8792cf44872f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -27,6 +27,7 @@ class LabeledPoint(object): + """ The features and labels of a data point. @@ -34,6 +35,7 @@ class LabeledPoint(object): @param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ + def __init__(self, label, features): self.label = label if (type(features) == ndarray or type(features) == SparseVector @@ -49,7 +51,9 @@ def __str__(self): class LinearModel(object): + """A linear model that has a vector of coefficients and an intercept.""" + def __init__(self, weights, intercept): self._coeff = weights self._intercept = intercept @@ -64,6 +68,7 @@ def intercept(self): class LinearRegressionModelBase(LinearModel): + """A linear regression model. >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) @@ -72,6 +77,7 @@ class LinearRegressionModelBase(LinearModel): >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True """ + def predict(self, x): """Predict the value of the dependent variable given a vector x""" """containing values for the independent variables.""" @@ -80,6 +86,7 @@ def predict(self, x): class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit. >>> from pyspark.mllib.regression import LabeledPoint @@ -111,6 +118,7 @@ class LinearRegressionModel(LinearRegressionModelBase): class LinearRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -146,6 +154,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_1 penalty term. @@ -178,6 +187,7 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): @@ -189,6 +199,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_2 penalty term. @@ -221,6 +232,7 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0a08a562d1f1f..982906b9d09f0 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -24,6 +24,7 @@ _serialize_double, _serialize_double_vector, \ _deserialize_double, _deserialize_double_matrix + class Statistics(object): @staticmethod diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 9d1e5be637a9a..6f3ec8ac94bac 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -39,6 +39,7 @@ class VectorTests(unittest.TestCase): + def test_serialize(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = array([1., 2., 3., 4.]) @@ -81,6 +82,7 @@ def test_squared_distance(self): class ListTests(PySparkTestCase): + """ Test MLlib algorithms on plain lists, to make sure they're passed through as NumPy arrays. @@ -128,7 +130,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = \ DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) @@ -168,7 +170,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = \ DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -179,6 +181,7 @@ def test_regression(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, if SciPy is available. @@ -276,7 +279,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -315,7 +318,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 1e0006df75ac6..2518001ea0b93 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -25,7 +25,9 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer + class DecisionTreeModel(object): + """ A decision tree model for classification or regression. @@ -77,6 +79,7 @@ def __str__(self): class DecisionTree(object): + """ Learning algorithm for a decision tree model for classification or regression. @@ -174,7 +177,6 @@ def trainRegressor(data, categoricalFeaturesInfo={}, categoricalFeaturesInfo, impurity, maxDepth, maxBins) - @staticmethod def train(data, algo, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins=100): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 639cda6350229..4962d05491c03 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -26,6 +26,7 @@ class MLUtils: + """ Helper methods to load, save and pre-process data used in MLlib. """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 309f5a9b6038d..30b834d2085cd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -233,7 +233,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): def _toPickleSerialization(self): if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): return self else: return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) @@ -1079,7 +1079,9 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, jconf) def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1125,8 +1127,10 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, - jconf, compressionCodecClass) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, + jconf, compressionCodecClass) def saveAsSequenceFile(self, path, compressionCodecClass=None): """ @@ -1348,7 +1352,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): outputSerializer = self.ctx._unbatched_serializer limit = (_parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) / 2) + "spark.python.worker.memory", "512m")) / 2) def add_shuffle_key(split, iterator): @@ -1430,12 +1434,12 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + "spark.python.worker.memory", "512m")) agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeValues(iterator) return merger.iteritems() @@ -1444,7 +1448,7 @@ def combineLocally(iterator): def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) return merger.iteritems() @@ -1588,7 +1592,7 @@ def sampleByKey(self, withReplacement, fractions, seed=None): """ for fraction in fractions.values(): assert fraction >= 0.0, "Negative fraction value: %s" % fraction - return self.mapPartitionsWithIndex( \ + return self.mapPartitionsWithIndex( RDDStratifiedSampler(withReplacement, fractions, seed).func, True) def subtractByKey(self, other, numPartitions=None): diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 2df000fdb08ca..55e247da0e4dc 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -20,6 +20,7 @@ class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy @@ -95,6 +96,7 @@ def shuffle(self, vals): class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fraction = fraction @@ -113,7 +115,9 @@ def func(self, split, iterator): if self.getUniformSample(split) <= self._fraction: yield obj + class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fractions = fractions diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index df34740fc8176..ef04c82866e6c 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -21,9 +21,11 @@ class ResultIterable(collections.Iterable): + """ A special result iterable. This is used because the standard iterator can not be pickled """ + def __init__(self, data): self.data = data self.index = 0 diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a10f85b55ad30..b35558db3e007 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -111,6 +111,7 @@ def __ne__(self, other): class FramedSerializer(Serializer): + """ Serializer that writes objects as a stream of (length, data) pairs, where C{length} is a 32-bit integer and data is C{length} bytes. @@ -162,6 +163,7 @@ def loads(self, obj): class BatchedSerializer(Serializer): + """ Serializes a stream of objects in batches by calling its wrapped Serializer with streams of objects. @@ -207,6 +209,7 @@ def __str__(self): class CartesianDeserializer(FramedSerializer): + """ Deserializes the JavaRDD cartesian() of two PythonRDDs. """ @@ -240,6 +243,7 @@ def __str__(self): class PairDeserializer(CartesianDeserializer): + """ Deserializes the JavaRDD zip() of two PythonRDDs. """ @@ -289,6 +293,7 @@ def _hack_namedtuple(cls): """ Make class generated by namedtuple picklable """ name = cls.__name__ fields = cls._fields + def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ @@ -301,10 +306,11 @@ def _hijack_namedtuple(): if hasattr(collections.namedtuple, "__hijack"): return - global _old_namedtuple # or it will put in closure + global _old_namedtuple # or it will put in closure + def _copy_func(f): return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + f.func_defaults, f.func_closure) _old_namedtuple = _copy_func(collections.namedtuple) @@ -323,15 +329,16 @@ def namedtuple(name, fields, verbose=False, rename=False): # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.iteritems(): if (type(o) is type and o.__base__ is tuple - and hasattr(o, "_fields") - and "__reduce__" not in o.__dict__): - _hack_namedtuple(o) # hack inplace + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace _hijack_namedtuple() class PickleSerializer(FramedSerializer): + """ Serializes objects using Python's cPickle serializer: @@ -354,6 +361,7 @@ def dumps(self, obj): class MarshalSerializer(FramedSerializer): + """ Serializes objects using Python's Marshal serializer: @@ -367,9 +375,11 @@ class MarshalSerializer(FramedSerializer): class AutoSerializer(FramedSerializer): + """ Choose marshal or cPickle as serialization protocol autumatically """ + def __init__(self): FramedSerializer.__init__(self) self._type = None @@ -394,6 +404,7 @@ def loads(self, obj): class UTF8Deserializer(Serializer): + """ Deserializes streams written by String.getBytes. """ diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e3923d1c36c57..2c68cd4921deb 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -45,7 +45,7 @@ def get_used_memory(): return int(line.split()[1]) >> 10 else: warnings.warn("Please install psutil to have better " - "support with spilling") + "support with spilling") if platform.system() == "Darwin": import resource rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -141,7 +141,7 @@ class ExternalMerger(Merger): This class works as follows: - - It repeatedly combine the items and save them in one dict in + - It repeatedly combine the items and save them in one dict in memory. - When the used memory goes above memory limit, it will split @@ -190,12 +190,12 @@ class ExternalMerger(Merger): MAX_TOTAL_PARTITIONS = 4096 def __init__(self, aggregator, memory_limit=512, serializer=None, - localdirs=None, scale=1, partitions=59, batch=1000): + localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + BatchedSerializer(PickleSerializer(), 1024) self.localdirs = localdirs or self._get_dirs() # number of partitions when spill data into disks self.partitions = partitions @@ -341,7 +341,7 @@ def _spill(self): self.pdata[i].clear() self.spills += 1 - gc.collect() # release the memory as much as possible + gc.collect() # release the memory as much as possible def iteritems(self): """ Return all merged items as iterator """ @@ -370,8 +370,8 @@ def _external_items(self): if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS and j < self.spills - 1 and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible for v in self._recursive_merged_items(i): yield v return @@ -409,9 +409,9 @@ def _recursive_merged_items(self, start): for i in range(start, self.partitions): subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] + for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -419,7 +419,7 @@ def _recursive_merged_items(self, start): path = self._get_spill_dir(j) p = os.path.join(path, str(i)) m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) + self.serializer.load_stream(open(p))) if get_used_memory() > limit: m._spill() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index adc56e7ec0e2b..950e275adbf01 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -45,6 +45,7 @@ class DataType(object): + """Spark SQL DataType""" def __repr__(self): @@ -62,6 +63,7 @@ def __ne__(self, other): class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" _instances = {} @@ -73,6 +75,7 @@ def __call__(cls): class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton @@ -83,6 +86,7 @@ def __eq__(self, other): class StringType(PrimitiveType): + """Spark SQL StringType The data type representing string values. @@ -90,6 +94,7 @@ class StringType(PrimitiveType): class BinaryType(PrimitiveType): + """Spark SQL BinaryType The data type representing bytearray values. @@ -97,6 +102,7 @@ class BinaryType(PrimitiveType): class BooleanType(PrimitiveType): + """Spark SQL BooleanType The data type representing bool values. @@ -104,6 +110,7 @@ class BooleanType(PrimitiveType): class TimestampType(PrimitiveType): + """Spark SQL TimestampType The data type representing datetime.datetime values. @@ -111,6 +118,7 @@ class TimestampType(PrimitiveType): class DecimalType(PrimitiveType): + """Spark SQL DecimalType The data type representing decimal.Decimal values. @@ -118,6 +126,7 @@ class DecimalType(PrimitiveType): class DoubleType(PrimitiveType): + """Spark SQL DoubleType The data type representing float values. @@ -125,6 +134,7 @@ class DoubleType(PrimitiveType): class FloatType(PrimitiveType): + """Spark SQL FloatType The data type representing single precision floating-point values. @@ -132,6 +142,7 @@ class FloatType(PrimitiveType): class ByteType(PrimitiveType): + """Spark SQL ByteType The data type representing int values with 1 singed byte. @@ -139,6 +150,7 @@ class ByteType(PrimitiveType): class IntegerType(PrimitiveType): + """Spark SQL IntegerType The data type representing int values. @@ -146,6 +158,7 @@ class IntegerType(PrimitiveType): class LongType(PrimitiveType): + """Spark SQL LongType The data type representing long values. If the any value is @@ -155,6 +168,7 @@ class LongType(PrimitiveType): class ShortType(PrimitiveType): + """Spark SQL ShortType The data type representing int values with 2 signed bytes. @@ -162,6 +176,7 @@ class ShortType(PrimitiveType): class ArrayType(DataType): + """Spark SQL ArrayType The data type representing list values. An ArrayType object @@ -187,10 +202,11 @@ def __init__(self, elementType, containsNull=False): def __str__(self): return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) + str(self.containsNull).lower()) class MapType(DataType): + """Spark SQL MapType The data type representing dict values. A MapType object comprises @@ -226,10 +242,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True): def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) + str(self.valueContainsNull).lower()) class StructField(DataType): + """Spark SQL StructField Represents a field in a StructType. @@ -263,10 +280,11 @@ def __init__(self, name, dataType, nullable): def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) + str(self.nullable).lower()) class StructType(DataType): + """Spark SQL StructType The data type representing rows. @@ -291,7 +309,7 @@ def __init__(self, fields): def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): @@ -319,7 +337,7 @@ def _parse_datatype_list(datatype_list_string): _all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) def _parse_datatype_string(datatype_string): @@ -459,16 +477,16 @@ def _infer_schema(row): items = sorted(row.items()) elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple + if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row + elif hasattr(row, "__FIELDS__"): # Row items = zip(row.__FIELDS__, tuple(row)) elif all(isinstance(x, tuple) and len(x) == 2 for x in row): items = row else: raise ValueError("Can't infer schema from tuple") - elif hasattr(row, "__dict__"): # object + elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) else: @@ -499,7 +517,7 @@ def _create_converter(obj, dataType): conv = lambda o: tuple(o.get(n) for n in names) elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple + if hasattr(obj, "_fields"): # namedtuple conv = tuple elif hasattr(obj, "__FIELDS__"): conv = tuple @@ -508,7 +526,7 @@ def _create_converter(obj, dataType): else: raise ValueError("unexpected tuple") - elif hasattr(obj, "__dict__"): # object + elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] nested = any(_has_struct(f.dataType) for f in dataType.fields) @@ -660,7 +678,7 @@ def _infer_schema_type(obj, dataType): assert len(fs) == len(obj), \ "Obj(%s) have different length with fields(%s)" % (obj, fs) fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] + for o, f in zip(obj, fs)] return StructType(fields) else: @@ -683,6 +701,7 @@ def _infer_schema_type(obj, dataType): StructType: (tuple, list), } + def _verify_type(obj, dataType): """ Verify the type of obj against dataType, raise an exception if @@ -728,7 +747,7 @@ def _verify_type(obj, dataType): elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) + "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) @@ -861,6 +880,7 @@ def __reduce__(self): raise Exception("unexpected data type: %s" % dataType) class Row(tuple): + """ Row in SchemaRDD """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) @@ -872,7 +892,7 @@ class Row(tuple): def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__FIELDS__)) def __reduce__(self): return (_restore_object, (self.__DATATYPE__, tuple(self))) @@ -881,6 +901,7 @@ def __reduce__(self): class SQLContext: + """Main entry point for SparkSQL functionality. A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as @@ -960,7 +981,7 @@ def registerFunction(self, name, f, returnType=StringType()): env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) + self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, bytearray(CloudPickleSerializer().dumps(command)), env, @@ -1012,7 +1033,7 @@ def inferSchema(self, rdd): first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " - "can not infer schema") + "can not infer schema") if type(first) is dict: warnings.warn("Using RDD of dict to inferSchema is deprecated") @@ -1287,6 +1308,7 @@ def uncacheTable(self, tableName): class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. Configuration for Hive is read from hive-site.xml on the classpath. @@ -1327,6 +1349,7 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): + """Starts up an instance of hive where metadata is stored locally. An in-process metadata data is created with data stored in ./metadata. @@ -1357,7 +1380,7 @@ class LocalHiveContext(HiveContext): def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) + "Use HiveContext instead.", DeprecationWarning) def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -1376,6 +1399,7 @@ def _create_row(fields, values): class Row(tuple): + """ A row in L{SchemaRDD}. The fields in it can be accessed like attributes. @@ -1417,7 +1441,6 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - # let obect acs like class def __call__(self, *args): """create new Row object""" @@ -1443,12 +1466,13 @@ def __reduce__(self): def __repr__(self): if hasattr(self, "__FIELDS__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__FIELDS__, self)) else: return "" % ", ".join(self) class SchemaRDD(RDD): + """An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can @@ -1659,7 +1683,7 @@ def subtract(self, other, numPartitions=None): rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -1686,9 +1710,9 @@ def _test(): jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', + '"field6":[{"field7": "row2"}]}', '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 5d77a131f2856..2aa0fb9d2c1ed 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -19,6 +19,7 @@ class StorageLevel: + """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4ac94ba729d35..88a61176e51ab 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,53 +62,53 @@ def setUp(self): self.N = 1 << 16 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = InMemoryMerger(self.agg) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_medium_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N)) * 3) + sum(xrange(self.N)) * 3) def test_huge_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), - self.N * 10) + self.N * 10) m._cleanup() @@ -188,6 +188,7 @@ def test_add_py_file(self): log4j = self.sc._jvm.org.apache.log4j old_level = log4j.LogManager.getRootLogger().getLevel() log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + def func(x): from userlibrary import UserClass return UserClass().hello() @@ -355,8 +356,8 @@ def test_sequencefiles(self): self.assertEqual(doubles, ed) bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) ebs = [(1, bytearray('aa', 'utf-8')), (1, bytearray('aa', 'utf-8')), (2, bytearray('aa', 'utf-8')), @@ -428,9 +429,9 @@ def test_sequencefiles(self): self.assertEqual(clazz[0], ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + batchSize=1).collect()) self.assertEqual(unbatched_clazz[0], ec) def test_oldhadoop(self): @@ -443,7 +444,7 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - oldconf = {"mapred.input.dir" : hellopath} + oldconf = {"mapred.input.dir": hellopath} hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -462,7 +463,7 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - newconf = {"mapred.input.dir" : hellopath} + newconf = {"mapred.input.dir": hellopath} hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -517,6 +518,7 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + class TestOutputFormat(PySparkTestCase): def setUp(self): @@ -574,8 +576,8 @@ def test_sequencefiles(self): def test_oldhadoop(self): basepath = self.tempdir.name dict_data = [(1, {}), - (1, {"row1" : 1.0}), - (2, {"row2" : 2.0})] + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] self.sc.parallelize(dict_data).saveAsHadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -589,12 +591,13 @@ def test_oldhadoop(self): self.assertEqual(result, dict_data) conf = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", - "mapred.output.dir" : basepath + "/olddataset/"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapred.output.dir": basepath + "/olddataset/" + } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + input_conf = {"mapred.input.dir": basepath + "/olddataset/"} old_dataset = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -622,14 +625,17 @@ def test_newhadoop(self): valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) self.assertEqual(result, array_data) - conf = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", - "mapred.output.dir" : basepath + "/newdataset/"} - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + conf = { + "mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + input_conf = {"mapred.input.dir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -640,7 +646,7 @@ def test_newhadoop(self): def test_newolderror(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/newolderror/saveAsHadoopFile/", "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) @@ -650,7 +656,7 @@ def test_newolderror(self): def test_bad_inputs(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/badinputs/saveAsHadoopFile/", "org.apache.hadoop.mapred.NotValidOutputFormat")) @@ -685,30 +691,32 @@ def test_reserialization(self): result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) self.assertEqual(result1, data) - rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) self.assertEqual(result2, data) - rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + rdd.saveAsNewAPIHadoopFile( + basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) self.assertEqual(result3, data) conf4 = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/dataset"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/dataset"} rdd.saveAsHadoopDataset(conf4) result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) self.assertEqual(result4, data) - conf5 = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/newdataset"} + conf5 = {"mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/newdataset"} rdd.saveAsNewAPIHadoopDataset(conf5) result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) @@ -719,25 +727,28 @@ def test_unbatched_save_and_read(self): self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( basepath + "/unbatched/") - unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + unbatched_sequence = sorted(self.sc.sequenceFile( + basepath + "/unbatched/", batchSize=1).collect()) self.assertEqual(unbatched_sequence, ei) - unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + unbatched_hadoopFile = sorted(self.sc.hadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) self.assertEqual(unbatched_hadoopFile, ei) - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) self.assertEqual(unbatched_newAPIHadoopFile, ei) - oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + oldconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -746,7 +757,7 @@ def test_unbatched_save_and_read(self): batchSize=1).collect()) self.assertEqual(unbatched_hadoopRDD, ei) - newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + newconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -763,7 +774,9 @@ def test_malformed_RDD(self): self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( basepath + "/malformed/sequence")) + class TestDaemon(unittest.TestCase): + def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -810,12 +823,15 @@ def test_termination_sigterm(self): class TestWorker(PySparkTestCase): + def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() path = temp.name + def sleep(x): - import os, time + import os + import time with open(path, 'w') as f: f.write("%d %d" % (os.getppid(), os.getpid())) time.sleep(100) @@ -845,7 +861,7 @@ def run(): os.kill(worker_pid, 0) time.sleep(0.1) except OSError: - break # worker was killed + break # worker was killed else: self.fail("worker has not been killed after 5 seconds") @@ -855,12 +871,13 @@ def run(): self.fail("daemon had been killed") def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default + N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) class TestSparkSubmit(unittest.TestCase): + def setUp(self): self.programDir = tempfile.mkdtemp() self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") @@ -953,9 +970,9 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen( - [self.sparkSubmit, "--py-files", zip, "--master", "local-cluster[1,1,512]", script], - stdout=subprocess.PIPE) + proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + "local-cluster[1,1,512]", script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) @@ -981,6 +998,7 @@ def test_single_script_on_cluster(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """General PySpark tests that depend on scipy """ def test_serialize(self): @@ -993,15 +1011,16 @@ def test_serialize(self): @unittest.skipIf(not _have_numpy, "NumPy not installed") class NumPyTests(PySparkTestCase): + """General PySpark tests that depend on numpy """ def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])]) + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) s = x.stats() - self.assertSequenceEqual([2.0,2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0,1.0], s.min().tolist()) - self.assertSequenceEqual([3.0,3.0], s.max().tolist()) - self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist()) + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) if __name__ == "__main__": diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py index 8e4a6292bc17c..73fd26e71f10d 100755 --- a/python/test_support/userlibrary.py +++ b/python/test_support/userlibrary.py @@ -19,6 +19,8 @@ Used to test shipping of code depenencies with SparkContext.addPyFile(). """ + class UserClass(object): + def hello(self): return "Hello World!" diff --git a/tox.ini b/tox.ini index 44766e529bf7f..a1fefdd0e176f 100644 --- a/tox.ini +++ b/tox.ini @@ -15,3 +15,4 @@ [pep8] max-line-length=100 +exclude=cloudpickle.py From 4e982364426c7d65032e8006c63ca4f9a0d40470 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Wed, 6 Aug 2014 13:10:33 -0700 Subject: [PATCH 03/83] SPARK-2566. Update ShuffleWriteMetrics incrementally I haven't tested this out on a cluster yet, but wanted to make sure the approach (passing ShuffleWriteMetrics down to DiskBlockObjectWriter) was ok Author: Sandy Ryza Closes #1481 from sryza/sandy-spark-2566 and squashes the following commits: 8090d88 [Sandy Ryza] Fix ExternalSorter b2a62ed [Sandy Ryza] Fix more test failures 8be6218 [Sandy Ryza] Fix test failures and mark a couple variables private c5e68e5 [Sandy Ryza] SPARK-2566. Update ShuffleWriteMetrics incrementally --- .../apache/spark/executor/TaskMetrics.scala | 4 +- .../shuffle/hash/HashShuffleWriter.scala | 16 ++-- .../shuffle/sort/SortShuffleWriter.scala | 16 ++-- .../apache/spark/storage/BlockManager.scala | 12 +-- .../spark/storage/BlockObjectWriter.scala | 77 ++++++++++--------- .../spark/storage/ShuffleBlockManager.scala | 9 ++- .../collection/ExternalAppendOnlyMap.scala | 18 +++-- .../util/collection/ExternalSorter.scala | 17 ++-- .../storage/BlockObjectWriterSuite.scala | 65 ++++++++++++++++ .../spark/storage/DiskBlockManagerSuite.scala | 9 ++- .../spark/tools/StoragePerfTester.scala | 3 +- 11 files changed, 164 insertions(+), 82 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 56cd8723a3a22..11a6e10243211 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -190,10 +190,10 @@ class ShuffleWriteMetrics extends Serializable { /** * Number of bytes written for the shuffle by this task */ - var shuffleBytesWritten: Long = _ + @volatile var shuffleBytesWritten: Long = _ /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ - var shuffleWriteTime: Long = _ + @volatile var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 45d3b8b9b8725..51e454d9313c9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V]( // we don't try deleting files, etc twice. private var stopping = false + private val writeMetrics = new ShuffleWriteMetrics() + metrics.shuffleWriteMetrics = Some(writeMetrics) + private val blockManager = SparkEnv.get.blockManager private val shuffleBlockManager = blockManager.shuffleBlockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) - private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, + writeMetrics) /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { @@ -99,22 +103,12 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => writer.commitAndClose() val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.shuffleWriteMetrics = Some(shuffleMetrics) - new MapStatus(blockManager.blockManagerId, compressedSizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 24db2f287a47b..e54e6383d2ccc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -52,6 +52,9 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null + private val writeMetrics = new ShuffleWriteMetrics() + context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) + /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { // Get an iterator with the elements for each partition ID @@ -84,13 +87,10 @@ private[spark] class SortShuffleWriter[K, V, C]( val offsets = new Array[Long](numPartitions + 1) val lengths = new Array[Long](numPartitions) - // Statistics - var totalBytes = 0L - var totalTime = 0L - for ((id, elements) <- partitions) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) + val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize, + writeMetrics) for (elem <- elements) { writer.write(elem) } @@ -98,18 +98,12 @@ private[spark] class SortShuffleWriter[K, V, C]( val segment = writer.fileSegment() offsets(id + 1) = segment.offset + segment.length lengths(id) = segment.length - totalTime += writer.timeWriting() - totalBytes += segment.length } else { // The partition is empty; don't create a new writer to avoid writing headers, etc offsets(id + 1) = offsets(id) } } - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3876cf43e2a7d..8d21b02b747ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -562,17 +562,19 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. This is currently used for - * writing shuffle files out. Callers should handle error cases. + * The Block will be appended to the File specified by filename. Callers should handle error + * cases. */ def getDiskWriter( blockId: BlockId, file: File, serializer: Serializer, - bufferSize: Int): BlockObjectWriter = { + bufferSize: Int, + writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, + writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 01d46e1ffc960..adda971fd7b47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.executor.ShuffleWriteMetrics /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment - - /** - * Cumulative time spent performing blocking writes, in ns. - */ - def timeWriting(): Long - - /** - * Number of bytes written so far - */ - def bytesWritten: Long } -/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +/** + * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * The given write metrics will be updated incrementally, but will not necessarily be current until + * commitAndClose is called. + */ private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int, compressStream: OutputStream => OutputStream, - syncWrites: Boolean) + syncWrites: Boolean, + writeMetrics: ShuffleWriteMetrics) extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def timeWriting = _timeWriting - private var _timeWriting = 0L - - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - _timeWriting += (System.nanoTime() - start) - } - def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) @@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter( private val initialPosition = file.length() private var finalPosition: Long = -1 private var initialized = false - private var _timeWriting = 0L + + /** Calling channel.position() to update the write metrics can be a little bit expensive, so we + * only call it every N writes */ + private var writesSinceMetricsUpdate = 0 + private var lastPosition = initialPosition override def open(): BlockObjectWriter = { fos = new FileOutputStream(file, true) @@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - _timeWriting += System.nanoTime() - start + def sync = fos.getFD.sync() + callWithTiming(sync) } objOut.close() - _timeWriting += ts.timeWriting - channel = null bs = null fos = null @@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter( // serializer stream and the lower level stream. objOut.flush() bs.flush() + updateBytesWritten() close() } finalPosition = file.length() @@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { + writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition) + if (initialized) { objOut.flush() bs.flush() @@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter( if (!initialized) { open() } + objOut.writeObject(value) + + if (writesSinceMetricsUpdate == 32) { + writesSinceMetricsUpdate = 0 + updateBytesWritten() + } else { + writesSinceMetricsUpdate += 1 + } } override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, bytesWritten) + new FileSegment(file, initialPosition, finalPosition - initialPosition) } - // Only valid if called after close() - override def timeWriting() = _timeWriting + private def updateBytesWritten() { + val pos = channel.position() + writeMetrics.shuffleBytesWritten += (pos - lastPosition) + lastPosition = pos + } + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + writeMetrics.shuffleWriteTime += (System.nanoTime() - start) + } - // Only valid if called after commit() - override def bytesWritten: Long = { - assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") - finalPosition - initialPosition + // For testing + private[spark] def flush() { + objOut.flush() + bs.flush() } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index f9fdffae8bd8f..3565719b54545 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.executor.ShuffleWriteMetrics /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -111,7 +112,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { + def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + writeMetrics: ShuffleWriteMetrics) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) @@ -121,7 +123,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize) + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + writeMetrics) } } else { Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -136,7 +139,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize) + blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 260a5c3888aa7..9f85b94a70800 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator +import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C]( private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C]( logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -199,7 +205,9 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3f93afd57b3ad..eb4849ebc6e52 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.BlockId +import org.apache.spark.executor.ShuffleWriteMetrics /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -112,11 +113,14 @@ private[spark] class ExternalSorter[K, V, C]( // What threshold of elementsRead we start estimating map size at. private val trackMemoryThreshold = 1000 - // Spilling statistics + // Total spilling statistics private var spillCount = 0 private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -239,7 +243,8 @@ private[spark] class ExternalSorter[K, V, C]( logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -254,9 +259,8 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -275,7 +279,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala new file mode 100644 index 0000000000000..bbc7e1357b90d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import org.scalatest.FunSuite +import java.io.File +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkConf + +class BlockObjectWriterSuite extends FunSuite { + test("verify write metrics") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.commitAndClose() + assert(file.length() == writeMetrics.shuffleBytesWritten) + } + + test("verify write metrics on revert") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleBytesWritten == 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 985ac9394738c..b8299e2ea187f 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.executor.ShuffleWriteMetrics class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -153,7 +154,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val shuffleManager = store.shuffleBlockManager - val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics) for (writer <- shuffle1.writers) { writer.write("test1") writer.write("test2") @@ -165,7 +166,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val shuffle1Segment = shuffle1.writers(0).fileSegment() shuffle1.releaseWriters(success = true) - val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle2.writers) { writer.write("test3") @@ -183,7 +185,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // of block based on remaining data in file : which could mess things up when there is concurrent read // and writes happening to the same shuffle group. - val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle3.writers) { writer.write("test3") writer.write("test4") diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8a05fcb449aa6..17bf7c2541d13 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils +import org.apache.spark.executor.ShuffleWriteMetrics /** * Internal utility for micro-benchmarking shuffle write performance. @@ -56,7 +57,7 @@ object StoragePerfTester { def writeOutputBytes(mapId: Int, total: AtomicLong) = { val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer(sc.conf)) + new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) From 25cff1019da9d6cfc486a31d035b372ea5fbdfd2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 6 Aug 2014 14:07:51 -0700 Subject: [PATCH 04/83] [SPARK-2852][MLLIB] API consistency for `mllib.feature` This is part of SPARK-2828: 1. added a Java-friendly fit method to Word2Vec with tests 2. change DeveloperApi to Experimental for Normalizer & StandardScaler 3. change default feature dimension to 2^20 in HashingTF Author: Xiangrui Meng Closes #1807 from mengxr/feature-api-check and squashes the following commits: 773c1a9 [Xiangrui Meng] change default numFeatures to 2^20 in HashingTF change annotation from DeveloperApi to Experimental in Normalizer and StandardScaler 883e122 [Xiangrui Meng] add @Experimental to Word2VecModel add a Java-friendly method to Word2Vec.fit with tests --- .../spark/mllib/feature/HashingTF.scala | 4 +- .../spark/mllib/feature/Normalizer.scala | 6 +- .../spark/mllib/feature/StandardScaler.scala | 6 +- .../apache/spark/mllib/feature/Word2Vec.scala | 19 +++++- .../mllib/feature/JavaWord2VecSuite.java | 66 +++++++++++++++++++ 5 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 0f6d5809e098f..c53475818395f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.Utils * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * - * @param numFeatures number of features (default: 1000000) + * @param numFeatures number of features (default: 2^20^) */ @Experimental class HashingTF(val numFeatures: Int) extends Serializable { - def this() = this(1000000) + def this() = this(1 << 20) /** * Returns the index of the input term. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index ea9fd0a80d8e0..3afb47767281c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * :: DeveloperApi :: + * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * * @param p Normalization in L^p^ space, p = 2 by default. */ -@DeveloperApi +@Experimental class Normalizer(p: Double) extends VectorTransformer { def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index cc2d7579c2901..e6c9f8f67df63 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: DeveloperApi :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. * @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ -@DeveloperApi +@Experimental class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { def this() = this(false, true) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3bf44ad7c44e3..395037e1ec47c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -17,6 +17,9 @@ package org.apache.spark.mllib.feature +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -25,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ @@ -239,7 +243,7 @@ class Word2Vec extends Serializable with Logging { a += 1 } } - + /** * Computes the vector representation of each word in vocabulary. * @param dataset an RDD of words @@ -369,11 +373,22 @@ class Word2Vec extends Serializable with Logging { new Word2VecModel(word2VecMap.toMap) } + + /** + * Computes the vector representation of each word in vocabulary (Java version). + * @param dataset a JavaRDD of words + * @return a Word2VecModel + */ + def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { + fit(dataset.rdd.map(_.asScala)) + } } /** -* Word2Vec model + * :: Experimental :: + * Word2Vec model */ +@Experimental class Word2VecModel private[mllib] ( private val model: Map[String, Array[Float]]) extends Serializable { diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000..fb7afe8c6434b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import com.google.common.base.Strings; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaWord2VecSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void word2Vec() { + // The tests are to check Java compatibility. + String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); + List words = Lists.newArrayList(sentence.split(" ")); + List> localDoc = Lists.newArrayList(words, words); + JavaRDD> doc = sc.parallelize(localDoc); + Word2Vec word2vec = new Word2Vec() + .setVectorSize(10) + .setSeed(42L); + Word2VecModel model = word2vec.fit(doc); + Tuple2[] syms = model.findSynonyms("a", 2); + Assert.assertEquals(2, syms.length); + Assert.assertEquals("b", syms[0]._1()); + Assert.assertEquals("c", syms[1]._1()); + } +} From e537b33c63d3fb373fe41deaa607d72e76e3906b Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Wed, 6 Aug 2014 14:12:21 -0700 Subject: [PATCH 05/83] [PySpark] Add blanklines to Python docstrings so example code renders correctly Author: RJ Nowling Closes #1808 from rnowling/pyspark_docs and squashes the following commits: c06d774 [RJ Nowling] Add blanklines to Python docstrings so example code renders correctly --- python/pyspark/rdd.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 30b834d2085cd..756e8f35fb03d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -134,6 +134,7 @@ class MaxHeapQ(object): """ An implementation of MaxHeap. + >>> import pyspark.rdd >>> heap = pyspark.rdd.MaxHeapQ(5) >>> [heap.insert(i) for i in range(10)] @@ -381,6 +382,7 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): def getNumPartitions(self): """ Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) >>> rdd.getNumPartitions() 2 @@ -570,6 +572,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. # noqa + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] @@ -1209,6 +1212,7 @@ def collectAsMap(self): def keys(self): """ Return an RDD with the keys of each tuple. + >>> m = sc.parallelize([(1, 2), (3, 4)]).keys() >>> m.collect() [1, 3] @@ -1218,6 +1222,7 @@ def keys(self): def values(self): """ Return an RDD with the values of each tuple. + >>> m = sc.parallelize([(1, 2), (3, 4)]).values() >>> m.collect() [2, 4] @@ -1642,6 +1647,7 @@ def repartition(self, numPartitions): Internally, this uses a shuffle to redistribute data. If you are decreasing the number of partitions in this RDD, consider using `coalesce`, which can avoid performing a shuffle. + >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1656,6 +1662,7 @@ def repartition(self, numPartitions): def coalesce(self, numPartitions, shuffle=False): """ Return a new RDD that is reduced into `numPartitions` partitions. + >>> sc.parallelize([1, 2, 3, 4, 5], 3).glom().collect() [[1], [2, 3], [4, 5]] >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() @@ -1694,6 +1701,7 @@ def name(self): def setName(self, name): """ Assign a name to this RDD. + >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.setName('RDD1') >>> rdd1.name() @@ -1753,6 +1761,7 @@ class PipelinedRDD(RDD): """ Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() [4, 8, 12, 16] From c6889d2cb9cd99f7e3e0ee14a4fdf301f1f9810e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 6 Aug 2014 16:34:53 -0700 Subject: [PATCH 06/83] [HOTFIX][Streaming] Handle port collisions in flume polling test This is failing my tests in #1777. @tdas Author: Andrew Or Closes #1803 from andrewor14/fix-flaky-streaming-test and squashes the following commits: ea11a03 [Andrew Or] Catch all exceptions caused by BindExceptions 54a0ca0 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-flaky-streaming-test 664095c [Andrew Or] Tone down bind exception message af3ddc9 [Andrew Or] Handle port collisions in flume polling test --- .../flume/FlumePollingStreamSuite.scala | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 27bf2ac962721..a69baa16981a1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ +import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { @@ -45,8 +46,37 @@ class FlumePollingStreamSuite extends TestSuiteBase { val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 + val maxAttempts = 5 test("flume polling test") { + testMultipleTimes(testFlumePolling) + } + + test("flume polling test multiple hosts") { + testMultipleTimes(testFlumePollingMultipleHost) + } + + /** + * Run the given test until no more java.net.BindException's are thrown. + * Do this only up to a certain attempt limit. + */ + private def testMultipleTimes(test: () => Unit): Unit = { + var testPassed = false + var attempt = 0 + while (!testPassed && attempt < maxAttempts) { + try { + test() + testPassed = true + } catch { + case e: Exception if Utils.isBindCollision(e) => + logWarning("Exception when running flume polling test: " + e) + attempt += 1 + } + } + assert(testPassed, s"Test failed after $attempt attempts!") + } + + private def testFlumePolling(): Unit = { val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) @@ -80,7 +110,7 @@ class FlumePollingStreamSuite extends TestSuiteBase { channel.stop() } - test("flume polling test multiple hosts") { + private def testFlumePollingMultipleHost(): Unit = { val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) From 4e008334ee0fb60f9fe8820afa06f7b7f0fa7a6c Mon Sep 17 00:00:00 2001 From: Gregory Owen Date: Wed, 6 Aug 2014 16:52:00 -0700 Subject: [PATCH 07/83] SPARK-2882: Spark build now checks local maven cache for dependencies Fixes [SPARK-2882](https://issues.apache.org/jira/browse/SPARK-2882) Author: Gregory Owen Closes #1818 from GregOwen/spark-2882 and squashes the following commits: 294446d [Gregory Owen] SPARK-2882: Spark build now checks local maven cache for dependencies --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 40b588512ff08..ed587783d5606 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -115,7 +115,8 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - + + resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) From 17caae48b3608552dd6e3ae652043831f932ce95 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Aug 2014 17:27:55 -0700 Subject: [PATCH 08/83] [SPARK-2583] ConnectionManager error reporting This patch modifies the ConnectionManager so that error messages are sent in reply when uncaught exceptions occur during message processing. This prevents message senders from hanging while waiting for an acknowledgment if the remote message processing failed. This is an updated version of sarutak's PR, #1490. The main change is to use Futures / Promises to signal errors. Author: Kousuke Saruta Author: Josh Rosen Closes #1758 from JoshRosen/connection-manager-fixes and squashes the following commits: 68620cb [Josh Rosen] Fix test in BlockFetcherIteratorSuite: 83673de [Josh Rosen] Error ACKs should trigger IOExceptions, so catch only those exceptions in the test. b8bb4d4 [Josh Rosen] Fix manager.id vs managerServer.id typo that broke security tests. 659521f [Josh Rosen] Include previous exception when throwing new one a2f745c [Josh Rosen] Remove sendMessageReliablySync; callers can wait themselves. c01c450 [Josh Rosen] Return Try[Message] from sendMessageReliablySync. f1cd1bb [Josh Rosen] Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error reporting 7399c6b [Josh Rosen] Merge remote-tracking branch 'origin/pr/1490' into connection-manager-fixes ee91bb7 [Kousuke Saruta] Modified BufferMessage.scala to keep the spark code style 9dfd0d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e7d9aa6 [Kousuke Saruta] rebase to master 326a17f [Kousuke Saruta] Add test cases to ConnectionManagerSuite.scala for SPARK-2583 2a18d6b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 22d7ebd [Kousuke Saruta] Add test cases to BlockManagerSuite for SPARK-2583 e579302 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 281589c [Kousuke Saruta] Add a test case to BlockFetcherIteratorSuite.scala for fetching block from remote from successfully 0654128 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 ffaa83d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 12d3de8 [Kousuke Saruta] Added BlockFetcherIteratorSuite.scala 4117b8f [Kousuke Saruta] Modified ConnectionManager to be alble to handle error during processing message 717c9c3 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 6635467 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e2b8c4a [Kousuke Saruta] Modify to propagete error using ConnectionManager --- .../apache/spark/network/BufferMessage.scala | 7 +- .../spark/network/ConnectionManager.scala | 143 ++++++++++-------- .../org/apache/spark/network/Message.scala | 2 + .../spark/network/MessageChunkHeader.scala | 7 +- .../org/apache/spark/network/SenderTest.scala | 7 +- .../spark/storage/BlockFetcherIterator.scala | 9 +- .../spark/storage/BlockManagerWorker.scala | 30 ++-- .../network/ConnectionManagerSuite.scala | 38 ++++- .../storage/BlockFetcherIteratorSuite.scala | 98 +++++++++++- .../spark/storage/BlockManagerSuite.scala | 110 +++++++++++++- 10 files changed, 362 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 04df2f3b0d696..af35f1fc3e459 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, + hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 4c00225280cce..95f96b8463a01 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ @@ -45,16 +46,26 @@ private[spark] class ConnectionManager( name: String = "Connection manager") extends Logging { + /** + * Used by sendMessageReliably to track messages being sent. + * @param message the message that was sent + * @param connectionManagerId the connection manager that sent this message + * @param completionHandler callback that's invoked when the send has completed or failed + */ class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, completionHandler: MessageStatus => Unit) { + /** This is non-None if message has been ack'd */ var ackMessage: Option[Message] = None - var attempted = false - var acked = false - def markDone() { completionHandler(this) } + def markDone(ackMessage: Option[Message]) { + this.synchronized { + this.ackMessage = ackMessage + completionHandler(this) + } + } } private val selector = SelectorProvider.provider.openSelector() @@ -442,11 +453,7 @@ private[spark] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + status.markDone(None) }) messageStatuses.retain((i, status) => { @@ -475,11 +482,7 @@ private[spark] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } + s.markDone(None) } messageStatuses.retain((i, status) => { @@ -547,13 +550,13 @@ private[spark] class ConnectionManager( val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security message") + if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: ", e) } } } @@ -661,34 +664,39 @@ private[spark] class ConnectionManager( } } } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } + sentMessageStatus.markDone(Some(message)) } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } + var ackMessage : Option[Message] = None + try { + ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + + ackMessage.get.getClass) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } + } + } catch { + case e: Exception => { + logError(s"Exception was thrown while processing message", e) + val m = Message.createBufferMessage(bufferMessage.id) + m.hasError = true + ackMessage = Some(m) } + } finally { + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) } } case _ => throw new Exception("Unknown type message received") @@ -800,11 +808,7 @@ private[spark] class ConnectionManager( case Some(msgStatus) => { messageStatuses -= message.id logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.synchronized { - msgStatus.attempted = true - msgStatus.acked = false - msgStatus.markDone() - } + msgStatus.markDone(None) } case None => { logError("no messageStatus for failed message id: " + message.id) @@ -823,11 +827,28 @@ private[spark] class ConnectionManager( selector.wakeup() } + /** + * Send a message and block until an acknowldgment is received or an error occurs. + * @param connectionManagerId the message's destination + * @param message the message being sent + * @return a Future that either returns the acknowledgment message or captures an exception. + */ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus( - message, connectionManagerId, s => promise.success(s.ackMessage)) + : Future[Message] = { + val promise = Promise[Message]() + val status = new MessageStatus(message, connectionManagerId, s => { + s.ackMessage match { + case None => // Indicates a failure where we either never sent or never got ACK'd + promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) + case Some(ackMessage) => + if (ackMessage.hasError) { + promise.failure( + new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + } else { + promise.success(ackMessage) + } + } + }) messageStatuses.synchronized { messageStatuses += ((message.id, status)) } @@ -835,11 +856,6 @@ private[spark] class ConnectionManager( promise.future } - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Option[Message] = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) - } - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { onReceiveCallback = callback } @@ -862,6 +878,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { + import ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf @@ -896,7 +913,7 @@ private[spark] object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) }) println("--------------------------") println() @@ -917,8 +934,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -952,8 +971,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -982,8 +1003,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis Thread.sleep(1000) diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 7caccfdbb44f9..04ea50f62918c 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var startTime = -1L var finishTime = -1L var isSecurityNeg = false + var hasError = false def size: Int @@ -87,6 +88,7 @@ private[spark] object Message { case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } + newMessage.hasError = header.hasError newMessage.senderAddress = header.address newMessage } diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index ead663ede7a1c..f3ecca5f992e0 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val hasError: Boolean, val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). putInt(securityNeg). putInt(ip.size). put(ip). @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader( private[spark] object MessageChunkHeader { - val HEADER_SIZE = 44 + val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val hasError = buffer.get() != 0 val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..ea2ad104ecae1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -20,6 +20,10 @@ package org.apache.spark.network import java.nio.ByteBuffer import org.apache.spark.{SecurityManager, SparkConf} +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.Try + private[spark] object SenderTest { def main(args: Array[String]) { @@ -51,7 +55,8 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) + val responseStr: String = Try(Await.result(promise, Duration.Inf)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index ccf830e118ee7..938af6f5b923a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf @@ -118,8 +119,8 @@ object BlockFetcherIterator { bytesInFlight += req.size val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { + future.onComplete { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) for (blockMessage <- blockMessageArray) { @@ -135,8 +136,8 @@ object BlockFetcherIterator { logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - case None => { - logError("Could not get block(s) from " + cmId) + case Failure(exception) => { + logError("Could not get block(s) from " + cmId, exception) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index c7766a3a65671..bf002a42d5dc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,6 +23,10 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.{Try, Failure, Success} + /** * A network interface for BlockManager. Each slave should have one * BlockManagerWorker. @@ -44,13 +48,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message", e) - None + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } } } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - None + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) } } } @@ -109,9 +119,9 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - resultMessage.isDefined + val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) + resultMessage.isSuccess } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -119,10 +129,10 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) responseMessage match { - case Some(message) => { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] logDebug("Response message received " + bufferMessage) BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { @@ -130,7 +140,7 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received") + case Failure(exception) => logDebug("No response message received") } null } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 415ad8c432c12..846537df003df 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import org.apache.spark.{SecurityManager, SparkConf} @@ -25,6 +26,7 @@ import org.scalatest.FunSuite import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Try /** * Test the ConnectionManager with various security settings. @@ -46,7 +48,7 @@ class ConnectionManagerSuite extends FunSuite { buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(receivedMessage == true) @@ -79,7 +81,7 @@ class ConnectionManagerSuite extends FunSuite { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) }) assert(numReceivedServerMessages == 10) @@ -118,7 +120,10 @@ class ConnectionManagerSuite extends FunSuite { val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + // Expect managerServer to close connection, which we'll report as an error: + intercept[IOException] { + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) + } assert(numReceivedServerMessages == 0) assert(numReceivedMessages == 0) @@ -163,6 +168,8 @@ class ConnectionManagerSuite extends FunSuite { val g = Await.result(f, 1 second) assert(false) } catch { + case i: IOException => + assert(true) case e: TimeoutException => { // we should timeout here since the client can't do the negotiation assert(true) @@ -209,7 +216,6 @@ class ConnectionManagerSuite extends FunSuite { }).foreach(f => { try { val g = Await.result(f, 1 second) - if (!g.isDefined) assert(false) else assert(true) } catch { case e: Exception => { assert(false) @@ -223,7 +229,31 @@ class ConnectionManagerSuite extends FunSuite { managerServer.stop() } + test("Ack error message") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + val managerServer = new ConnectionManager(0, conf, securityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + throw new Exception + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + intercept[IOException] { + Await.result(future, 1 second) + } + manager.stop() + managerServer.stop() + + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 8dca2ebb312f5..1538995a6b404 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -17,18 +17,22 @@ package org.apache.spark.storage +import java.io.IOException +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + import org.scalatest.{FunSuite, Matchers} -import org.scalatest.PrivateMethodTester._ import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.stubbing.Answer import org.mockito.invocation.InvocationOnMock -import org.apache.spark._ import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, - Message} +import org.apache.spark.network.{ConnectionManager, Message} class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -137,4 +141,90 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") } + test("block fetch from remote fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val f = future { + throw new IOException("Send failed or we received an error ACK") + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (!r.isDefined) should be(true) + } + } + } + + test("block fetch from remote succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) + val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) + val blockMessageArray = new BlockMessageArray( + Seq(blockMessage1, blockMessage2)) + + val bufferMessage = blockMessageArray.toBufferMessage + val buffer = ByteBuffer.allocate(bufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + bufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + + val f = future { + Message.createBufferMessage(arrayBuffer) + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (r.isDefined) should be(true) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0ac0269d7cfc1..94bb2c445d2e9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,7 +25,11 @@ import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.Matchers.any +import org.mockito.Mockito.{doAnswer, mock, spy, when} +import org.mockito.stubbing.Answer + import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,6 +37,7 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.{Message, ConnectionManagerId} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1000,6 +1005,109 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + val reqBufferMessage = reqBlockMessages.toBufferMessage + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + throw new Exception + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + // Test when exception was thrown during processing block messages + var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + + assert(ackMessage.isDefined, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When Exception was thown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should have error") + + val notBufferMessage = mock(classOf[Message]) + + // Test when not BufferMessage was received + ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When not BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should have error") + } + + test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + + val tmpBufferMessage = reqBlockMessages.toBufferMessage + val buffer = ByteBuffer.allocate(tmpBufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + tmpBufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + val reqBufferMessage = Message.createBufferMessage(arrayBuffer) + + // setup ack block messages + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) + val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( + reqBlockMessage1)) { + return Some(ackBlockMessage1) + } else { + return Some(ackBlockMessage2) + } + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should be defined") + assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should not have error") + } + test("reserve/release unroll memory") { store = makeBlockManager(12000) val memoryStore = store.memoryStore From 4201d2711cd20a2892c40eb11102f73c2f826b2e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 6 Aug 2014 18:13:35 -0700 Subject: [PATCH 09/83] SPARK-2879 [BUILD] Use HTTPS to access Maven Central and other repos Maven Central has just now enabled HTTPS access for everyone to Maven Central (http://central.sonatype.org/articles/2014/Aug/03/https-support-launching-now/) This is timely, as a reminder of how easily an attacker can slip malicious code into a build that's downloading artifacts over HTTP (http://blog.ontoillogical.com/blog/2014/07/28/how-to-take-over-any-java-developer/). In the meantime, it looks like the Spring repo also now supports HTTPS, so can be used this way too. I propose to use HTTPS to access these repos. Author: Sean Owen Closes #1805 from srowen/SPARK-2879 and squashes the following commits: 7043a8e [Sean Owen] Use HTTPS for Maven Central libs and plugins; use id 'central' to override parent properly; use HTTPS for Spring repo --- pom.xml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4ab027bad55c0..76bf6d8f902a8 100644 --- a/pom.xml +++ b/pom.xml @@ -143,11 +143,11 @@ - maven-repo + central Maven Repository - http://repo.maven.apache.org/maven2 + https://repo.maven.apache.org/maven2 true @@ -213,7 +213,7 @@ spring-releases Spring Release Repository - http://repo.spring.io/libs-release + https://repo.spring.io/libs-release true @@ -222,6 +222,15 @@ + + + central + https://repo1.maven.org/maven2 + + true + + + From a263a7e9f060b3017142cdae5f1270db9458d8d3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 6 Aug 2014 18:45:03 -0700 Subject: [PATCH 10/83] HOTFIX: Support custom Java 7 location --- dev/create-release/create-release.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 42473629d4f15..1867cf4ec46ca 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -35,6 +35,12 @@ RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} RC_NAME=${RC_NAME:-rc2} USER_NAME=${USER_NAME:-pwendell} +if [ -z "$JAVA_HOME" ]; then + echo "Error: JAVA_HOME is not set, cannot proceed." + exit -1 +fi +JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} + set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME @@ -130,7 +136,8 @@ scp spark-* \ cd spark sbt/sbt clean cd docs -PRODUCTION=1 jekyll build +# Compile docs with Java 7 to use nicer format +JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs ssh $USER_NAME@people.apache.org \ From ffd1f59a62a9dd9a4d5a7b09490b9d01ff1cd42d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Aug 2014 21:22:13 -0700 Subject: [PATCH 11/83] [SPARK-2887] fix bug of countApproxDistinct() when have more than one partition fix bug of countApproxDistinct() when have more than one partition Author: Davies Liu Closes #1812 from davies/approx and squashes the following commits: bf757ce [Davies Liu] fix bug of countApproxDistinct() when have more than one partition --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e1c49e35abecd..0159003c88e06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1004,7 +1004,7 @@ abstract class RDD[T: ClassTag]( }, (h1: HyperLogLogPlus, h2: HyperLogLogPlus) => { h1.addAll(h2) - h2 + h1 }).cardinality() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index b31e3a09e5b9c..4a7dc8dca25e2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -81,11 +81,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble - val size = 100 - val uniformDistro = for (i <- 1 to 100000) yield i % size - val simpleRdd = sc.makeRDD(uniformDistro) - assert(error(simpleRdd.countApproxDistinct(4, 0), size) < 0.4) - assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.1) + val size = 1000 + val uniformDistro = for (i <- 1 to 5000) yield i % size + val simpleRdd = sc.makeRDD(uniformDistro, 10) + assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.2) + assert(error(simpleRdd.countApproxDistinct(12, 0), size) < 0.1) } test("SparkContext.union") { From 47ccd5e71be49b723476f3ff8d5768f0f45c2ea6 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 6 Aug 2014 22:58:59 -0700 Subject: [PATCH 12/83] [SPARK-2851] [mllib] DecisionTree Python consistency update Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). Added factory classes for Algo and Impurity, but made private[mllib]. CC: mengxr dorx Please let me know if there are other changes which would help with API consistency---thanks! Author: Joseph K. Bradley Closes #1798 from jkbradley/dt-python-consistency and squashes the following commits: 6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD. ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types) 00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency fe6dbfa [Joseph K. Bradley] removed unnecessary imports e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs). c699850 [Joseph K. Bradley] a few doc comments eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters --- .../mllib/api/python/PythonMLLibAPI.scala | 19 +-- .../spark/mllib/tree/DecisionTree.scala | 151 ++++++++++++++---- .../spark/mllib/tree/configuration/Algo.scala | 6 + .../mllib/tree/impurity/Impurities.scala | 32 ++++ python/pyspark/mllib/tree.py | 50 ++---- 5 files changed, 181 insertions(+), 77 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fd0b9556c7d54..ba7ccd8ce4b8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames @@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable { val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val algo: Algo = algoStr match { - case "classification" => Classification - case "regression" => Regression - case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") - } - val impurity: Impurity = impurityStr match { - case "gini" => Gini - case "entropy" => Entropy - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") - } + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) val strategy = new Strategy( algo = algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36cf..c8a865659682f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,14 +17,18 @@ package org.apache.spark.mllib.tree +import org.apache.spark.api.java.JavaRDD + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging { * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ def train( @@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + private val InvalidBinIndex = -1 /** @@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging { * Categorical features: * For each feature, there is 1 bin per split. * Splits and bins are handled in 2 ways: - * (a) For multiclass classification with a low-arity feature + * (a) "unordered features" + * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are 2^(maxFeatureValue - 1) - 1 splits. - * (b) For regression and binary classification, + * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * (b) "ordered features" + * For regression and binary classification, * and for multiclass classification with a high-arity feature, - * there is one split per category. - - * Categorical case (a) features are called unordered features. - * Other cases are called ordered features. + * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 79a01f58319e8..0ef9c6181a0a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value + + private[mllib] def fromString(name: String): Algo = name match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala new file mode 100644 index 0000000000000..9a6452aa13a61 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Factory for Impurity instances. + */ +private[mllib] object Impurities { + + def fromString(name: String): Impurity = name match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") + } + +} diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2518001ea0b93..e1a4671709b7d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -131,7 +131,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -150,12 +150,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -173,42 +181,14 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model) From 75993a65173172da32bbe98751e8c0f55c17a52e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 7 Aug 2014 00:04:18 -0700 Subject: [PATCH 13/83] SPARK-2879 part 2 [BUILD] Use HTTPS to access Maven Central and other repos .. and use canonical repo1.maven.org Maven Central repo. (And make sure snapshots are disabled for plugins from Maven Central.) Author: Sean Owen Closes #1828 from srowen/SPARK-2879.2 and squashes the following commits: 639f495 [Sean Owen] .. and use canonical repo1.maven.org Maven Central repo. (And make sure snapshots are disabled for plugins from Maven Central.) --- pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 76bf6d8f902a8..920912353fe9c 100644 --- a/pom.xml +++ b/pom.xml @@ -146,8 +146,7 @@ central Maven Repository - - https://repo.maven.apache.org/maven2 + https://repo1.maven.org/maven2 true @@ -229,6 +228,9 @@ true + + false + From 8d1dec4fa4798bb48b8947446d306ec9ba6bddb5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 7 Aug 2014 00:20:38 -0700 Subject: [PATCH 14/83] [mllib] DecisionTree Strategy parameter checks Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters. CC mengxr Author: Joseph K. Bradley Closes #1821 from jkbradley/dt-robustness and squashes the following commits: 4dc449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-robustness 7a61f7b [Joseph K. Bradley] Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters --- .../spark/mllib/tree/DecisionTree.scala | 10 ++++-- .../mllib/tree/configuration/Strategy.scala | 31 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c8a865659682f..bb50f07be5d7b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -44,6 +44,8 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { + strategy.assertValid() + /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -1465,10 +1467,14 @@ object DecisionTree extends Serializable with Logging { /* - * Ensure #bins is always greater than the categories. For multiclass classification, - * #bins should be greater than 2^(maxCategories - 1) - 1. + * Ensure numBins is always greater than the categories. For multiclass classification, + * numBins should be greater than 2^(maxCategories - 1) - 1. * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. + * + * This needs to be checked here instead of in Strategy since numBins can be determined + * by the number of training examples. + * TODO: Allow this case, where we simply will know nothing about some categories. */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 4ee4bcd0bcbc7..f31a503608b22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -90,4 +90,33 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification >= 2, + s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," + + s" but numClassesForClassification = $numClassesForClassification.") + require(Set(Gini, Entropy).contains(impurity), + s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + + s" Valid settings: Gini, Entropy") + case Regression => + require(impurity == Variance, + s"DecisionTree Strategy given invalid impurity for Regression: $impurity." + + s" Valid settings: Variance") + case _ => + throw new IllegalArgumentException( + s"DecisionTree Strategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." + + s" Valid values are integers >= 0.") + require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + + s" Valid values are integers >= 2.") + categoricalFeaturesInfo.foreach { case (feature, arity) => + require(arity >= 2, + s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + + s" feature $feature has $arity categories. The number of categories should be >= 2.") + } + } + } From b9e9e53773a618e4322b845c40deae22f2ba52ac Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 7 Aug 2014 11:28:12 -0700 Subject: [PATCH 15/83] [SPARK-2852][MLLIB] Separate model from IDF/StandardScaler algorithms This is part of SPARK-2828: 1. separate IDF model from IDF algorithm (which generates a model) 2. separate StandardScaler model from StandardScaler CC: dbtsai Author: Xiangrui Meng Closes #1814 from mengxr/feature-api-update and squashes the following commits: 40d863b [Xiangrui Meng] move mean and variance to model 48a0fff [Xiangrui Meng] separate Model from StandardScaler algorithm 89f3486 [Xiangrui Meng] update IDF to separate Model from Algorithm --- .../org/apache/spark/mllib/feature/IDF.scala | 130 ++++++++---------- .../spark/mllib/feature/StandardScaler.scala | 58 ++++---- .../apache/spark/mllib/feature/IDFSuite.scala | 12 +- .../mllib/feature/StandardScalerSuite.scala | 50 +++---- 4 files changed, 121 insertions(+), 129 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 7ed611a857acc..d40d5553c1d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -36,87 +36,25 @@ class IDF { // TODO: Allow different IDF formulations. - private var brzIdf: BDV[Double] = _ - /** * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ - def fit(dataset: RDD[Vector]): this.type = { - brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + def fit(dataset: RDD[Vector]): IDFModel = { + val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( seqOp = (df, v) => df.add(v), combOp = (df1, df2) => df1.merge(df2) ).idf() - this + new IDFModel(idf) } /** * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ - def fit(dataset: JavaRDD[Vector]): this.type = { + def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors. - * @param dataset an RDD of term frequency vectors - * @return an RDD of TF-IDF vectors - */ - def transform(dataset: RDD[Vector]): RDD[Vector] = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") - } - val theIdf = brzIdf - val bcIdf = dataset.context.broadcast(theIdf) - dataset.mapPartitions { iter => - val thisIdf = bcIdf.value - iter.map { v => - val n = v.size - v match { - case sv: SparseVector => - val nnz = sv.indices.size - val newValues = new Array[Double](nnz) - var k = 0 - while (k < nnz) { - newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) - k += 1 - } - Vectors.sparse(n, sv.indices, newValues) - case dv: DenseVector => - val newValues = new Array[Double](n) - var j = 0 - while (j < n) { - newValues(j) = dv.values(j) * thisIdf(j) - j += 1 - } - Vectors.dense(newValues) - case other => - throw new UnsupportedOperationException( - s"Only sparse and dense vectors are supported but got ${other.getClass}.") - } - } - } - } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). - * @param dataset a JavaRDD of term frequency vectors - * @return a JavaRDD of TF-IDF vectors - */ - def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { - transform(dataset.rdd).toJavaRDD() - } - - /** Returns the IDF vector. */ - def idf(): Vector = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") - } - Vectors.fromBreeze(brzIdf) - } - - private def initialized: Boolean = brzIdf != null } private object IDF { @@ -177,18 +115,72 @@ private object IDF { private def isEmpty: Boolean = m == 0L /** Returns the current IDF vector. */ - def idf(): BDV[Double] = { + def idf(): Vector = { if (isEmpty) { throw new IllegalStateException("Haven't seen any document yet.") } val n = df.length - val inv = BDV.zeros[Double](n) + val inv = new Array[Double](n) var j = 0 while (j < n) { inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) j += 1 } - inv + Vectors.dense(inv) } } } + +/** + * :: Experimental :: + * Represents an IDF model that can transform term frequency vectors. + */ +@Experimental +class IDFModel private[mllib] (val idf: Vector) extends Serializable { + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + val bcIdf = dataset.context.broadcast(idf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index e6c9f8f67df63..4dfd1f0ab8134 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,8 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ @@ -35,37 +36,55 @@ import org.apache.spark.rdd.RDD * @param withStd True by default. Scales the data to unit standard deviation. */ @Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { +class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { def this() = this(false, true) - require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") - - private var mean: BV[Double] = _ - private var factor: BV[Double] = _ + if (!(withMean || withStd)) { + logWarning("Both withMean and withStd are false. The model does nothing.") + } /** * Computes the mean and variance and stores as a model to be used for later scaling. * * @param data The data used to compute the mean and variance to build the transformation model. - * @return This StandardScalar object. + * @return a StandardScalarModel */ - def fit(data: RDD[Vector]): this.type = { + def fit(data: RDD[Vector]): StandardScalerModel = { + // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + } +} - mean = summary.mean.toBreeze - factor = summary.variance.toBreeze - require(mean.length == factor.length) +/** + * :: Experimental :: + * Represents a StandardScaler model that can transform vectors. + * + * @param withMean whether to center the data before scaling + * @param withStd whether to scale the data to have unit standard deviation + * @param mean column mean values + * @param variance column variance values + */ +@Experimental +class StandardScalerModel private[mllib] ( + val withMean: Boolean, + val withStd: Boolean, + val mean: Vector, + val variance: Vector) extends VectorTransformer { + + require(mean.size == variance.size) + private lazy val factor: BDV[Double] = { + val f = BDV.zeros[Double](variance.size) var i = 0 - while (i < factor.length) { - factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + while (i < f.size) { + f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 i += 1 } - - this + f } /** @@ -76,13 +95,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor * for the column with zero variance. */ override def transform(vector: Vector): Vector = { - if (mean == null || factor == null) { - throw new IllegalStateException( - "Haven't learned column summary statistics yet. Call fit first.") - } - - require(vector.size == mean.length) - + require(mean.size == vector.size) if (withMean) { vector.toBreeze match { case dv: BDV[Double] => @@ -115,5 +128,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor vector } } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 78a2804ff204b..53d9c0c640b98 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -36,18 +36,12 @@ class IDFSuite extends FunSuite with LocalSparkContext { val m = localTermFrequencies.size val termFrequencies = sc.parallelize(localTermFrequencies, 2) val idf = new IDF - intercept[IllegalStateException] { - idf.idf() - } - intercept[IllegalStateException] { - idf.transform(termFrequencies) - } - idf.fit(termFrequencies) + val model = idf.fit(termFrequencies) val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((m.toDouble + 1.0) / (x + 1.0)) }) - assert(idf.idf() ~== expected absTol 1e-12) - val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(model.idf ~== expected absTol 1e-12) + val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() assert(tfidf.size === 3) val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] assert(tfidf0.indices === Array(1, 3)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 5a9be923a8625..e217b93cebbdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - withClue("Using a standardizer before fitting the model should throw exception.") { - intercept[IllegalStateException] { - data.map(standardizer1.transform) - } - } - - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) - val data1RDD = standardizer1.transform(dataRDD) - val data2RDD = standardizer2.transform(dataRDD) - val data3RDD = standardizer3.transform(dataRDD) + val data1RDD = model1.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) + val data3RDD = model3.transform(dataRDD) val summary = computeSummary(dataRDD) val summary1 = computeSummary(data1RDD) @@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(standardizer2.transform) + val data2 = data.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer1.transform) + data.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer3.transform) + data.map(model3.transform) } } - val data2RDD = standardizer2.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) val summary2 = computeSummary(data2RDD) @@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler(withMean = true, withStd = false) val standardizer3 = new StandardScaler(withMean = false, withStd = true) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") From 80ec5bad1311651fe56e1d5178090dc63753233b Mon Sep 17 00:00:00 2001 From: Oleg Danilov Date: Thu, 7 Aug 2014 15:48:44 -0700 Subject: [PATCH 16/83] SPARK-2905 Fixed path sbin => bin Author: Oleg Danilov Closes #1835 from dosoft/SPARK-2905 and squashes the following commits: 4df423c [Oleg Danilov] SPARK-2905 Fixed path sbin => bin --- bin/spark-sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-sql b/bin/spark-sql index 61ebd8ab6dec8..7813ccc361415 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -29,7 +29,7 @@ CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" FWDIR="$(cd `dirname $0`/..; pwd)" function usage { - echo "Usage: ./sbin/spark-sql [options] [cli option]" + echo "Usage: ./bin/spark-sql [options] [cli option]" pattern="usage" pattern+="\|Spark assembly has been built with Hive" pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" From 32096c2aed9978cfb9a904b4f56bb61800d17e9e Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 7 Aug 2014 16:24:22 -0700 Subject: [PATCH 17/83] SPARK-2899 Doc generation is back to working in new SBT Build. The reason for this bug was introduciton of OldDeps project. It had to be excluded to prevent unidocs from trying to put it on "docs compile" classpath. Author: Prashant Sharma Closes #1830 from ScrapCodes/doc-fix and squashes the following commits: e5d52e6 [Prashant Sharma] SPARK-2899 Doc generation is back to working in new SBT Build. --- project/SparkBuild.scala | 60 ++++++++++++++++++++++------------------ project/plugins.sbt | 2 +- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ed587783d5606..63a285b81a60c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = @@ -44,8 +44,9 @@ object BuildCommons { val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) - val tools = "tools" - + val tools = ProjectRef(buildLocation, "tools") + // Root project. + val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation } @@ -126,26 +127,6 @@ object SparkBuild extends PomBuild { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn ) - /** Following project only exists to pull previous artifacts of Spark for generating - Mima ignores. For more information see: SPARK 2071 */ - lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) - - def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.0.0") - } - - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( - name := "old-deps", - scalaVersion := "2.10.4", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", - "spark-core").map(versionArtifact(_).get intransitive()) - ) - def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { val existingSettings = projectsMap.getOrElse(projectRef.project, Seq[Setting[_]]()) projectsMap += (projectRef.project -> (existingSettings ++ settings)) @@ -184,7 +165,7 @@ object SparkBuild extends PomBuild { super.projectDefinitions(baseDirectory).map { x => if (projectsMap.exists(_._1 == x.id)) x.settings(projectsMap(x.id): _*) else x.settings(Seq[Setting[_]](): _*) - } ++ Seq[Project](oldDeps) + } ++ Seq[Project](OldDeps.project) } } @@ -193,6 +174,31 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + * Following project only exists to pull previous artifacts of Spark for generating + * Mima ignores. For more information see: SPARK 2071 + */ +object OldDeps { + + lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) + + def versionArtifact(id: String): Option[sbt.ModuleID] = { + val fullId = id + "_2.10" + Some("org.apache.spark" % fullId % "1.0.0") + } + + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(versionArtifact(_).get intransitive()) + ) +} + object Catalyst { lazy val settings = Seq( addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), @@ -285,9 +291,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { diff --git a/project/plugins.sbt b/project/plugins.sbt index 06d18e193076e..2a61f56c2ea60 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,6 +23,6 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") From 6906b69cf568015f20c7d7c77cbcba650e5431a9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 7 Aug 2014 18:04:49 -0700 Subject: [PATCH 18/83] SPARK-2787: Make sort-based shuffle write files directly when there's no sorting/aggregation and # partitions is small As described in https://issues.apache.org/jira/browse/SPARK-2787, right now sort-based shuffle is more expensive than hash-based for map operations that do no partial aggregation or sorting, such as groupByKey. This is because it has to serialize each data item twice (once when spilling to intermediate files, and then again when merging these files object-by-object). This patch adds a code path to just write separate files directly if the # of output partitions is small, and concatenate them at the end to produce a sorted file. On the unit test side, I added some tests that force or don't force this bypass path to be used, and checked that our tests for other features (e.g. all the operations) cover both cases. Author: Matei Zaharia Closes #1799 from mateiz/SPARK-2787 and squashes the following commits: 88cf26a [Matei Zaharia] Fix rebase 10233af [Matei Zaharia] Review comments 398cb95 [Matei Zaharia] Fix looking up shuffle manager in conf ca3efd9 [Matei Zaharia] Add docs for shuffle manager properties, and allow short names for them d0ae3c5 [Matei Zaharia] Fix some comments 90d084f [Matei Zaharia] Add code path to bypass merge-sort in ExternalSorter, and tests 31e5d7c [Matei Zaharia] Move existing logic for writing partitioned files into ExternalSorter --- .../scala/org/apache/spark/SparkEnv.scala | 27 +- .../shuffle/hash/HashShuffleReader.scala | 2 +- .../shuffle/sort/SortShuffleWriter.scala | 80 ++---- .../util/collection/ExternalSorter.scala | 233 +++++++++++++++--- .../util/collection/ExternalSorterSuite.scala | 165 +++++++++++-- docs/configuration.md | 18 ++ 6 files changed, 407 insertions(+), 118 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9d4edeb6d96cf..22d8d1cb1ddcf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -156,11 +156,9 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", boundPort.toString) } - // Create an instance of the class named by the given Java system property, or by - // defaultClassName if the property is not set, and return it as a T - def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = conf.get(propertyName, defaultClassName) - val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) + // Create an instance of the class with the given name, possibly initializing it with our conf + def instantiateClass[T](className: String): T = { + val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -178,11 +176,17 @@ object SparkEnv extends Logging { } } - val serializer = instantiateClass[Serializer]( + // Create an instance of the class named by the given SparkConf property, or defaultClassName + // if the property is not set, possibly initializing it with our conf + def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = { + instantiateClass[T](conf.get(propertyName, defaultClassName)) + } + + val serializer = instantiateClassFromConf[Serializer]( "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val closureSerializer = instantiateClass[Serializer]( + val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") def registerOrLookup(name: String, newActor: => Actor): ActorRef = { @@ -246,8 +250,13 @@ object SparkEnv extends Logging { "." } - val shuffleManager = instantiateClass[ShuffleManager]( - "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val shuffleMemoryManager = new ShuffleMemoryManager(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7c9dc8e5f88ef..88a5f1e5ddf58 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -58,7 +58,7 @@ private[spark] class HashShuffleReader[K, C]( // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) - sorter.write(aggregatedIter) + sorter.insertAll(aggregatedIter) context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled sorter.iterator diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index e54e6383d2ccc..22f656fa371ea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var sorter: ExternalSorter[K, V, _] = null private var outputFile: File = null + private var indexFile: File = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -57,78 +58,36 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - // Get an iterator with the elements for each partition ID - val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { - if (dep.mapSideCombine) { - if (!dep.aggregator.isDefined) { - throw new IllegalStateException("Aggregator is empty for map-side combine") - } - sorter = new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.write(records) - sorter.partitionedIterator - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we - // don't care whether the keys get sorted in each partition; that will be done on the - // reduce side if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) - sorter.write(records) - sorter.partitionedIterator + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.insertAll(records) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.insertAll(records) } // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later // serve different ranges of this file using an index file that we create at the end. val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) - outputFile = blockManager.diskBlockManager.getFile(blockId) - - // Track location of each range in the output file - val offsets = new Array[Long](numPartitions + 1) - val lengths = new Array[Long](numPartitions) - - for ((id, elements) <- partitions) { - if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize, - writeMetrics) - for (elem <- elements) { - writer.write(elem) - } - writer.commitAndClose() - val segment = writer.fileSegment() - offsets(id + 1) = segment.offset + segment.length - lengths(id) = segment.length - } else { - // The partition is empty; don't create a new writer to avoid writing headers, etc - offsets(id + 1) = offsets(id) - } - } - - context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled - // Write an index file with the offsets of each block, plus a final offset at the end for the - // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure - // out where each block begins and ends. + outputFile = blockManager.diskBlockManager.getFile(blockId) + indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index") - val diskBlockManager = blockManager.diskBlockManager - val indexFile = diskBlockManager.getFile(blockId.name + ".index") - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { - var i = 0 - while (i < numPartitions + 1) { - out.writeLong(offsets(i)) - i += 1 - } - } finally { - out.close() - } + val partitionLengths = sorter.writePartitionedFile(blockId, context) // Register our map output with the ShuffleBlockManager, which handles cleaning it over time blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) mapStatus = new MapStatus(blockManager.blockManagerId, - lengths.map(MapOutputTracker.compressSize)) + partitionLengths.map(MapOutputTracker.compressSize)) } /** Close this writer, passing along whether the map completed */ @@ -145,6 +104,9 @@ private[spark] class SortShuffleWriter[K, V, C]( if (outputFile != null) { outputFile.delete() } + if (indexFile != null) { + indexFile.delete() + } return None } } finally { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eb4849ebc6e52..b73d5e0cf1714 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -25,10 +25,10 @@ import scala.collection.mutable import com.google.common.io.ByteStreams -import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} +import org.apache.spark._ import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.BlockId import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.{BlockObjectWriter, BlockId} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics * for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. + * + * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is + * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to + * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can + * then concatenate these files to produce a single sorted file, without having to serialize and + * de-serialize each item twice (as is needed during the merge). This speeds up the map side of + * groupBy, sort, etc operations since they do no partial aggregation. */ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, @@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C]( // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need + // local aggregation and sorting, write numPartitions files directly and just concatenate them + // at the end. This avoids doing serialization and deserialization twice to merge together the + // spilled files, which would happen with the normal code path. The downside is having multiple + // files open at a time and thus more memory allocated to buffers. + private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val bypassMergeSort = + (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) + + // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled + private var partitionWriters: Array[BlockObjectWriter] = null + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) - // A comparator for (Int, K) elements that orders them by partition and then possibly by key + // A comparator for (Int, K) pairs that orders them by only their partition ID + private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + // A comparator that orders (Int, K) pairs by partition ID and then possibly by key private val partitionKeyComparator: Comparator[(Int, K)] = { if (ordering.isDefined || aggregator.isDefined) { // Sort by partition ID then key comparator @@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C]( } } else { // Just sort it by partition ID - new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + partitionComparator } } @@ -171,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsPerPartition: Array[Long]) private val spills = new ArrayBuffer[SpilledFile] - def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C]( val threadId = Thread.currentThread().getId logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + + if (bypassMergeSort) { + spillToPartitionFiles(collection) + } else { + spillToMergeableFile(collection) + } + + if (usingMap) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } else { + buffer = new SizeTrackingPairBuffer[(Int, K), C] + } + + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0 + + _memoryBytesSpilled += memorySize + } + + /** + * Spill our in-memory collection to a sorted file that we can merge later (normal code path). + * We add this file into spilledFiles to find it later. + * + * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. + * See spillToPartitionedFiles() for that code path. + * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(!bypassMergeSort) + val (blockId, file) = diskBlockManager.createTempBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) @@ -304,18 +358,36 @@ private[spark] class ExternalSorter[K, V, C]( } } - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } + spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + } - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 + /** + * Spill our in-memory collection to separate files, one for each partition. This is used when + * there's no aggregator and ordering and the number of partitions is small, because it allows + * writePartitionedFile to just concatenate files without deserializing data. + * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(bypassMergeSort) + + // Create our file writers if we haven't done so yet + if (partitionWriters == null) { + curWriteMetrics = new ShuffleWriteMetrics() + partitionWriters = Array.fill(numPartitions) { + val (blockId, file) = diskBlockManager.createTempBlock() + blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + } + } - spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) - _memoryBytesSpilled += memorySize + val it = collection.iterator // No need to sort stuff, just write each element out + while (it.hasNext) { + val elem = it.next() + val partitionId = elem._1._1 + val key = elem._1._2 + val value = elem._2 + partitionWriters(partitionId).write((key, value)) + } } /** @@ -479,7 +551,6 @@ private[spark] class ExternalSorter[K, V, C]( skipToNextPartition() - // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams var fileStream: FileInputStream = null @@ -619,23 +690,25 @@ private[spark] class ExternalSorter[K, V, C]( def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - if (spills.isEmpty) { + if (spills.isEmpty && partitionWriters == null) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { - // The user isn't requested sorted keys, so only sort by partition ID, not key - val partitionComparator = new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + // The user hasn't requested sorted keys, so only sort by partition ID, not key groupByPartition(collection.destructiveSortedIterator(partitionComparator)) } else { // We do need to sort by both partition ID and key groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) } + } else if (bypassMergeSort) { + // Read data from each partition file and merge it together with the data in memory; + // note that there's no ordering or aggregator in this case -- we just partition objects + val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + collIter.map { case (partitionId, values) => + (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) + } } else { - // General case: merge spilled and in-memory data + // Merge spilled and in-memory data merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) } } @@ -645,9 +718,113 @@ private[spark] class ExternalSorter[K, V, C]( */ def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + /** + * Write all the data added into this ExternalSorter into a file in the disk store, creating + * an .index file for it as well with the offsets of each partition. This is called by the + * SortShuffleWriter and can go through an efficient path of just concatenating binary files + * if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = { + val outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + if (bypassMergeSort && partitionWriters != null) { + // We decided to write separate files for each partition, so just concatenate them. To keep + // this simple we spill out the current in-memory collection so that everything is in files. + spillToPartitionFiles(if (aggregator.isDefined) map else buffer) + partitionWriters.foreach(_.commitAndClose()) + var out: FileOutputStream = null + var in: FileInputStream = null + try { + out = new FileOutputStream(outputFile) + for (i <- 0 until numPartitions) { + val file = partitionWriters(i).fileSegment().file + in = new FileInputStream(file) + org.apache.spark.util.Utils.copyStream(in, out) + in.close() + in = null + lengths(i) = file.length() + offsets(i + 1) = offsets(i) + lengths(i) + } + } finally { + if (out != null) { + out.close() + } + if (in != null) { + in.close() + } + } + } else { + // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by + // partition and just write everything directly. + for ((id, elements) <- this.partitionedIterator) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter( + blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + for (elem <- elements) { + writer.write(elem) + } + writer.commitAndClose() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + } + + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + lengths + } + + /** + * Read a partition file back as an iterator (used in our iterator method) + */ + def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + if (writer.isOpen) { + writer.commitAndClose() + } + blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() + if (partitionWriters != null) { + partitionWriters.foreach { w => + w.revertPartialWritesAndClose() + diskBlockManager.getFile(w.blockId).delete() + } + partitionWriters = null + } } def memoryBytesSpilled: Long = _memoryBytesSpilled diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 57dcb4ffabac1..706faed980f31 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite +import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ -class ExternalSorterSuite extends FunSuite with LocalSparkContext { +class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) // Make the Java serializer write a reset instruction (TC_RESET) after each object to test @@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf } + private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") + } + + private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") + } + test("empty data stream") { val conf = new SparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -86,28 +96,28 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), Some(ord), None) - sorter.write(elements.iterator) + sorter.insertAll(elements.iterator) assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), None, None) - sorter2.write(elements.iterator) + sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - sorter3.write(elements.iterator) + sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter4.write(elements.iterator) + sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() } @@ -118,13 +128,37 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll(elements) + assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + test("empty partitions with spilling, bypass merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter.write(elements) + assertBypassedMergeSort(sorter) + sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) assert(iter.next() === (0, Nil)) @@ -286,14 +320,43 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter2) + sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + sorter2.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() assert(diskBlockManager.getAllBlocks().length === 0) val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter2.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter2) + sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) sorter2.stop() @@ -307,9 +370,35 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + intercept[SparkException] { + sorter.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + assertBypassedMergeSort(sorter) intercept[SparkException] { - sorter.write((0 until 100000).iterator.map(i => { + sorter.insertAll((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -365,7 +454,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 4, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) @@ -381,7 +470,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -397,7 +486,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -414,7 +503,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -431,7 +520,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -448,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -495,7 +584,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) - sorter.write(toInsert) + sorter.insertAll(toInsert) // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -524,7 +613,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) - sorter.write(toInsert.iterator) + sorter.insertAll(toInsert.iterator) val it = sorter.iterator var count = 0 @@ -548,7 +637,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -572,7 +661,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), ("1", null.asInstanceOf[String]), (null.asInstanceOf[String], null.asInstanceOf[String]) @@ -584,4 +673,38 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { it.next() } } + + test("conditions for bypassing merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Sorters with no ordering or aggregator: should bypass unless # of partitions is high + + val sorter1 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertBypassedMergeSort(sorter1) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter2) + + // Sorters with an ordering or aggregator: should not bypass even if they have few partitions + + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) + assertDidNotBypassMergeSort(sorter3) + + val sorter4 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter4) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 5e3eb0f0871af..4d27c5a918fe0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -281,6 +281,24 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. + + spark.shuffle.manager + HASH + + Implementation to use for shuffling data. A hash-based shuffle manager is the default, but + starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more + memory-efficient in environments with small executors, such as YARN. To use that, change + this value to SORT. + + + + spark.shuffle.sort.bypassMergeThreshold + 200 + + (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no + map-side aggregation and there are at most this many reduce partitions. + + #### Spark UI From 4c51098f320f164eb66f92ff0f26b0b595a58f38 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 7 Aug 2014 18:09:03 -0700 Subject: [PATCH 19/83] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched Author: Sandy Ryza Closes #1507 from sryza/sandy-spark-2565 and squashes the following commits: 74dad41 [Sandy Ryza] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched --- .../org/apache/spark/executor/Executor.scala | 1 + .../apache/spark/executor/TaskMetrics.scala | 55 ++++++++++++++----- .../hash/BlockStoreShuffleFetcher.scala | 13 ++--- .../shuffle/hash/HashShuffleReader.scala | 4 +- .../spark/storage/BlockFetcherIterator.scala | 40 +++++--------- .../apache/spark/storage/BlockManager.scala | 11 ++-- .../org/apache/spark/util/JsonProtocol.scala | 5 +- .../storage/BlockFetcherIteratorSuite.scala | 13 +++-- .../ui/jobs/JobProgressListenerSuite.scala | 4 +- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- 10 files changed, 84 insertions(+), 64 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c2b9c660ddaec..eac1f2326a29d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -374,6 +374,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + metrics.updateShuffleReadMetrics tasksMetrics += ((taskRunner.taskId, metrics)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 11a6e10243211..99a88c13456df 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.{BlockId, BlockStatus} @@ -81,12 +83,27 @@ class TaskMetrics extends Serializable { var inputMetrics: Option[InputMetrics] = None /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here + * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. + * This includes read metrics aggregated over all the task's shuffle dependencies. */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None def shuffleReadMetrics = _shuffleReadMetrics + /** + * This should only be used when recreating TaskMetrics, not when updating read metrics in + * executors. + */ + private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { + _shuffleReadMetrics = shuffleReadMetrics + } + + /** + * ShuffleReadMetrics per dependency for collecting independently while task is in progress. + */ + @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = + new ArrayBuffer[ShuffleReadMetrics]() + /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected * here @@ -98,19 +115,31 @@ class TaskMetrics extends Serializable { */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None - /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */ - def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { - _shuffleReadMetrics match { - case Some(existingMetrics) => - existingMetrics.shuffleFinishTime = math.max( - existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) - existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime - existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched - existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched - existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead - case None => - _shuffleReadMetrics = Some(newMetrics) + /** + * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization + * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each + * dependency, and merge these metrics before reporting them to the driver. This method returns + * a ShuffleReadMetrics for a dependency and registers it for merging later. + */ + private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics() + depsShuffleReadMetrics += readMetrics + readMetrics + } + + /** + * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. + */ + private[spark] def updateShuffleReadMetrics() = synchronized { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.fetchWaitTime += depMetrics.fetchWaitTime + merged.localBlocksFetched += depMetrics.localBlocksFetched + merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched + merged.remoteBytesRead += depMetrics.remoteBytesRead + merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime) } + _shuffleReadMetrics = Some(merged) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 99788828981c7..12b475658e29d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) + serializer: Serializer, + shuffleMetrics: ShuffleReadMetrics) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics() }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 88a5f1e5ddf58..7bed97a63f0f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, + readMetrics) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 938af6f5b923a..5f44f5f3197fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf import org.apache.spark.{Logging, SparkException} +import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId import org.apache.spark.network.netty.ShuffleCopier @@ -47,10 +48,6 @@ import org.apache.spark.util.Utils private[storage] trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() - def numLocalBlocks: Int - def numRemoteBlocks: Int - def fetchWaitTime: Long - def remoteBytesRead: Long } @@ -72,14 +69,12 @@ object BlockFetcherIterator { class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) + serializer: Serializer, + readMetrics: ShuffleReadMetrics) extends BlockFetcherIterator { import blockManager._ - private var _remoteBytesRead = 0L - private var _fetchWaitTime = 0L - if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } @@ -89,13 +84,9 @@ object BlockFetcherIterator { protected var startTime = System.currentTimeMillis - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks protected val remoteBlocksToFetch = new HashSet[BlockId]() @@ -132,7 +123,10 @@ object BlockFetcherIterator { val networkSize = blockMessage.getData.limit() results.put(new FetchResult(blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize + // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can + // be incrementing bytes read at the same time (SPARK-2625). + readMetrics.remoteBytesRead += networkSize + readMetrics.remoteBlocksFetched += 1 logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } @@ -155,14 +149,14 @@ object BlockFetcherIterator { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] + var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size if (address == blockManagerId) { - numLocal = blockInfos.size // Filter out zero-sized blocks localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) _numBlocksToFetch += localBlocksToFetch.size } else { - numRemote += blockInfos.size val iterator = blockInfos.iterator var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] @@ -192,7 +186,7 @@ object BlockFetcherIterator { } } logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - (numLocal + numRemote) + " blocks") + totalBlocks + " blocks") remoteRequests } @@ -205,6 +199,7 @@ object BlockFetcherIterator { // getLocalFromDisk never return None but throws BlockException val iter = getLocalFromDisk(id, serializer).get // Pass 0 as size since it's not in flight + readMetrics.localBlocksFetched += 1 results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } catch { @@ -238,12 +233,6 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. @volatile protected var resultsGotten = 0 @@ -255,7 +244,7 @@ object BlockFetcherIterator { val startFetchWait = System.currentTimeMillis() val result = results.take() val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) + readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (! result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { @@ -269,8 +258,9 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + serializer: Serializer, + readMetrics: ShuffleReadMetrics) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { import blockManager._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8d21b02b747ff..e8bbd298c631a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics} +import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -539,12 +539,15 @@ private[spark] class BlockManager( */ def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer): BlockFetcherIterator = { + serializer: Serializer, + readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { val iter = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } iter.initialize() iter diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b112b359368cd..6f8eb1ee12634 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -560,9 +560,8 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => - metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) - } + metrics.setShuffleReadMetrics( + Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 1538995a6b404..bcbfe8baf36ad 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock import org.apache.spark.storage.BlockFetcherIterator._ import org.apache.spark.network.{ConnectionManager, Message} +import org.apache.spark.executor.ShuffleReadMetrics class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ @@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ case (_, r) => { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index cb8252515238e..f5ba31c309277 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.remoteBytesRead = base + 1 shuffleReadMetrics.remoteBlocksFetched = base + 2 diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 2002a817d9168..97ffb07662482 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite { sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.updateShuffleReadMetrics(sr) + t.setShuffleReadMetrics(Some(sr)) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d From 9de6a42bb34ea8963225ce90f1a45adcfee38b58 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 7 Aug 2014 18:53:15 -0700 Subject: [PATCH 20/83] [SPARK-2904] Remove non-used local variable in SparkSubmitArguments Author: Kousuke Saruta Closes #1834 from sarutak/SPARK-2904 and squashes the following commits: 38e7d45 [Kousuke Saruta] Removed non-used variable in SparkSubmitArguments --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 087dd4d633db0..c21f1529a1837 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -219,7 +219,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - var inSparkOpts = true val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r // Delineates parsing of Spark options from parsing of user options. From 9a54de16ed9de536e0436d532c587384e1ea0af6 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Thu, 7 Aug 2014 23:45:16 -0700 Subject: [PATCH 21/83] [SPARK-2911]: provide rdd.parent[T](j) to obtain jth parent RDD Author: Erik Erlandson Closes #1841 from erikerlandson/spark-2911-pr and squashes the following commits: 4699e2f [Erik Erlandson] [SPARK-2911]: provide rdd.parent[T](j) to obtain jth parent RDD --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 5 +++++ .../src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0159003c88e06..19e10bd04681b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1233,6 +1233,11 @@ abstract class RDD[T: ClassTag]( dependencies.head.rdd.asInstanceOf[RDD[U]] } + /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ + protected[spark] def parent[U: ClassTag](j: Int) = { + dependencies(j).rdd.asInstanceOf[RDD[U]] + } + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context = sc diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 4a7dc8dca25e2..926d4fecb5b91 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -726,6 +726,16 @@ class RDDSuite extends FunSuite with SharedSparkContext { jrdd.rdd.retag.collect() } + test("parent method") { + val rdd1 = sc.parallelize(1 to 10, 2) + val rdd2 = rdd1.filter(_ % 2 == 0) + val rdd3 = rdd2.map(_ + 1) + val rdd4 = new UnionRDD(sc, List(rdd1, rdd2, rdd3)) + assert(rdd4.parent(0).isInstanceOf[ParallelCollectionRDD[_]]) + assert(rdd4.parent(1).isInstanceOf[FilteredRDD[_]]) + assert(rdd4.parent(2).isInstanceOf[MappedRDD[_, _]]) + } + test("getNarrowAncestors") { val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1) From 9016af3f2729101027e33593e094332f05f48d92 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:01:51 -0700 Subject: [PATCH 22/83] [SPARK-2888] [SQL] Fix addColumnMetadataToConf in HiveTableScan JIRA: https://issues.apache.org/jira/browse/SPARK-2888 Author: Yin Huai Closes #1817 from yhuai/fixAddColumnMetadataToConf and squashes the following commits: fba728c [Yin Huai] Fix addColumnMetadataToConf. --- .../sql/hive/execution/HiveTableScan.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8920e2a76a27f..577ca928b43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -72,17 +72,12 @@ case class HiveTableScan( } private def addColumnMetadataToConf(hiveConf: HiveConf) { - // Specifies IDs and internal names of columns to be scanned. - val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer) - val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") - - if (attributes.size == relation.output.size) { - // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` - ColumnProjectionUtils.setFullyReadColumns(hiveConf) - } else { - ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) - } + // Specifies needed column IDs for those non-partitioning columns. + val neededColumnIDs = + attributes.map(a => + relation.attributes.indexWhere(_.name == a.name): Integer).filter(index => index >= 0) + ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) // Specifies types and object inspectors of columns to be scanned. @@ -99,7 +94,7 @@ case class HiveTableScan( .mkString(",") hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } addColumnMetadataToConf(context.hiveconf) From 0489cee6b24ca34f1adab03a75d157e04a9e06b7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:10:11 -0700 Subject: [PATCH 23/83] [SPARK-2908] [SQL] JsonRDD.nullTypeToStringType does not convert all NullType to StringType JIRA: https://issues.apache.org/jira/browse/SPARK-2908 Author: Yin Huai Closes #1840 from yhuai/SPARK-2908 and squashes the following commits: 86e833e [Yin Huai] Update test. cb11759 [Yin Huai] nullTypeToStringType should check columns with the type of array of structs. --- .../scala/org/apache/spark/sql/json/JsonRDD.scala | 4 +++- .../scala/org/apache/spark/sql/json/JsonSuite.scala | 11 ++++++++--- .../org/apache/spark/sql/json/TestJsonData.scala | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index a3d2a1c7a51f8..1c0b03c684f10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -109,7 +109,9 @@ private[sql] object JsonRDD extends Logging { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case struct: StructType => nullTypeToStringType(struct) + case ArrayType(struct: StructType, containsNull) => + ArrayType(nullTypeToStringType(struct), containsNull) + case struct: StructType =>nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 75c0589eb208e..58b1e23891a3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -213,7 +213,8 @@ class JsonSuite extends QueryTest { StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil)), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType, true) :: Nil), true) :: @@ -263,8 +264,12 @@ class JsonSuite extends QueryTest { // Access elements of an array of structs. checkAnswer( - sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2] from jsonTable"), - (true :: "str1" :: Nil, false :: null :: Nil, null) :: Nil + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + (true :: "str1" :: null :: Nil, + false :: null :: null :: Nil, + null :: null :: null :: Nil, + null) :: Nil ) // Access a struct and fields inside of it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index d0180f3754f22..a88310b5f1b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -43,7 +43,7 @@ object TestJsonData { "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], "arrayOfBoolean":[true, false, true], "arrayOfNull":[null, null, null, null], - "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) From c874723fa844b49f057bb2434a12228b2f717e99 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:15:16 -0700 Subject: [PATCH 24/83] [SPARK-2877] [SQL] MetastoreRelation should use SparkClassLoader when creating the tableDesc JIRA: https://issues.apache.org/jira/browse/SPARK-2877 Author: Yin Huai Closes #1806 from yhuai/SPARK-2877 and squashes the following commits: 4142bcb [Yin Huai] Use Spark's classloader. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 301cf51c00e2b..82e9c1a248626 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -288,7 +287,10 @@ private[hive] case class MetastoreRelation ) val tableDesc = new TableDesc( - Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], + Class.forName( + hiveQlTable.getSerializationLib, + true, + Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to From 45d8f4deab50ae069ecde2201bd486d464a4501e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 8 Aug 2014 11:23:58 -0700 Subject: [PATCH 25/83] [SPARK-2919] [SQL] Basic support for analyze command in HiveQl The command we will support is ``` ANALYZE TABLE tablename COMPUTE STATISTICS noscan ``` Other cases shown in https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables will still be treated as Hive native commands. JIRA: https://issues.apache.org/jira/browse/SPARK-2919 Author: Yin Huai Closes #1848 from yhuai/sqlAnalyze and squashes the following commits: 0b79d36 [Yin Huai] Typo and format. c59d94b [Yin Huai] Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan". --- .../org/apache/spark/sql/hive/HiveQl.scala | 21 +++++++-- .../spark/sql/hive/HiveStrategies.scala | 2 + .../{DropTable.scala => commands.scala} | 26 +++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 45 ++++++++++++++++++- 4 files changed, 89 insertions(+), 5 deletions(-) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/{DropTable.scala => commands.scala} (72%) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bc2fefafd58c8..05b2f5f6cd3f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -46,6 +46,8 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command +private[hive] case class AnalyzeTable(tableName: String) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -74,7 +76,6 @@ private[hive] object HiveQl { "TOK_CREATEFUNCTION", "TOK_DROPFUNCTION", - "TOK_ANALYZE", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", @@ -92,7 +93,6 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ANALYZE", "TOK_CREATEDATABASE", "TOK_CREATEFUNCTION", "TOK_CREATEINDEX", @@ -239,7 +239,6 @@ private[hive] object HiveQl { ShellCommand(sql.drop(1)) } else { val tree = getAst(sql) - if (nativeCommands contains tree.getText) { NativeCommand(sql) } else { @@ -387,6 +386,22 @@ private[hive] object HiveQl { ifExists) => val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") DropTable(tableName, ifExists.nonEmpty) + // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: + isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) + } // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 2175c5f3835a6..85d2496a34cfb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,6 +83,8 @@ private[hive] trait HiveStrategies { case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala similarity index 72% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 9cd0c86c6c796..2985169da033c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -23,6 +23,32 @@ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.HiveContext +/** + * :: DeveloperApi :: + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ +@DeveloperApi +case class AnalyzeTable(tableName: String) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult = { + hiveContext.analyze(tableName) + Seq.empty[Any] + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} + /** * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bf5931bbf97ee..7c82964b5ecdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,13 +19,54 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag + import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.NativeCommand import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { + test("parse analyze commands") { + def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { + val parsed = HiveQl.parseSql(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeTable => a + case o => o + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators(0)} + |parsed command: + |$parsed + """.stripMargin) + } + } + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", + classOf[AnalyzeTable]) + } + test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = catalog.lookupRelation(None, tableName).statistics.sizeInBytes @@ -37,7 +78,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) - analyze("analyzeTable") + sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -66,7 +107,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) - analyze("analyzeTable_part") + sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) From b7c89a7f0ca73153dce36e0f01b81a3947ee1189 Mon Sep 17 00:00:00 2001 From: chutium Date: Fri, 8 Aug 2014 13:31:08 -0700 Subject: [PATCH 26/83] [SPARK-2700] [SQL] Hidden files (such as .impala_insert_staging) should be filtered out by sqlContext.parquetFile Author: chutium Closes #1691 from chutium/SPARK-2700 and squashes the following commits: b76ae8c [chutium] [SPARK-2700] [SQL] fixed styling issue d75a8bd [chutium] [SPARK-2700] [SQL] Hidden files (such as .impala_insert_staging) should be filtered out by sqlContext.parquetFile --- .../scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index aaef1a1d474fe..2867dc0a8b1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -373,8 +373,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } ParquetRelation.enableLogForwarding() - val children = fs.listStatus(path).filterNot { - _.getPath.getName == FileOutputCommitter.SUCCEEDED_FILE_NAME + val children = fs.listStatus(path).filterNot { status => + val name = status.getPath.getName + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row From 74d6f62264babfc6045c21545552f0a2e6958155 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 8 Aug 2014 15:07:31 -0700 Subject: [PATCH 27/83] [SPARK-1997][MLLIB] update breeze to 0.9 0.9 dependences (this version doesn't depend on scalalogging and I excluded commons-math3 from its transitive dependencies): ~~~ +-org.scalanlp:breeze_2.10:0.9 [S] +-com.github.fommil.netlib:core:1.1.2 +-com.github.rwl:jtransforms:2.4.0 +-net.sf.opencsv:opencsv:2.3 +-net.sourceforge.f2j:arpack_combined_all:0.1 +-org.scalanlp:breeze-macros_2.10:0.3.1 [S] | +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] | +-org.slf4j:slf4j-api:1.7.5 +-org.spire-math:spire_2.10:0.7.4 [S] +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] | +-org.spire-math:spire-macros_2.10:0.7.4 [S] +-org.scalamacros:quasiquotes_2.10:2.0.0 [S] ~~~ Closes #1749 CC: witgo avati Author: Xiangrui Meng Closes #1857 from mengxr/breeze-0.9 and squashes the following commits: 7fc16b6 [Xiangrui Meng] don't know why but exclude a private method for mima dcc502e [Xiangrui Meng] update breeze to 0.9 --- mllib/pom.xml | 2 +- .../org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 4 ++-- .../spark/mllib/linalg/distributed/RowMatrixSuite.scala | 2 +- project/MimaExcludes.scala | 4 ++++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index 9a33bd1cf6ad1..fc1ecfbea708f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.7 + 0.9 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 45486b2c7d82d..e76bc9fefff01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -222,7 +222,7 @@ class RowMatrix( EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] - val (uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) + val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") @@ -338,7 +338,7 @@ class RowMatrix( val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] - val (u: BDM[Double], _, _) = brzSvd(Cov) + val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov) if (k == n) { Matrices.dense(n, k, u.data) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 325b817980f68..1d3a3221365cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -99,7 +99,7 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { val localMat = mat.toBreeze() - val (localU, localSigma, localVt) = brzSvd(localMat) + val brzSvd.SVD(localU, localSigma, localVt) = brzSvd(localMat) val localV: BDM[Double] = localVt.t.toDenseMatrix for (k <- 1 to n) { val skip = (mode == "local-eigs" || mode == "dist-eigs") && k == n diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 537ca0dcf267d..b4653c72c10b5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -110,6 +110,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq ( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") ) case v if v.startsWith("1.0") => Seq( From ec79063fad44751a6689f5e58d47886babeaecff Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 8 Aug 2014 16:57:26 -0700 Subject: [PATCH 28/83] [SPARK-2897][SPARK-2920]TorrentBroadcast does use the serializer class specified in the spark option "spark.serializer" Author: GuoQiang Li Closes #1836 from witgo/SPARK-2897 and squashes the following commits: 23cdc5b [GuoQiang Li] review commit ada4fba [GuoQiang Li] TorrentBroadcast does not support broadcast compression fb91792 [GuoQiang Li] org.apache.spark.broadcast.TorrentBroadcast does use the serializer class specified in the spark option "spark.serializer" --- .../spark/broadcast/TorrentBroadcast.scala | 31 +++++++++++++++---- .../spark/broadcast/BroadcastSuite.scala | 10 ++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 86731b684f441..fe73456ef8fad 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,14 +17,15 @@ package org.apache.spark.broadcast -import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} +import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream, + ObjectInputStream, ObjectOutputStream, OutputStream} import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils /** * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like @@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging { private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + private var compress: Boolean = false + private var compressionCodec: CompressionCodec = null def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { + compress = conf.getBoolean("spark.broadcast.compress", true) + compressionCodec = CompressionCodec.createCodec(conf) initialized = true } } @@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging { initialized = false } - def blockifyObject[T](obj: T): TorrentInfo = { - val byteArray = Utils.serialize[T](obj) + def blockifyObject[T: ClassTag](obj: T): TorrentInfo = { + val bos = new ByteArrayOutputStream() + val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject[T](obj).close() + val byteArray = bos.toByteArray val bais = new ByteArrayInputStream(byteArray) var blockNum = byteArray.length / BLOCK_SIZE @@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging { info } - def unBlockifyObject[T]( + def unBlockifyObject[T: ClassTag]( arrayOfBlocks: Array[TorrentBlock], totalBytes: Int, totalBlocks: Int): T = { @@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) } - Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + + val in: InputStream = { + val arrIn = new ByteArrayInputStream(retByteArray) + if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn + } + val ser = SparkEnv.get.serializer.newInstance() + val serIn = ser.deserializeStream(in) + val obj = serIn.readObject[T]() + serIn.close() + obj } /** diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 7c3d0208b195a..17c64455b2429 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -44,7 +44,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing HttpBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val conf = httpConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -69,7 +72,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val conf = torrentConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) From 1c84dba9881118204687c81003bded6d49e27255 Mon Sep 17 00:00:00 2001 From: WangTao Date: Fri, 8 Aug 2014 20:53:21 -0700 Subject: [PATCH 29/83] [Web UI]Make decision order of Worker's WebUI port consistent with Master's The decision order of Worker's WebUI port is "--webui-port", SPARK_WORKER_WEBUI_POR, 8081(default), spark.worker.ui.port. But in Master, the order is "--webui-port", spark.master.ui.port, SPARK_MASTER_WEBUI_PORT and 8080(default). So we change the order in Worker's to keep it consistent with Master. Author: WangTao Closes #1838 from WangTaoTheTonic/reOrder and squashes the following commits: 460f4d4 [WangTao] Make decision order of Worker's WebUI consistent with Master's --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 5 +++-- .../org/apache/spark/deploy/worker/WorkerArguments.scala | 6 +++++- .../org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 9 ++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 458d9947bd873..bacb514ed6335 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -136,7 +136,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) + webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -373,7 +373,8 @@ private[spark] class Worker( private[spark] object Worker extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) - val args = new WorkerArguments(argStrings) + val conf = new SparkConf + val args = new WorkerArguments(argStrings, conf) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index dc5158102054e..1e295aaa48c30 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker import java.lang.management.ManagementFactory import org.apache.spark.util.{IntParam, MemoryParam, Utils} +import org.apache.spark.SparkConf /** * Command-line parser for the worker. */ -private[spark] class WorkerArguments(args: Array[String]) { +private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 0 var webUiPort = 8081 @@ -46,6 +47,9 @@ private[spark] class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 47fbda600bea7..b07942a9ca729 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -34,8 +34,8 @@ private[spark] class WorkerWebUI( val worker: Worker, val workDir: File, - port: Option[Int] = None) - extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI") + requestedPort: Int) + extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) @@ -55,10 +55,5 @@ class WorkerWebUI( } private[spark] object WorkerWebUI { - val DEFAULT_PORT = 8081 val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR - - def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) - } } From 43af2817007eaa2cce2567bd83f5cde1ee28d1f7 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Fri, 8 Aug 2014 20:58:44 -0700 Subject: [PATCH 30/83] [SPARK-2911] apply parent[T](j) to clarify UnionRDD code References to dependencies(j) for actually obtaining RDD parents are less common than I originally estimated. It does clarify UnionRDD (also will clarify some of my other PRs) Use of firstParent[T] is ubiquitous, but not as sure that benefits from being replaced with parent(0)[T]. Author: Erik Erlandson Closes #1858 from erikerlandson/spark-2911-pr2 and squashes the following commits: 7ffea74 [Erik Erlandson] [SPARK-2911] apply parent[T](j) to clarify UnionRDD code --- core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 197167ecad0bd..0c97eb0aaa51f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -83,8 +83,7 @@ class UnionRDD[T: ClassTag]( override def compute(s: Partition, context: TaskContext): Iterator[T] = { val part = s.asInstanceOf[UnionPartition[T]] - val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] - parentRdd.iterator(part.parentPartition, context) + parent[T](part.parentRddIndex).iterator(part.parentPartition, context) } override def getPreferredLocations(s: Partition): Seq[String] = From 28dbae85aaf6842e22cd7465cb11cb34d58fc56d Mon Sep 17 00:00:00 2001 From: li-zhihui Date: Fri, 8 Aug 2014 22:52:56 -0700 Subject: [PATCH 31/83] [SPARK-2635] Fix race condition at SchedulerBackend.isReady in standalone mode In SPARK-1946(PR #900), configuration spark.scheduler.minRegisteredExecutorsRatio was introduced. However, in standalone mode, there is a race condition where isReady() can return true because totalExpectedExecutors has not been correctly set. Because expected executors is uncertain in standalone mode, the PR try to use CPU cores(--total-executor-cores) as expected resources to judge whether SchedulerBackend is ready. Author: li-zhihui Author: Li Zhihui Closes #1525 from li-zhihui/fixre4s and squashes the following commits: e9a630b [Li Zhihui] Rename variable totalExecutors and clean codes abf4860 [Li Zhihui] Push down variable totalExpectedResources to children classes ca54bd9 [li-zhihui] Format log with String interpolation 88c7dc6 [li-zhihui] Few codes and docs refactor 41cf47e [li-zhihui] Fix race condition at SchedulerBackend.isReady in standalone mode --- .../CoarseGrainedSchedulerBackend.scala | 30 +++++++++---------- .../cluster/SparkDeploySchedulerBackend.scala | 6 +++- docs/configuration.md | 13 ++++---- .../cluster/YarnClientSchedulerBackend.scala | 9 ++++-- .../cluster/YarnClusterSchedulerBackend.scala | 17 +++++++---- 5 files changed, 43 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9f085eef46720..33500d967ebb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -47,19 +47,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - var totalExpectedExecutors = new AtomicInteger(0) + var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered executors / total expected executors) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = conf.getDouble("spark.scheduler.minRegisteredExecutorsRatio", 0) - if (minRegisteredRatio > 1) minRegisteredRatio = 1 - // Whatever minRegisteredExecutorsRatio is arrived, submit tasks after the time(milliseconds). + var minRegisteredRatio = + math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) + // Submit tasks after maxRegisteredWaitingTime milliseconds + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredExecutorsWaitingTime", 30000) + conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - var ready = if (minRegisteredRatio <= 0) true else false class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { private val executorActor = new HashMap[String, ActorRef] @@ -94,12 +94,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) - if (executorActor.size >= totalExpectedExecutors.get() * minRegisteredRatio && !ready) { - ready = true - logInfo("SchedulerBackend is ready for scheduling beginning, registered executors: " + - executorActor.size + ", total expected executors: " + totalExpectedExecutors.get() + - ", minRegisteredExecutorsRatio: " + minRegisteredRatio) - } + totalRegisteredExecutors.addAndGet(1) makeOffers() } @@ -268,14 +263,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } + def sufficientResourcesRegistered(): Boolean = true + override def isReady(): Boolean = { - if (ready) { + if (sufficientResourcesRegistered) { + logInfo("SchedulerBackend is ready for scheduling beginning after " + + s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { - ready = true logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - "maxRegisteredExecutorsWaitingTime: " + maxRegisteredWaitingTime) + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") return true } false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a28446f6c8a6b..589dba2e40d20 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -36,6 +36,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() @@ -97,7 +98,6 @@ private[spark] class SparkDeploySchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - totalExpectedExecutors.addAndGet(1) logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } @@ -110,4 +110,8 @@ private[spark] class SparkDeploySchedulerBackend( logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason.toString) } + + override def sufficientResourcesRegistered(): Boolean = { + totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio + } } diff --git a/docs/configuration.md b/docs/configuration.md index 4d27c5a918fe0..617a72a021f6e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -825,21 +825,22 @@ Apart from these, the following properties are also available, and may be useful - spark.scheduler.minRegisteredExecutorsRatio + spark.scheduler.minRegisteredResourcesRatio 0 - The minimum ratio of registered executors (registered executors / total expected executors) + The minimum ratio of registered resources (registered resources / total expected resources) + (resources are executors in yarn mode, CPU cores in standalone mode) to wait for before scheduling begins. Specified as a double between 0 and 1. - Regardless of whether the minimum ratio of executors has been reached, + Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime 30000 - Maximum amount of time to wait for executors to register before scheduling begins + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds). diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f8fb96b312f23..833e249f9f612 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -30,15 +30,15 @@ private[spark] class YarnClientSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } var client: Client = null var appId: ApplicationId = null var checkerThread: Thread = null var stopping: Boolean = false + var totalExpectedExecutors = 0 private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -84,7 +84,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors.set(args.numExecutors) + totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.runApp() waitForApp() @@ -150,4 +150,7 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 0ad1794d19538..55665220a6f96 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -27,19 +27,24 @@ private[spark] class YarnClusterSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + var totalExpectedExecutors = 0 + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } override def start() { super.start() - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors) + totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) + .getOrElse(totalExpectedExecutors) } // System property can override environment variable. - numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors) - totalExpectedExecutors.set(numExecutors) + totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } } From b431e6747f410aaf9624585920adc1f303159861 Mon Sep 17 00:00:00 2001 From: Chandan Kumar Date: Sat, 9 Aug 2014 00:45:54 -0700 Subject: [PATCH 32/83] [SPARK-2861] Fix Doc comment of histogram method Tested and ready to merge. Author: Chandan Kumar Closes #1786 from nrchandan/spark-2861 and squashes the following commits: cb0bc1e [Chandan Kumar] [SPARK-2861] Fix a typo in the histogram doc comment 6a2a71b [Chandan Kumar] SPARK-2861. Fix Doc comment of histogram method --- .../scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 9ca971c8a4c27..f233544d128f5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -119,11 +119,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the left except for the last which is closed + * to the right except for the last which is closed * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets From e45daf226d780f4a7aaabc2de9f04367bee16f26 Mon Sep 17 00:00:00 2001 From: Chris Cope Date: Sat, 9 Aug 2014 20:58:56 -0700 Subject: [PATCH 33/83] [SPARK-1766] sorted functions to meet pedantic requirements Pedantry is underrated Author: Chris Cope Closes #1859 from copester/master and squashes the following commits: 0fb4499 [Chris Cope] [SPARK-1766] sorted functions to meet pedantic requirements --- .../apache/spark/rdd/PairRDDFunctions.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93af50c0a9cd1..5dd6472b0776c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -237,6 +237,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey[V]((v: V) => v, func, func, partitioner) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + /** * Merge the values for each key using an associative reduce function, but return the results * immediately to the master as a Map. This will also perform the merging locally on each mapper @@ -374,15 +393,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - /** * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. @@ -482,16 +492,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. From 4f4a9884d9268ba9808744b3d612ac23c75f105a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 9 Aug 2014 21:10:43 -0700 Subject: [PATCH 34/83] [SPARK-2894] spark-shell doesn't accept flags As sryza reported, spark-shell doesn't accept any flags. The root cause is wrong usage of spark-submit in spark-shell and it come to the surface by #1801 Author: Kousuke Saruta Author: Cheng Lian Closes #1715, Closes #1864, and Closes #1861 Closes #1825 from sarutak/SPARK-2894 and squashes the following commits: 47f3510 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2894 2c899ed [Kousuke Saruta] Removed useless code from java_gateway.py 98287ed [Kousuke Saruta] Removed useless code from java_gateway.py 513ad2e [Kousuke Saruta] Modified util.sh to enable to use option including white spaces 28a374e [Kousuke Saruta] Modified java_gateway.py to recognize arguments 5afc584 [Cheng Lian] Filter out spark-submit options when starting Python gateway e630d19 [Cheng Lian] Fixing pyspark and spark-shell CLI options --- bin/pyspark | 18 ++++-- bin/spark-shell | 20 +++++-- bin/utils.sh | 59 +++++++++++++++++++ .../spark/deploy/SparkSubmitArguments.scala | 4 ++ dev/merge_spark_pr.py | 2 + python/pyspark/java_gateway.py | 2 +- 6 files changed, 94 insertions(+), 11 deletions(-) create mode 100644 bin/utils.sh diff --git a/bin/pyspark b/bin/pyspark index 39a20e2a24a3c..01d42025c978e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -23,12 +23,18 @@ FWDIR="$(cd `dirname $0`/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +source $FWDIR/bin/utils.sh + SCALA_VERSION=2.10 -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then +function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage fi # Exit if the user hasn't compiled Spark @@ -66,10 +72,11 @@ fi # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. - +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" PYSPARK_SUBMIT_ARGS="" whitespace="[[:space:]]" -for i in "$@"; do +for i in "${SUBMISSION_OPTS[@]}"; do if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" @@ -90,7 +97,10 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - exec $FWDIR/bin/spark-submit "$@" + primary=$1 + shift + gatherSparkSubmitOpts "$@" + exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" else # Only use ipython if no command line arguments were provided [SPARK-1134] if [[ "$IPYTHON" = "1" ]]; then diff --git a/bin/spark-shell b/bin/spark-shell index 756c8179d12b6..8b7ccd7439551 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -31,13 +31,21 @@ set -o posix ## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" +function usage() { + echo "Usage: ./bin/spark-shell [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +} + if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + usage fi -function main(){ +source $FWDIR/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" + +function main() { if $cygwin; then # Workaround for issue involving JLine and Cygwin # (see http://sourceforge.net/p/jline/bugs/40/). @@ -46,11 +54,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" fi } diff --git a/bin/utils.sh b/bin/utils.sh new file mode 100644 index 0000000000000..0804b1ed9f231 --- /dev/null +++ b/bin/utils.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Gather all all spark-submit options into SUBMISSION_OPTS +function gatherSparkSubmitOpts() { + + if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then + echo "Function for printing usage of $0 is not set." 1>&2 + echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 + exit 1 + fi + + # NOTE: If you add or remove spark-sumbmit options, + # modify NOT ONLY this script but also SparkSubmitArgument.scala + SUBMISSION_OPTS=() + APPLICATION_OPTS=() + while (($#)); do + case "$1" in + --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ + --conf | --properties-file | --driver-memory | --driver-java-options | \ + --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + if [[ $# -lt 2 ]]; then + "$SUBMIT_USAGE_FUNCTION" + exit 1; + fi + SUBMISSION_OPTS+=("$1"); shift + SUBMISSION_OPTS+=("$1"); shift + ;; + + --verbose | -v | --supervise) + SUBMISSION_OPTS+=("$1"); shift + ;; + + *) + APPLICATION_OPTS+=("$1"); shift + ;; + esac + done + + export SUBMISSION_OPTS + export APPLICATION_OPTS +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c21f1529a1837..d545f58c5da7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -224,6 +224,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { // Delineates parsing of Spark options from parsing of user options. parse(opts) + /** + * NOTE: If you add or remove spark-submit options, + * modify NOT ONLY this file but also utils.sh + */ def parse(opts: Seq[String]): Unit = opts match { case ("--name") :: value :: tail => name = value diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 53df9b5a3f1d5..d48c8bde12905 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -74,8 +74,10 @@ def fail(msg): def run_cmd(cmd): if isinstance(cmd, list): + print " ".join(cmd) return subprocess.check_output(cmd) else: + print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 37386ab0d7d49..c7f7c1fe591b0 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -39,7 +39,7 @@ def launch_gateway(): submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script), "pyspark-shell"] + submit_args + command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): From 5b6585de6b939837d5bdc4b1a44634301949add6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Aug 2014 22:05:36 -0700 Subject: [PATCH 35/83] Updated Spark SQL README to include the hive-thriftserver module Author: Reynold Xin Closes #1867 from rxin/sql-readme and squashes the following commits: 42a5307 [Reynold Xin] Updated Spark SQL README to include the hive-thriftserver module --- sql/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 14d5555f0c713..31f9152344086 100644 --- a/sql/README.md +++ b/sql/README.md @@ -3,10 +3,11 @@ Spark SQL This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. -Spark SQL is broken up into three subprojects: +Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Other dependencies for developers From 482c5afbf6f3f12ac23851300a33249b26ddff3c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 9 Aug 2014 23:06:54 -0700 Subject: [PATCH 36/83] Turn UpdateBlockInfo into case class. This helps us log UpdateBlockInfo properly once #1870 is merged. Author: Reynold Xin Closes #1872 from rxin/UpdateBlockInfo and squashes the following commits: 0cee1c2 [Reynold Xin] Turn UpdateBlockInfo into case class. --- .../spark/storage/BlockManagerMessages.scala | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 10b65286fb7db..2ba16b8476600 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -53,7 +53,7 @@ private[spark] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - class UpdateBlockInfo( + case class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, var storageLevel: StorageLevel, @@ -84,24 +84,6 @@ private[spark] object BlockManagerMessages { } } - object UpdateBlockInfo { - def apply( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo) - : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize)) - } - } - case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster From 3570119c34ab8d61507e7703a171b742fb0957d4 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 10 Aug 2014 12:12:22 -0700 Subject: [PATCH 37/83] Remove extra semicolon in Task.scala Author: GuoQiang Li Closes #1876 from witgo/remove_semicolon_in_Task_scala and squashes the following commits: c6ea732 [GuoQiang Li] Remove extra semicolon in Task.scala --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c5e421404a21..cbe0bc0bcb0a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -46,7 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) - context.taskMetrics.hostname = Utils.localHostName(); + context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) From 1d03a26a4895c24ebfab1a3cf6656af75cb53003 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 10 Aug 2014 12:44:17 -0700 Subject: [PATCH 38/83] [SPARK-2950] Add gc time and shuffle write time to JobLogger The JobLogger is very useful for performing offline performance profiling of Spark jobs. GC Time and Shuffle Write time are available in TaskMetrics but are currently missed from the JobLogger output. This patch adds these two fields. ~~Since this is a small change, I didn't create a JIRA. Let me know if I should do that.~~ cc kayousterhout Author: Shivaram Venkataraman Closes #1869 from shivaram/job-logger and squashes the following commits: 1b709fc [Shivaram Venkataraman] Add a space before GC_TIME c418105 [Shivaram Venkataraman] Add gc time and shuffle write time to JobLogger --- .../scala/org/apache/spark/scheduler/JobLogger.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 47dd112f68325..4d6b5c81883b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -162,6 +162,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime val inputMetrics = taskMetrics.inputMetrics match { case Some(metrics) => " READ_METHOD=" + metrics.readMethod.toString + @@ -179,11 +180,13 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case Some(metrics) => + " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + + " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + inputMetrics + shuffleReadMetrics + - writeMetrics) + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + shuffleReadMetrics + writeMetrics) } /** From 28dcbb531ae57dc50f15ad9df6c31022731669c9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 10 Aug 2014 13:00:38 -0700 Subject: [PATCH 39/83] [SPARK-2898] [PySpark] fix bugs in deamon.py 1. do not use signal handler for SIGCHILD, it's easy to cause deadlock 2. handle EINTR during accept() 3. pass errno into JVM 4. handle EAGAIN during fork() Now, it can pass 50k tasks tests in 180 seconds. Author: Davies Liu Closes #1842 from davies/qa and squashes the following commits: f0ea451 [Davies Liu] fix lint 03a2e8c [Davies Liu] cleanup dead children every seconds 32cb829 [Davies Liu] fix lint 0cd0817 [Davies Liu] fix bugs in deamon.py --- .../api/python/PythonWorkerFactory.scala | 2 +- python/pyspark/daemon.py | 78 +++++++++++-------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7af260d0b7f26..bf716a8ab025b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -68,7 +68,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val socket = new Socket(daemonHost, daemonPort) val pid = new DataInputStream(socket.getInputStream).readInt() if (pid < 0) { - throw new IllegalStateException("Python daemon failed to launch worker") + throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } daemonWorkers.put(socket, pid) socket diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index e73538baf0b93..22ab8d30c0ae3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -22,7 +22,8 @@ import socket import sys import traceback -from errno import EINTR, ECHILD +import time +from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main @@ -80,6 +81,17 @@ def waitSocketClose(sock): os._exit(compute_real_exit_code(exit_code)) +# Cleanup zombie children +def cleanup_dead_children(): + try: + while True: + pid, _ = os.waitpid(0, os.WNOHANG) + if not pid: + break + except: + pass + + def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -102,29 +114,21 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP - # Cleanup zombie children - def handle_sigchld(*args): - try: - pid, status = os.waitpid(0, os.WNOHANG) - if status != 0: - msg = "worker %s crashed abruptly with exit status %s" % (pid, status) - print >> sys.stderr, msg - except EnvironmentError as err: - if err.errno not in (ECHILD, EINTR): - raise - signal.signal(SIGCHLD, handle_sigchld) - # Initialization complete sys.stdout.close() try: while True: try: - ready_fds = select.select([0, listen_sock], [], [])[0] + ready_fds = select.select([0, listen_sock], [], [], 1)[0] except select.error as ex: if ex[0] == EINTR: continue else: raise + + # cleanup in signal handler will cause deadlock + cleanup_dead_children() + if 0 in ready_fds: try: worker_pid = read_int(sys.stdin) @@ -137,29 +141,41 @@ def handle_sigchld(*args): pass # process already died if listen_sock in ready_fds: - sock, addr = listen_sock.accept() + try: + sock, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + # Launch a worker process try: pid = os.fork() - if pid == 0: - listen_sock.close() - try: - worker(sock) - except: - traceback.print_exc() - os._exit(1) - else: - os._exit(0) + except OSError as e: + if e.errno in (EAGAIN, EINTR): + time.sleep(1) + pid = os.fork() # error here will shutdown daemon else: + outfile = sock.makefile('w') + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() sock.close() - - except OSError as e: - print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - write_int(-1, outfile) # Signal that the fork failed - outfile.flush() - outfile.close() + continue + + if pid == 0: + # in child process + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: sock.close() + finally: shutdown(1) From b715aa0c8090cd57158ead2a1b35632cb98a6277 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Sun, 10 Aug 2014 16:31:07 -0700 Subject: [PATCH 40/83] [SPARK-2937] Separate out samplyByKeyExact as its own API in PairRDDFunction To enable Python consistency and `Experimental` label of the `sampleByKeyExact` API. Author: Doris Xin Author: Xiangrui Meng Closes #1866 from dorx/stratified and squashes the following commits: 0ad97b2 [Doris Xin] reviewer comments. 2948aae [Doris Xin] remove unrelated changes e990325 [Doris Xin] Merge branch 'master' into stratified 555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API 616e55c [Doris Xin] merge master 245439e [Doris Xin] moved minSamplingRate to getUpperBound eaf5771 [Doris Xin] bug fixes. 17a381b [Doris Xin] fixed a merge issue and a failed unit ea7d27f [Doris Xin] merge master b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java b3013a4 [Xiangrui Meng] move math3 back to test scope eecee5f [Doris Xin] Merge branch 'master' into stratified f4c21f3 [Doris Xin] Reviewer comments a10e68d [Doris Xin] style fix a2bf756 [Doris Xin] Merge branch 'master' into stratified 680b677 [Doris Xin] use mapPartitionWithIndex instead 9884a9f [Doris Xin] style fix bbfb8c9 [Doris Xin] Merge branch 'master' into stratified ee9d260 [Doris Xin] addressed reviewer comments 6b5b10b [Doris Xin] Merge branch 'master' into stratified 254e03c [Doris Xin] minor fixes and Java API. 4ad516b [Doris Xin] remove unused imports from PairRDDFunctions bd9dc6e [Doris Xin] unit bug and style violation fixed 1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check 944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate 0214a76 [Doris Xin] cleanUp 90d94c0 [Doris Xin] merge master 9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey 7327611 [Doris Xin] merge master 50581fc [Doris Xin] added a TODO for logging in python 46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function 7e1a481 [Doris Xin] changed the permission on SamplingUtil 1d413ce [Doris Xin] fixed checkstyle issues 9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS --- .../apache/spark/api/java/JavaPairRDD.scala | 68 +++--- .../apache/spark/rdd/PairRDDFunctions.scala | 51 +++-- .../java/org/apache/spark/JavaAPISuite.java | 20 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 205 +++++++++++------- 4 files changed, 216 insertions(+), 128 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 76d4193e96aea..feeb6c02caa78 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], - exact: Boolean, seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], - exact: Boolean): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * Return a subset of this RDD sampled by key (via stratified sampling). - * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. */ - def sampleByKey(withReplacement: Boolean, + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, seed) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) /** - * Return a subset of this RDD sampled by key (via stratified sampling). + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ - def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** * Return the union of this RDD and another one. Any identical elements will appear multiple diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 5dd6472b0776c..f6d9d12fe9006 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use - * additional passes over the RDD to create a sample size that's exactly equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling - * without replacement, we need one additional pass over the RDD to guarantee sample size; - * when sampling with replacement, we need two additional passes. + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. * * @param withReplacement whether to sample with or without replacement * @param fractions map of specific keys to sampling rates * @param seed seed for the random number generator - * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key * @return RDD containing the sampled subset */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - exact: Boolean = false, - seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). + * + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @return RDD containing the sampled subset + */ + @Experimental + def sampleByKeyExact(withReplacement: Boolean, + fractions: Map[K, Double], + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") val samplingFunc = if (withReplacement) { - StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed) } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 56150caa5d6ba..e1c13de04a0be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1239,12 +1239,28 @@ public Tuple2 call(Integer i) { Assert.assertTrue(worCounts.size() == 2); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); - JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKeyExact() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); Assert.assertTrue(wrExactCounts.size() == 2); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); - JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); Assert.assertTrue(worExactCounts.size() == 2); Assert.assertTrue(worExactCounts.get(0) == 2); diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 4f49d4a1d4d34..63d3ddb4af98a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("sampleByKey") { - def stratifier (fractionPositive: Double) = { - (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" - } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { - if (exact) { - return expected == actual - } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // Without replacement validation - def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } - assert(takeSample.size === takeSample.toSet.size) - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // With replacement validation - def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey().mapValues(count => - math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } - val groupedByKey = takeSample.groupBy(_._1) - for ((key, v) <- groupedByKey) { - if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { - // sample large enough for there to be repeats with high likelihood - assert(v.toSet.size < expectedSampleSize(key)) - } else { - if (exact) { - assert(v.toSet.size <= expectedSampleSize(key)) - } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) - } - } - } - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n) } - def checkAllCombos(stratifiedData: RDD[(String, Int)], - samplingRate: Double, - seed: Long, - n: Long) = { - takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } + } + test("sampleByKeyExact") { val defaultSeed = 1L // vary RDD size for (n <- List(100, 1000, 1000000)) { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // vary fractionPositive for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // Use the same data for the rest of the tests val fractionPositive = 0.3 val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) // vary seed for (seed <- defaultSeed to defaultSeed + 5L) { val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, seed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, n) } // vary sampling rate for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } } @@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { intercept[IllegalArgumentException] {shuffled.lookup(-1)} } + private object StratifiedAuxiliary { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + def testSampleExact(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, true, samplingRate, seed, n) + testPoisson(stratifiedData, true, samplingRate, seed, n) + } + + def testSample(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, false, samplingRate, seed, n) + testPoisson(stratifiedData, false, samplingRate, seed, n) + } + + // Without replacement validation + def testBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(false, fractions, seed) + } else { + stratifiedData.sampleByKey(false, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def testPoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(true, fractions, seed) + } else { + stratifiedData.sampleByKey(true, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case (k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]")) + } + } + } /* From ba28a8fcbc3ba432e7ea4d6f0b535450a6ec96c6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 10 Aug 2014 20:36:54 -0700 Subject: [PATCH 41/83] [SPARK-2936] Migrate Netty network module from Java to Scala The Netty network module was originally written when Scala 2.9.x had a bug that prevents a pure Scala implementation, and a subset of the files were done in Java. We have since upgraded to Scala 2.10, and can migrate all Java files now to Scala. https://github.com/netty/netty/issues/781 https://github.com/mesos/spark/pull/522 Author: Reynold Xin Closes #1865 from rxin/netty and squashes the following commits: 332422f [Reynold Xin] Code review feedback ca9eeee [Reynold Xin] Minor update. 7f1434b [Reynold Xin] [SPARK-2936] Migrate Netty network module from Java to Scala --- .../spark/network/netty/FileClient.java | 100 ---------------- .../spark/network/netty/FileServer.java | 111 ------------------ .../network/netty/FileServerHandler.java | 83 ------------- .../spark/network/netty/FileClient.scala | 85 ++++++++++++++ .../netty/FileClientChannelInitializer.scala} | 24 ++-- .../network/netty/FileClientHandler.scala} | 47 ++++---- .../spark/network/netty/FileHeader.scala | 5 +- .../spark/network/netty/FileServer.scala | 91 ++++++++++++++ .../netty/FileServerChannelInitializer.scala} | 31 ++--- .../network/netty/FileServerHandler.scala | 68 +++++++++++ .../spark/network/netty/PathResolver.scala} | 9 +- .../spark/network/netty/ShuffleSender.scala | 2 +- 12 files changed, 292 insertions(+), 364 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClient.java delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServer.java delete mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileClient.scala rename core/src/main/{java/org/apache/spark/network/netty/FileClientChannelInitializer.java => scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala} (57%) rename core/src/main/{java/org/apache/spark/network/netty/FileClientHandler.java => scala/org/apache/spark/network/netty/FileClientHandler.scala} (51%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServer.scala rename core/src/main/{java/org/apache/spark/network/netty/FileServerChannelInitializer.java => scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala} (54%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala rename core/src/main/{java/org/apache/spark/network/netty/PathResolver.java => scala/org/apache/spark/network/netty/PathResolver.scala} (80%) mode change 100755 => 100644 diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java deleted file mode 100644 index 0d31894d6ec7a..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.util.concurrent.TimeUnit; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private static final Logger LOG = LoggerFactory.getLogger(FileClient.class.getName()); - - private final FileClientHandler handler; - private Channel channel = null; - private Bootstrap bootstrap = null; - private EventLoopGroup group = null; - private final int connectTimeout; - private final int sendTimeout = 60; // 1 min - - FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - group = new OioEventLoopGroup(); - bootstrap = new Bootstrap(); - bootstrap.group(group) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. - channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted while trying to connect", e); - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - try { - // Should be able to send the message to network link channel. - boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); - if (!bSent) { - throw new RuntimeException("Failed to send"); - } - } catch (InterruptedException e) { - LOG.error("Error", e); - } - } - - public void close() { - if (group != null) { - group.shutdownGracefully(); - group = null; - bootstrap = null; - } - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java deleted file mode 100644 index c93425e2787dc..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private static final Logger LOG = LoggerFactory.getLogger(FileServer.class.getName()); - - private EventLoopGroup bossGroup = null; - private EventLoopGroup workerGroup = null; - private ChannelFuture channelFuture = null; - private int port = 0; - - FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bossGroup = new OioEventLoopGroup(); - workerGroup = new OioEventLoopGroup(); - - ServerBootstrap bootstrap = new ServerBootstrap(); - bootstrap.group(bossGroup, workerGroup) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. - InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - Thread blockingThread = new Thread() { - @Override - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); - channelFuture = null; - } - - // Shutdown event groups - if (bossGroup != null) { - bossGroup.shutdownGracefully(); - bossGroup = null; - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully(); - workerGroup = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java deleted file mode 100644 index c0133e19c7f79..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.channel.DefaultFileRegion; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; - -class FileServerHandler extends SimpleChannelInboundHandler { - - private static final Logger LOG = LoggerFactory.getLogger(FileServerHandler.class.getName()); - - private final PathResolver pResolver; - - FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { - BlockId blockId = BlockId.apply(blockIdString); - FileSegment fileSegment = pResolver.getBlockLocation(blockId); - // if getBlockLocation returns null, close the channel - if (fileSegment == null) { - //ctx.close(); - return; - } - File file = fileSegment.file(); - if (file.exists()) { - if (!file.isFile()) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = fileSegment.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = (int) length; - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.write(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), fileSegment.offset(), fileSegment.length())); - } catch (Exception e) { - LOG.error("Exception: ", e); - } - } else { - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - LOG.error("Exception: ", cause); - ctx.close(); - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala new file mode 100644 index 0000000000000..c6d35f73db545 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.Bootstrap +import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioSocketChannel + +import org.apache.spark.Logging + +class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { + + private var channel: Channel = _ + private var bootstrap: Bootstrap = _ + private var group: EventLoopGroup = _ + private val sendTimeout = 60 + + def init(): Unit = { + group = new OioEventLoopGroup + bootstrap = new Bootstrap + bootstrap.group(group) + .channel(classOf[OioSocketChannel]) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) + .handler(new FileClientChannelInitializer(handler)) + } + + def connect(host: String, port: Int) { + try { + channel = bootstrap.connect(host, port).sync().channel() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted while trying to connect", e) + close() + } + } + + def waitForClose(): Unit = { + try { + channel.closeFuture.sync() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted", e) + } + } + + def sendRequest(file: String): Unit = { + try { + val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) + if (!bSent) { + throw new RuntimeException("Failed to send") + } + } catch { + case e: InterruptedException => + logError("Error", e) + } + } + + def close(): Unit = { + if (group != null) { + group.shutdownGracefully() + group = null + bootstrap = null + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala similarity index 57% rename from core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala index 264cf97d0209f..f4261c13f70a8 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala @@ -15,25 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.string.StringEncoder -class FileClientChannelInitializer extends ChannelInitializer { - private final FileClientHandler fhandler; +class FileClientChannelInitializer(handler: FileClientHandler) + extends ChannelInitializer[SocketChannel] { - FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } - - @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder()) - .addLast("handler", fhandler); + def initChannel(channel: SocketChannel) { + channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) } } diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala similarity index 51% rename from core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java rename to core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala index 63d3d927255f9..017302ec7d33d 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala @@ -15,41 +15,36 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockId -abstract class FileClientHandler extends SimpleChannelInboundHandler { - private FileHeader currentHeader = null; +abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { - private volatile boolean handlerCalled = false; + private var currentHeader: FileHeader = null - public boolean isComplete() { - return handlerCalled; - } + @volatile + private var handlerCalled: Boolean = false + + def isComplete: Boolean = handlerCalled + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(BlockId blockId); + def handleError(blockId: BlockId) - @Override - public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); + if (in.readableBytes >= currentHeader.fileLen) { + handle(ctx, in, currentHeader) + handlerCalled = true + currentHeader = null + ctx.close() } } - } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 136c1912045aa..607e560ff277f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -26,7 +26,7 @@ private[spark] class FileHeader ( val fileLen: Int, val blockId: BlockId) extends Logging { - lazy val buffer = { + lazy val buffer: ByteBuf = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) @@ -62,11 +62,10 @@ private[spark] object FileHeader { new FileHeader(length, blockId) } - def main (args:Array[String]) { + def main(args:Array[String]) { val header = new FileHeader(25, TestBlockId("my_block")) val buf = header.buffer val newHeader = FileHeader.create(buf) System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala new file mode 100644 index 0000000000000..dff77950659af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioServerSocketChannel + +import org.apache.spark.Logging + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { + + private val addr: InetSocketAddress = new InetSocketAddress(port) + private var bossGroup: EventLoopGroup = new OioEventLoopGroup + private var workerGroup: EventLoopGroup = new OioEventLoopGroup + + private var channelFuture: ChannelFuture = { + val bootstrap = new ServerBootstrap + bootstrap.group(bossGroup, workerGroup) + .channel(classOf[OioServerSocketChannel]) + .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) + .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) + .childHandler(new FileServerChannelInitializer(pResolver)) + bootstrap.bind(addr) + } + + try { + val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] + port = boundAddress.getPort + } catch { + case ie: InterruptedException => + port = 0 + } + + /** Start the file server asynchronously in a new thread. */ + def start(): Unit = { + val blockingThread: Thread = new Thread { + override def run(): Unit = { + try { + channelFuture.channel.closeFuture.sync + logInfo("FileServer exiting") + } catch { + case e: InterruptedException => + logError("File server start got interrupted", e) + } + // NOTE: bootstrap is shutdown in stop() + } + } + blockingThread.setDaemon(true) + blockingThread.start() + } + + def getPort: Int = port + + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bossGroup != null) { + bossGroup.shutdownGracefully() + bossGroup = null + } + if (workerGroup != null) { + workerGroup.shutdownGracefully() + workerGroup = null + } + } +} + diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala similarity index 54% rename from core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala index 46efec8f8d963..aaa2f913d0269 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala @@ -15,27 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} +import io.netty.handler.codec.string.StringDecoder -class FileServerChannelInitializer extends ChannelInitializer { +class FileServerChannelInitializer(pResolver: PathResolver) + extends ChannelInitializer[SocketChannel] { - private final PathResolver pResolver; - - FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; - } - - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("stringDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); + override def initChannel(channel: SocketChannel): Unit = { + channel.pipeline + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) + .addLast("stringDecoder", new StringDecoder) + .addLast("handler", new FileServerHandler(pResolver)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala new file mode 100644 index 0000000000000..96f60b2883ad9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.FileInputStream + +import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, FileSegment} + + +class FileServerHandler(pResolver: PathResolver) + extends SimpleChannelInboundHandler[String] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { + val blockId: BlockId = BlockId(blockIdString) + val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) + if (fileSegment == null) { + return + } + val file = fileSegment.file + if (file.exists) { + if (!file.isFile) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + val length: Long = fileSegment.length + if (length > Integer.MAX_VALUE || length <= 0) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + ctx.write(new FileHeader(length.toInt, blockId).buffer) + try { + val channel = new FileInputStream(file).getChannel + ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) + } catch { + case e: Exception => + logError("Exception: ", e) + } + } else { + ctx.write(new FileHeader(0, blockId).buffer) + } + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError("Exception: ", cause) + ctx.close() + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala old mode 100755 new mode 100644 similarity index 80% rename from core/src/main/java/org/apache/spark/network/netty/PathResolver.java rename to core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala index 7ad8d03efbadc..0d7695072a7b1 --- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java +++ b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; +import org.apache.spark.storage.{BlockId, FileSegment} -public interface PathResolver { +trait PathResolver { /** Get the file segment in which the given block resides. */ - FileSegment getBlockLocation(BlockId blockId); + def getBlockLocation(blockId: BlockId): FileSegment } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 7ef7aecc6a9fb..95958e30f7eeb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -32,7 +32,7 @@ private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) ext server.stop() } - def port: Int = server.getPort() + def port: Int = server.getPort } From db06a81fb7a413faa3fe0f8c35918f70454cb05d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Aug 2014 11:54:09 -0700 Subject: [PATCH 42/83] [PySpark] [SPARK-2954] [SPARK-2948] [SPARK-2910] [SPARK-2101] Python 2.6 Fixes - Modify python/run-tests to test with Python 2.6 - Use unittest2 when running on Python 2.6. - Fix issue with namedtuple. - Skip TestOutputFormat.test_newhadoop on Python 2.6 until SPARK-2951 is fixed. - Fix MLlib _deserialize_double on Python 2.6. Closes #1868. Closes #1042. Author: Josh Rosen Closes #1874 from JoshRosen/python2.6 and squashes the following commits: 983d259 [Josh Rosen] [SPARK-2954] Fix MLlib _deserialize_double on Python 2.6. 5d18fd7 [Josh Rosen] [SPARK-2948] [SPARK-2910] [SPARK-2101] Python 2.6 fixes --- python/pyspark/mllib/_common.py | 11 ++++++++++- python/pyspark/mllib/tests.py | 7 ++++++- python/pyspark/serializers.py | 4 ++-- python/pyspark/tests.py | 13 ++++++++++--- python/run-tests | 8 ++++++++ 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index db341da85f865..bb60d3d0c8463 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -16,6 +16,7 @@ # import struct +import sys import numpy from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD @@ -78,6 +79,14 @@ LABELED_POINT_MAGIC = 4 +# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s. +if sys.version_info[:2] <= (2, 6): + def _unpack(fmt, string): + return struct.unpack(fmt, buffer(string)) +else: + _unpack = struct.unpack + + def _deserialize_numpy_array(shape, ba, offset, dtype=float64): """ Deserialize a numpy array of the given type from an offset in @@ -191,7 +200,7 @@ def _deserialize_double(ba, offset=0): raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) if len(ba) - offset != 8: raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) - return struct.unpack("d", ba[offset:])[0] + return _unpack("d", ba[offset:])[0] def _deserialize_double_vector(ba, offset=0): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6f3ec8ac94bac..8a851bd35c0e8 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,8 +19,13 @@ Fuller unit tests for Python MLlib. """ +import sys from numpy import array, array_equal -import unittest + +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest from pyspark.mllib._common import _convert_vector, _serialize_double_vector, \ _deserialize_double_vector, _dot, _squared_distance diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b35558db3e007..df90cafb245bf 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -314,8 +314,8 @@ def _copy_func(f): _old_namedtuple = _copy_func(collections.namedtuple) - def namedtuple(name, fields, verbose=False, rename=False): - cls = _old_namedtuple(name, fields, verbose, rename) + def namedtuple(*args, **kwargs): + cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) # replace namedtuple with new one diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 88a61176e51ab..22b51110ed671 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -29,9 +29,14 @@ import sys import tempfile import time -import unittest import zipfile +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest + + from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int @@ -605,6 +610,7 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) + @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -905,8 +911,9 @@ def createFileInZip(self, name, content): pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) path = os.path.join(self.programDir, name + ".zip") - with zipfile.ZipFile(path, 'w') as zip: - zip.writestr(name, content) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() return path def test_single_script(self): diff --git a/python/run-tests b/python/run-tests index 48feba2f5bd63..1218edcbd7e08 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,6 +48,14 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +# Try to test with Python 2.6, since that's the minimum version that we support: +if [ $(which python2.6) ]; then + export PYSPARK_PYTHON="python2.6" +fi + +echo "Testing with Python version:" +$PYSPARK_PYTHON --version + run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" From 37338666655909502e424b4639d680271d6d4c12 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 11 Aug 2014 15:25:21 -0700 Subject: [PATCH 43/83] [SPARK-2952] Enable logging actor messages at DEBUG level Example messages: ``` 14/08/09 21:37:01 DEBUG BlockManagerMasterActor: [actor] received message RegisterBlockManager(BlockManagerId(0, rxin-mbp, 58092, 0),278302556,Actor[akka.tcp://spark@rxin-mbp:58088/user/BlockManagerActor1#-63596539]) from Actor[akka.tcp://spark@rxin-mbp:58088/temp/$c] 14/08/09 21:37:01 DEBUG BlockManagerMasterActor: [actor] handled message (0.279 ms) RegisterBlockManager(BlockManagerId(0, rxin-mbp, 58092, 0),278302556,Actor[akka.tcp://spark@rxin-mbp:58088/user/BlockManagerActor1#-63596539]) from Actor[akka.tcp://spark@rxin-mbp:58088/temp/$c] ``` cc @mengxr @tdas @pwendell Author: Reynold Xin Closes #1870 from rxin/actorLogging and squashes the following commits: c531ee5 [Reynold Xin] Added license header for ActorLogReceive. f6b1ebe [Reynold Xin] [SPARK-2952] Enable logging actor messages at DEBUG level --- .../org/apache/spark/HeartbeatReceiver.scala | 7 +- .../org/apache/spark/MapOutputTracker.scala | 4 +- .../org/apache/spark/deploy/Client.scala | 8 ++- .../spark/deploy/client/AppClient.scala | 6 +- .../apache/spark/deploy/master/Master.scala | 6 +- .../apache/spark/deploy/worker/Worker.scala | 6 +- .../spark/deploy/worker/WorkerWatcher.scala | 8 ++- .../CoarseGrainedExecutorBackend.scala | 7 +- .../CoarseGrainedSchedulerBackend.scala | 9 ++- .../spark/scheduler/local/LocalBackend.scala | 8 +-- .../storage/BlockManagerMasterActor.scala | 11 ++-- .../storage/BlockManagerSlaveActor.scala | 5 +- .../apache/spark/util/ActorLogReceive.scala | 64 +++++++++++++++++++ 13 files changed, 111 insertions(+), 38 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 24ccce21b62ca..83ae57b7f1516 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -21,6 +21,7 @@ import akka.actor.Actor import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.ActorLogReceive /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -36,8 +37,10 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { - override def receive = { +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 894091761485d..51705c895a55c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -38,10 +38,10 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - def receive = { + override def receiveWithLogging = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c07003784e8ac..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,12 +27,14 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * Proxy that relays messages to the driver. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { +private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) + extends Actor with ActorLogReceive with Logging { + var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) @@ -114,7 +116,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends } } - override def receive = { + override def receiveWithLogging = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index d38e9e79204c2..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -56,7 +56,7 @@ private[spark] class AppClient( var registered = false var activeMasterUrl: String = null - class ClientActor extends Actor with Logging { + class ClientActor extends Actor with ActorLogReceive with Logging { var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times @@ -119,7 +119,7 @@ private[spark] class AppClient( .contains(remoteUrl.hostPort) } - override def receive = { + override def receiveWithLogging = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a70ecdb375373..cfa2c028a807b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,14 +42,14 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class Master( host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -167,7 +167,7 @@ private[spark] class Master( context.stop(leaderElectionAgent) } - override def receive = { + override def receiveWithLogging = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index bacb514ed6335..80fde7e4b2624 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. @@ -51,7 +51,7 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher Utils.checkHost(host, "Expected hostname") @@ -187,7 +187,7 @@ private[spark] class Worker( } } - override def receive = { + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 530c147000904..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -22,13 +22,15 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, Di import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat +import org.apache.spark.util.ActorLogReceive /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) extends Actor - with Logging { +private[spark] class WorkerWatcher(workerUrl: String) + extends Actor with ActorLogReceive with Logging { + override def preStart() { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -48,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receive = { + override def receiveWithLogging = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1f46a0f176490..13af5b6f5812d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -31,14 +31,15 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, cores: Int, - sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { + sparkProperties: Seq[(String, String)]) + extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 33500d967ebb1..2a3711ae2a78c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,7 +30,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} import org.apache.spark.ui.JettyUtils /** @@ -61,7 +61,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + + override protected def log = CoarseGrainedSchedulerBackend.this.log + private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] @@ -79,7 +82,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receive = { + def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3d1cf312ccc97..bec9502f20466 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,9 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -43,7 +43,7 @@ private case class StopExecutor() private[spark] class LocalActor( scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, - private val totalCores: Int) extends Actor with Logging { + private val totalCores: Int) extends Actor with ActorLogReceive with Logging { private var freeCores = totalCores @@ -53,7 +53,7 @@ private[spark] class LocalActor( val executor = new Executor( localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) - def receive = { + override def receiveWithLogging = { case ReviveOffers => reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index bd31e3c5a187f..3ab07703b6f85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -31,7 +31,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -39,7 +39,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -55,8 +55,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", - 60000) + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) var timeoutCheckingTask: Cancellable = null @@ -67,9 +66,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus super.preStart() } - def receive = { + override def receiveWithLogging = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -118,7 +116,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus sender ! true case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") sender ! true if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 6d4db064dff58..c194e0fed3367 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -23,6 +23,7 @@ import akka.actor.{ActorRef, Actor} import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ActorLogReceive /** * An actor to take commands from the master to execute options. For example, @@ -32,12 +33,12 @@ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receive = { + override def receiveWithLogging = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala new file mode 100644 index 0000000000000..332d0cbb2dc0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.Actor +import org.slf4j.Logger + +/** + * A trait to enable logging all Akka actor messages. Here's an example of using this: + * + * {{{ + * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { + * ... + * override def receiveWithLogging = { + * case GetLocations(blockId) => + * sender ! getLocations(blockId) + * ... + * } + * ... + * } + * }}} + * + */ +private[spark] trait ActorLogReceive { + self: Actor => + + override def receive: Actor.Receive = new Actor.Receive { + + private val _receiveWithLogging = receiveWithLogging + + override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + + override def apply(o: Any): Unit = { + if (log.isDebugEnabled) { + log.debug(s"[actor] received message $o from ${self.sender}") + } + val start = System.nanoTime + _receiveWithLogging.apply(o) + val timeTaken = (System.nanoTime - start).toDouble / 1000000 + if (log.isDebugEnabled) { + log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") + } + } + } + + def receiveWithLogging: Actor.Receive + + protected def log: Logger +} From 7712e724ad69dd0b83754e938e9799d13a4d43b9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 11 Aug 2014 19:15:01 -0700 Subject: [PATCH 44/83] [SPARK-2931] In TaskSetManager, reset currentLocalityIndex after recomputing locality levels This addresses SPARK-2931, a bug where getAllowedLocalityLevel() could throw ArrayIndexOutOfBoundsException. The fix here is to reset currentLocalityIndex after recomputing the locality levels. Thanks to kayousterhout, mridulm, and lirui-intel for helping me to debug this. Author: Josh Rosen Closes #1896 from JoshRosen/SPARK-2931 and squashes the following commits: 48b60b5 [Josh Rosen] Move FakeRackUtil.cleanUp() info beforeEach(). 6fec474 [Josh Rosen] Set currentLocalityIndex after recomputing locality levels. 9384897 [Josh Rosen] Update SPARK-2931 test to reflect changes in 63bdb1f41b4895e3a9444f7938094438a94d3007. 9ecd455 [Josh Rosen] Apply @mridulm's patch for reproducing SPARK-2931. --- .../spark/scheduler/TaskSetManager.scala | 11 +++-- .../spark/scheduler/TaskSetManagerSuite.scala | 40 ++++++++++++++++++- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 20a4bd12f93f6..d9d53faf843ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -690,8 +690,7 @@ private[spark] class TaskSetManager( handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } // recalculate valid locality levels and waits when executor is lost - myLocalityLevels = computeValidLocalityLevels() - localityWaits = myLocalityLevels.map(getLocalityWait) + recomputeLocality() } /** @@ -775,9 +774,15 @@ private[spark] class TaskSetManager( levels.toArray } - def executorAdded() { + def recomputeLocality() { + val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) + currentLocalityIndex = getLocalityIndex(previousLocalityLevel) + } + + def executorAdded() { + recomputeLocality() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ffd23380a886f..93e8ddacf8865 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -154,6 +154,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 + override def beforeEach() { + super.beforeEach() + FakeRackUtil.cleanUp() + } + test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -471,7 +476,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("new executors get added and lost") { // Assign host2 to rack2 - FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) @@ -504,7 +508,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("test RACK_LOCAL tasks") { - FakeRackUtil.cleanUp() // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") // Assign host2 to rack1 @@ -607,6 +610,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("Ensure TaskSetManager is usable after addition of levels") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2", "execB.1"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execA", "host1") + sched.addExecutor("execB.2", "host2") + manager.executorAdded() + assert(manager.pendingTasksWithNoPrefs.size === 0) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + assert(manager.resourceOffer("execA", "host1", ANY) !== None) + clock.advance(LOCALITY_WAIT * 4) + assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) + sched.removeExecutor("execA") + sched.removeExecutor("execB.2") + manager.executorLost("execA", "host1") + manager.executorLost("execB.2", "host2") + clock.advance(LOCALITY_WAIT * 4) + sched.addExecutor("execC", "host3") + manager.executorAdded() + // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: + assert(manager.resourceOffer("execC", "host3", ANY) !== None) + } + + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) From 32638b5e74e02410831b391f555223f90c830498 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Mon, 11 Aug 2014 19:22:14 -0700 Subject: [PATCH 45/83] [SPARK-2515][mllib] Chi Squared test Author: Doris Xin Closes #1733 from dorx/chisquare and squashes the following commits: cafb3a7 [Doris Xin] fixed p-value for extreme case. d286783 [Doris Xin] Merge branch 'master' into chisquare e95e485 [Doris Xin] reviewer comments. 7dde711 [Doris Xin] ChiSqTestResult renaming and changed to Class 80d03e2 [Doris Xin] Reviewer comments. c39eeb5 [Doris Xin] units passed with updated API e90d90a [Doris Xin] Merge branch 'master' into chisquare 7eea80b [Doris Xin] WIP d64c2fb [Doris Xin] Merge branch 'master' into chisquare 5686082 [Doris Xin] facelift bc7eb2e [Doris Xin] unit passed; still need docs and some refactoring 50703a5 [Doris Xin] merge master 4e4e361 [Doris Xin] WIP e6b83f3 [Doris Xin] reviewer comments 3d61582 [Doris Xin] input names 706d436 [Doris Xin] Added API for RDD[Vector] 6598379 [Doris Xin] API and code structure. ff17423 [Doris Xin] WIP --- .../apache/spark/mllib/stat/Statistics.scala | 64 +++++ .../spark/mllib/stat/test/ChiSqTest.scala | 221 ++++++++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 88 +++++++ .../mllib/stat/HypothesisTestSuite.scala | 139 +++++++++++ 4 files changed, 512 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f416a9fbb323d..cf8679610e191 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,7 +19,9 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** @@ -89,4 +91,66 @@ object Statistics { */ @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the + * expected distribution. + * + * Note: the two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @param expected Vector containing the expected categorical counts/relative frequencies. + * `expected` is rescaled if the `expected` sum differs from the `observed` sum. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + ChiSqTest.chiSquared(observed, expected) + } + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform + * distribution, with each category having an expected frequency of `1 / observed.size`. + * + * Note: `observed` cannot contain negative values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test on the input contingency matrix, which cannot contain + * negative entries or columns or rows that sum up to 0. + * + * @param observed The contingency matrix (containing either counts or relative frequencies). + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the chi-squared statistic is computed. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @return an array containing the ChiSquaredTestResult for every feature against the label. + * The order of the elements in the returned array reflects the order of input features. + */ + @Experimental + def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { + ChiSqTest.chiSquaredFeatures(data) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala new file mode 100644 index 0000000000000..8f6752737402e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import breeze.linalg.{DenseMatrix => BDM} +import cern.jet.stat.Probability.chiSquareComplemented + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +/** + * Conduct the chi-squared test for the input RDDs using the specified method. + * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted + * on an input of type `Matrix` in which independence between columns is assessed. + * We also provide a method for computing the chi-squared statistic between each feature and the + * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = + * number of features in the inpuy RDD. + * + * Supported methods for goodness of fit: `pearson` (default) + * Supported methods for independence: `pearson` (default) + * + * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test + */ +private[stat] object ChiSqTest extends Logging { + + /** + * @param name String name for the method. + * @param chiSqFunc Function for computing the statistic given the observed and expected counts. + */ + case class Method(name: String, chiSqFunc: (Double, Double) => Double) + + // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test + val PEARSON = new Method("pearson", (observed: Double, expected: Double) => { + val dev = observed - expected + dev * dev / expected + }) + + // Null hypothesis for the two different types of chi-squared tests to be included in the result. + object NullHypothesis extends Enumeration { + type NullHypothesis = Value + val goodnessOfFit = Value("observed follows the same distribution as expected.") + val independence = Value("observations in each column are statistically independent.") + } + + // Method identification based on input methodName string + private def methodFromString(methodName: String): Method = { + methodName match { + case PEARSON.name => PEARSON + case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") + } + } + + /** + * Conduct Pearson's independence test for each feature against the label across the input RDD. + * The contingency table is constructed from the raw (feature, label) pairs and used to conduct + * the independence test. + * Returns an array containing the ChiSquaredTestResult for every feature against the label. + */ + def chiSquaredFeatures(data: RDD[LabeledPoint], + methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val numCols = data.first().features.size + val results = new Array[ChiSqTestResult](numCols) + var labels: Map[Double, Int] = null + // At most 100 columns at a time + val batchSize = 100 + var batch = 0 + while (batch * batchSize < numCols) { + // The following block of code can be cleaned up and made public as + // chiSquared(data: RDD[(V1, V2)]) + val startCol = batch * batchSize + val endCol = startCol + math.min(batchSize, numCols - startCol) + val pairCounts = data.flatMap { p => + // assume dense vectors + p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => + (col, feature, p.label) + } + }.countByValue() + + if (labels == null) { + // Do this only once for the first column since labels are invariant across features. + labels = + pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap + } + val numLabels = labels.size + pairCounts.keys.groupBy(_._1).map { case (col, keys) => + val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap + val numRows = features.size + val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) + keys.foreach { case (_, feature, label) => + val i = features(feature) + val j = labels(label) + contingency(i, j) += pairCounts((col, feature, label)) + } + results(col) = chiSquaredMatrix(Matrices.fromBreeze(contingency), methodName) + } + batch += 1 + } + results + } + + /* + * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. + * Uniform distribution is assumed when `expected` is not passed in. + */ + def chiSquared(observed: Vector, + expected: Vector = Vectors.dense(Array[Double]()), + methodName: String = PEARSON.name): ChiSqTestResult = { + + // Validate input arguments + val method = methodFromString(methodName) + if (expected.size != 0 && observed.size != expected.size) { + throw new IllegalArgumentException("observed and expected must be of the same size.") + } + val size = observed.size + if (size > 1000) { + logWarning("Chi-squared approximation may not be accurate due to low expected frequencies " + + s" as a result of a large number of categories: $size.") + } + val obsArr = observed.toArray + val expArr = if (expected.size == 0) Array.tabulate(size)(_ => 1.0 / size) else expected.toArray + if (!obsArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the observed vector.") + } + if (expected.size != 0 && ! expArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the expected vector.") + } + + // Determine the scaling factor for expected + val obsSum = obsArr.sum + val expSum = if (expected.size == 0.0) 1.0 else expArr.sum + val scale = if (math.abs(obsSum - expSum) < 1e-7) 1.0 else obsSum / expSum + + // compute chi-squared statistic + val statistic = obsArr.zip(expArr).foldLeft(0.0) { case (stat, (obs, exp)) => + if (exp == 0.0) { + if (obs == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input vectors due" + + " to 0.0 values in both observed and expected.") + } else { + return new ChiSqTestResult(0.0, size - 1, Double.PositiveInfinity, PEARSON.name, + NullHypothesis.goodnessOfFit.toString) + } + } + if (scale == 1.0) { + stat + method.chiSqFunc(obs, exp) + } else { + stat + method.chiSqFunc(obs, exp * scale) + } + } + val df = size - 1 + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) + } + + /* + * Pearon's independence test on the input contingency matrix. + * TODO: optimize for SparseMatrix when it becomes supported. + */ + def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + val method = methodFromString(methodName) + val numRows = counts.numRows + val numCols = counts.numCols + + // get row and column sums + val colSums = new Array[Double](numCols) + val rowSums = new Array[Double](numRows) + val colMajorArr = counts.toArray + var i = 0 + while (i < colMajorArr.size) { + val elem = colMajorArr(i) + if (elem < 0.0) { + throw new IllegalArgumentException("Contingency table cannot contain negative entries.") + } + colSums(i / numRows) += elem + rowSums(i % numRows) += elem + i += 1 + } + val total = colSums.sum + + // second pass to collect statistic + var statistic = 0.0 + var j = 0 + while (j < colMajorArr.size) { + val col = j / numRows + val colSum = colSums(col) + if (colSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in column [$col].") + } + val row = j % numRows + val rowSum = rowSums(row) + if (rowSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in row [$row].") + } + val expected = colSum * rowSum / total + statistic += method.chiSqFunc(colMajorArr(j), expected) + j += 1 + } + val df = (numCols - 1) * (numRows - 1) + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala new file mode 100644 index 0000000000000..2f278621335e1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Trait for hypothesis test results. + * @tparam DF Return type of `degreesOfFreedom`. + */ +@Experimental +trait TestResult[DF] { + + /** + * The probability of obtaining a test statistic result at least as extreme as the one that was + * actually observed, assuming that the null hypothesis is true. + */ + def pValue: Double + + /** + * Returns the degree(s) of freedom of the hypothesis test. + * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. + */ + def degreesOfFreedom: DF + + /** + * Test statistic. + */ + def statistic: Double + + /** + * String explaining the hypothesis test result. + * Specific classes implementing this trait should override this method to output test-specific + * information. + */ + override def toString: String = { + + // String explaining what the p-value indicates. + val pValueExplain = if (pValue <= 0.01) { + "Very strong presumption against null hypothesis." + } else if (0.01 < pValue && pValue <= 0.05) { + "Strong presumption against null hypothesis." + } else if (0.05 < pValue && pValue <= 0.01) { + "Low presumption against null hypothesis." + } else { + "No presumption against null hypothesis." + } + + s"degrees of freedom = ${degreesOfFreedom.toString} \n" + + s"statistic = $statistic \n" + + s"pValue = $pValue \n" + pValueExplain + } +} + +/** + * :: Experimental :: + * Object containing the test results for the chi squared hypothesis test. + */ +@Experimental +class ChiSqTestResult(override val pValue: Double, + override val degreesOfFreedom: Int, + override val statistic: Double, + val method: String, + val nullHypothesis: String) extends TestResult[Int] { + + override def toString: String = { + "Chi squared test summary: \n" + + s"method: $method \n" + + s"null hypothesis: $nullHypothesis \n" + + super.toString + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala new file mode 100644 index 0000000000000..5bd0521298c14 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class HypothesisTestSuite extends FunSuite with LocalSparkContext { + + test("chi squared pearson goodness of fit") { + + val observed = new DenseVector(Array[Double](4, 6, 5)) + val pearson = Statistics.chiSqTest(observed) + + // Results validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + assert(pearson.statistic === 0.4) + assert(pearson.degreesOfFreedom === 2) + assert(pearson.pValue ~== 0.8187 relTol 1e-4) + assert(pearson.method === ChiSqTest.PEARSON.name) + assert(pearson.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // different expected and observed sum + val observed1 = new DenseVector(Array[Double](21, 38, 43, 80)) + val expected1 = new DenseVector(Array[Double](3, 5, 7, 20)) + val pearson1 = Statistics.chiSqTest(observed1, expected1) + + // Results validated against the R command + // `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + assert(pearson1.statistic ~== 14.1429 relTol 1e-4) + assert(pearson1.degreesOfFreedom === 3) + assert(pearson1.pValue ~== 0.002717 relTol 1e-4) + assert(pearson1.method === ChiSqTest.PEARSON.name) + assert(pearson1.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // Vectors with different sizes + val observed3 = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected3 = new DenseVector(Array(1.0, 2.0, 3.0, 4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(observed3, expected3)) + + // negative counts in observed + val negObs = new DenseVector(Array(1.0, 2.0, 3.0, -4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(negObs, expected1)) + + // count = 0.0 in expected but not observed + val zeroExpected = new DenseVector(Array(1.0, 0.0, 3.0)) + val inf = Statistics.chiSqTest(observed, zeroExpected) + assert(inf.statistic === Double.PositiveInfinity) + assert(inf.degreesOfFreedom === 2) + assert(inf.pValue === 0.0) + assert(inf.method === ChiSqTest.PEARSON.name) + assert(inf.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // 0.0 in expected and observed simultaneously + val zeroObserved = new DenseVector(Array(2.0, 0.0, 1.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(zeroObserved, zeroExpected)) + } + + test("chi squared pearson matrix independence") { + val data = Array(40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0) + // [[40.0, 56.0, 31.0, 30.0], + // [24.0, 32.0, 10.0, 15.0], + // [29.0, 42.0, 0.0, 12.0]] + val chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + // Results validated against R command + // `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + assert(chi.statistic ~== 21.9958 relTol 1e-4) + assert(chi.degreesOfFreedom === 6) + assert(chi.pValue ~== 0.001213 relTol 1e-4) + assert(chi.method === ChiSqTest.PEARSON.name) + assert(chi.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + + // Negative counts + val negCounts = Array(4.0, 5.0, 3.0, -3.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, negCounts))) + + // Row sum = 0.0 + val rowZero = Array(0.0, 1.0, 0.0, 2.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, rowZero))) + + // Column sum = 0.0 + val colZero = Array(0.0, 0.0, 2.0, 2.0) + // IllegalArgumentException thrown here since it's thrown on driver, not inside a task + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, colZero))) + } + + test("chi squared pearson RDD[LabeledPoint]") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) + val feature1 = chi(0) + assert(feature1.statistic === 0.75) + assert(feature1.degreesOfFreedom === 2) + assert(feature1.pValue ~== 0.6873 relTol 1e-4) + assert(feature1.method === ChiSqTest.PEARSON.name) + assert(feature1.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + val feature2 = chi(1) + assert(feature2.statistic === 1.5) + assert(feature2.degreesOfFreedom === 3) + assert(feature2.pValue ~== 0.6823 relTol 1e-4) + assert(feature2.method === ChiSqTest.PEARSON.name) + assert(feature2.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + } + + // Test that the right number of results is returned + val numCols = 321 + val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) + assert(chi.size === numCols) + } +} From 6fab941b65f0cb6c9b32e0f8290d76889cda6a87 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 11 Aug 2014 19:49:29 -0700 Subject: [PATCH 46/83] [SPARK-2934][MLlib] Adding LogisticRegressionWithLBFGS Interface for training with LBFGS Optimizer which will converge faster than SGD. Author: DB Tsai Closes #1862 from dbtsai/dbtsai-lbfgs-lor and squashes the following commits: aa84b81 [DB Tsai] small change f852bcd [DB Tsai] Remove duplicate method f119fdc [DB Tsai] Formatting 97776aa [DB Tsai] address more feedback 85b4a91 [DB Tsai] address feedback 3cf50c2 [DB Tsai] LogisticRegressionWithLBFGS interface --- .../classification/LogisticRegression.scala | 51 ++++++++++- .../LogisticRegressionSuite.scala | 89 ++++++++++++++++++- 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2242329b7918e..31d474a20fa85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -101,7 +101,7 @@ class LogisticRegressionWithSGD private ( } /** - * Top-level methods for calling Logistic Regression. + * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ object LogisticRegressionWithSGD { @@ -188,3 +188,52 @@ object LogisticRegressionWithSGD { train(input, numIterations, 1.0, 1.0) } } + +/** + * Train a classification model for Logistic Regression using Limited-memory BFGS. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithLBFGS private ( + private var convergenceTol: Double, + private var maxNumIterations: Int, + private var regParam: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { + + /** + * Construct a LogisticRegression object with default parameters + */ + def this() = this(1E-4, 100, 0.0) + + private val gradient = new LogisticGradient() + private val updater = new SimpleUpdater() + // Have to return new LBFGS object every time since users can reset the parameters anytime. + override def optimizer = new LBFGS(gradient, updater) + .setNumCorrections(10) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + override protected val validators = List(DataValidators.binaryLabelValidator) + + /** + * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. + * Smaller value will lead to higher accuracy with the cost of more iterations. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(numIterations: Int): this.type = { + this.maxNumIterations = numIterations + this + } + + override protected def createModel(weights: Vector, intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index da7c633bbd2af..2289c6cdc19de 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match } // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { + test("logistic regression with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } - test("logistic regression with initial weights") { + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD) + + // Test the weights + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== model.weights(0) relTol 0.01) + assert(model.intercept ~== model.intercept relTol 0.01) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -125,11 +154,42 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("logistic regression with initial weights with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.02) + assert(model.intercept ~== 1.97 relTol 0.02) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { - test("task size should be small in both training and prediction") { + test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 val n = 200000 val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => @@ -139,6 +199,29 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = LogisticRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() } + + test("task size should be small in both training and prediction using LBFGS optimizer") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = + (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() + } + } From 490ecfa20327a636289321ea447722aa32b81657 Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Mon, 11 Aug 2014 20:06:06 -0700 Subject: [PATCH 47/83] [SPARK-2844][SQL] Correctly set JVM HiveContext if it is passed into Python HiveContext constructor https://issues.apache.org/jira/browse/SPARK-2844 Author: Ahir Reddy Closes #1768 from ahirreddy/python-hive-context-fix and squashes the following commits: 7972d3b [Ahir Reddy] Correctly set JVM HiveContext if it is passed into Python HiveContext constructor --- python/pyspark/sql.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 950e275adbf01..36040463e62a9 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -912,6 +912,8 @@ def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. + @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -1315,6 +1317,18 @@ class HiveContext(SQLContext): It supports running both SQL and HiveQL commands. """ + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + @param sparkContext: The SparkContext to wrap. + @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + @property def _ssql_ctx(self): try: From 21a95ef051f7b23a80d147aadb00dfa4ebb169b0 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 11 Aug 2014 20:08:06 -0700 Subject: [PATCH 48/83] [SPARK-2590][SQL] Added option to handle incremental collection, disabled by default JIRA issue: [SPARK-2590](https://issues.apache.org/jira/browse/SPARK-2590) Author: Cheng Lian Closes #1853 from liancheng/inc-collect-option and squashes the following commits: cb3ea45 [Cheng Lian] Moved incremental collection option to Thrift server 43ce3aa [Cheng Lian] Changed incremental collect option name 623abde [Cheng Lian] Added option to handle incremental collection, disabled by default --- .../server/SparkSQLOperationManager.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index dee092159dd4c..f192f490ac3d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -132,7 +132,16 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) - iter = result.queryExecution.toRdd.toLocalIterator + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray setHasResultSet(true) } catch { From e83fdcd421d132812411eb805565b76f087f1bc0 Mon Sep 17 00:00:00 2001 From: wangfei Date: Mon, 11 Aug 2014 20:10:13 -0700 Subject: [PATCH 49/83] [sql]use SparkSQLEnv.stop() in ShutdownHook Author: wangfei Closes #1852 from scwf/patch-3 and squashes the following commits: ae28c29 [wangfei] use SparkSQLEnv.stop() in ShutdownHook --- .../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6f7942aba314a..cadf7aaf42157 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -60,7 +60,7 @@ private[hive] object HiveThriftServer2 extends Logging { Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { - SparkSQLEnv.sparkContext.stop() + SparkSQLEnv.stop() } } ) From 647aeba3a9e101d35083f7c4afbcfe7a33f7fc62 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 11 Aug 2014 20:11:29 -0700 Subject: [PATCH 50/83] [SQL] A tiny refactoring in HiveContext#analyze I should use `EliminateAnalysisOperators` in `analyze` instead of manually pattern matching. Author: Yin Huai Closes #1881 from yhuai/useEliminateAnalysisOperators and squashes the following commits: f3e1e7f [Yin Huai] Use EliminateAnalysisOperators. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 53f3dc11dbb9f..a8da676ffa0e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,7 +39,8 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException @@ -119,10 +120,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * in the Hive metastore. */ def analyze(tableName: String) { - val relation = catalog.lookupRelation(None, tableName) match { - case LowerCaseSchema(r) => r - case o => o - } + val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { case relation: MetastoreRelation => { From c9c89c31b6114832fe282c21fecd663d8105b9bc Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Aug 2014 20:15:01 -0700 Subject: [PATCH 51/83] [SPARK-2965][SQL] Fix HashOuterJoin output nullabilities. Output attributes of opposite side of `OuterJoin` should be nullable. Author: Takuya UESHIN Closes #1887 from ueshin/issues/SPARK-2965 and squashes the following commits: bcb2d37 [Takuya UESHIN] Fix HashOuterJoin output nullabilities. --- .../org/apache/spark/sql/execution/joins.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 51bb61530744c..ea075f8c65bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -168,7 +168,18 @@ case class HashOuterJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. From c686b7dd4668b5e9fc3177f15edeae3446d2e634 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 11 Aug 2014 20:18:03 -0700 Subject: [PATCH 52/83] [SPARK-2968][SQL] Fix nullabilities of Explode. Output nullabilities of `Explode` could be detemined by `ArrayType.containsNull` or `MapType.valueContainsNull`. Author: Takuya UESHIN Closes #1888 from ueshin/issues/SPARK-2968 and squashes the following commits: d128c95 [Takuya UESHIN] Fix nullability of Explode. --- .../spark/sql/catalyst/expressions/generators.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3d41acb79e5fd..e99c5b452d183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,19 +86,19 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et, _) => et :: Nil - case MapType(kt,vt, _) => kt :: vt :: Nil + case ArrayType(et, containsNull) => (et, containsNull) :: Nil + case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } // TODO: Move this pattern into Generator. protected def makeOutput() = if (attributeNames.size == elementTypes.size) { attributeNames.zip(elementTypes).map { - case (n, t) => AttributeReference(n, t, nullable = true)() + case (n, (t, nullable)) => AttributeReference(n, t, nullable)() } } else { elementTypes.zipWithIndex.map { - case (t, i) => AttributeReference(s"c_$i", t, nullable = true)() + case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() } } From bad21ed085a505559dccc06223b486170371ddd2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 11 Aug 2014 20:21:56 -0700 Subject: [PATCH 53/83] [SPARK-2650][SQL] Build column buffers in smaller batches Author: Michael Armbrust Closes #1880 from marmbrus/columnBatches and squashes the following commits: 0649987 [Michael Armbrust] add test 4756fad [Michael Armbrust] fix compilation 2314532 [Michael Armbrust] Build column buffers in smaller batches --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../org/apache/spark/sql/SQLContext.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 76 ++++++++++++------- .../apache/spark/sql/CachedTableSuite.scala | 12 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- 7 files changed, 70 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 0fd7aaaa36eb8..35c51dec0bcf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,6 +25,7 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" @@ -71,6 +72,9 @@ trait SQLConf { /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** The number of rows that will be */ + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 71d338d21d0f2..af9f7c62a1d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -273,7 +273,7 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) } catalog.registerTable(None, tableName, asInMemoryRelation) @@ -284,7 +284,7 @@ class SQLContext(@transient val sparkContext: SparkContext) table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => + case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 88901debbb4e9..3364d0e18bcc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -28,13 +28,14 @@ import org.apache.spark.sql.Row import org.apache.spark.SparkConf object InMemoryRelation { - def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, child)() + def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, child)() } private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, + batchSize: Int, child: SparkPlan) (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) extends LogicalPlan with MultiInstanceRelation { @@ -43,22 +44,31 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) - }.toArray - - var row: Row = null - while (iterator.hasNext) { - row = iterator.next() - var i = 0 - while (i < row.length) { - columnBuilders(i).appendFrom(row, i) - i += 1 + val cached = child.execute().mapPartitions { baseIterator => + new Iterator[Array[ByteBuffer]] { + def next() = { + val columnBuilders = output.map { attribute => + ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) + }.toArray + + var row: Row = null + var rowCount = 0 + + while (baseIterator.hasNext && rowCount < batchSize) { + row = baseIterator.next() + var i = 0 + while (i < row.length) { + columnBuilders(i).appendFrom(row, i) + i += 1 + } + rowCount += 1 + } + + columnBuilders.map(_.build()) } - } - Iterator.single(columnBuilders.map(_.build())) + def hasNext = baseIterator.hasNext + } }.cache() cached.setName(child.toString) @@ -74,6 +84,7 @@ private[sql] case class InMemoryRelation( new InMemoryRelation( output.map(_.newInstance), useCompression, + batchSize, child)( _cachedColumnBuffers).asInstanceOf[this.type] } @@ -90,22 +101,31 @@ private[sql] case class InMemoryColumnarTableScan( override def execute() = { relation.cachedColumnBuffers.mapPartitions { iterator => - val columnBuffers = iterator.next() - assert(!iterator.hasNext) + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } new Iterator[Row] { - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } + private[this] var columnBuffers: Array[ByteBuffer] = null + private[this] var columnAccessors: Seq[ColumnAccessor] = null + nextBatch() + + private[this] val nextRow = new GenericMutableRow(columnAccessors.length) - val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - val nextRow = new GenericMutableRow(columnAccessors.length) + def nextBatch() = { + columnBuffers = iterator.next() + columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) + } override def next() = { + if (!columnAccessors.head.hasNext) { + nextBatch() + } + var i = 0 while (i < nextRow.length) { columnAccessors(i).extractTo(nextRow, i) @@ -114,7 +134,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors.head.hasNext || iterator.hasNext } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fbf9bd9dbcdea..befef46d93973 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,9 +22,19 @@ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableSca import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +case class BigData(s: String) + class CachedTableSuite extends QueryTest { TestData // Load test tables. + test("too big for memory") { + val data = "*" * 10000 + sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") + cacheTable("bigData") + assert(table("bigData").count() === 1000000L) + uncacheTable("bigData") + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) @@ -37,7 +47,7 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index b561b44ad7ee2..736c0f8571e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 82e9c1a248626..3b371211e14cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -137,7 +137,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 85d2496a34cfb..5fcc1bd4b9adf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -45,7 +45,7 @@ private[hive] trait HiveStrategies { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil From 5d54d71ddbac1fbb26925a8c9138bbb8c0e81db8 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 11 Aug 2014 20:45:14 -0700 Subject: [PATCH 54/83] [SQL] [SPARK-2826] Reduce the memory copy while building the hashmap for HashOuterJoin This is a follow up for #1147 , this PR will improve the performance about 10% - 15% in my local tests. ``` Before: LeftOuterJoin: took 16750 ms ([3000000] records) LeftOuterJoin: took 15179 ms ([3000000] records) RightOuterJoin: took 15515 ms ([3000000] records) RightOuterJoin: took 15276 ms ([3000000] records) FullOuterJoin: took 19150 ms ([6000000] records) FullOuterJoin: took 18935 ms ([6000000] records) After: LeftOuterJoin: took 15218 ms ([3000000] records) LeftOuterJoin: took 13503 ms ([3000000] records) RightOuterJoin: took 13663 ms ([3000000] records) RightOuterJoin: took 14025 ms ([3000000] records) FullOuterJoin: took 16624 ms ([6000000] records) FullOuterJoin: took 16578 ms ([6000000] records) ``` Besides the performance improvement, I also do some clean up as suggested in #1147 Author: Cheng Hao Closes #1765 from chenghao-intel/hash_outer_join_fixing and squashes the following commits: ab1f9e0 [Cheng Hao] Reduce the memory copy while building the hashmap --- .../apache/spark/sql/execution/joins.scala | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index ea075f8c65bff..c86811e838bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.{HashMap => JavaHashMap} + import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ @@ -136,14 +138,6 @@ trait HashJoin { } } -/** - * Constant Value for Binary Join Node - */ -object HashOuterJoin { - val DUMMY_LIST = Seq[Row](null) - val EMPTY_LIST = Seq[Row]() -} - /** * :: DeveloperApi :: * Performs a hash based outer join for two child relations by shuffling the data using @@ -181,6 +175,9 @@ case class HashOuterJoin( } } + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. @@ -199,8 +196,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right @@ -224,8 +221,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. @@ -259,10 +256,10 @@ case class HashOuterJoin( rightMatchedSet.add(idx) joinedRow.copy } - } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. @@ -287,18 +284,22 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { - // TODO: Use Spark's HashMap implementation. - val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) - val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + existingMatchList += currentRow.copy() } - - hashTable.toMap[Row, ArrayBuffer[Row]] + + hashTable } def execute() = { @@ -309,21 +310,22 @@ case class HashOuterJoin( // Build HashMap for current partition in right relation val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + import scala.collection.JavaConversions._ val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, - leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } From 9038d94e1e50e05de00fd51af4fd7b9280481cdc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 11 Aug 2014 22:33:45 -0700 Subject: [PATCH 55/83] [SPARK-2923][MLLIB] Implement some basic BLAS routines Having some basic BLAS operations implemented in MLlib can help simplify the current implementation and improve some performance. Tested on my local machine: ~~~ bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \ examples/target/scala-*/spark-examples-*.jar --algorithm LR --regType L2 \ --regParam 1.0 --numIterations 1000 ~/share/data/rcv1.binary/rcv1_train.binary ~~~ 1. before: ~1m 2. after: ~30s CC: jkbradley Author: Xiangrui Meng Closes #1849 from mengxr/ml-blas and squashes the following commits: ba583a2 [Xiangrui Meng] exclude Vector.copy a4d7d2f [Xiangrui Meng] Merge branch 'master' into ml-blas 6edeab9 [Xiangrui Meng] address comments 940bdeb [Xiangrui Meng] rename MLlibBLAS to BLAS c2a38bc [Xiangrui Meng] enhance dot tests 4cfaac4 [Xiangrui Meng] add apache header 48d01d2 [Xiangrui Meng] add tests for zeros and copy 3b882b1 [Xiangrui Meng] use blas.scal in gradient 735eb23 [Xiangrui Meng] remove d from BLAS routines d2d7d3c [Xiangrui Meng] update gradient and lbfgs 7f78186 [Xiangrui Meng] add zeros to Vectors; add dscal and dcopy to BLAS 14e6645 [Xiangrui Meng] add ddot cbb8273 [Xiangrui Meng] add daxpy test 07db0bb [Xiangrui Meng] Merge branch 'master' into ml-blas e8c326d [Xiangrui Meng] axpy --- .../org/apache/spark/mllib/linalg/BLAS.scala | 200 ++++++++++++++++++ .../apache/spark/mllib/linalg/Vectors.scala | 35 ++- .../spark/mllib/optimization/Gradient.scala | 60 ++---- .../spark/mllib/optimization/LBFGS.scala | 39 ++-- .../apache/spark/mllib/linalg/BLASSuite.scala | 129 +++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 30 +++ project/MimaExcludes.scala | 5 +- 7 files changed, 432 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala new file mode 100644 index 0000000000000..70e23033c8754 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} + +/** + * BLAS routines for MLlib's vectors and matrices. + */ +private[mllib] object BLAS extends Serializable { + + @transient private var _f2jBLAS: NetlibBLAS = _ + + // For level-1 routines, we use Java implementation. + private def f2jBLAS: NetlibBLAS = { + if (_f2jBLAS == null) { + _f2jBLAS = new F2jBLAS + } + _f2jBLAS + } + + /** + * y += a * x + */ + def axpy(a: Double, x: Vector, y: Vector): Unit = { + require(x.size == y.size) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + axpy(a, sx, dy) + case dx: DenseVector => + axpy(a, dx, dy) + case _ => + throw new UnsupportedOperationException( + s"axpy doesn't support x type ${x.getClass}.") + } + case _ => + throw new IllegalArgumentException( + s"axpy only supports adding to a dense vector but got type ${y.getClass}.") + } + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { + val n = x.size + f2jBLAS.daxpy(n, a, x.values, 1, y.values, 1) + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { + val nnz = x.indices.size + if (a == 1.0) { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += x.values(k) + k += 1 + } + } else { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += a * x.values(k) + k += 1 + } + } + } + + /** + * dot(x, y) + */ + def dot(x: Vector, y: Vector): Double = { + require(x.size == y.size) + (x, y) match { + case (dx: DenseVector, dy: DenseVector) => + dot(dx, dy) + case (sx: SparseVector, dy: DenseVector) => + dot(sx, dy) + case (dx: DenseVector, sy: SparseVector) => + dot(sy, dx) + case (sx: SparseVector, sy: SparseVector) => + dot(sx, sy) + case _ => + throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).") + } + } + + /** + * dot(x, y) + */ + private def dot(x: DenseVector, y: DenseVector): Double = { + val n = x.size + f2jBLAS.ddot(n, x.values, 1, y.values, 1) + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: DenseVector): Double = { + val nnz = x.indices.size + var sum = 0.0 + var k = 0 + while (k < nnz) { + sum += x.values(k) * y.values(x.indices(k)) + k += 1 + } + sum + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: SparseVector): Double = { + var kx = 0 + val nnzx = x.indices.size + var ky = 0 + val nnzy = y.indices.size + var sum = 0.0 + // y catching x + while (kx < nnzx && ky < nnzy) { + val ix = x.indices(kx) + while (ky < nnzy && y.indices(ky) < ix) { + ky += 1 + } + if (ky < nnzy && y.indices(ky) == ix) { + sum += x.values(kx) * y.values(ky) + ky += 1 + } + kx += 1 + } + sum + } + + /** + * y = x + */ + def copy(x: Vector, y: Vector): Unit = { + val n = y.size + require(x.size == n) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + var i = 0 + var k = 0 + val nnz = sx.indices.size + while (k < nnz) { + val j = sx.indices(k) + while (i < j) { + dy.values(i) = 0.0 + i += 1 + } + dy.values(i) = sx.values(k) + i += 1 + k += 1 + } + while (i < n) { + dy.values(i) = 0.0 + i += 1 + } + case dx: DenseVector => + Array.copy(dx.values, 0, dy.values, 0, n) + } + case _ => + throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") + } + } + + /** + * x = a * x + */ + def scal(a: Double, x: Vector): Unit = { + x match { + case sx: SparseVector => + f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + case dx: DenseVector => + f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + case _ => + throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 77b3e8c714997..a45781d12e41e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.linalg import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} -import java.util.Arrays +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -30,6 +30,8 @@ import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. + * + * Note: Users should not implement this interface. */ trait Vector extends Serializable { @@ -46,12 +48,12 @@ trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { case v: Vector => - Arrays.equals(this.toArray, v.toArray) + util.Arrays.equals(this.toArray, v.toArray) case _ => false } } - override def hashCode(): Int = Arrays.hashCode(this.toArray) + override def hashCode(): Int = util.Arrays.hashCode(this.toArray) /** * Converts the instance to a breeze vector. @@ -63,6 +65,13 @@ trait Vector extends Serializable { * @param i index */ def apply(i: Int): Double = toBreeze(i) + + /** + * Makes a deep copy of this vector. + */ + def copy: Vector = { + throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + } } /** @@ -127,6 +136,16 @@ object Vectors { }.toSeq) } + /** + * Creates a dense vector of all zeros. + * + * @param size vector size + * @return a zero vector + */ + def zeros(size: Int): Vector = { + new DenseVector(new Array[Double](size)) + } + /** * Parses a string resulted from `Vector#toString` into * an [[org.apache.spark.mllib.linalg.Vector]]. @@ -142,7 +161,7 @@ object Vectors { case Seq(size: Double, indices: Array[Double], values: Array[Double]) => Vectors.sparse(size.toInt, indices.map(_.toInt), values) case other => - throw new SparkException(s"Cannot parse $other.") + throw new SparkException(s"Cannot parse $other.") } } @@ -183,6 +202,10 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int) = values(i) + + override def copy: DenseVector = { + new DenseVector(values.clone()) + } } /** @@ -213,5 +236,9 @@ class SparseVector( data } + override def copy: SparseVector = { + new SparseVector(size, indices.clone(), values.clone()) + } + private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 9d82f011e674a..fdd67160114ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import breeze.linalg.{axpy => brzAxpy} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} /** * :: DeveloperApi :: @@ -61,11 +60,10 @@ abstract class Gradient extends Serializable { @DeveloperApi class LogisticGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = brzData * gradientMultiplier + val gradient = data.copy + scal(gradientMultiplier, gradient) val loss = if (label > 0) { math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p @@ -73,7 +71,7 @@ class LogisticGradient extends Gradient { math.log1p(math.exp(margin)) - margin } - (Vectors.fromBreeze(gradient), loss) + (gradient, loss) } override def compute( @@ -81,13 +79,9 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) - + axpy(gradientMultiplier, data, cumGradient) if (label > 0) { math.log1p(math.exp(margin)) } else { @@ -106,13 +100,11 @@ class LogisticGradient extends Gradient { @DeveloperApi class LeastSquaresGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label + val diff = dot(data, weights) - label val loss = diff * diff - val gradient = brzData * (2.0 * diff) - - (Vectors.fromBreeze(gradient), loss) + val gradient = data.copy + scal(2.0 * diff, gradient) + (gradient, loss) } override def compute( @@ -120,12 +112,8 @@ class LeastSquaresGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label - - brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) - + val diff = dot(data, weights) - label + axpy(2.0 * diff, data, cumGradient) diff * diff } } @@ -139,18 +127,16 @@ class LeastSquaresGradient extends Gradient { @DeveloperApi class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + val gradient = data.copy + scal(-labelScaled, gradient) + (gradient, 1.0 - labelScaled * dotProduct) } else { - (Vectors.dense(new Array[Double](weights.size)), 0.0) + (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0) } } @@ -159,16 +145,12 @@ class HingeGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + axpy(-labelScaled, data, cumGradient) 1.0 - labelScaled * dotProduct } else { 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 26a2b62e76ed0..033fe44f34f3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, axpy} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -192,31 +193,29 @@ object LBFGS extends Logging { regParam: Double, numExamples: Long) extends DiffFunction[BDV[Double]] { - private var i = 0 - - override def calculate(weights: BDV[Double]) = { + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { // Have a local copy to avoid the serialization of CostFun object which is not serializable. + val w = Vectors.fromBreeze(weights) + val n = w.size + val bcW = data.context.broadcast(w) val localGradient = gradient - val n = weights.length - val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( - features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) + features, label, bcW.value, grad) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + axpy(1.0, grad2, grad1) + (grad1, loss1 + loss2) }) /** * regVal is sum of weight squares if it's L2 updater; * for other updater, the same logic is followed. */ - val regVal = updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2 val loss = lossSum / numExamples + regVal /** @@ -236,17 +235,13 @@ object LBFGS extends Logging { */ // The following gradientTotal is actually the regularization part of gradient. // Will add the gradientSum computed from the data with weights in the next step. - val gradientTotal = weights - updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze + val gradientTotal = w.copy + axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal) // gradientTotal = gradientSum / numExamples + gradientTotal axpy(1.0 / numExamples, gradientSum, gradientTotal) - i += 1 - - (loss, gradientTotal) + (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]]) } } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala new file mode 100644 index 0000000000000..1952e6734ecf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.linalg.BLAS._ + +class BLASSuite extends FunSuite { + + test("copy") { + val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0) + val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0)) + val dy = Array(2.0, 1.0, 0.0, 1.0) + + val dy1 = Vectors.dense(dy.clone()) + copy(sx, dy1) + assert(dy1 ~== dx absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + copy(dx, dy2) + assert(dy2 ~== dx absTol 1e-15) + + intercept[IllegalArgumentException] { + copy(sx, sy) + } + + intercept[IllegalArgumentException] { + copy(dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + copy(sx, Vectors.dense(0.0, 1.0, 2.0)) + } + } + } + + test("scal") { + val a = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + + scal(a, sx) + assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15) + + scal(a, dx) + assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15) + } + + test("axpy") { + val alpha = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val dy = Array(2.0, 1.0, 0.0) + val expected = Vectors.dense(2.1, 1.0, -0.2) + + val dy1 = Vectors.dense(dy.clone()) + axpy(alpha, sx, dy1) + assert(dy1 ~== expected absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + axpy(alpha, dx, dy2) + assert(dy2 ~== expected absTol 1e-15) + + val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0)) + + intercept[IllegalArgumentException] { + axpy(alpha, sx, sy) + } + + intercept[IllegalArgumentException] { + axpy(alpha, dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + axpy(alpha, sx, Vectors.dense(1.0, 2.0)) + } + } + } + + test("dot") { + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0)) + val dy = Vectors.dense(2.0, 1.0, 0.0) + + assert(dot(sx, sy) ~== 2.0 absTol 1e-15) + assert(dot(sy, sx) ~== 2.0 absTol 1e-15) + assert(dot(sx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, sx) ~== 2.0 absTol 1e-15) + assert(dot(dx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, dx) ~== 2.0 absTol 1e-15) + + assert(dot(sx, sx) ~== 5.0 absTol 1e-15) + assert(dot(dx, dx) ~== 5.0 absTol 1e-15) + assert(dot(sx, dx) ~== 5.0 absTol 1e-15) + assert(dot(dx, sx) ~== 5.0 absTol 1e-15) + + val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15) + assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15) + + withClue("vector sizes must match") { + intercept[Exception] { + dot(sx, Vectors.dense(2.0, 1.0)) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 7972ceea1fe8a..cd651fe2d2ddf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite { } } } + + test("zeros") { + assert(Vectors.zeros(3) === Vectors.dense(0.0, 0.0, 0.0)) + } + + test("Vector.copy") { + val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0)) + val svCopy = sv.copy + (sv, svCopy) match { + case (sv: SparseVector, svCopy: SparseVector) => + assert(sv.size === svCopy.size) + assert(sv.indices === svCopy.indices) + assert(sv.values === svCopy.values) + assert(!sv.indices.eq(svCopy.indices)) + assert(!sv.values.eq(svCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${svCopy.getClass} on ${sv.getClass}.") + } + + val dv = Vectors.dense(1.0, 0.0, 2.0) + val dvCopy = dv.copy + (dv, dvCopy) match { + case (dv: DenseVector, dvCopy: DenseVector) => + assert(dv.size === dvCopy.size) + assert(dv.values === dvCopy.values) + assert(!dv.values.eq(dvCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") + } + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4653c72c10b5..6e72035f2c15b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -111,9 +111,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) ++ - Seq ( // package-private classes removed in MLlib + Seq( // package-private classes removed in MLlib ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") ) case v if v.startsWith("1.0") => Seq( From f0060b75ff67ab60babf54149a6860edc53cb6e9 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Tue, 12 Aug 2014 00:28:00 -0700 Subject: [PATCH 56/83] [MLlib] Correctly set vectorSize and alpha mengxr Correctly set vectorSize and alpha in Word2Vec training. Author: Liquan Pei Closes #1900 from Ishiihara/Word2Vec-bugfix and squashes the following commits: 85f64f2 [Liquan Pei] correctly set vectorSize and alpha --- .../apache/spark/mllib/feature/Word2Vec.scala | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 395037e1ec47c..ecd49ea2ff533 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging { private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging { private var vocabSize = 0 private var vocab: Array[VocabWord] = null private var vocabHash = mutable.HashMap.empty[String, Int] - private var alpha = startingAlpha private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) @@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) - var syn1Global = new Array[Float](vocabSize * layer1Size) + Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) + var syn1Global = new Array[Float](vocabSize * vectorSize) + var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging { val c = pos - window + a if (c >= 0 && c < sentence.size) { val lastWord = sentence(c) - val l1 = lastWord * layer1Size - val neu1e = new Array[Float](layer1Size) + val l1 = lastWord * vectorSize + val neu1e = new Array[Float](vectorSize) // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * layer1Size + val l2 = bcVocab.value(word).point(d) * vectorSize // Propagate hidden -> output - var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) } d += 1 } - blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) } } a += 1 @@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging { var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word - val vector = new Array[Float](layer1Size) - Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + val vector = new Array[Float](vectorSize) + Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) word2VecMap += word -> vector i += 1 } From 882da57a1c8c075a87909d516b169b624941a6ec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Aug 2014 16:26:01 -0700 Subject: [PATCH 57/83] fix flaky tests Python 2.6 does not handle float error well as 2.7+ Author: Davies Liu Closes #1910 from davies/fix_test and squashes the following commits: 7e51200 [Davies Liu] fix flaky tests --- python/pyspark/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 36040463e62a9..27f1d2ddf942a 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1094,7 +1094,7 @@ def applySchema(self, rdd, schema): ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), From c235b83e2782cce0626ecc403c0a67e442be52c1 Mon Sep 17 00:00:00 2001 From: Ameet Talwalkar Date: Tue, 12 Aug 2014 17:15:21 -0700 Subject: [PATCH 58/83] SPARK-2830 [MLlib]: re-organize mllib documentation As per discussions with Xiangrui, I've reorganized and edited the mllib documentation. Author: Ameet Talwalkar Closes #1908 from atalwalkar/master and squashes the following commits: fe6938a [Ameet Talwalkar] made xiangruis suggested changes 840028b [Ameet Talwalkar] made xiangruis suggested changes 7ec366a [Ameet Talwalkar] reorganize and edit mllib documentation --- docs/mllib-basics.md | 117 +++++---------------- docs/mllib-classification-regression.md | 37 +++++++ docs/mllib-clustering.md | 15 +-- docs/mllib-collaborative-filtering.md | 21 ++-- docs/mllib-dimensionality-reduction.md | 44 ++++---- docs/mllib-feature-extraction.md | 12 +++ docs/mllib-guide.md | 30 +++--- docs/mllib-linear-methods.md | 134 ++++++++++++------------ docs/mllib-naive-bayes.md | 32 +++--- docs/mllib-stats.md | 95 +++++++++++++++++ 10 files changed, 317 insertions(+), 220 deletions(-) create mode 100644 docs/mllib-classification-regression.md create mode 100644 docs/mllib-feature-extraction.md create mode 100644 docs/mllib-stats.md diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index f9585251fafac..8752df412950a 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -9,17 +9,17 @@ displayTitle: MLlib - Basics MLlib supports local vectors and matrices stored on a single machine, as well as distributed matrices backed by one or more RDDs. -In the current implementation, local vectors and matrices are simple data models -to serve public interfaces. The underlying linear algebra operations are provided by +Local vectors and local matrices are simple data models +that serve as public interfaces. The underlying linear algebra operations are provided by [Breeze](http://www.scalanlp.org/) and [jblas](http://jblas.org/). -A training example used in supervised learning is called "labeled point" in MLlib. +A training example used in supervised learning is called a "labeled point" in MLlib. ## Local vector A local vector has integer-typed and 0-based indices and double-typed values, stored on a single machine. MLlib supports two types of local vectors: dense and sparse. A dense vector is backed by a double array representing its entry values, while a sparse vector is backed by two parallel -arrays: indices and values. For example, a vector $(1.0, 0.0, 3.0)$ can be represented in dense +arrays: indices and values. For example, a vector `(1.0, 0.0, 3.0)` can be represented in dense format as `[1.0, 0.0, 3.0]` or in sparse format as `(3, [0, 2], [1.0, 3.0])`, where `3` is the size of the vector. @@ -44,8 +44,7 @@ val sv1: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)) val sv2: Vector = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0))) {% endhighlight %} -***Note*** - +***Note:*** Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. @@ -110,8 +109,8 @@ sv2 = sps.csc_matrix((np.array([1.0, 3.0]), np.array([0, 2]), np.array([0, 2])), A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. -For binary classification, label should be either $0$ (negative) or $1$ (positive). -For multiclass classification, labels should be class indices staring from zero: $0, 1, 2, \ldots$. +For binary classification, a label should be either `0` (negative) or `1` (positive). +For multiclass classification, labels should be class indices starting from zero: `0, 1, 2, ...`.
@@ -172,7 +171,7 @@ neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) It is very common in practice to have sparse training data. MLlib supports reading training examples stored in `LIBSVM` format, which is the default format used by [`LIBSVM`](http://www.csie.ntu.edu.tw/~cjlin/libsvm/) and -[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format. Each line +[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format in which each line represents a labeled sparse feature vector using the following format: ~~~ @@ -226,7 +225,7 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") ## Local matrix A local matrix has integer-typed row and column indices and double-typed values, stored on a single -machine. MLlib supports dense matrix, whose entry values are stored in a single double array in +machine. MLlib supports dense matrices, whose entry values are stored in a single double array in column major. For example, the following matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ @@ -234,7 +233,6 @@ column major. For example, the following matrix `\[ \begin{pmatrix} \end{pmatrix} \]` is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the matrix size `(3, 2)`. -We are going to add sparse matrix in the next release.
@@ -242,7 +240,7 @@ We are going to add sparse matrix in the next release. The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local matrices. @@ -259,7 +257,7 @@ val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) The base class of local matrices is [`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local matrices. @@ -279,28 +277,30 @@ Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); A distributed matrix has long-typed row and column indices and double-typed values, stored distributively in one or more RDDs. It is very important to choose the right format to store large and distributed matrices. Converting a distributed matrix to a different format may require a -global shuffle, which is quite expensive. We implemented three types of distributed matrices in -this release and will add more types in the future. +global shuffle, which is quite expensive. Three types of distributed matrices have been implemented +so far. The basic type is called `RowMatrix`. A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, e.g., a collection of feature vectors. It is backed by an RDD of its rows, where each row is a local vector. -We assume that the number of columns is not huge for a `RowMatrix`. +We assume that the number of columns is not huge for a `RowMatrix` so that a single +local vector can be reasonably communicated to the driver and can also be stored / +operated on using a single node. An `IndexedRowMatrix` is similar to a `RowMatrix` but with row indices, -which can be used for identifying rows and joins. -A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix) format, +which can be used for identifying rows and executing joins. +A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_.28COO.29) format, backed by an RDD of its entries. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -It is always error-prone to have non-deterministic RDDs. +In general the use of non-deterministic RDDs can lead to errors. ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD -of its rows, where each row is a local vector. This is similar to `data matrix` in the context of -multivariate statistics. Since each row is represented by a local vector, the number of columns is +of its rows, where each row is a local vector. +Since each row is represented by a local vector, the number of columns is limited by the integer range but it should be much smaller in practice.
@@ -344,70 +344,10 @@ long n = mat.numCols();
-#### Multivariate summary statistics - -We provide column summary statistics for `RowMatrix`. -If the number of columns is not large, say, smaller than 3000, you can also compute -the covariance matrix as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the -number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, -which could be faster if the rows are sparse. - -
-
- -[`RowMatrix#computeColumnSummaryStatistics`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of -[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -val mat: RowMatrix = ... // a RowMatrix - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -// Compute the covariance matrix. -val cov: Matrix = mat.computeCovariance() -{% endhighlight %} -
- -
- -[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of -[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight java %} -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; - -RowMatrix mat = ... // a RowMatrix - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -// Compute the covariance matrix. -Matrix cov = mat.computeCovariance(); -{% endhighlight %} -
-
- ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, which each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector.
@@ -467,7 +407,7 @@ RowMatrix rowMat = mat.toRowMatrix(); A `CoordinateMatrix` is a distributed matrix backed by an RDD of its entries. Each entry is a tuple of `(i: Long, j: Long, value: Double)`, where `i` is the row index, `j` is the column index, and -`value` is the entry value. A `CoordinateMatrix` should be used only in the case when both +`value` is the entry value. A `CoordinateMatrix` should be used only when both dimensions of the matrix are huge and the matrix is very sparse.
@@ -477,9 +417,9 @@ A [`CoordinateMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.CoordinateMatrix) can be created from an `RDD[MatrixEntry]` instance, where [`MatrixEntry`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.MatrixEntry) is a -wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. In this release, we do not provide other -computation for `CoordinateMatrix`. +wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry} @@ -503,8 +443,9 @@ A [`CoordinateMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.html) can be created from a `JavaRDD` instance, where [`MatrixEntry`](api/java/org/apache/spark/mllib/linalg/distributed/MatrixEntry.html) is a -wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. +wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight java %} import org.apache.spark.api.java.JavaRDD; diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md new file mode 100644 index 0000000000000..719cc95767b00 --- /dev/null +++ b/docs/mllib-classification-regression.md @@ -0,0 +1,37 @@ +--- +layout: global +title: Classification and Regression - MLlib +displayTitle: MLlib - Classification and Regression +--- + +MLlib supports various methods for +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), +[multiclass +classification](http://en.wikipedia.org/wiki/Multiclass_classification), and +[regression analysis](http://en.wikipedia.org/wiki/Regression_analysis). The table below outlines +the supported algorithms for each type of problem. + + + + + + + + + + + + + + + + +
Problem TypeSupported Methods
Binary Classificationlinear SVMs, logistic regression, decision trees, naive Bayes
Multiclass Classificationdecision trees, naive Bayes
Regressionlinear least squares, Lasso, ridge regression, decision trees
+ +More details for these methods can be found here: + +* [Linear models](mllib-linear-methods.html) + * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) + * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) +* [Decision trees](mllib-decision-tree.html) +* [Naive Bayes](mllib-naive-bayes.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 561de48910132..dfd9cd572888c 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -38,7 +38,7 @@ a given dataset, the algorithm returns the best clustering result).
-Following code snippets can be executed in `spark-shell`. +The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data @@ -70,7 +70,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A standalone application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: {% highlight java %} import org.apache.spark.api.java.*; @@ -113,14 +113,15 @@ public class KMeansExample { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
-Following examples can be tested in the PySpark shell. +The following examples can be tested in the PySpark shell. In the following example after loading and parsing data, we use the KMeans object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0d28b5f7c89b3..ab10b2f01f87b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -14,13 +14,13 @@ is commonly used for recommender systems. These techniques aim to fill in the missing entries of a user-item association matrix. MLlib currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -In particular, we implement the [alternating least squares +MLlib uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) algorithm to learn these latent factors. The implementation in MLlib has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in our model. +* *rank* is the number of latent factors in the model. * *iterations* is the number of iterations to run. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for @@ -86,8 +86,8 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => println("Mean Squared Error = " + MSE) {% endhighlight %} -If the rating matrix is derived from other source of information (i.e., it is inferred from -other signals), you can use the trainImplicit method to get better results. +If the rating matrix is derived from another source of information (e.g., it is inferred from +other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 @@ -174,10 +174,11 @@ public class CollaborativeFiltering { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
@@ -219,5 +220,5 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) ## Tutorial -[AMP Camp](http://ampcamp.berkeley.edu/) provides a hands-on tutorial for -[personalized movie recommendation with MLlib](http://ampcamp.berkeley.edu/big-data-mini-course/movie-recommendation-with-mllib.html). +The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for +[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 8e434998c15ea..065d646496131 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -9,9 +9,9 @@ displayTitle: MLlib - Dimensionality Reduction [Dimensionality reduction](http://en.wikipedia.org/wiki/Dimensionality_reduction) is the process of reducing the number of variables under consideration. -It is used to extract latent features from raw and noisy features, +It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -In this release, we provide preliminary support for dimensionality reduction on tall-and-skinny matrices. +MLlib provides support for dimensionality reduction on tall-and-skinny matrices. ## Singular value decomposition (SVD) @@ -30,17 +30,17 @@ where * $V$ is an orthonormal matrix, whose columns are called right singular vectors. For large matrices, usually we don't need the complete factorization but only the top singular -values and its associated singular vectors. This can save storage, and more importantly, de-noise +values and its associated singular vectors. This can save storage, de-noise and recover the low-rank structure of the matrix. -If we keep the top $k$ singular values, then the dimensions of the return will be: +If we keep the top $k$ singular values, then the dimensions of the resulting low-rank matrix will be: * `$U$`: `$m \times k$`, * `$\Sigma$`: `$k \times k$`, * `$V$`: `$n \times k$`. -In this release, we provide SVD computation to row-oriented matrices that have only a few columns, -say, less than $1000$, but many rows, which we call *tall-and-skinny*. +MLlib provides SVD functionality to row-oriented matrices that have only a few columns, +say, less than $1000$, but many rows, i.e., *tall-and-skinny* matrices.
@@ -58,15 +58,10 @@ val s: Vector = svd.s // The singular values are stored in a local dense vector. val V: Matrix = svd.V // The V factor is a local dense matrix. {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
-In order to run the following standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. - {% highlight java %} import java.util.LinkedList; @@ -104,8 +99,16 @@ public class SVD { } } {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. + +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`. + +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. +
@@ -116,7 +119,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -In this release, we implement PCA for tall-and-skinny matrices stored in row-oriented format. +MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format.
@@ -180,9 +183,10 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md new file mode 100644 index 0000000000000..21453cb9cd8c9 --- /dev/null +++ b/docs/mllib-feature-extraction.md @@ -0,0 +1,12 @@ +--- +layout: global +title: Feature Extraction - MLlib +displayTitle: MLlib - Feature Extraction +--- + +* Table of contents +{:toc} + +## Word2Vec + +## TFIDF diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 95ee6bc96801f..23d5a0c4607af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -3,18 +3,19 @@ layout: global title: Machine Learning Library (MLlib) --- -MLlib is a Spark implementation of some common machine learning algorithms and utilities, +MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives: +filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: -* [Basics](mllib-basics.html) - * data types +* [Data types](mllib-basics.html) +* [Basic statistics](mllib-stats.html) + * data generators + * stratified sampling * summary statistics -* Classification and regression - * [linear support vector machine (SVM)](mllib-linear-methods.html#linear-support-vector-machine-svm) - * [logistic regression](mllib-linear-methods.html#logistic-regression) - * [linear least squares, Lasso, and ridge regression](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) - * [decision tree](mllib-decision-tree.html) + * hypothesis testing +* [Classification and regression](mllib-classification-regression.html) + * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) + * [decision trees](mllib-decision-tree.html) * [naive Bayes](mllib-naive-bayes.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) @@ -23,17 +24,18 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) -* [Optimization](mllib-optimization.html) +* [Feature extraction and transformation](mllib-feature-extraction.html) +* [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) -MLlib is a new component under active development. +MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and we will provide migration guide between releases. +and the migration guide below will explain all changes between releases. # Dependencies -MLlib uses linear algebra packages [Breeze](http://www.scalanlp.org/), which depends on +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on [netlib-java](https://github.com/fommil/netlib-java), and [jblas](https://github.com/mikiobraun/jblas). `netlib-java` and `jblas` depend on native Fortran routines. @@ -56,7 +58,7 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. +take advantage of sparsity in both storage and computation. Details are described below.
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 254201147edc1..e504cd7f0f578 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -33,24 +33,24 @@ the task of finding a minimizer of a convex function `$f$` that depends on a var Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} - f(\wv) := - \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) + - \lambda\, R(\wv_i) + f(\wv) := \lambda\, R(\wv) + + \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) \label{eq:regPrimal} \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several MLlib's classification and regression algorithms fall into this category, +Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: -the loss that measures the error of the model on the training data, -and the regularizer that measures the complexity of the model. -The loss function `$L(\wv;.)$` must be a convex function in `$\wv$`. -The fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) defines the trade-off -between the two goals of small loss and small model complexity. +the regularizer that controls the complexity of the model, +and the loss that measures the error of the model on the training data. +The loss function `$L(\wv;.)$` is typically a convex function in `$\wv$`. The +fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) +defines the trade-off between the two goals of minimizing the loss (i.e., +training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions @@ -80,10 +80,10 @@ methods MLlib supports: ### Regularizers -The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to -encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid -over-fitting. -We support the following regularizers in MLlib: +The purpose of the +[regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to +encourage simple models and avoid overfitting. We support the following +regularizers in MLlib: @@ -106,27 +106,28 @@ Here `$\mathrm{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. -However, L1 regularization can help promote sparsity in weights, leading to simpler models, which is -also used for feature selection. It is not recommended to train models without any regularization, +However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. +It is not recommended to train models without any regularization, especially when the number of training examples is small. ## Binary classification -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) is to divide items into -two categories: positive and negative. MLlib supports two linear methods for binary classification: -linear support vector machine (SVM) and logistic regression. The training data set is represented -by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the mathematical -formulation, a training label $y$ is either $+1$ (positive) or $-1$ (negative), which is convenient -for the formulation. *However*, the negative label is represented by $0$ in MLlib instead of $-1$, -to be consistent with multiclass labeling. +[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) +aims to divide items into two categories: positive and negative. MLlib +supports two linear methods for binary classification: linear support vector +machines (SVMs) and logistic regression. For both methods, MLlib supports +L1 and L2 regularized variants. The training data set is represented by an RDD +of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the +mathematical formulation in this guide, a training label $y$ is denoted as +either $+1$ (positive) or $-1$ (negative), which is convenient for the +formulation. *However*, the negative label is represented by $0$ in MLlib +instead of $-1$, to be consistent with multiclass labeling. -### Linear support vector machine (SVM) +### Linear support vector machines (SVMs) The [linear SVM](http://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM) -has become a standard choice for large-scale classification tasks. -The name "linear SVM" is actually ambiguous. -By "linear SVM", we mean specifically the linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the hinge loss +is a standard method for large-scale classification tasks. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the hinge loss: + `\[ L(\wv;\x,y) := \max \{0, 1-y \wv^T \x \}. \]` @@ -134,39 +135,44 @@ By default, linear SVMs are trained with an L2 regularization. We also support alternative L1 regularization. In this case, the problem becomes a [linear program](http://en.wikipedia.org/wiki/Linear_programming). -Linear SVM algorithm outputs a SVM model, which makes predictions based on the value of $\wv^T \x$. -By the default, if $\wv^T \x \geq 0$, the outcome is positive, or negative otherwise. -However, quite often in practice, the default threshold $0$ is not a good choice. -The threshold should be determined via model evaluation. +The linear SVMs algorithm outputs an SVM model. Given a new data point, +denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. +By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative +otherwise. ### Logistic regression [Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. It is a linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the logistic loss +binary response. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the logistic loss: `\[ L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). \]` -Logistic regression algorithm outputs a logistic regression model, which makes predictions by +The logistic regression algorithm outputs a logistic regression model. Given a +new data point, denoted by $\x$, the model makes predictions by applying the logistic function `\[ \mathrm{f}(z) = \frac{1}{1 + e^{-z}} \]` where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. -For the same reason mentioned above, quite often in practice, this default threshold is not a good choice. -The threshold should be determined via model evaluation. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). ### Evaluation metrics -MLlib supports common evaluation metrics for binary classification (not available in Python). This +MLlib supports common evaluation metrics for binary classification (not available in PySpark). +This includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), [receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), precision-recall curve, and [area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -Among the metrics, area under ROC is commonly used to compare models and precision/recall/F-measure -can help determine the threshold to use. +AUC is commonly used to compare the performance of various models while +precision/recall/F-measure can help determine the appropriate threshold to use +for prediction purposes. ### Examples @@ -233,8 +239,7 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -Similarly, you can use replace `SVMWithSGD` by -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD). +[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. @@ -318,10 +323,11 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
@@ -354,24 +360,22 @@ print("Training Error = " + str(trainErr)) ## Linear least squares, Lasso, and ridge regression -Linear least squares is a family of linear methods with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the squared loss +Linear least squares is the most common formulation for regression problems. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the squared loss: `\[ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` -Depending on the regularization type, we call the method -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or simply -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) if there -is no regularization, [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) if L2 -regularization is used, and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) if L1 -regularization is used. This average loss $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$ is also +Various related regression methods are derived by using different types of regularization: +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses + no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 +regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 +regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -Note that the squared loss is sensitive to outliers. -Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually necessary in practice. - ### Examples
@@ -379,7 +383,7 @@ Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually ne
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} @@ -407,9 +411,8 @@ val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) {% endhighlight %} -Similarly you can use [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) -and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD). +and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`.
@@ -479,16 +482,17 @@ public class LinearRegression { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight python %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index b1650c83c98b9..86d94aebd9442 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -4,23 +4,23 @@ title: Naive Bayes - MLlib displayTitle: MLlib - Naive Bayes --- -Naive Bayes is a simple multiclass classification algorithm with the assumption of independence -between every pair of features. Naive Bayes can be trained very efficiently. Within a single pass to -the training data, it computes the conditional probability distribution of each feature given label, -and then it applies Bayes' theorem to compute the conditional probability distribution of label -given an observation and use it for prediction. For more details, please visit the Wikipedia page -[Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier). - -In MLlib, we implemented multinomial naive Bayes, which is typically used for document -classification. Within that context, each observation is a document, each feature represents a term, -whose value is the frequency of the term. For its formulation, please visit the Wikipedia page -[Multinomial Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -or the section -[Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) -from the book Introduction to Information -Retrieval. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple +multiclass classification algorithm with the assumption of independence between +every pair of features. Naive Bayes can be trained very efficiently. Within a +single pass to the training data, it computes the conditional probability +distribution of each feature given label, and then it applies Bayes' theorem to +compute the conditional probability distribution of label given an observation +and use it for prediction. + +MLlib supports [multinomial naive +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), +which is typically used for [document +classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each +feature represents a term whose value is the frequency of the term. +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature -vectors are usually sparse. Please supply sparse vectors as input to take advantage of +vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of sparsity. Since the training data is only used once, it is not necessary to cache it. ## Examples diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md new file mode 100644 index 0000000000000..ca9ef46c15186 --- /dev/null +++ b/docs/mllib-stats.md @@ -0,0 +1,95 @@ +--- +layout: global +title: Statistics Functionality - MLlib +displayTitle: MLlib - Statistics Functionality +--- + +* Table of contents +{:toc} + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +## Data Generators + +## Stratified Sampling + +## Summary Statistics + +### Multivariate summary statistics + +We provide column summary statistics for `RowMatrix` (note: this functionality is not currently supported in `IndexedRowMatrix` or `CoordinateMatrix`). +If the number of columns is not large, e.g., on the order of thousands, then the +covariance matrix can also be computed as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the +number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, +and is faster if the rows are sparse. + +
+
+ +[`computeColumnSummaryStatistics()`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of +[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +val mat: RowMatrix = ... // a RowMatrix + +// Compute column summary statistics. +val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() +println(summary.mean) // a dense vector containing the mean value for each column +println(summary.variance) // column-wise variance +println(summary.numNonzeros) // number of nonzeros in each column + +// Compute the covariance matrix. +val cov: Matrix = mat.computeCovariance() +{% endhighlight %} +
+ +
+ +[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of +[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight java %} +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; + +RowMatrix mat = ... // a RowMatrix + +// Compute column summary statistics. +MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); +System.out.println(summary.mean()); // a dense vector containing the mean value for each column +System.out.println(summary.variance()); // column-wise variance +System.out.println(summary.numNonzeros()); // number of nonzeros in each column + +// Compute the covariance matrix. +Matrix cov = mat.computeCovariance(); +{% endhighlight %} +
+
+ + +## Hypothesis Testing From 676f98289dad61c091bb45bd35a2b9613b22d64a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Aug 2014 22:50:29 -0700 Subject: [PATCH 59/83] [SPARK-2953] Allow using short names for io compression codecs Instead of requiring "org.apache.spark.io.LZ4CompressionCodec", it is easier for users if Spark just accepts "lz4", "lzf", "snappy". Author: Reynold Xin Closes #1873 from rxin/compressionCodecShortForm and squashes the following commits: 9f50962 [Reynold Xin] Specify short-form compression codec names first. 63f78ee [Reynold Xin] Updated configuration documentation. 47b3848 [Reynold Xin] [SPARK-2953] Allow using short names for io compression codecs --- .../org/apache/spark/io/CompressionCodec.scala | 11 +++++++++-- .../spark/io/CompressionCodecSuite.scala | 18 ++++++++++++++++++ docs/configuration.md | 8 +++++--- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 1b66218d86dd9..ef9c43ecf14f6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -46,17 +46,24 @@ trait CompressionCodec { private[spark] object CompressionCodec { + + private val shortCompressionCodecNames = Map( + "lz4" -> classOf[LZ4CompressionCodec].getName, + "lzf" -> classOf[LZFCompressionCodec].getName, + "snappy" -> classOf[SnappyCompressionCodec].getName) + def createCodec(conf: SparkConf): CompressionCodec = { createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val ctor = Class.forName(codecName, true, Utils.getContextOrSparkClassLoader) + val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) .getConstructor(classOf[SparkConf]) ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = "snappy" } diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 3f882a724b047..25be7f25c21bb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -56,15 +56,33 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lz4") + assert(codec.getClass === classOf[LZ4CompressionCodec]) + testCodec(codec) + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } + test("lzf compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lzf") + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } + + test("snappy compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "snappy") + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 617a72a021f6e..8136bd62ab6af 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -373,10 +373,12 @@ Apart from these, the following properties are also available, and may be useful
- + From 246cb3f158686348a698d1c0da3001c314727129 Mon Sep 17 00:00:00 2001 From: Raymond Liu Date: Tue, 12 Aug 2014 23:19:35 -0700 Subject: [PATCH 60/83] Use transferTo when copy merge files in ExternalSorter Since this is a file to file copy, using transferTo should be faster. Author: Raymond Liu Closes #1884 from colorant/externalSorter and squashes the following commits: 6e42f3c [Raymond Liu] More code into copyStream bfb496b [Raymond Liu] Use transferTo when copy merge files in ExternalSorter --- .../scala/org/apache/spark/util/Utils.scala | 29 ++++++++++++++----- .../util/collection/ExternalSorter.scala | 7 ++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c60be4f8a11d2..8cac5da644fa9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -284,17 +284,32 @@ private[spark] object Utils extends Logging { /** Copy all data from an InputStream to an OutputStream */ def copyStream(in: InputStream, out: OutputStream, - closeStreams: Boolean = false) + closeStreams: Boolean = false): Long = { + var count = 0L try { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) { + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = in.asInstanceOf[FileInputStream].getChannel() + val outChannel = out.asInstanceOf[FileOutputStream].getChannel() + val size = inChannel.size() + + // In case transferTo method transferred less data than we have required. + while (count < size) { + count += inChannel.transferTo(count, size - count, outChannel) + } + } else { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + count += n + } } } + count } finally { if (closeStreams) { try { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b73d5e0cf1714..5d8a648d9551e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -745,12 +745,11 @@ private[spark] class ExternalSorter[K, V, C]( try { out = new FileOutputStream(outputFile) for (i <- 0 until numPartitions) { - val file = partitionWriters(i).fileSegment().file - in = new FileInputStream(file) - org.apache.spark.util.Utils.copyStream(in, out) + in = new FileInputStream(partitionWriters(i).fileSegment().file) + val size = org.apache.spark.util.Utils.copyStream(in, out, false) in.close() in = null - lengths(i) = file.length() + lengths(i) = size offsets(i + 1) = offsets(i) + lengths(i) } } finally { From 2bd812639c3d8c62a725fb7577365ef0816f2898 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 12 Aug 2014 23:43:36 -0700 Subject: [PATCH 61/83] [SPARK-1777 (partial)] bugfix: make size of requested memory correctly Author: Zhang, Liye Closes #1892 from liyezhang556520/lazy_memory_request and squashes the following commits: 335ab61 [Zhang, Liye] [SPARK-1777 (partial)] bugfix: make size of requested memory correctly --- .../src/main/scala/org/apache/spark/storage/MemoryStore.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 28f675c2bbb1e..0a09c24d61879 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -238,7 +238,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // If our vector's size has exceeded the threshold, request more memory val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { @@ -254,7 +254,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } // New threshold is currentSize * memoryGrowthFactor - memoryThreshold = currentSize + amountToRequest + memoryThreshold += amountToRequest } } elementsUnrolled += 1 From fe4735958e62b1b32a01960503876000f3d2e520 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 12 Aug 2014 23:47:42 -0700 Subject: [PATCH 62/83] [SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics For both Scala and Python. The ser/de util functions were moved out of `PythonMLLibAPI` and into their own object to avoid creating the `PythonMLLibAPI` object inside of `MultivariateStatisticalSummarySerialized`, which is then referenced inside of a method in `PythonMLLibAPI`. `MultivariateStatisticalSummarySerialized` was created to serialize the `Vector` fields in `MultivariateStatisticalSummary`. Author: Doris Xin Closes #1911 from dorx/colStats and squashes the following commits: 77b9924 [Doris Xin] developerAPI tag de9cbbe [Doris Xin] reviewer comments and moved more ser/de 459faba [Doris Xin] colStats in Statistics for both Scala and Python --- .../mllib/api/python/PythonMLLibAPI.scala | 532 ++++++++++-------- .../MatrixFactorizationModel.scala | 7 +- .../apache/spark/mllib/stat/Statistics.scala | 13 + .../api/python/PythonMLLibAPISuite.scala | 17 +- python/pyspark/mllib/stat.py | 66 ++- 5 files changed, 374 insertions(+), 261 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ba7ccd8ce4b8b..18dc087856785 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -48,182 +48,7 @@ import org.apache.spark.util.Utils */ @DeveloperApi class PythonMLLibAPI extends Serializable { - private val DENSE_VECTOR_MAGIC: Byte = 1 - private val SPARSE_VECTOR_MAGIC: Byte = 2 - private val DENSE_MATRIX_MAGIC: Byte = 3 - private val LABELED_POINT_MAGIC: Byte = 4 - - private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { - require(bytes.length - offset >= 5, "Byte array too short") - val magic = bytes(offset) - if (magic == DENSE_VECTOR_MAGIC) { - deserializeDenseVector(bytes, offset) - } else if (magic == SPARSE_VECTOR_MAGIC) { - deserializeSparseVector(bytes, offset) - } else { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - } - - private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { - require(bytes.length - offset == 8, "Wrong size byte array for Double") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - bb.getDouble - } - private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 5, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val length = bb.getInt() - require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - Vectors.dense(ans) - } - - private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 9, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val size = bb.getInt() - val nonZeros = bb.getInt() - require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) - val ib = bb.asIntBuffer() - val indices = new Array[Int](nonZeros) - ib.get(indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - val values = new Array[Double](nonZeros) - db.get(values) - Vectors.sparse(size, indices, values) - } - - /** - * Returns an 8-byte array for the input Double. - * - * Note: we currently do not use a magic byte for double for storage efficiency. - * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. - * The corresponding deserializer, deserializeDouble, needs to be modified as well if the - * serialization scheme changes. - */ - private[python] def serializeDouble(double: Double): Array[Byte] = { - val bytes = new Array[Byte](8) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putDouble(double) - bytes - } - - private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](5 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_VECTOR_MAGIC) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - bytes - } - - private def serializeSparseVector(vector: SparseVector): Array[Byte] = { - val nonZeros = vector.indices.length - val bytes = new Array[Byte](9 + 12 * nonZeros) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(SPARSE_VECTOR_MAGIC) - bb.putInt(vector.size) - bb.putInt(nonZeros) - val ib = bb.asIntBuffer() - ib.put(vector.indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - db.put(vector.values) - bytes - } - - private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { - case s: SparseVector => - serializeSparseVector(s) - case _ => - serializeDenseVector(vector.toArray) - } - - private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 9) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - if (magic != DENSE_MATRIX_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val rows = bb.getInt() - val cols = bb.getInt() - if (packetLength != 9 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) - } - ans - } - - private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length - } - val bytes = new Array[Byte](9 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_MATRIX_MAGIC) - bb.putInt(rows) - bb.putInt(cols) - val db = bb.asDoubleBuffer() - for (i <- 0 until rows) { - db.put(doubles(i)) - } - bytes - } - - private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { - val fb = serializeDoubleVector(p.features) - val bytes = new Array[Byte](1 + 8 + fb.length) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(LABELED_POINT_MAGIC) - bb.putDouble(p.label) - bb.put(fb) - bytes - } - - private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { - require(bytes.length >= 9, "Byte array too short") - val magic = bytes(0) - if (magic != LABELED_POINT_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val labelBytes = ByteBuffer.wrap(bytes, 1, 8) - labelBytes.order(ByteOrder.nativeOrder()) - val label = labelBytes.asDoubleBuffer().get(0) - LabeledPoint(label, deserializeDoubleVector(bytes, 9)) - } /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -236,17 +61,17 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val initialWeights = deserializeDoubleVector(initialWeightsBA) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) + val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(SerDe.serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -405,12 +230,12 @@ class PythonMLLibAPI extends Serializable { def trainNaiveBayes( dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(Vectors.dense(model.labels))) - ret.add(serializeDoubleVector(Vectors.dense(model.pi))) - ret.add(serializeDoubleMatrix(model.theta)) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels))) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi))) + ret.add(SerDe.serializeDoubleMatrix(model.theta)) ret } @@ -423,52 +248,13 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes)) + val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes)) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) + ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) ret } - /** Unpack a Rating object from an array of bytes */ - private def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - new Rating(user, product, rating) - } - - /** Unpack a tuple of Ints from an array of bytes */ - private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { - val bb = ByteBuffer.wrap(tupleBytes) - bb.order(ByteOrder.nativeOrder()) - val v1 = bb.getInt() - val v2 = bb.getInt() - (v1, v2) - } - - /** - * Serialize a Rating object into an array of bytes. - * It can be deserialized using RatingDeserializer(). - * - * @param rate the Rating object to serialize - * @return - */ - private[spark] def serializeRating(rate: Rating): Array[Byte] = { - val len = 3 - val bytes = new Array[Byte](4 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(rate.user.toDouble) - db.put(rate.product.toDouble) - db.put(rate.rating) - bytes - } - /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -481,7 +267,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -498,7 +284,7 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } @@ -519,7 +305,7 @@ class PythonMLLibAPI extends Serializable { maxDepth: Int, maxBins: Int): DecisionTreeModel = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val algo = Algo.fromString(algoStr) val impurity = Impurities.fromString(impurityStr) @@ -545,7 +331,7 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, featuresBytes: Array[Byte]): Double = { - val features: Vector = deserializeDoubleVector(featuresBytes) + val features: Vector = SerDe.deserializeDoubleVector(featuresBytes) model.predict(features) } @@ -559,8 +345,17 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - model.predict(data).map(serializeDouble) + val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes)) + model.predict(data).map(SerDe.serializeDouble) + } + + /** + * Java stub for mllib Statistics.colStats(X: RDD[Vector]). + * TODO figure out return type. + */ + def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = { + val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_))) + new MultivariateStatisticalSummarySerialized(cStats) } /** @@ -569,17 +364,17 @@ class PythonMLLibAPI extends Serializable { * pyspark. */ def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { - val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_)) val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) - serializeDoubleMatrix(to2dArray(result)) + SerDe.serializeDoubleMatrix(SerDe.to2dArray(result)) } /** * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). */ def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { - val xDeser = x.rdd.map(deserializeDouble(_)) - val yDeser = y.rdd.map(deserializeDouble(_)) + val xDeser = x.rdd.map(SerDe.deserializeDouble(_)) + val yDeser = y.rdd.map(SerDe.deserializeDouble(_)) Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) } @@ -588,12 +383,6 @@ class PythonMLLibAPI extends Serializable { if (method == null) CorrelationNames.defaultCorrName else method } - // Reformat a Matrix into Array[Array[Double]] for serialization - private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { - val values = matrix.toArray - Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) - } - // Used by the *RDD methods to get default seed if not passed in from pyspark private def getSeedOrDefault(seed: java.lang.Long): Long = { if (seed == null) Utils.random.nextLong else seed @@ -621,7 +410,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -633,7 +422,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -646,7 +435,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble) } /** @@ -659,7 +448,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -672,7 +461,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -686,7 +475,256 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + } + +} + +/** + * :: DeveloperApi :: + * MultivariateStatisticalSummary with Vector fields serialized. + */ +@DeveloperApi +class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary) + extends Serializable { + + def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean) + + def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance) + + def count: Long = summary.count + + def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros) + + def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max) + + def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min) +} + +/** + * SerDe utility functions for PythonMLLibAPI. + */ +private[spark] object SerDe extends Serializable { + private val DENSE_VECTOR_MAGIC: Byte = 1 + private val SPARSE_VECTOR_MAGIC: Byte = 2 + private val DENSE_MATRIX_MAGIC: Byte = 3 + private val LABELED_POINT_MAGIC: Byte = 4 + + private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + require(bytes.length - offset >= 5, "Byte array too short") + val magic = bytes(offset) + if (magic == DENSE_VECTOR_MAGIC) { + deserializeDenseVector(bytes, offset) + } else if (magic == SPARSE_VECTOR_MAGIC) { + deserializeSparseVector(bytes, offset) + } else { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } } + private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { + require(bytes.length - offset == 8, "Wrong size byte array for Double") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + bb.getDouble + } + + private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 5, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val length = bb.getInt() + require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + Vectors.dense(ans) + } + + private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 9, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val size = bb.getInt() + val nonZeros = bb.getInt() + require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) + val ib = bb.asIntBuffer() + val indices = new Array[Int](nonZeros) + ib.get(indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + val values = new Array[Double](nonZeros) + db.get(values) + Vectors.sparse(size, indices, values) + } + + /** + * Returns an 8-byte array for the input Double. + * + * Note: we currently do not use a magic byte for double for storage efficiency. + * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. + * The corresponding deserializer, deserializeDouble, needs to be modified as well if the + * serialization scheme changes. + */ + private[python] def serializeDouble(double: Double): Array[Byte] = { + val bytes = new Array[Byte](8) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putDouble(double) + bytes + } + + private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](5 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_VECTOR_MAGIC) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + bytes + } + + private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = { + val nonZeros = vector.indices.length + val bytes = new Array[Byte](9 + 12 * nonZeros) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(SPARSE_VECTOR_MAGIC) + bb.putInt(vector.size) + bb.putInt(nonZeros) + val ib = bb.asIntBuffer() + ib.put(vector.indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + db.put(vector.values) + bytes + } + + private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + case s: SparseVector => + serializeSparseVector(s) + case _ => + serializeDenseVector(vector.toArray) + } + + private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 9) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + if (magic != DENSE_MATRIX_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getInt() + val cols = bb.getInt() + if (packetLength != 9 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + ans + } + + private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](9 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_MATRIX_MAGIC) + bb.putInt(rows) + bb.putInt(cols) + val db = bb.asDoubleBuffer() + for (i <- 0 until rows) { + db.put(doubles(i)) + } + bytes + } + + private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { + val fb = serializeDoubleVector(p.features) + val bytes = new Array[Byte](1 + 8 + fb.length) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(LABELED_POINT_MAGIC) + bb.putDouble(p.label) + bb.put(fb) + bytes + } + + private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + require(bytes.length >= 9, "Byte array too short") + val magic = bytes(0) + if (magic != LABELED_POINT_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val labelBytes = ByteBuffer.wrap(bytes, 1, 8) + labelBytes.order(ByteOrder.nativeOrder()) + val label = labelBytes.asDoubleBuffer().get(0) + LabeledPoint(label, deserializeDoubleVector(bytes, 9)) + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + + + /** Unpack a Rating object from an array of bytes */ + private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + new Rating(user, product, rating) + } + + /** Unpack a tuple of Ints from an array of bytes */ + def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { + val bb = ByteBuffer.wrap(tupleBytes) + bb.order(ByteOrder.nativeOrder()) + val v1 = bb.getInt() + val v2 = bb.getInt() + (v1, v2) + } + + /** + * Serialize a Rating object into an array of bytes. + * It can be deserialized using RatingDeserializer(). + * + * @param rate the Rating object to serialize + * @return + */ + def serializeRating(rate: Rating): Array[Byte] = { + val len = 3 + val bytes = new Array[Byte](4 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(rate.user.toDouble) + db.put(rate.product.toDouble) + db.put(rate.rating) + bytes + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index a1a76fcbe9f9c..478c6485052b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.PythonMLLibAPI +import org.apache.spark.mllib.api.python.SerDe /** * Model representing the result of matrix factorization. @@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] ( */ @DeveloperApi def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val pythonAPI = new PythonMLLibAPI() - val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) - predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) + val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes)) + predict(usersProducts).map(rate => SerDe.serializeRating(rate)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index cf8679610e191..3cf1028fbc725 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations @@ -30,6 +31,18 @@ import org.apache.spark.rdd.RDD @Experimental object Statistics { + /** + * :: Experimental :: + * Computes column-wise summary statistics for the input RDD[Vector]. + * + * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. + * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + */ + @Experimental + def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { + new RowMatrix(X).computeColumnSummaryStatistics() + } + /** * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index bd413a80f5107..092d67bbc5238 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { - val py = new PythonMLLibAPI test("vector serialization") { val vectors = Seq( @@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite { Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => - val bytes = py.serializeDoubleVector(v) - val u = py.deserializeDoubleVector(bytes) + val bytes = SerDe.serializeDoubleVector(v) + val u = SerDe.deserializeDoubleVector(bytes) assert(u.getClass === v.getClass) assert(u === v) } @@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite { LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => - val bytes = py.serializeLabeledPoint(p) - val q = py.deserializeLabeledPoint(bytes) + val bytes = SerDe.serializeLabeledPoint(p) + val q = SerDe.deserializeLabeledPoint(bytes) assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) @@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite { test("double serialization") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { - val bytes = py.serializeDouble(x) - val deser = py.deserializeDouble(bytes) + val bytes = SerDe.serializeDouble(x) + val deser = SerDe.deserializeDouble(bytes) // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } @@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite { test("matrix to 2D array") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) - val arr = py.to2dArray(matrix) + val arr = SerDe.to2dArray(matrix) val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) assert(arr === expected) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) - val empty2D = py.to2dArray(emptyMatrix) + val empty2D = SerDe.to2dArray(emptyMatrix) assert(empty2D === Array[Array[Double]]()) } } diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 982906b9d09f0..a73abc5ff90df 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,11 +22,75 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix + _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +class MultivariateStatisticalSummary(object): + + """ + Trait for multivariate statistical summary of a data matrix. + """ + + def __init__(self, sc, java_summary): + """ + :param sc: Spark context + :param java_summary: Handle to Java summary object + """ + self._sc = sc + self._java_summary = java_summary + + def __del__(self): + self._sc._gateway.detach(self._java_summary) + + def mean(self): + return _deserialize_double_vector(self._java_summary.mean()) + + def variance(self): + return _deserialize_double_vector(self._java_summary.variance()) + + def count(self): + return self._java_summary.count() + + def numNonzeros(self): + return _deserialize_double_vector(self._java_summary.numNonzeros()) + + def max(self): + return _deserialize_double_vector(self._java_summary.max()) + + def min(self): + return _deserialize_double_vector(self._java_summary.min()) class Statistics(object): + @staticmethod + def colStats(X): + """ + Computes column-wise summary statistics for the input RDD[Vector]. + + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), + ... Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8])]) + >>> cStats = Statistics.colStats(rdd) + >>> cStats.mean() + array([ 4., 4., 0., 3.]) + >>> cStats.variance() + array([ 4., 13., 0., 25.]) + >>> cStats.count() + 3L + >>> cStats.numNonzeros() + array([ 3., 2., 0., 3.]) + >>> cStats.max() + array([ 6., 7., 0., 8.]) + >>> cStats.min() + array([ 2., 0., 0., -2.]) + """ + sc = X.ctx + Xser = _get_unmangled_double_vector_rdd(X) + cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + return MultivariateStatisticalSummary(sc, cStats) + @staticmethod def corr(x, y=None, method=None): """ From 869f06c759c29b09c8dc72e0e4034c03f908ba30 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 13 Aug 2014 14:42:57 -0700 Subject: [PATCH 63/83] [SPARK-2963] [SQL] There no documentation about building to use HiveServer and CLI for SparkSQL Author: Kousuke Saruta Closes #1885 from sarutak/SPARK-2963 and squashes the following commits: ed53329 [Kousuke Saruta] Modified description and notaton of proper noun 07c59fc [Kousuke Saruta] Added a description about how to build to use HiveServer and CLI for SparkSQL to building-with-maven.md 6e6645a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2963 c88fa93 [Kousuke Saruta] Added a description about building to use HiveServer and CLI for SparkSQL --- README.md | 9 +++++++++ docs/building-with-maven.md | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/README.md b/README.md index f87e07aa5cc90..a1a48f5bd0819 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,15 @@ If your project is built with Maven, add this to your POM file's ` +## A Note About Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. + + $ sbt/sbt -Phive-thriftserver assembly + + ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 672d0ef114f6d..4d87ab92cec5b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,6 +96,15 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} +# Building Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. +{% highlight bash %} +mvn -Phive-thriftserver assembly +{% endhighlight %} + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). From c974a716e17c9fe2628b1ba1d4309ead1bd855ad Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Aug 2014 14:56:11 -0700 Subject: [PATCH 64/83] [SPARK-3013] [SQL] [PySpark] convert array into list because Pyrolite does not support array from Python 2.6 Author: Davies Liu Closes #1928 from davies/fix_array and squashes the following commits: 858e6c5 [Davies Liu] convert array into list --- python/pyspark/sql.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 27f1d2ddf942a..46540ca3f1e8a 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -498,10 +498,7 @@ def _infer_schema(row): def _create_converter(obj, dataType): """Create an converter to drop the names of fields in obj """ - if not _has_struct(dataType): - return lambda x: x - - elif isinstance(dataType, ArrayType): + if isinstance(dataType, ArrayType): conv = _create_converter(obj[0], dataType.elementType) return lambda row: map(conv, row) @@ -510,6 +507,9 @@ def _create_converter(obj, dataType): conv = _create_converter(value, dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif not isinstance(dataType, StructType): + return lambda x: x + # dataType must be StructType names = [f.name for f in dataType.fields] @@ -529,8 +529,7 @@ def _create_converter(obj, dataType): elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] - nested = any(_has_struct(f.dataType) for f in dataType.fields) - if not nested: + if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): return conv row = conv(obj) @@ -1037,7 +1036,8 @@ def inferSchema(self, rdd): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated") + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) From 434bea1c002b597cff9db899da101490e1f1e9ed Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Aug 2014 14:57:12 -0700 Subject: [PATCH 65/83] [SPARK-2983] [PySpark] improve performance of sortByKey() 1. skip partitionBy() when numOfPartition is 1 2. use bisect_left (O(lg(N))) instread of loop (O(N)) in rangePartitioner Author: Davies Liu Closes #1898 from davies/sort and squashes the following commits: 0a9608b [Davies Liu] Merge branch 'master' into sort 1cf9565 [Davies Liu] improve performance of sortByKey() --- python/pyspark/rdd.py | 47 ++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 756e8f35fb03d..3934bdda0a466 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -30,6 +30,7 @@ from threading import Thread import warnings import heapq +import bisect from random import Random from math import sqrt, log @@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey(True, 1).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] @@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - bounds = list() + if numPartitions == 1: + if self.getNumPartitions() > 1: + self = self.coalesce(1) + + def sort(iterator): + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + + return self.mapPartitions(sort) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them - if numPartitions > 1: - rddSize = self.count() - # constant from Spark's RangePartitioner - maxSampleSize = numPartitions * 20.0 - fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - - samples = self.sample(False, fraction, 1).map( - lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) - - # we have numPartitions many parts but one of the them has - # an implicit boundary - for i in range(0, numPartitions - 1): - index = (len(samples) - 1) * (i + 1) / numPartitions - bounds.append(samples[index]) + rddSize = self.count() + maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + fraction = min(maxSampleSize / max(rddSize, 1), 1.0) + samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = sorted(samples, reverse=(not ascending), key=keyfunc) + + # we have numPartitions many parts but one of the them has + # an implicit boundary + bounds = [samples[len(samples) * (i + 1) / numPartitions] + for i in range(0, numPartitions - 1)] def rangePartitionFunc(k): - p = 0 - while p < len(bounds) and keyfunc(k) > bounds[p]: - p += 1 + p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p def mapFunc(iterator): - yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) - .mapPartitions(mapFunc, preservesPartitioning=True) - .flatMap(lambda x: x, preservesPartitioning=True)) + return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ From 7ecb867c4cd6916b6cb12f2ece1a4c88591ad5b5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 Aug 2014 16:20:49 -0700 Subject: [PATCH 66/83] [MLLIB] use Iterator.fill instead of Array.fill Iterator.fill uses less memory Author: Xiangrui Meng Closes #1930 from mengxr/rand-gen-iter and squashes the following commits: 24178ca [Xiangrui Meng] use Iterator.fill instead of Array.fill --- .../scala/org/apache/spark/mllib/rdd/RandomRDD.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index c8db3910c6eab..910eff9540a47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -105,16 +105,16 @@ private[mllib] object RandomRDD { def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(generator.nextValue()).toIterator + Iterator.fill(partition.size)(generator.nextValue()) } // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition[Double], - vectorSize: Int): Iterator[Vector] = { + def getVectorIterator( + partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(new DenseVector( - (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator + Iterator.fill(partition.size)(new DenseVector(Array.fill(vectorSize)(generator.nextValue()))) } } From bdc7a1a4749301f8d18617c130c7766684aa8789 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 Aug 2014 16:27:50 -0700 Subject: [PATCH 67/83] [SPARK-3004][SQL] Added null checking when retrieving row set JIRA issue: [SPARK-3004](https://issues.apache.org/jira/browse/SPARK-3004) HiveThriftServer2 throws exception when the result set contains `NULL`. Should check `isNullAt` in `SparkSQLOperationManager.getNextRowSet`. Note that simply using `row.addColumnValue(null)` doesn't work, since Hive set the column type of a null `ColumnValue` to String by default. Author: Cheng Lian Closes #1920 from liancheng/spark-3004 and squashes the following commits: 1b1db1c [Cheng Lian] Adding NULL column values in the Hive way 2217722 [Cheng Lian] Fixed SPARK-3004: added null checking when retrieving row set --- .../server/SparkSQLOperationManager.scala | 93 +++++++++++++------ .../data/files/small_kv_with_null.txt | 10 ++ .../thriftserver/HiveThriftServer2Suite.scala | 26 +++++- 3 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index f192f490ac3d0..9338e8121b0fe 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage var curCol = 0 while (curCol < sparkRow.length) { - dataTypes(curCol) match { - case StringType => - row.addString(sparkRow(curCol).asInstanceOf[String]) - case IntegerType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) - case BooleanType => - row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) - case DoubleType => - row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) - case FloatType => - row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) - case DecimalType => - val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal - row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) - case ByteType => - row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) - case ShortType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) - case TimestampType => - row.addColumnValue( - ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) - row.addColumnValue(ColumnValue.stringValue(hiveString)) + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) } curCol += 1 } @@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } } + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.intValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + def getResultSetSchema: TableSchema = { logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt new file mode 100644 index 0000000000000..ae08c640e6c13 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt @@ -0,0 +1,10 @@ +238val_238 + +311val_311 +val_27 +val_165 +val_409 +255val_255 +278val_278 +98val_98 +val_484 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 78bffa2607349..aedef6ce1f5f2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val stmt = createStatement() stmt.execute("DROP TABLE IF EXISTS test") stmt.execute("DROP TABLE IF EXISTS test_cached") - stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute("CREATE TABLE test(key INT, val STRING)") stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") - stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4") stmt.execute("CACHE TABLE test_cached") - var rs = stmt.executeQuery("select count(*) from test") + var rs = stmt.executeQuery("SELECT COUNT(*) FROM test") rs.next() assert(rs.getInt(1) === 5) - rs = stmt.executeQuery("select count(*) from test_cached") + rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached") rs.next() assert(rs.getInt(1) === 4) stmt.close() } + test("SPARK-3004 regression: result set containing NULL") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv_with_null.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test_null") + stmt.execute("CREATE TABLE test_null(key INT, val STRING)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") + + val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + var count = 0 + while (rs.next()) { + count += 1 + } + assert(count === 5) + + stmt.close() + } + def getConnection: Connection = { val connectURI = s"jdbc:hive2://localhost:$PORT/" DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") From 13f54e2b97744beab45e1bdbcdf8d215ca481b78 Mon Sep 17 00:00:00 2001 From: tianyi Date: Wed, 13 Aug 2014 16:50:02 -0700 Subject: [PATCH 68/83] [SPARK-2817] [SQL] add "show create table" support In spark sql component, the "show create table" syntax had been disabled. We thought it is a useful funciton to describe a hive table. Author: tianyi Author: tianyi Author: tianyi Closes #1760 from tianyi/spark-2817 and squashes the following commits: 7d28b15 [tianyi] [SPARK-2817] fix too short prefix problem cbffe8b [tianyi] [SPARK-2817] fix the case problem 565ec14 [tianyi] [SPARK-2817] fix the case problem 60d48a9 [tianyi] [SPARK-2817] use system temporary folder instead of temporary files in the source tree, and also clean some empty line dbe1031 [tianyi] [SPARK-2817] move some code out of function rewritePaths, as it may be called multiple times 9b2ba11 [tianyi] [SPARK-2817] fix the line length problem 9f97586 [tianyi] [SPARK-2817] remove test.tmp.dir from pom.xml bfc2999 [tianyi] [SPARK-2817] add "File.separator" support, create a "testTmpDir" outside the rewritePaths bde800a [tianyi] [SPARK-2817] add "${system:test.tmp.dir}" support add "last_modified_by" to nonDeterministicLineIndicators in HiveComparisonTest bb82726 [tianyi] [SPARK-2817] remove test which requires a system from the whitelist. bbf6b42 [tianyi] [SPARK-2817] add a systemProperties named "test.tmp.dir" to pass the test which contains "${system:test.tmp.dir}" a337bd6 [tianyi] [SPARK-2817] add "show create table" support a03db77 [tianyi] [SPARK-2817] add "show create table" support --- .../execution/HiveCompatibilitySuite.scala | 8 +++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 1 + .../org/apache/spark/sql/hive/TestHive.scala | 8 +++++++ ...e_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 | 0 ...e_alter-1-2a91d52719cf4552ebeb867204552a26 | 18 +++++++++++++++ ..._alter-10-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_alter-2-928cc85c025440b731e5ee33e437e404 | 0 ...e_alter-3-2a91d52719cf4552ebeb867204552a26 | 22 +++++++++++++++++++ ...e_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 | 0 ...e_alter-5-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...le_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb | 0 ...e_alter-7-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...e_alter-8-22ab6ed5b15a018756f454dd2294847e | 0 ...e_alter-9-2a91d52719cf4552ebeb867204552a26 | 21 ++++++++++++++++++ ...b_table-0-67509558a4b2d39b25787cca33f52635 | 0 ...b_table-1-549981e00a3d95f03dd5a9ef6044aa20 | 2 ++ ...db_table-2-34ae7e611d0aedbc62b6e420347abee | 0 ...b_table-3-7a9e67189d3d4151f23b12c22bde06b5 | 0 ...b_table-4-b585371b624cbab2616a49f553a870a0 | 13 +++++++++++ ...b_table-5-964757b7e7f2a69fe36132c1a5712199 | 0 ...b_table-6-ac09cf81e7e734cf10406f30b9fa566e | 0 ...limited-0-97228478b9925f06726ceebb6571bf34 | 0 ...limited-1-2a91d52719cf4552ebeb867204552a26 | 17 ++++++++++++++ ...limited-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...itioned-0-4be9a3b1ff0840786a1f001cba170a0c | 0 ...itioned-1-2a91d52719cf4552ebeb867204552a26 | 16 ++++++++++++++ ...itioned-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_serde-0-33f15d91810b75ee05c7b9dea0abb01c | 0 ...e_serde-1-2a91d52719cf4552ebeb867204552a26 | 15 +++++++++++++ ...e_serde-2-259d978ed9543204c8b9c25b6e25b0de | 0 ...e_serde-3-fd12b3e0fe30f5d71c67676791b4a33b | 0 ...e_serde-4-2a91d52719cf4552ebeb867204552a26 | 14 ++++++++++++ ...e_serde-5-259d978ed9543204c8b9c25b6e25b0de | 0 ...le_view-0-ecef6821e4e9212e553ca38142fd0250 | 0 ...le_view-1-1e931ea3fa6065107859ffbb29bb0ed7 | 1 + ...le_view-2-ed97e9e56d95c5b3db57485cba5ad17f | 0 .../hive/execution/HiveComparisonTest.scala | 1 + 37 files changed, 199 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e create mode 100644 sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 create mode 100644 sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4fef071161719..210753efe7678 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -635,6 +635,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "serde_regex", "serde_reported_schema", "set_variable_sub", + "show_create_table_partitioned", + "show_create_table_delimited", + "show_create_table_alter", + "show_create_table_view", + "show_create_table_serde", + "show_create_table_db_table", + "show_create_table_does_not_exist", + "show_create_table_index", "show_describe_func_quotes", "show_functions", "show_partitions", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 05b2f5f6cd3f7..1d9ba1b24a7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -53,6 +53,7 @@ private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", "TOK_DESCDATABASE", + "TOK_SHOW_CREATETABLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index d890df866fbe5..a013f3f7a805f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -70,6 +70,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("hive.metastore.warehouse.dir", warehousePath) } + val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") + testTempDir.delete() + testTempDir.mkdir() + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ @@ -109,6 +116,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveFilesTemp.mkdir() hiveFilesTemp.deleteOnExit() + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 b/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..3c1fc128bedce --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,18 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 b/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2ece813dd7d56 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,22 @@ +CREATE TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'temporary table' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'EXTERNAL'='FALSE', + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 b/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2af657bd29506 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb b/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..f793ffb7a0bfd --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e b/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..c65aff26a7fc1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='1') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 b/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 new file mode 100644 index 0000000000000..707b2ae3ed1df --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 @@ -0,0 +1,2 @@ +default +tmp_feng diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee b/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 new file mode 100644 index 0000000000000..b5a18368ed85e --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -0,0 +1,13 @@ +CREATE TABLE tmp_feng.tmp_showcrt( + key string, + value int) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_feng.db/tmp_showcrt' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132107') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 b/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e b/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..d36ad25dc8273 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,17 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + COLLECTION ITEMS TERMINATED BY '|' + MAP KEYS TERMINATED BY '%' + LINES TERMINATED BY '\n' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132730') diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c b/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..9e572c0d7df6a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,16 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + newvalue boolean COMMENT 'a new value') +COMMENT 'temporary table' +PARTITIONED BY ( + value bigint COMMENT 'some value') +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132112') diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c b/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..69a38e1a7b20a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,15 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +COMMENT 'temporary table' +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b b/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..b4e693dc622fb --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,14 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + value boolean) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='$', + 'field.delim'=',') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 b/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 new file mode 100644 index 0000000000000..be3fb3ce30960 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 @@ -0,0 +1 @@ +CREATE VIEW tmp_copy_src AS SELECT `src`.`key`, `src`.`value` FROM `default`.`src` diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f b/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0ebaf6ffd5458..502ce8fb297e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -161,6 +161,7 @@ abstract class HiveComparisonTest "transient_lastDdlTime", "grantTime", "lastUpdateTime", + "last_modified_by", "last_modified_time", "Owner:", // The following are hive specific schema parameters which we do not need to match exactly. From 9256d4a9c8c9ddb9ae6bbe3c3b99b03fb66b946b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 13 Aug 2014 17:35:38 -0700 Subject: [PATCH 69/83] [SPARK-2994][SQL] Support for udfs that take complex types Author: Michael Armbrust Closes #1915 from marmbrus/arrayUDF and squashes the following commits: a1c503d [Michael Armbrust] Support for udfs that take complex types --- .../spark/sql/hive/HiveInspectors.scala | 14 ++++++- .../org/apache/spark/sql/hive/hiveUdfs.scala | 41 +++++++++++-------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 354fcd53f303b..943bbaa8ce25e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -71,6 +71,9 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + + // Hive seems to return this for struct types? + case c: Class[_] if c == classOf[java.lang.Object] => NullType } /** Converts hive types to native catalyst types. */ @@ -147,7 +150,10 @@ private[hive] trait HiveInspectors { case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + // Some UDFs seem to assume we pass in a HashMap. + val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) }) + hashMap case null => null } @@ -214,6 +220,12 @@ private[hive] trait HiveInspectors { import TypeInfoFactory._ def toTypeInfo: TypeInfo = dt match { + case ArrayType(elemType, _) => + getListTypeInfo(elemType.toTypeInfo) + case StructType(fields) => + getStructTypeInfo(fields.map(_.name), fields.map(_.dataType.toTypeInfo)) + case MapType(keyType, valueType, _) => + getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo case ByteType => byteTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 179aac5cbd5cd..c6497a15efa0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -55,7 +55,10 @@ private[hive] abstract class HiveFunctionRegistry HiveSimpleUdf( functionClassName, - children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + children.zip(expectedDataTypes).map { + case (e, NullType) => e + case (e, t) => Cast(e, t) + } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) @@ -115,22 +118,26 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) } - val constructor = matchingConstructor.getOrElse( - sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) - - (a: Any) => { - logDebug( - s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") - // We must make sure that primitives get boxed java style. - if (a == null) { - null - } else { - constructor.newInstance(a match { - case i: Int => i: java.lang.Integer - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case other: AnyRef => other - }).asInstanceOf[AnyRef] - } + matchingConstructor match { + case Some(constructor) => + (a: Any) => { + logDebug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} $constructor.") + // We must make sure that primitives get boxed java style. + if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + case None => + (a: Any) => a match { + case wrapper => wrap(wrapper) + } } } From 376a82e196e102ef49b9722e8be0b01ac5890a8b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 Aug 2014 17:37:55 -0700 Subject: [PATCH 70/83] [SPARK-2650][SQL] More precise initial buffer size estimation for in-memory column buffer This is a follow up of #1880. Since the row number within a single batch is known, we can estimate a much more precise initial buffer size when building an in-memory column buffer. Author: Cheng Lian Closes #1901 from liancheng/precise-init-buffer-size and squashes the following commits: d5501fa [Cheng Lian] More precise initial buffer size estimation for in-memory column buffer --- .../sql/columnar/InMemoryColumnarTableScan.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 3364d0e18bcc9..e63b4903041f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{SparkPlan, LeafNode} -import org.apache.spark.sql.Row -import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} object InMemoryRelation { def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = @@ -48,7 +47,9 @@ private[sql] case class InMemoryRelation( new Iterator[Array[ByteBuffer]] { def next() = { val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) + val columnType = ColumnType(attribute.dataType) + val initialBufferSize = columnType.defaultSize * batchSize + ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray var row: Row = null From 9fde1ff5fc114b5edb755ed40944607419b62184 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 13 Aug 2014 17:40:59 -0700 Subject: [PATCH 71/83] [SPARK-2935][SQL]Fix parquet predicate push down bug Author: Michael Armbrust Closes #1863 from marmbrus/parquetPredicates and squashes the following commits: 10ad202 [Michael Armbrust] left <=> right f249158 [Michael Armbrust] quiet parquet tests. 802da5b [Michael Armbrust] Add test case. eab2eda [Michael Armbrust] Fix parquet predicate push down bug --- .../scala/org/apache/spark/sql/parquet/ParquetFilters.scala | 5 +++-- sql/core/src/test/resources/log4j.properties | 3 +++ .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 5 ++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index cc575bedd8fcb..2298a9b933df5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -201,8 +201,9 @@ object ParquetFilters { (leftFilter, rightFilter) match { case (None, Some(filter)) => Some(filter) case (Some(filter), None) => Some(filter) - case (_, _) => - Some(new AndFilter(leftFilter.get, rightFilter.get)) + case (Some(leftF), Some(rightF)) => + Some(new AndFilter(leftF, rightF)) + case _ => None } } case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index dffd15a61838b..c7e0ff1cf6494 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,6 +36,9 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9933575038bd3..502f6702e394e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -381,11 +381,14 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) assert(badfilter.isDefined === false) + + val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) + val badfilter2 = ParquetFilters.createFilter(predicate6) + assert(badfilter2.isDefined === false) } test("test filter by predicate pushdown") { for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - println(s"testing field $myval") val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") assert( query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], From 905dc4b405e679feb145f5e6b35e952db2442e0d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 13 Aug 2014 17:42:38 -0700 Subject: [PATCH 72/83] [SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled Author: Kousuke Saruta Closes #1891 from sarutak/SPARK-2970 and squashes the following commits: 4a2d2fe [Kousuke Saruta] Modified comment style 8bd833c [Kousuke Saruta] Modified style 6c0997c [Kousuke Saruta] Modified the timing of shutdown hook execution. It should be executed before shutdown hook of o.a.h.f.FileSystem --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4d0c506c5a397..4ed0f58ebc531 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,6 +26,8 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} @@ -116,13 +118,17 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( + /** + * This should be executed before shutdown hook of + * FileSystem to avoid race condition of FileSystem operation + */ + ShutdownHookManager.get.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - ) + , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { From 63d6777737ca8559d4344d1661500b8ad868bb47 Mon Sep 17 00:00:00 2001 From: guowei Date: Wed, 13 Aug 2014 17:45:24 -0700 Subject: [PATCH 73/83] [SPARK-2986] [SQL] fixed: setting properties does not effect it seems that set command does not run by SparkSQLDriver. it runs on hive api. user can not change reduce number by setting spark.sql.shuffle.partitions but i think setting hive properties seems just a role to spark sql. Author: guowei Closes #1904 from guowei2/temp-branch and squashes the following commits: 7d47dde [guowei] fixed: setting properties like spark.sql.shuffle.partitions does not effective --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4ed0f58ebc531..c16a7d3661c66 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -284,7 +284,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { val driver = new SparkSQLDriver driver.init() From 0c7b452904fe6b5a966a66b956369123d8a9dd4b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 13 Aug 2014 18:08:38 -0700 Subject: [PATCH 74/83] SPARK-3020: Print completed indices rather than tasks in web UI Author: Patrick Wendell Closes #1933 from pwendell/speculation and squashes the following commits: 33a3473 [Patrick Wendell] Use OpenHashSet 8ce2ff0 [Patrick Wendell] SPARK-3020: Print completed indices rather than tasks in web UI --- .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 1 + core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a57a354620163..a3e9566832d06 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,6 +153,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case org.apache.spark.Success => + stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 (None, Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3dcfaf76e4aba..15998404ed612 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -168,7 +168,7 @@ private[ui] class StageTableBase( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 85db15472a00c..a336bf7e1ed02 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.util.collection.OpenHashSet import scala.collection.mutable.HashMap @@ -38,6 +39,7 @@ private[jobs] object UIData { class StageUIData { var numActiveTasks: Int = _ var numCompleteTasks: Int = _ + var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ var executorRunTime: Long = _ From 9497b12d429cf9d075807896637e40e205175203 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Wed, 13 Aug 2014 22:17:07 -0700 Subject: [PATCH 75/83] [SPARK-3006] Failed to execute spark-shell in Windows OS Modified the order of the options and arguments in spark-shell.cmd Author: Masayoshi TSUZUKI Closes #1918 from tsudukim/feature/SPARK-3006 and squashes the following commits: 8bba494 [Masayoshi TSUZUKI] [SPARK-3006] Failed to execute spark-shell in Windows OS 1a32410 [Masayoshi TSUZUKI] [SPARK-3006] Failed to execute spark-shell in Windows OS --- bin/spark-shell.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index b56d69801171c..2ee60b4e2a2b3 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell From e4245656438d00714ebd59e89c4de3fdaae83494 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 Aug 2014 23:24:23 -0700 Subject: [PATCH 76/83] [Docs] Add missing tags (minor) These configs looked inconsistent from the rest. Author: Andrew Or Closes #1936 from andrewor14/docs-code and squashes the following commits: 15f578a [Andrew Or] Add tag --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8136bd62ab6af..c8336b39133de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -562,7 +562,7 @@ Apart from these, the following properties are also available, and may be useful - + - + + + + + +
spark.io.compression.codecorg.apache.spark.io.
SnappyCompressionCodec
snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, + Spark provides three codecs: lz4, lzf, and snappy. You + can also use fully qualified class names to specify the codec, e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec. {submissionTime} {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} {inputReadWithUnit}
spark.hadoop.validateOutputSpecsspark.hadoop.validateOutputSpecs true If set to true, validates the output specification (e.g. checking if the output directory already exists) used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing @@ -570,7 +570,7 @@ Apart from these, the following properties are also available, and may be useful previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
spark.executor.heartbeatIntervalspark.executor.heartbeatInterval 10000 Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress From 69a57a18ee35af1cc5a00b67a80837ea317cd330 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 Aug 2014 23:53:44 -0700 Subject: [PATCH 77/83] [SPARK-2995][MLLIB] add ALS.setIntermediateRDDStorageLevel As mentioned in SPARK-2465, using `MEMORY_AND_DISK_SER` for user/product in/out links together with `spark.rdd.compress=true` can help reduce the space requirement by a lot, at the cost of speed. It might be useful to add this option so people can run ALS on much bigger datasets. Another option for the method name is `setIntermediateRDDStorageLevel`. Author: Xiangrui Meng Closes #1913 from mengxr/als-storagelevel and squashes the following commits: d942017 [Xiangrui Meng] rename to setIntermediateRDDStorageLevel 7550029 [Xiangrui Meng] add ALS.setIntermediateDataStorageLevel --- .../spark/mllib/recommendation/ALS.scala | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8ebc7e27ed4dd..84d192db53e26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -111,11 +111,17 @@ class ALS private ( */ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) + /** If true, do alternating nonnegative least squares. */ + private var nonnegative = false + + /** storage level for user/product in/out links */ + private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ - def setBlocks(numBlocks: Int): ALS = { + def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -124,7 +130,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ - def setUserBlocks(numUserBlocks: Int): ALS = { + def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this } @@ -132,31 +138,31 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ - def setProductBlocks(numProductBlocks: Int): ALS = { + def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { + def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { + def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { + def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ - def setImplicitPrefs(implicitPrefs: Boolean): ALS = { + def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this } @@ -166,29 +172,38 @@ class ALS private ( * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ @Experimental - def setAlpha(alpha: Double): ALS = { + def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ - def setSeed(seed: Long): ALS = { + def setSeed(seed: Long): this.type = { this.seed = seed this } - /** If true, do alternating nonnegative least squares. */ - private var nonnegative = false - /** * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ - def setNonnegative(b: Boolean): ALS = { + def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this } + /** + * :: DeveloperApi :: + * Sets storage level for intermediate RDDs (user/product in/out links). The default value is + * `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and + * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. + */ + @DeveloperApi + def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + this.intermediateRDDStorageLevel = storageLevel + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -441,8 +456,8 @@ class ALS private ( }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) - inLinks.persist(StorageLevel.MEMORY_AND_DISK) - outLinks.persist(StorageLevel.MEMORY_AND_DISK) + inLinks.persist(intermediateRDDStorageLevel) + outLinks.persist(intermediateRDDStorageLevel) (inLinks, outLinks) } From d069c5d9d2f6ce06389ca2ddf0b3ae4db72c5797 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 14 Aug 2014 01:37:38 -0700 Subject: [PATCH 78/83] [SPARK-3029] Disable local execution of Spark jobs by default Currently, local execution of Spark jobs is only used by take(), and it can be problematic as it can load a significant amount of data onto the driver. The worst case scenarios occur if the RDD is cached (guaranteed to load whole partition), has very large elements, or the partition is just large and we apply a filter with high selectivity or computational overhead. Additionally, jobs that run locally in this manner do not show up in the web UI, and are thus harder to track or understand what is occurring. This PR adds a flag to disable local execution, which is turned OFF by default, with the intention of perhaps eventually removing this functionality altogether. Removing it now is a tougher proposition since it is part of the public runJob API. An alternative solution would be to limit the flag to take()/first() to avoid impacting any external users of this API, but such usage (or, at least, reliance upon the feature) is hopefully minimal. Author: Aaron Davidson Closes #1321 from aarondav/allowlocal and squashes the following commits: 136b253 [Aaron Davidson] Fix DAGSchedulerSuite 5599d55 [Aaron Davidson] [RFC] Disable local execution of Spark jobs by default --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 7 ++++++- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 4 +++- docs/configuration.md | 9 +++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 430e45ada5808..36bbaaa3f1c85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -121,6 +121,9 @@ class DAGScheduler( private[scheduler] var eventProcessActor: ActorRef = _ + /** If enabled, we may run certain actions like take() and first() locally. */ + private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -732,7 +735,9 @@ class DAGScheduler( logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + val shouldRunLocally = + localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8c1b0fed11f72..bd829752eb401 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -141,7 +141,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } before { - sc = new SparkContext("local", "DAGSchedulerSuite") + // Enable local execution for this test + val conf = new SparkConf().set("spark.localExecution.enabled", "true") + sc = new SparkContext("local", "DAGSchedulerSuite", conf) sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null diff --git a/docs/configuration.md b/docs/configuration.md index c8336b39133de..c408c468dcd94 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -846,6 +846,15 @@ Apart from these, the following properties are also available, and may be useful (in milliseconds).
spark.localExecution.enabledfalse + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver. +
#### Security From 6b8de0e36c7548046c3b8a57f2c8e7e788dde8cc Mon Sep 17 00:00:00 2001 From: Graham Dennis Date: Thu, 14 Aug 2014 02:24:18 -0700 Subject: [PATCH 79/83] SPARK-2893: Do not swallow Exceptions when running a custom kryo registrator The previous behaviour of swallowing ClassNotFound exceptions when running a custom Kryo registrator could lead to difficult to debug problems later on at serialisation / deserialisation time, see SPARK-2878. Instead it is better to fail fast. Added test case. Author: Graham Dennis Closes #1827 from GrahamDennis/feature/spark-2893 and squashes the following commits: fbe4cb6 [Graham Dennis] [SPARK-2878]: Update the test case to match the updated exception message 65e53c5 [Graham Dennis] [SPARK-2893]: Improve message when a spark.kryo.registrator fails. f480d85 [Graham Dennis] [SPARK-2893] Fix typo. b59d2c2 [Graham Dennis] SPARK-2893: Do not swallow Exceptions when running a custom spark.kryo.registrator --- .../org/apache/spark/serializer/KryoSerializer.scala | 11 ++++++----- .../apache/spark/serializer/KryoSerializerSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 407cb9db6ee9a..85944eabcfefc 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -79,15 +79,16 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Allow the user to register their own classes by setting spark.kryo.registrator - try { - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) + for (regCls <- registrator) { + logDebug("Running user registrator: " + regCls) + try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] reg.registerClasses(kryo) + } catch { + case e: Exception => + throw new SparkException(s"Failed to invoke $regCls", e) } - } catch { - case e: Exception => logError("Failed to run spark.kryo.registrator", e) } // Register Chill's classes; we do this after our ranges and the user's own classes to let diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 789b773bae316..3bf9efebb39d2 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -207,6 +207,16 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } + + test("kryo with nonexistent custom registrator should fail") { + import org.apache.spark.{SparkConf, SparkException} + + val conf = new SparkConf(false) + conf.set("spark.kryo.registrator", "this.class.does.not.exist") + + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) + } } class KryoSerializerResizableOutputSuite extends FunSuite { From 078f3fbda860e2f5de34153c55dfc3fecb4256e9 Mon Sep 17 00:00:00 2001 From: Chia-Yung Su Date: Thu, 14 Aug 2014 10:43:08 -0700 Subject: [PATCH 80/83] [SPARK-3011][SQL] _temporary directory should be filtered out by sqlContext.parquetFile Author: Chia-Yung Su Closes #1924 from joesu/bugfix-spark3011 and squashes the following commits: c7e44f2 [Chia-Yung Su] match syntax f8fc32a [Chia-Yung Su] filter out tmp dir --- .../main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2867dc0a8b1f9..37091bcf73dd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -375,7 +375,8 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName - name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME || + name == FileOutputCommitter.TEMP_DIR_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row From add75d4831fdc35712bf8b737574ea0bc677c37c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 14 Aug 2014 10:46:33 -0700 Subject: [PATCH 81/83] [SPARK-2927][SQL] Add a conf to configure if we always read Binary columns stored in Parquet as String columns This PR adds a new conf flag `spark.sql.parquet.binaryAsString`. When it is `true`, if there is no parquet metadata file available to provide the schema of the data, we will always treat binary fields stored in parquet as string fields. This conf is used to provide a way to read string fields generated without UTF8 decoration. JIRA: https://issues.apache.org/jira/browse/SPARK-2927 Author: Yin Huai Closes #1855 from yhuai/parquetBinaryAsString and squashes the following commits: 689ffa9 [Yin Huai] Add missing "=". 80827de [Yin Huai] Unit test. 1765ca4 [Yin Huai] Use .toBoolean. 9d3f199 [Yin Huai] Merge remote-tracking branch 'upstream/master' into parquetBinaryAsString 5d436a1 [Yin Huai] The initial support of adding a conf to treat binary columns stored in Parquet as string columns. --- .../scala/org/apache/spark/sql/SQLConf.scala | 10 +++- .../spark/sql/parquet/ParquetRelation.scala | 6 ++- .../sql/parquet/ParquetTableSupport.scala | 3 +- .../spark/sql/parquet/ParquetTypes.scala | 36 +++++++------ .../spark/sql/parquet/ParquetQuerySuite.scala | 54 +++++++++++++++++-- 5 files changed, 87 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 35c51dec0bcf5..90de11182e605 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -31,6 +31,7 @@ private[spark] object SQLConf { val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" + val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -87,8 +88,7 @@ trait SQLConf { * * Defaults to false as this feature is currently experimental. */ - private[spark] def codegenEnabled: Boolean = - if (getConf(CODEGEN_ENABLED, "false") == "true") true else false + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -108,6 +108,12 @@ trait SQLConf { private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + /** + * When set to true, we always treat byte arrays in Parquet files as strings. + */ + private[spark] def isParquetBinaryAsString: Boolean = + getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b3bae5db0edbc..053b2a154389c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -60,7 +60,11 @@ private[sql] case class ParquetRelation( .getSchema /** Attributes */ - override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) + override val output = + ParquetTypesConverter.readSchemaFromFile( + new Path(path), + conf, + sqlContext.isParquetBinaryAsString) override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 6d4ce32ac5bfa..6a657c20fe46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -80,9 +80,10 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { } } // if both unavailable, fall back to deducing the schema from the given Parquet schema + // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 37091bcf73dd6..b0579f76da073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -43,10 +43,13 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + def toPrimitiveDataType( + parquetType: ParquetPrimitiveType, + binayAsString: Boolean): DataType = parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || + binayAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -85,7 +88,7 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType): DataType = { + def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -107,7 +110,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -116,7 +119,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field), containsNull = false) + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -126,9 +129,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert( keyValueGroup.getFieldCount == 2, "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. @@ -138,22 +141,22 @@ private[parquet] object ParquetTypesConverter extends Logging { // Note: the order of these checks is important! if (correspondsToMap(groupType)) { // MapType val keyValueGroup = groupType.getFields.apply(0).asGroupType() - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType - val elementType = toDataType(groupType.getFields.apply(0)) + val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString) ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype), + toDataType(ptype, isBinaryAsString), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -276,7 +279,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -284,7 +287,7 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field), + toDataType(field, isBinaryAsString), field.getRepetition != Repetition.REQUIRED)()) } @@ -404,7 +407,10 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param conf The Hadoop configuration to use. * @return A list of attributes that make up the schema. */ - def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + def readSchemaFromFile( + origPath: Path, + conf: Option[Configuration], + isBinaryAsString: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -413,7 +419,7 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema) + readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 502f6702e394e..172dcd6aa0ee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -21,8 +21,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import parquet.schema.MessageTypeParser - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -138,6 +135,57 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } + test("Treat binary as string") { + val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString + + // Create the test file. + val file = getTempFilePath("parquet") + val path = file.toString + val range = (0 to 255) + val rowRDD = TestSQLContext.sparkContext.parallelize(range) + .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) + // We need to ask Parquet to store the String column as a Binary column. + val schema = StructType( + StructField("c1", IntegerType, false) :: + StructField("c2", BinaryType, false) :: Nil) + val schemaRDD1 = applySchema(rowRDD, schema) + schemaRDD1.saveAsParquetFile(path) + val resultWithBinary = parquetFile(path).collect + range.foreach { + i => + assert(resultWithBinary(i).getInt(0) === i) + assert(resultWithBinary(i)(1) === s"val_$i".getBytes) + } + + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + // This ParquetRelation always use Parquet types to derive output. + val parquetRelation = new ParquetRelation( + path.toString, + Some(TestSQLContext.sparkContext.hadoopConfiguration), + TestSQLContext) { + override val output = + ParquetTypesConverter.convertToAttributes( + ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, + TestSQLContext.isParquetBinaryAsString) + } + val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) + val resultWithString = schemaRDD.collect + range.foreach { + i => + assert(resultWithString(i).getInt(0) === i) + assert(resultWithString(i)(1) === s"val_$i") + } + + schemaRDD.registerTempTable("tmp") + checkAnswer( + sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + // Set it back. + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) + } + test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) From fde692b361773110c262abe219e7c8128bd76419 Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Thu, 14 Aug 2014 10:48:52 -0700 Subject: [PATCH 82/83] [SQL] Python JsonRDD UTF8 Encoding Fix Only encode unicode objects to UTF-8, and not strings Author: Ahir Reddy Closes #1914 from ahirreddy/json-rdd-unicode-fix1 and squashes the following commits: ca4e9ba [Ahir Reddy] Encoding Fix --- python/pyspark/sql.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 46540ca3f1e8a..95086a2258222 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1267,7 +1267,9 @@ def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) From 267fdffe2743bc2dc706c8ac8af0ae33a358a5d3 Mon Sep 17 00:00:00 2001 From: wangfei Date: Thu, 14 Aug 2014 10:55:51 -0700 Subject: [PATCH 83/83] [SPARK-2925] [sql]fix spark-sql and start-thriftserver shell bugs when set --driver-java-options https://issues.apache.org/jira/browse/SPARK-2925 Run cmd like this will get the error bin/spark-sql --driver-java-options '-Xdebug -Xnoagent -Xrunjdwp:transport=dt_socket,address=8788,server=y,suspend=y' Error: Unrecognized option '-Xnoagent'. Run with --help for usage help or --verbose for debug output Author: wangfei Author: wangfei Closes #1851 from scwf/patch-2 and squashes the following commits: 516554d [wangfei] quote variables to fix this issue 8bd40f2 [wangfei] quote variables to fix this problem e6d79e3 [wangfei] fix start-thriftserver bug when set driver-java-options 948395d [wangfei] fix spark-sql error when set --driver-java-options --- bin/spark-sql | 18 +++++++++--------- sbin/start-thriftserver.sh | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bin/spark-sql b/bin/spark-sql index 7813ccc361415..564f1f419060f 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -65,30 +65,30 @@ while (($#)); do case $1 in -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -e) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=(\"$1\"); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -s | --silent) - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift ;; -v | --verbose) # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose - CLI_ARGS+=($1) - SUBMISSION_ARGS+=($1); shift + CLI_ARGS+=("$1") + SUBMISSION_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${CLI_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${CLI_ARGS[@]}" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 603f50ae13240..2c4452473ccbc 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -65,14 +65,14 @@ while (($#)); do case $1 in --hiveconf) ensure_arg_number $# 2 - THRIFT_SERVER_ARGS+=($1); shift - THRIFT_SERVER_ARGS+=($1); shift + THRIFT_SERVER_ARGS+=("$1"); shift + THRIFT_SERVER_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${THRIFT_SERVER_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${THRIFT_SERVER_ARGS[@]}"