diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 0746e43bab..d9d7b06cdb 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -169,4 +169,8 @@ object FileCommitProtocol extends Logging { ctor.newInstance(jobId, outputPath) } } + + def getStagingDir(path: String, jobId: String): Path = { + new Path(path, ".spark-staging-" + jobId) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 11ce608f52..30f9a650a6 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -41,13 +41,28 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime - * dynamically, i.e., we first write files under a staging - * directory with partition path, e.g. - * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, - * we first clean up the corresponding partition directories at - * destination path, e.g. /path/to/destination/a=1/b=1, and move - * files from staging directory to the corresponding partition - * directories under destination path. + * dynamically. Suppose final path is /path/to/outputPath, output + * path of [[FileOutputCommitter]] is an intermediate path, e.g. + * /path/to/outputPath/.spark-staging-{jobId}, which is a staging + * directory. Task attempts firstly write files under the + * intermediate path, e.g. + * /path/to/outputPath/.spark-staging-{jobId}/_temporary/ + * {appAttemptId}/_temporary/{taskAttemptId}/a=1/b=1/xxx.parquet. + * + * 1. When [[FileOutputCommitter]] algorithm version set to 1, + * we firstly move task attempt output files to + * /path/to/outputPath/.spark-staging-{jobId}/_temporary/ + * {appAttemptId}/{taskId}/a=1/b=1, + * then move them to + * /path/to/outputPath/.spark-staging-{jobId}/a=1/b=1. + * 2. When [[FileOutputCommitter]] algorithm version set to 2, + * committing tasks directly move task attempt output files to + * /path/to/outputPath/.spark-staging-{jobId}/a=1/b=1. + * + * At the end of committing job, we move output files from + * intermediate path to final path, e.g., move files from + * /path/to/outputPath/.spark-staging-{jobId}/a=1/b=1 + * to /path/to/outputPath/a=1/b=1 */ class HadoopMapReduceCommitProtocol( jobId: String, @@ -89,7 +104,7 @@ class HadoopMapReduceCommitProtocol( * The staging directory of this write job. Spark uses it to deal with files with absolute output * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. 
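 * For example, with jobId "abc" and path "/path/to/outputPath", this resolves to
 * /path/to/outputPath/.spark-staging-abc, the same value returned by the
 * getStagingDir(path, jobId) helper added to FileCommitProtocol above.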
*/ - private def stagingDir = new Path(path, ".spark-staging-" + jobId) + protected def stagingDir = getStagingDir(path, jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.getConstructor().newInstance() @@ -106,13 +121,13 @@ class HadoopMapReduceCommitProtocol( val filename = getFilename(taskContext, ext) val stagingDir: Path = committer match { - case _ if dynamicPartitionOverwrite => - assert(dir.isDefined, - "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") - partitionPaths += dir.get - this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => + if (dynamicPartitionOverwrite) { + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + } new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) case _ => new Path(path) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 147731a0fb..c607fb28b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -169,7 +169,7 @@ private[spark] class StorageStatus( .getOrElse((0L, 0L)) case _ if !level.useOffHeap => (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) - case _ if level.useOffHeap => + case _ => (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 13f7cb4533..103965e486 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -757,7 +757,7 @@ private[spark] object JsonProtocol { def taskResourceRequestMapFromJson(json: JValue): Map[String, TaskResourceRequest] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val req = taskResourceRequestFromJson(v) (k, req) }.toMap @@ -765,7 +765,7 @@ private[spark] object JsonProtocol { def executorResourceRequestMapFromJson(json: JValue): Map[String, ExecutorResourceRequest] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val req = executorResourceRequestFromJson(v) (k, req) }.toMap @@ -1229,7 +1229,7 @@ private[spark] object JsonProtocol { def resourcesMapFromJson(json: JValue): Map[String, ResourceInformation] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, v) => + jsonFields.collect { case JField(k, v) => val resourceInfo = ResourceInformation.parseJson(v) (k, resourceInfo) }.toMap @@ -1241,7 +1241,7 @@ private[spark] object JsonProtocol { def mapFromJson(json: JValue): Map[String, String] = { val jsonFields = json.asInstanceOf[JObject].obj - jsonFields.map { case JField(k, JString(v)) => (k, v) }.toMap + jsonFields.collect { case JField(k, JString(v)) => (k, v) }.toMap } def propertiesFromJson(json: JValue): Properties = { diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c3adc696a5..c155d4ea3f 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,6 +26,7 @@ FWDIR="$( cd "$( dirname "$0" )/.." 
&& pwd )" cd "$FWDIR" export PATH=/home/anaconda/envs/py36/bin:$PATH +export LANG="en_US.UTF-8" PYTHON_VERSION_CHECK=$(python3 -c 'import sys; print(sys.version_info < (3, 6, 0))') if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then diff --git a/docs/_config.yml b/docs/_config.yml index cd341063a1..026b3dd804 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -26,15 +26,20 @@ SCALA_VERSION: "2.12.10" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark -# Before a new release, we should apply a new `apiKey` for the new Spark documentation -# on https://docsearch.algolia.com/. Otherwise, after release, the search results are always based -# on the latest documentation(https://spark.apache.org/docs/latest/) even when visiting the -# documentation of previous releases. +# Before a new release, we should: +# 1. update the `version` array for the new Spark documentation +# on https://github.com/algolia/docsearch-configs/blob/master/configs/apache_spark.json. +# 2. update the value of `facetFilters.version` in `algoliaOptions` on the new release branch. +# Otherwise, after release, the search results are always based on the latest documentation +# (https://spark.apache.org/docs/latest/) even when visiting the documentation of previous releases. DOCSEARCH_SCRIPT: | docsearch({ apiKey: 'b18ca3732c502995563043aa17bc6ecb', indexName: 'apache_spark', inputSelector: '#docsearch-input', enhancedSearchInput: true, + algoliaOptions: { + 'facetFilters': ["version:latest"] + }, debug: false // Set debug to true if you want to inspect the dropdown }); diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 368f177cda..b6c1b011f0 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -302,6 +302,8 @@ private[spark] object BLAS extends Serializable { j += 1 prevCol = col } + case _ => + throw new IllegalArgumentException(s"spr doesn't support vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index dbbfd8f329..c5b28c95eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -286,6 +286,7 @@ private[ml] object RFormulaParser extends RegexParsers { private val pow: Parser[Term] = term ~ "^" ~ "^[1-9]\\d*".r ^^ { case base ~ "^" ~ degree => power(base, degree.toInt) + case t => throw new IllegalArgumentException(s"Invalid term: $t") } | term private val interaction: Parser[Term] = pow * (":" ^^^ { interact _ }) @@ -298,7 +299,10 @@ private[ml] object RFormulaParser extends RegexParsers { private val expr = (sum | term) private val formula: Parser[ParsedRFormula] = - (label ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) } + (label ~ "~" ~ expr) ^^ { + case r ~ "~" ~ t => ParsedRFormula(r, t.asTerms.terms) + case t => throw new IllegalArgumentException(s"Invalid term: $t") + } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 7434b1adb2..92dee46ad0 
100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -314,6 +314,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { case SparseVector(size, indices, values) => val newValues = transformSparseWithScale(scale, indices, values.clone()) Vectors.sparse(size, indices, newValues) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } case (false, false) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala index 0bee643412..8f03a29eb9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonMatrixConverter.scala @@ -74,6 +74,8 @@ private[ml] object JsonMatrixConverter { ("values" -> values.toSeq) ~ ("isTransposed" -> isTransposed) compact(render(jValue)) + case _ => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala index 781e69f8d6..1b949d75ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/JsonVectorConverter.scala @@ -57,6 +57,8 @@ private[ml] object JsonVectorConverter { case DenseVector(values) => val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) compact(render(jValue)) + case _ => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index 37f173bc20..35bbaf5aa1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -45,6 +45,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala index 3d72512563..0fe1ed231a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala @@ -200,6 +200,9 @@ private[ml] class BlockHingeAggregator( case sm: SparseMatrix if !fitIntercept => val gradSumVec = new DenseVector(gradientSumArray) BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec) + + case m => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } if (fitIntercept) gradientSumArray(numFeatures) += vec.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala index 2496c789f8..5a516940b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala @@ -504,6 +504,9 @@ private[ml] class BlockLogisticAggregator( case sm: SparseMatrix if !fitIntercept => 
val gradSumVec = new DenseVector(gradientSumArray) BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec) + + case m => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") } if (fitIntercept) gradientSumArray(numFeatures) += vec.values.sum diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index d4b39e11fd..2215c2b071 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -192,6 +192,8 @@ private[spark] object Instrumentation { case Failure(NonFatal(e)) => instr.logFailure(e) throw e + case Failure(e) => + throw e case Success(result) => instr.logSuccess() result diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 8f9d6d07a4..12a5a0f2b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -167,6 +167,8 @@ class StandardScalerModel @Since("1.3.0") ( val newValues = NewStandardScalerModel .transformSparseWithScale(localScale, indices, values.clone()) Vectors.sparse(size, indices, newValues) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } case _ => vector diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index da486010cf..bd60364326 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -285,6 +285,8 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 prevCol = col } + case _ => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 2fe415f140..9ed9dd0c88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -289,6 +289,8 @@ class VectorUDT extends UserDefinedType[Vector] { row.setNullAt(2) row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index ad79230c75..da5d165069 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -145,6 +145,8 @@ class IndexedRowMatrix @Since("1.0.0") ( .map { case (values, blockColumn) => ((blockRow.toInt, blockColumn), (rowInBlock.toInt, values.zipWithIndex)) } + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } }.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map { case ((blockRow, blockColumn), itr) => @@ -187,6 +189,8 @@ class IndexedRowMatrix @Since("1.0.0") ( Iterator.tabulate(indices.length)(i => MatrixEntry(rowIndex, indices(i), values(i))) case DenseVector(values) => Iterator.tabulate(values.length)(i => 
MatrixEntry(rowIndex, i, values(i))) + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } new CoordinateMatrix(entries, numRows(), numCols()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 07b9d91c1f..c618b71ddc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -748,6 +748,8 @@ class RowMatrix @Since("1.0.0") ( } buf }.flatten + case v => + throw new IllegalArgumentException(s"Unknown vector type ${v.getClass}.") } } }.reduceByKey(_ + _).map { case ((i, j), sim) => diff --git a/pom.xml b/pom.xml index 0ab5a8c5b3..e5b1f30edd 100644 --- a/pom.xml +++ b/pom.xml @@ -3264,7 +3264,7 @@ scala-2.13 - 2.13.3 + 2.13.4 2.13 diff --git a/python/docs/source/reference/pyspark.mllib.rst b/python/docs/source/reference/pyspark.mllib.rst index acc834c065..df5ea017d0 100644 --- a/python/docs/source/reference/pyspark.mllib.rst +++ b/python/docs/source/reference/pyspark.mllib.rst @@ -216,6 +216,8 @@ Statistics ChiSqTestResult MultivariateGaussian KernelDensity + ChiSqTestResult + KolmogorovSmirnovTestResult Tree @@ -250,4 +252,3 @@ Utilities Loader MLUtils Saveable - diff --git a/python/mypy.ini b/python/mypy.ini index 4a5368a519..5103452a05 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -16,10 +16,97 @@ ; [mypy] +strict_optional = True +no_implicit_optional = True +disallow_untyped_defs = True + +; Allow untyped def in internal modules and tests + +[mypy-pyspark.daemon] +disallow_untyped_defs = False + +[mypy-pyspark.find_spark_home] +disallow_untyped_defs = False + +[mypy-pyspark._globals] +disallow_untyped_defs = False + +[mypy-pyspark.install] +disallow_untyped_defs = False + +[mypy-pyspark.java_gateway] +disallow_untyped_defs = False + +[mypy-pyspark.join] +disallow_untyped_defs = False + +[mypy-pyspark.ml.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.mllib.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.rddsampler] +disallow_untyped_defs = False + +[mypy-pyspark.resource.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.shuffle] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.streaming.util] +disallow_untyped_defs = False + +[mypy-pyspark.sql.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.serializers] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.types] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.typehints] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas.utils] +disallow_untyped_defs = False + +[mypy-pyspark.sql.pandas._typing.protocols.*] +disallow_untyped_defs = False + +[mypy-pyspark.sql.utils] +disallow_untyped_defs = False + +[mypy-pyspark.tests.*] +disallow_untyped_defs = False + +[mypy-pyspark.testing.*] +disallow_untyped_defs = False + +[mypy-pyspark.traceback_utils] +disallow_untyped_defs = False + +[mypy-pyspark.util] +disallow_untyped_defs = False + +[mypy-pyspark.worker] +disallow_untyped_defs = False + +; Ignore errors in embedded third party code [mypy-pyspark.cloudpickle.*] ignore_errors = True +; Ignore missing imports for external untyped packages + [mypy-py4j.*] ignore_missing_imports = True diff --git a/python/pyspark/broadcast.pyi b/python/pyspark/broadcast.pyi index 
4b019a509a..944cb06d41 100644 --- a/python/pyspark/broadcast.pyi +++ b/python/pyspark/broadcast.pyi @@ -17,7 +17,7 @@ # under the License. import threading -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar T = TypeVar("T") @@ -32,14 +32,14 @@ class Broadcast(Generic[T]): path: Optional[Any] = ..., sock_file: Optional[Any] = ..., ) -> None: ... - def dump(self, value: Any, f: Any) -> None: ... - def load_from_path(self, path: Any): ... - def load(self, file: Any): ... + def dump(self, value: T, f: Any) -> None: ... + def load_from_path(self, path: Any) -> T: ... + def load(self, file: Any) -> T: ... @property def value(self) -> T: ... def unpersist(self, blocking: bool = ...) -> None: ... def destroy(self, blocking: bool = ...) -> None: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Callable[[int], T], Tuple[int]]: ... class BroadcastPickleRegistry(threading.local): def __init__(self) -> None: ... diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 2789a38b3b..640a69cad0 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -16,7 +16,19 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Optional, + Tuple, + Type, + TypeVar, +) +from types import TracebackType from py4j.java_gateway import JavaGateway, JavaObject # type: ignore[import] @@ -51,9 +63,14 @@ class SparkContext: jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., ) -> None: ... - def __getnewargs__(self): ... - def __enter__(self): ... - def __exit__(self, type, value, trace): ... + def __getnewargs__(self) -> NoReturn: ... + def __enter__(self) -> SparkContext: ... + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + trace: Optional[TracebackType], + ) -> None: ... @classmethod def getOrCreate(cls, conf: Optional[SparkConf] = ...) -> SparkContext: ... def setLogLevel(self, logLevel: str) -> None: ... diff --git a/python/pyspark/ml/classification.pyi b/python/pyspark/ml/classification.pyi index 4bde851bb1..c44176a13a 100644 --- a/python/pyspark/ml/classification.pyi +++ b/python/pyspark/ml/classification.pyi @@ -107,7 +107,7 @@ class _JavaProbabilisticClassifier( class _JavaProbabilisticClassificationModel( ProbabilisticClassificationModel, _JavaClassificationModel[T] ): - def predictProbability(self, value: Any): ... + def predictProbability(self, value: Vector) -> Vector: ... class _ClassificationSummary(JavaWrapper): @property @@ -543,7 +543,7 @@ class RandomForestClassificationModel( @property def trees(self) -> List[DecisionTreeClassificationModel]: ... def summary(self) -> RandomForestClassificationTrainingSummary: ... - def evaluate(self, dataset) -> RandomForestClassificationSummary: ... + def evaluate(self, dataset: DataFrame) -> RandomForestClassificationSummary: ... class RandomForestClassificationSummary(_ClassificationSummary): ... class RandomForestClassificationTrainingSummary( @@ -891,7 +891,7 @@ class FMClassifier( solver: str = ..., thresholds: Optional[Any] = ..., seed: Optional[Any] = ..., - ): ... + ) -> FMClassifier: ... def setFactorSize(self, value: int) -> FMClassifier: ... def setFitLinear(self, value: bool) -> FMClassifier: ... def setMiniBatchFraction(self, value: float) -> FMClassifier: ... 
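Taken together, the mypy.ini changes and the annotated .pyi stubs above turn on disallow_untyped_defs for most of the public PySpark API while leaving internal modules and tests on per-module overrides. Below is a minimal, hypothetical sketch of the kind of downstream code the tightened stubs can now check precisely; the function name broadcast_sum is invented for illustration, and it assumes the context.pyi stub types broadcast() generically, as the Broadcast[T] and __enter__/__exit__ annotations above suggest. Only the PySpark calls themselves (SparkContext.getOrCreate, SparkContext.broadcast, Broadcast.value) are real APIs.

    from typing import List

    from pyspark import SparkContext


    def broadcast_sum(values: List[int]) -> int:
        # With the annotated stubs, __enter__ is declared to return SparkContext
        # rather than an implicit Any, so `sc` is fully typed inside the block.
        with SparkContext.getOrCreate() as sc:
            # Broadcast is Generic[T]; mypy can infer Broadcast[List[int]] here,
            # so b.value below is List[int] instead of Any.
            b = sc.broadcast(values)
            return sum(b.value)

Checked against python/mypy.ini, a module like this falls under the strict default (disallow_untyped_defs = True); only the internal and test modules listed in the per-module overrides keep untyped defs.
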
diff --git a/python/pyspark/ml/common.pyi b/python/pyspark/ml/common.pyi index 7bf0ed6183..a38fc5734f 100644 --- a/python/pyspark/ml/common.pyi +++ b/python/pyspark/ml/common.pyi @@ -16,5 +16,11 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def inherit_doc(cls): ... +from typing import Any, TypeVar + +import pyspark.context + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/ml/evaluation.pyi b/python/pyspark/ml/evaluation.pyi index ea0a9f045c..55a3ae2774 100644 --- a/python/pyspark/ml/evaluation.pyi +++ b/python/pyspark/ml/evaluation.pyi @@ -39,9 +39,12 @@ from pyspark.ml.param.shared import ( HasWeightCol, ) from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.sql.dataframe import DataFrame class Evaluator(Params, metaclass=abc.ABCMeta): - def evaluate(self, dataset, params: Optional[ParamMap] = ...) -> float: ... + def evaluate( + self, dataset: DataFrame, params: Optional[ParamMap] = ... + ) -> float: ... def isLargerBetter(self) -> bool: ... class JavaEvaluator(JavaParams, Evaluator, metaclass=abc.ABCMeta): @@ -75,16 +78,15 @@ class BinaryClassificationEvaluator( def setLabelCol(self, value: str) -> BinaryClassificationEvaluator: ... def setRawPredictionCol(self, value: str) -> BinaryClassificationEvaluator: ... def setWeightCol(self, value: str) -> BinaryClassificationEvaluator: ... - -def setParams( - self, - *, - rawPredictionCol: str = ..., - labelCol: str = ..., - metricName: BinaryClassificationEvaluatorMetricType = ..., - weightCol: Optional[str] = ..., - numBins: int = ... -) -> BinaryClassificationEvaluator: ... + def setParams( + self, + *, + rawPredictionCol: str = ..., + labelCol: str = ..., + metricName: BinaryClassificationEvaluatorMetricType = ..., + weightCol: Optional[str] = ..., + numBins: int = ... + ) -> BinaryClassificationEvaluator: ... class RegressionEvaluator( JavaEvaluator, diff --git a/python/pyspark/ml/feature.pyi b/python/pyspark/ml/feature.pyi index f5b12a5b2f..4999defdf8 100644 --- a/python/pyspark/ml/feature.pyi +++ b/python/pyspark/ml/feature.pyi @@ -100,9 +100,9 @@ class _LSHParams(HasInputCol, HasOutputCol): def getNumHashTables(self) -> int: ... class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable): - def setNumHashTables(self: P, value) -> P: ... - def setInputCol(self: P, value) -> P: ... - def setOutputCol(self: P, value) -> P: ... + def setNumHashTables(self: P, value: int) -> P: ... + def setInputCol(self: P, value: str) -> P: ... + def setOutputCol(self: P, value: str) -> P: ... class _LSHModel(JavaModel, _LSHParams): def setInputCol(self: P, value: str) -> P: ... @@ -1518,7 +1518,7 @@ class ChiSqSelector( fpr: float = ..., fdr: float = ..., fwe: float = ... - ): ... + ) -> ChiSqSelector: ... def setSelectorType(self, value: str) -> ChiSqSelector: ... def setNumTopFeatures(self, value: int) -> ChiSqSelector: ... def setPercentile(self, value: float) -> ChiSqSelector: ... @@ -1602,7 +1602,10 @@ class _VarianceThresholdSelectorParams(HasFeaturesCol, HasOutputCol): def getVarianceThreshold(self) -> float: ... 
class VarianceThresholdSelector( - JavaEstimator, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaEstimator[VarianceThresholdSelectorModel], + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelector], + JavaMLWritable, ): def __init__( self, @@ -1615,13 +1618,16 @@ class VarianceThresholdSelector( featuresCol: str = ..., outputCol: Optional[str] = ..., varianceThreshold: float = ..., - ): ... + ) -> VarianceThresholdSelector: ... def setVarianceThreshold(self, value: float) -> VarianceThresholdSelector: ... def setFeaturesCol(self, value: str) -> VarianceThresholdSelector: ... def setOutputCol(self, value: str) -> VarianceThresholdSelector: ... class VarianceThresholdSelectorModel( - JavaModel, _VarianceThresholdSelectorParams, JavaMLReadable, JavaMLWritable + JavaModel, + _VarianceThresholdSelectorParams, + JavaMLReadable[VarianceThresholdSelectorModel], + JavaMLWritable, ): def setFeaturesCol(self, value: str) -> VarianceThresholdSelectorModel: ... def setOutputCol(self, value: str) -> VarianceThresholdSelectorModel: ... diff --git a/python/pyspark/ml/linalg/__init__.pyi b/python/pyspark/ml/linalg/__init__.pyi index a576b30aec..b4fba8823b 100644 --- a/python/pyspark/ml/linalg/__init__.pyi +++ b/python/pyspark/ml/linalg/__init__.pyi @@ -17,7 +17,7 @@ # under the License. from typing import overload -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union from pyspark.ml import linalg as newlinalg # noqa: F401 from pyspark.sql.types import StructType, UserDefinedType @@ -45,7 +45,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -64,9 +64,7 @@ class DenseVector(Vector): def __init__(self, __arr: bytes) -> None: ... @overload def __init__(self, __arr: Iterable[float]) -> None: ... - @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -112,16 +110,14 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... - @staticmethod - def parse(s: str) -> SparseVector: ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... def dot(self, other: Iterable[float]) -> float64: ... def squared_distance(self, other: Iterable[float]) -> float64: ... def toArray(self) -> ndarray: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -144,13 +140,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... 
@overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def stringify(vector: Vector) -> str: ... @staticmethod @@ -158,8 +154,6 @@ class Vectors: @staticmethod def norm(vector: Vector, p: Union[float, str]) -> float64: ... @staticmethod - def parse(s: str) -> Vector: ... - @staticmethod def zeros(size: int) -> DenseVector: ... class Matrix: @@ -170,7 +164,7 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... + def toArray(self) -> NoReturn: ... class DenseMatrix(Matrix): values: Any @@ -186,11 +180,11 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -216,11 +210,13 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/ml/pipeline.pyi b/python/pyspark/ml/pipeline.pyi index 44680586d7..f47e9e012a 100644 --- a/python/pyspark/ml/pipeline.pyi +++ b/python/pyspark/ml/pipeline.pyi @@ -51,7 +51,7 @@ class PipelineWriter(MLWriter): def __init__(self, instance: Pipeline) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineReader(MLReader): +class PipelineReader(MLReader[Pipeline]): cls: Type[Pipeline] def __init__(self, cls: Type[Pipeline]) -> None: ... def load(self, path: str) -> Pipeline: ... @@ -61,7 +61,7 @@ class PipelineModelWriter(MLWriter): def __init__(self, instance: PipelineModel) -> None: ... def saveImpl(self, path: str) -> None: ... -class PipelineModelReader(MLReader): +class PipelineModelReader(MLReader[PipelineModel]): cls: Type[PipelineModel] def __init__(self, cls: Type[PipelineModel]) -> None: ... def load(self, path: str) -> PipelineModel: ... diff --git a/python/pyspark/ml/regression.pyi b/python/pyspark/ml/regression.pyi index 5cb0e7a509..b8f1e61859 100644 --- a/python/pyspark/ml/regression.pyi +++ b/python/pyspark/ml/regression.pyi @@ -414,7 +414,7 @@ class RandomForestRegressionModel( _TreeEnsembleModel, _RandomForestRegressorParams, JavaMLWritable, - JavaMLReadable, + JavaMLReadable[RandomForestRegressionModel], ): @property def trees(self) -> List[DecisionTreeRegressionModel]: ... @@ -749,10 +749,10 @@ class _FactorizationMachinesParams( initStd: Param[float] solver: Param[str] def __init__(self, *args: Any): ... - def getFactorSize(self): ... - def getFitLinear(self): ... - def getMiniBatchFraction(self): ... - def getInitStd(self): ... + def getFactorSize(self) -> int: ... + def getFitLinear(self) -> bool: ... + def getMiniBatchFraction(self) -> float: ... + def getInitStd(self) -> float: ... 
class FMRegressor( _JavaRegressor[FMRegressionModel], diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bbca216cce..bd43e91afd 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -88,20 +88,26 @@ class LogisticRegressionModel(LinearClassificationModel): Classification model trained using Multinomial/Binary Logistic Regression. - :param weights: - Weights computed for every feature. - :param intercept: - Intercept computed for this model. (Only used in Binary Logistic - Regression. In Multinomial Logistic Regression, the intercepts will - not bea single value, so the intercepts will be part of the - weights.) - :param numFeatures: - The dimension of the features. - :param numClasses: - The number of possible outcomes for k classes classification problem - in Multinomial Logistic Regression. By default, it is binary - logistic regression so numClasses will be set to 2. + .. versionadded:: 0.9.0 + Parameters + ---------- + weights : :py:class:`pyspark.mllib.linalg.Vector` + Weights computed for every feature. + intercept : float + Intercept computed for this model. (Only used in Binary Logistic + Regression. In Multinomial Logistic Regression, the intercepts will + not be a single value, so the intercepts will be part of the + weights.) + numFeatures : int + The dimension of the features. + numClasses : int + The number of possible outcomes for k classes classification problem + in Multinomial Logistic Regression. By default, it is binary + logistic regression so numClasses will be set to 2. + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -159,8 +165,6 @@ class LogisticRegressionModel(LinearClassificationModel): 1 >>> mcm.predict([0.0, 0.0, 0.3]) 2 - - .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) @@ -263,54 +267,60 @@ def __repr__(self): class LogisticRegressionWithSGD(object): """ + Train a classification model for Binary Logistic Regression using Stochastic Gradient Descent. + .. versionadded:: 0.9.0 - .. note:: Deprecated in 2.0.0. Use ml.classification.LogisticRegression or - LogisticRegressionWithLBFGS. + .. deprecated:: 2.0.0 + Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS. """ @classmethod - @since('0.9.0') def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType="l2", intercept=False, validateData=True, convergenceTol=0.001): """ Train a logistic regression model on the given data. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param step: - The step parameter used in SGD. - (default: 1.0) - :param miniBatchFraction: - Fraction of data to be used for each SGD iteration. - (default: 1.0) - :param initialWeights: - The initial weights. - (default: None) - :param regParam: - The regularizer parameter. - (default: 0.01) - :param regType: - The type of regularizer used for training our model. - Supported values: + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of :py:class:`pyspark.mllib.regression.LabeledPoint`. + iterations : int, optional + The number of iterations. + (default: 100) + step : float, optional + The step parameter used in SGD. 
+ (default: 1.0) + miniBatchFraction : float, optional + Fraction of data to be used for each SGD iteration. + (default: 1.0) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. + (default: None) + regParam : float, optional + The regularizer parameter. + (default: 0.01) + regType : str, optional + The type of regularizer used for training our model. + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) - None for no regularization - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e., whether bias - features are activated or not). - (default: False) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param convergenceTol: - A condition which decides iteration termination. - (default: 0.001) + + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + convergenceTol : float, optional + A condition which decides iteration termination. + (default: 0.001) """ warnings.warn( "Deprecated in 2.0.0. Use ml.classification.LogisticRegression or " @@ -326,55 +336,65 @@ def train(rdd, i): class LogisticRegressionWithLBFGS(object): """ + Train a classification model for Multinomial/Binary Logistic Regression + using Limited-memory BFGS. + + Standard feature scaling and L2 regularization are used by default. .. versionadded:: 1.2.0 """ @classmethod - @since('1.2.0') def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param initialWeights: - The initial weights. - (default: None) - :param regParam: - The regularizer parameter. - (default: 0.0) - :param regType: - The type of regularizer used for training our model. - Supported values: + .. versionadded:: 1.2.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of :py:class:`pyspark.mllib.regression.LabeledPoint`. + iterations : int, optional + The number of iterations. + (default: 100) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. + (default: None) + regParam : float, optional + The regularizer parameter. + (default: 0.01) + regType : str, optional + The type of regularizer used for training our model. + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) - None for no regularization - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e., whether bias - features are activated or not). - (default: False) - :param corrections: - The number of corrections used in the LBFGS update. - If a known updater is used for binary classification, - it calls the ml implementation and this parameter will - have no effect. (default: 10) - :param tolerance: - The convergence tolerance of iterations for L-BFGS. 
- (default: 1e-6) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param numClasses: - The number of classes (i.e., outcomes) a label can take in - Multinomial Logistic Regression. - (default: 2) + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + corrections : int, optional + The number of corrections used in the LBFGS update. + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) + tolerance : float, optional + The convergence tolerance of iterations for L-BFGS. + (default: 1e-6) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + numClasses : int, optional + The number of classes (i.e., outcomes) a label can take in + Multinomial Logistic Regression. + (default: 2) + + Examples + -------- >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), @@ -406,11 +426,17 @@ class SVMModel(LinearClassificationModel): """ Model for Support Vector Machines (SVMs). - :param weights: - Weights computed for every feature. - :param intercept: - Intercept computed for this model. + .. versionadded:: 0.9.0 + + Parameters + ---------- + weights : :py:class:`pyspark.mllib.linalg.Vector` + Weights computed for every feature. + intercept : float + Intercept computed for this model. + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -451,8 +477,6 @@ class SVMModel(LinearClassificationModel): ... rmtree(path) ... except: ... pass - - .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) @@ -501,53 +525,59 @@ def load(cls, sc, path): class SVMWithSGD(object): """ + Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + .. versionadded:: 0.9.0 """ @classmethod - @since('0.9.0') def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False, validateData=True, convergenceTol=0.001): """ Train a support vector machine on the given data. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param step: - The step parameter used in SGD. - (default: 1.0) - :param regParam: - The regularizer parameter. - (default: 0.01) - :param miniBatchFraction: - Fraction of data to be used for each SGD iteration. - (default: 1.0) - :param initialWeights: - The initial weights. - (default: None) - :param regType: - The type of regularizer used for training our model. - Allowed values: + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of :py:class:`pyspark.mllib.regression.LabeledPoint`. + iterations : int, optional + The number of iterations. + (default: 100) + step : float, optional + The step parameter used in SGD. + (default: 1.0) + regParam : float, optional + The regularizer parameter. + (default: 0.01) + miniBatchFraction : float, optional + Fraction of data to be used for each SGD iteration. + (default: 1.0) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. 
+ (default: None) + regType : str, optional + The type of regularizer used for training our model. + Allowed values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) - None for no regularization - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e. whether bias - features are activated or not). - (default: False) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param convergenceTol: - A condition which decides iteration termination. - (default: 0.001) + + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + convergenceTol : float, optional + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), @@ -563,14 +593,20 @@ class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - :param labels: - List of labels. - :param pi: - Log of class priors, whose dimension is C, number of labels. - :param theta: - Log of class conditional probabilities, whose dimension is C-by-D, - where D is number of features. + .. versionadded:: 0.9.0 + Parameters + ---------- + labels : :py:class:`numpy.ndarray` + List of labels. + pi : :py:class:`numpy.ndarray` + Log of class priors, whose dimension is C, number of labels. + theta : :py:class:`numpy.ndarray` + Log of class conditional probabilities, whose dimension is C-by-D, + where D is number of features. + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -605,8 +641,6 @@ class NaiveBayesModel(Saveable, Loader): ... rmtree(path) ... except OSError: ... pass - - .. versionadded:: 0.9.0 """ def __init__(self, labels, pi, theta): self.labels = labels @@ -652,11 +686,12 @@ def load(cls, sc, path): class NaiveBayes(object): """ + Train a Multinomial Naive Bayes model. + .. versionadded:: 0.9.0 """ @classmethod - @since('0.9.0') def train(cls, data, lambda_=1.0): """ Train a Naive Bayes model given an RDD of (label, features) @@ -669,11 +704,15 @@ def train(cls, data, lambda_=1.0): it can also be used as `Bernoulli NB `_. The input feature values must be nonnegative. - :param data: - RDD of LabeledPoint. - :param lambda_: - The smoothing parameter. - (default: 1.0) + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of :py:class:`pyspark.mllib.regression.LabeledPoint`. + lambda\\_ : float, optional + The smoothing parameter. + (default: 1.0) """ first = data.first() if not isinstance(first, LabeledPoint): @@ -694,23 +733,25 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): of features must be constant. An initial weight vector must be provided. - :param stepSize: - Step size for each iteration of gradient descent. - (default: 0.1) - :param numIterations: - Number of iterations run for each batch of data. - (default: 50) - :param miniBatchFraction: - Fraction of each batch of data to use for updates. - (default: 1.0) - :param regParam: - L2 Regularization parameter. 
- (default: 0.0) - :param convergenceTol: - Value used to determine when to terminate iterations. - (default: 0.001) - .. versionadded:: 1.5.0 + + Parameters + ---------- + stepSize : float, optional + Step size for each iteration of gradient descent. + (default: 0.1) + numIterations : int, optional + Number of iterations run for each batch of data. + (default: 50) + miniBatchFraction : float, optional + Fraction of each batch of data to use for updates. + (default: 1.0) + regParam : float, optional + L2 Regularization parameter. + (default: 0.0) + convergenceTol : float, optional + Value used to determine when to terminate iterations. + (default: 0.001) """ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): diff --git a/python/pyspark/mllib/classification.pyi b/python/pyspark/mllib/classification.pyi index c51882c87b..967b0a9f28 100644 --- a/python/pyspark/mllib/classification.pyi +++ b/python/pyspark/mllib/classification.pyi @@ -118,7 +118,7 @@ class NaiveBayesModel(Saveable, Loader[NaiveBayesModel]): labels: ndarray pi: ndarray theta: ndarray - def __init__(self, labels, pi, theta) -> None: ... + def __init__(self, labels: ndarray, pi: ndarray, theta: ndarray) -> None: ... @overload def predict(self, x: VectorLike) -> float64: ... @overload diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b99a4150c3..e1a009643c 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -41,6 +41,10 @@ class BisectingKMeansModel(JavaModelWrapper): """ A clustering model derived from the bisecting k-means method. + .. versionadded:: 2.0.0 + + Examples + -------- >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> bskm = BisectingKMeans() >>> model = bskm.train(sc.parallelize(data, 2), k=4) @@ -51,8 +55,6 @@ class BisectingKMeansModel(JavaModelWrapper): 4 >>> model.computeCost(p) 0.0 - - .. versionadded:: 2.0.0 """ def __init__(self, java_model): @@ -72,17 +74,25 @@ def k(self): """Get the number of clusters""" return self.call("k") - @since('2.0.0') def predict(self, x): """ Find the cluster that each of the points belongs to in this model. - :param x: - A data point (or RDD of points) to determine cluster index. - :return: - Predicted cluster index or an RDD of predicted cluster indices - if the input is an RDD. + .. versionadded:: 2.0.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + A data point (or RDD of points) to determine cluster index. + :py:class:`pyspark.mllib.linalg.Vector` can be replaced with equivalent + objects (list, tuple, numpy.ndarray). + + Returns + ------- + int or :py:class:`pyspark.RDD` of int + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -91,15 +101,20 @@ def predict(self, x): x = _convert_to_vector(x) return self.call("predict", x) - @since('2.0.0') def computeCost(self, x): """ Return the Bisecting K-means cost (sum of squared distances of points to their nearest center) for this model on the given data. If provided with an RDD of points returns the sum. - :param point: - A data point (or RDD of points) to compute the cost(s). + .. versionadded:: 2.0.0 + + Parameters + ---------- + point : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + A data point (or RDD of points) to compute the cost(s). 
+ :py:class:`pyspark.mllib.linalg.Vector` can be replaced with equivalent + objects (list, tuple, numpy.ndarray). """ if isinstance(x, RDD): vecs = x.map(_convert_to_vector) @@ -122,37 +137,43 @@ class BisectingKMeans(object): clusters on the bottom level would result more than `k` leaf clusters, larger clusters get higher priority. - Based on - `Steinbach, Karypis, and Kumar, A comparison of document clustering - techniques, KDD Workshop on Text Mining, 2000 - `_. - .. versionadded:: 2.0.0 + + Notes + ----- + See the original paper [1]_ + + .. [1] Steinbach, M. et al. “A Comparison of Document Clustering Techniques.” (2000). + KDD Workshop on Text Mining, 2000 + http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf """ @classmethod - @since('2.0.0') def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1888008604): """ Runs the bisecting k-means algorithm return the model. - :param rdd: - Training points as an `RDD` of `Vector` or convertible - sequence types. - :param k: - The desired number of leaf clusters. The actual number could - be smaller if there are no divisible leaf clusters. - (default: 4) - :param maxIterations: - Maximum number of iterations allowed to split clusters. - (default: 20) - :param minDivisibleClusterSize: - Minimum number of points (if >= 1.0) or the minimum proportion - of points (if < 1.0) of a divisible cluster. - (default: 1) - :param seed: - Random seed value for cluster initialization. - (default: -1888008604 from classOf[BisectingKMeans].getName.##) + .. versionadded:: 2.0.0 + + Parameters + ---------- + rdd : :py:class:`pyspark.RDD` + Training points as an `RDD` of `Vector` or convertible + sequence types. + k : int, optional + The desired number of leaf clusters. The actual number could + be smaller if there are no divisible leaf clusters. + (default: 4) + maxIterations : int, optional + Maximum number of iterations allowed to split clusters. + (default: 20) + minDivisibleClusterSize : float, optional + Minimum number of points (if >= 1.0) or the minimum proportion + of points (if < 1.0) of a divisible cluster. + (default: 1) + seed : int, optional + Random seed value for cluster initialization. + (default: -1888008604 from classOf[BisectingKMeans].getName.##) """ java_model = callMLlibFunc( "trainBisectingKMeans", rdd.map(_convert_to_vector), @@ -165,6 +186,10 @@ class KMeansModel(Saveable, Loader): """A clustering model derived from the k-means method. + .. versionadded:: 0.9.0 + + Examples + -------- >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> model = KMeans.train( ... sc.parallelize(data), 2, maxIterations=10, initializationMode="random", @@ -213,8 +238,6 @@ class KMeansModel(Saveable, Loader): ... initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)])) >>> model.clusterCenters [array([-1000., -1000.]), array([ 5., 5.]), array([ 1000., 1000.])] - - .. versionadded:: 0.9.0 """ def __init__(self, centers): @@ -232,17 +255,25 @@ def k(self): """Total number of clusters.""" return len(self.centers) - @since('0.9.0') def predict(self, x): """ Find the cluster that each of the points belongs to in this model. - :param x: - A data point (or RDD of points) to determine cluster index. - :return: - Predicted cluster index or an RDD of predicted cluster indices - if the input is an RDD. + .. versionadded:: 0.9.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + A data point (or RDD of points) to determine cluster index. 
+ :py:class:`pyspark.mllib.linalg.Vector` can be replaced with equivalent + objects (list, tuple, numpy.ndarray). + + Returns + ------- + int or :py:class:`pyspark.RDD` of int + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. """ best = 0 best_distance = float("inf") @@ -257,15 +288,18 @@ def predict(self, x): best_distance = distance return best - @since('1.4.0') def computeCost(self, rdd): """ Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data. - :param rdd: - The RDD of points to compute the cost on. + .. versionadded:: 1.4.0 + + Parameters + ---------- + rdd : ::py:class:`pyspark.RDD` + The RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) @@ -292,46 +326,51 @@ def load(cls, sc, path): class KMeans(object): """ + K-means clustering. + .. versionadded:: 0.9.0 """ @classmethod - @since('0.9.0') def train(cls, rdd, k, maxIterations=100, initializationMode="k-means||", seed=None, initializationSteps=2, epsilon=1e-4, initialModel=None): """ Train a k-means clustering model. - :param rdd: - Training points as an `RDD` of `Vector` or convertible - sequence types. - :param k: - Number of clusters to create. - :param maxIterations: - Maximum number of iterations allowed. - (default: 100) - :param initializationMode: - The initialization algorithm. This can be either "random" or - "k-means||". - (default: "k-means||") - :param seed: - Random seed value for cluster initialization. Set as None to - generate seed based on system time. - (default: None) - :param initializationSteps: - Number of steps for the k-means|| initialization mode. - This is an advanced setting -- the default of 2 is almost - always enough. - (default: 2) - :param epsilon: - Distance threshold within which a center will be considered to - have converged. If all centers move less than this Euclidean - distance, iterations are stopped. - (default: 1e-4) - :param initialModel: - Initial cluster centers can be provided as a KMeansModel object - rather than using the random or k-means|| initializationModel. - (default: None) + .. versionadded:: 0.9.0 + + Parameters + ---------- + rdd : ::py:class:`pyspark.RDD` + Training points as an `RDD` of :py:class:`pyspark.mllib.linalg.Vector` + or convertible sequence types. + k : int + Number of clusters to create. + maxIterations : int, optional + Maximum number of iterations allowed. + (default: 100) + initializationMode : str, optional + The initialization algorithm. This can be either "random" or + "k-means||". + (default: "k-means||") + seed : int, optional + Random seed value for cluster initialization. Set as None to + generate seed based on system time. + (default: None) + initializationSteps : + Number of steps for the k-means|| initialization mode. + This is an advanced setting -- the default of 2 is almost + always enough. + (default: 2) + epsilon : float, optional + Distance threshold within which a center will be considered to + have converged. If all centers move less than this Euclidean + distance, iterations are stopped. + (default: 1e-4) + initialModel : :py:class:`KMeansModel`, optional + Initial cluster centers can be provided as a KMeansModel object + rather than using the random or k-means|| initializationModel. 
+ (default: None) """ clusterInitialModel = [] if initialModel is not None: @@ -352,6 +391,10 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ A clustering model derived from the Gaussian Mixture Model method. + .. versionadded:: 1.3.0 + + Examples + -------- >>> from pyspark.mllib.linalg import Vectors, DenseMatrix >>> from numpy.testing import assert_equal >>> from shutil import rmtree @@ -410,8 +453,6 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): True >>> labels[2]==labels[3]==labels[4] True - - .. versionadded:: 1.3.0 """ @property @@ -440,17 +481,23 @@ def k(self): """Number of gaussians in mixture.""" return len(self.weights) - @since('1.3.0') def predict(self, x): """ Find the cluster to which the point 'x' or each point in RDD 'x' has maximum membership in this model. - :param x: - A feature vector or an RDD of vectors representing data points. - :return: - Predicted cluster label or an RDD of predicted cluster labels - if the input is an RDD. + .. versionadded:: 1.3.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + A feature vector or an RDD of vectors representing data points. + + Returns + ------- + numpy.float64 or :py:class:`pyspark.RDD` of int + Predicted cluster label or an RDD of predicted cluster labels + if the input is an RDD. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) @@ -459,16 +506,22 @@ def predict(self, x): z = self.predictSoft(x) return z.argmax() - @since('1.3.0') def predictSoft(self, x): """ Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: - A feature vector or an RDD of vectors representing data points. - :return: - The membership value to all mixture components for vector 'x' - or each vector in RDD 'x'. + .. versionadded:: 1.3.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + A feature vector or an RDD of vectors representing data points. + + Returns + ------- + numpy.ndarray or :py:class:`pyspark.RDD` + The membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -479,14 +532,16 @@ def predictSoft(self, x): return self.call("predictSoft", _convert_to_vector(x)).toArray() @classmethod - @since('1.5.0') def load(cls, sc, path): """Load the GaussianMixtureModel from disk. - :param sc: - SparkContext. - :param path: - Path to where the model is stored. + .. versionadded:: 1.5.0 + + Parameters + ---------- + sc : :py:class:`SparkContext` + path : str + Path to where the model is stored. """ model = cls._load_java(sc, path) wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model) @@ -499,32 +554,36 @@ class GaussianMixture(object): .. versionadded:: 1.3.0 """ + @classmethod - @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """ Train a Gaussian Mixture clustering model. - :param rdd: - Training points as an `RDD` of `Vector` or convertible - sequence types. - :param k: - Number of independent Gaussians in the mixture model. - :param convergenceTol: - Maximum change in log-likelihood at which convergence is - considered to have occurred. - (default: 1e-3) - :param maxIterations: - Maximum number of iterations allowed. 
- (default: 100) - :param seed: - Random seed for initial Gaussian distribution. Set as None to - generate seed based on system time. - (default: None) - :param initialModel: - Initial GMM starting point, bypassing the random - initialization. - (default: None) + .. versionadded:: 1.3.0 + + Parameters + ---------- + rdd : ::py:class:`pyspark.RDD` + Training points as an `RDD` of :py:class:`pyspark.mllib.linalg.Vector` + or convertible sequence types. + k : int + Number of independent Gaussians in the mixture model. + convergenceTol : float, optional + Maximum change in log-likelihood at which convergence is + considered to have occurred. + (default: 1e-3) + maxIterations : int, optional + Maximum number of iterations allowed. + (default: 100) + seed : int, optional + Random seed for initial Gaussian distribution. Set as None to + generate seed based on system time. + (default: None) + initialModel : GaussianMixtureModel, optional + Initial GMM starting point, bypassing the random + initialization. + (default: None) """ initialModelWeights = None initialModelMu = None @@ -545,8 +604,12 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - Model produced by [[PowerIterationClustering]]. + Model produced by :py:class:`PowerIterationClustering`. + .. versionadded:: 1.5.0 + + Examples + -------- >>> import math >>> def genCircle(r, n): ... points = [] @@ -589,8 +652,6 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass - - .. versionadded:: 1.5.0 """ @property @@ -623,37 +684,48 @@ def load(cls, sc, path): class PowerIterationClustering(object): """ - Power Iteration Clustering (PIC), a scalable graph clustering algorithm - developed by [[http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf Lin and Cohen]]. - From the abstract: PIC finds a very low-dimensional embedding of a - dataset using truncated power iteration on a normalized pair-wise - similarity matrix of the data. + Power Iteration Clustering (PIC), a scalable graph clustering algorithm. + + + Developed by Lin and Cohen [1]_. From the abstract: + + "PIC finds a very low-dimensional embedding of a + dataset using truncated power iteration on a normalized pair-wise + similarity matrix of the data." .. versionadded:: 1.5.0 + + .. [1] Lin, Frank & Cohen, William. (2010). Power Iteration Clustering. + http://www.cs.cmu.edu/~frank/papers/icml2010-pic-final.pdf """ @classmethod - @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): r""" - :param rdd: - An RDD of (i, j, s\ :sub:`ij`\) tuples representing the - affinity matrix, which is the matrix A in the PIC paper. The - similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric - matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with - nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or - (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, - because it is assumed s\ :sub:`ij`\ = 0.0. - :param k: - Number of clusters. - :param maxIterations: - Maximum number of iterations of the PIC algorithm. - (default: 100) - :param initMode: - Initialization mode. This can be either "random" to use - a random vector as vertex properties, or "degree" to use - normalized sum similarities. - (default: "random") + Train PowerIterationClusteringModel + + .. 
versionadded:: 1.5.0 + + Parameters + ---------- + rdd : :py:class:`pyspark.RDD` + An RDD of (i, j, s\ :sub:`ij`\) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. The + similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric + matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with + nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or + (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, + because it is assumed s\ :sub:`ij`\ = 0.0. + k : int + Number of clusters. + maxIterations : int, optional + Maximum number of iterations of the PIC algorithm. + (default: 100) + initMode : str, optional + Initialization mode. This can be either "random" to use + a random vector as vertex properties, or "degree" to use + normalized sum similarities. + (default: "random") """ model = callMLlibFunc("trainPowerIterationClusteringModel", rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) @@ -673,29 +745,37 @@ class StreamingKMeansModel(KMeansModel): The update formula for each centroid is given by - * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) - * n_t+1 = n_t * a + m_t + - c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) + - n_t+1 = n_t * a + m_t where - * c_t: Centroid at the n_th iteration. - * n_t: Number of samples (or) weights associated with the centroid - at the n_th iteration. - * x_t: Centroid of the new data closest to c_t. - * m_t: Number of samples (or) weights of the new data closest to c_t - * c_t+1: New centroid. - * n_t+1: New number of weights. - * a: Decay Factor, which gives the forgetfulness. + - c_t: Centroid at the n_th iteration. + - n_t: Number of samples (or) weights associated with the centroid + at the n_th iteration. + - x_t: Centroid of the new data closest to c_t. + - m_t: Number of samples (or) weights of the new data closest to c_t + - c_t+1: New centroid. + - n_t+1: New number of weights. + - a: Decay Factor, which gives the forgetfulness. - .. note:: If a is set to 1, it is the weighted mean of the previous - and new data. If it set to zero, the old centroids are completely - forgotten. - - :param clusterCenters: - Initial cluster centers. - :param clusterWeights: - List of weights assigned to each cluster. + .. versionadded:: 1.5.0 + Parameters + ---------- + clusterCenters : list of :py:class:`pyspark.mllib.linalg.Vector` or covertible + Initial cluster centers. + clusterWeights : :py:class:`pyspark.mllib.linalg.Vector` or covertible + List of weights assigned to each cluster. + + Notes + ----- + If a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. + + Examples + -------- >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] >>> initWeights = [1.0, 1.0] >>> stkm = StreamingKMeansModel(initCenters, initWeights) @@ -723,8 +803,6 @@ class StreamingKMeansModel(KMeansModel): 0 >>> stkm.predict([1.5, 1.5]) 1 - - .. versionadded:: 1.5.0 """ def __init__(self, clusterCenters, clusterWeights): super(StreamingKMeansModel, self).__init__(centers=clusterCenters) @@ -740,14 +818,18 @@ def clusterWeights(self): def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data - :param data: - RDD with new data for the model update. - :param decayFactor: - Forgetfulness of the previous centroids. - :param timeUnit: - Can be "batches" or "points". 
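# [Editor's sketch, not part of the patch] The decay formula above in action: a single
# update() call with decayFactor and timeUnit as documented. Assumes an active
# SparkContext `sc`; centers, weights and points are illustrative only.
from pyspark.mllib.clustering import StreamingKMeansModel
stkm = StreamingKMeansModel(clusterCenters=[[0.0, 0.0], [1.0, 1.0]],
                            clusterWeights=[1.0, 1.0])
batch = sc.parallelize([[0.1, 0.1], [0.9, 0.9]])
# With decayFactor=0.5 the old weights n_t are halved before the new points are folded in.
stkm = stkm.update(batch, decayFactor=0.5, timeUnit="batches")
stkm.centers, stkm.clusterWeights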
If points, then the decay factor - is raised to the power of number of new points and if batches, - then decay factor will be used as is. + .. versionadded:: 1.5.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + RDD with new data for the model update. + decayFactor : float + Forgetfulness of the previous centroids. + timeUnit : str + Can be "batches" or "points". If points, then the decay factor + is raised to the power of number of new points and if batches, + then decay factor will be used as is. """ if not isinstance(data, RDD): raise TypeError("Data should be of an RDD, got %s." % type(data)) @@ -772,19 +854,21 @@ class StreamingKMeans(object): More details on how the centroids are updated are provided under the docs of StreamingKMeansModel. - :param k: - Number of clusters. - (default: 2) - :param decayFactor: - Forgetfulness of the previous centroids. - (default: 1.0) - :param timeUnit: - Can be "batches" or "points". If points, then the decay factor is - raised to the power of number of new points and if batches, then - decay factor will be used as is. - (default: "batches") - .. versionadded:: 1.5.0 + + Parameters + ---------- + k : int, optional + Number of clusters. + (default: 2) + decayFactor : float, optional + Forgetfulness of the previous centroids. + (default: 1.0) + timeUnit : str, optional + Can be "batches" or "points". If points, then the decay factor is + raised to the power of number of new points and if batches, then + decay factor will be used as is. + (default: "batches") """ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): self._k = k @@ -887,13 +971,23 @@ class LDAModel(JavaModelWrapper, JavaSaveable, Loader): Latent Dirichlet Allocation (LDA), a topic model designed for text documents. Terminology + - "word" = "term": an element of the vocabulary - "token": instance of a term appearing in a document - "topic": multinomial distribution over words representing some concept - References: - - Original LDA paper (journal version): - Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + .. versionadded:: 1.5.0 + + Notes + ----- + See the original LDA paper (journal version) [1]_ + + .. [1] Blei, D. et al. "Latent Dirichlet Allocation." + J. Mach. Learn. Res. 3 (2003): 993-1022. + https://www.jmlr.org/papers/v3/blei03a + + Examples + -------- >>> from pyspark.mllib.linalg import Vectors >>> from numpy.testing import assert_almost_equal, assert_equal >>> data = [ @@ -925,8 +1019,6 @@ class LDAModel(JavaModelWrapper, JavaSaveable, Loader): ... rmtree(path) ... except OSError: ... pass - - .. versionadded:: 1.5.0 """ @since('1.5.0') @@ -939,19 +1031,24 @@ def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - @since('1.6.0') def describeTopics(self, maxTermsPerTopic=None): """Return the topics described by weighted terms. - WARNING: If vocabSize and k are large, this can return a large object! - - :param maxTermsPerTopic: - Maximum number of terms to collect for each topic. - (default: vocabulary size) - :return: - Array over topics. Each topic is represented as a pair of - matching arrays: (term indices, term weights in topic). - Each topic's terms are sorted in order of decreasing weight. + .. versionadded:: 1.6.0 + .. warning:: If vocabSize and k are large, this can return a large object! + + Parameters + ---------- + maxTermsPerTopic : int, optional + Maximum number of terms to collect for each topic. 
+ (default: vocabulary size) + + Returns + ------- + list + Array over topics. Each topic is represented as a pair of + matching arrays: (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") @@ -960,14 +1057,16 @@ def describeTopics(self, maxTermsPerTopic=None): return topics @classmethod - @since('1.5.0') def load(cls, sc, path): """Load the LDAModel from disk. - :param sc: - SparkContext. - :param path: - Path to where the model is stored. + .. versionadded:: 1.5.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + path : str + Path to where the model is stored. """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) @@ -979,47 +1078,52 @@ def load(cls, sc, path): class LDA(object): """ + Train Latent Dirichlet Allocation (LDA) model. + .. versionadded:: 1.5.0 """ @classmethod - @since('1.5.0') def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. - :param rdd: - RDD of documents, which are tuples of document IDs and term - (word) count vectors. The term count vectors are "bags of - words" with a fixed-size vocabulary (where the vocabulary size - is the length of the vector). Document IDs must be unique - and >= 0. - :param k: - Number of topics to infer, i.e., the number of soft cluster - centers. - (default: 10) - :param maxIterations: - Maximum number of iterations allowed. - (default: 20) - :param docConcentration: - Concentration parameter (commonly named "alpha") for the prior - placed on documents' distributions over topics ("theta"). - (default: -1.0) - :param topicConcentration: - Concentration parameter (commonly named "beta" or "eta") for - the prior placed on topics' distributions over terms. - (default: -1.0) - :param seed: - Random seed for cluster initialization. Set as None to generate - seed based on system time. - (default: None) - :param checkpointInterval: - Period (in iterations) between checkpoints. - (default: 10) - :param optimizer: - LDAOptimizer used to perform the actual calculation. Currently - "em", "online" are supported. - (default: "em") + .. versionadded:: 1.5.0 + + Parameters + ---------- + rdd : :py:class:`pyspark.RDD` + RDD of documents, which are tuples of document IDs and term + (word) count vectors. The term count vectors are "bags of + words" with a fixed-size vocabulary (where the vocabulary size + is the length of the vector). Document IDs must be unique + and >= 0. + k : int, optional + Number of topics to infer, i.e., the number of soft cluster + centers. + (default: 10) + maxIterations : int, optional + Maximum number of iterations allowed. + (default: 20) + docConcentration : float, optional + Concentration parameter (commonly named "alpha") for the prior + placed on documents' distributions over topics ("theta"). + (default: -1.0) + topicConcentration : float, optional + Concentration parameter (commonly named "beta" or "eta") for + the prior placed on topics' distributions over terms. + (default: -1.0) + seed : int, optional + Random seed for cluster initialization. Set as None to generate + seed based on system time. + (default: None) + checkpointInterval : int, optional + Period (in iterations) between checkpoints. + (default: 10) + optimizer : str, optional + LDAOptimizer used to perform the actual calculation. 
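# [Editor's sketch, not part of the patch] Putting the LDA.train parameters above
# together with describeTopics(). Assumes an active SparkContext `sc`; the corpus has
# the same toy shape as the doctest, i.e. (document id, term-count vector) pairs.
from pyspark.mllib.clustering import LDA
from pyspark.mllib.linalg import Vectors
corpus = sc.parallelize([
    [1, Vectors.dense([0.0, 1.0])],
    [2, Vectors.dense([1.0, 0.0])],
])
lda_model = LDA.train(corpus, k=2, maxIterations=20, seed=10, optimizer="em")
lda_model.vocabSize()
lda_model.describeTopics(maxTermsPerTopic=1)   # one (term indices, term weights) pair per topic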
Currently + "em", "online" are supported. + (default: "em") """ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, docConcentration, topicConcentration, seed, diff --git a/python/pyspark/mllib/clustering.pyi b/python/pyspark/mllib/clustering.pyi index 1c3eba17e2..b4f349612f 100644 --- a/python/pyspark/mllib/clustering.pyi +++ b/python/pyspark/mllib/clustering.pyi @@ -63,7 +63,7 @@ class BisectingKMeans: class KMeansModel(Saveable, Loader[KMeansModel]): centers: List[ndarray] - def __init__(self, centers: List[ndarray]) -> None: ... + def __init__(self, centers: List[VectorLike]) -> None: ... @property def clusterCenters(self) -> List[ndarray]: ... @property @@ -144,7 +144,9 @@ class PowerIterationClustering: class Assignment(NamedTuple("Assignment", [("id", int), ("cluster", int)])): ... class StreamingKMeansModel(KMeansModel): - def __init__(self, clusterCenters, clusterWeights) -> None: ... + def __init__( + self, clusterCenters: List[VectorLike], clusterWeights: VectorLike + ) -> None: ... @property def clusterWeights(self) -> List[float64]: ... centers: ndarray diff --git a/python/pyspark/mllib/common.pyi b/python/pyspark/mllib/common.pyi index 1df308b91b..daba212d93 100644 --- a/python/pyspark/mllib/common.pyi +++ b/python/pyspark/mllib/common.pyi @@ -16,12 +16,20 @@ # specific language governing permissions and limitations # under the License. -def callJavaFunc(sc, func, *args): ... -def callMLlibFunc(name, *args): ... +from typing import Any, TypeVar + +import pyspark.context + +from py4j.java_gateway import JavaObject + +C = TypeVar("C", bound=type) + +def callJavaFunc(sc: pyspark.context.SparkContext, func: Any, *args: Any) -> Any: ... +def callMLlibFunc(name: str, *args: Any) -> Any: ... class JavaModelWrapper: - def __init__(self, java_model) -> None: ... - def __del__(self): ... - def call(self, name, *a): ... + def __init__(self, java_model: JavaObject) -> None: ... + def __del__(self) -> None: ... + def call(self, name: str, *a: Any) -> Any: ... -def inherit_doc(cls): ... +def inherit_doc(cls: C) -> C: ... diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index f3be827fb6..198a979177 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -30,8 +30,15 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. - :param scoreAndLabels: an RDD of score, label and optional weight. + .. versionadded:: 1.4.0 + + Parameters + ---------- + scoreAndLabels : :py:class:`pyspark.RDD` + an RDD of score, label and optional weight. + Examples + -------- >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -48,8 +55,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): 0.79... >>> metrics.areaUnderPR 0.88... - - .. versionadded:: 1.4.0 """ def __init__(self, scoreAndLabels): @@ -95,8 +100,15 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. - :param predictionAndObservations: an RDD of prediction, observation and optional weight. + .. versionadded:: 1.4.0 + + Parameters + ---------- + predictionAndObservations : :py:class:`pyspark.RDD` + an RDD of prediction, observation and optional weight. + Examples + -------- >>> predictionAndObservations = sc.parallelize([ ... 
(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -115,8 +127,6 @@ class RegressionMetrics(JavaModelWrapper): >>> metrics = RegressionMetrics(predictionAndObservationsWithOptWeight) >>> metrics.rootMeanSquaredError 0.68... - - .. versionadded:: 1.4.0 """ def __init__(self, predictionAndObservations): @@ -182,9 +192,15 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels: an RDD of prediction, label, optional weight - and optional probability. + .. versionadded:: 1.4.0 + + Parameters + ---------- + predictionAndLabels : :py:class:`pyspark.RDD` + an RDD of prediction, label, optional weight and optional probability. + Examples + -------- >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -246,8 +262,6 @@ class MulticlassMetrics(JavaModelWrapper): >>> metrics = MulticlassMetrics(predictionAndLabelsWithProbabilities) >>> metrics.logLoss() 0.9682... - - .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): @@ -377,9 +391,15 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. - :param predictionAndLabels: an RDD of (predicted ranking, - ground truth set) pairs. + .. versionadded:: 1.4.0 + Parameters + ---------- + predictionAndLabels : :py:class:`pyspark.RDD` + an RDD of (predicted ranking, ground truth set) pairs. + + Examples + -------- >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -407,8 +427,6 @@ class RankingMetrics(JavaModelWrapper): 0.35... >>> metrics.recallAt(15) 0.66... - - .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): @@ -484,10 +502,16 @@ class MultilabelMetrics(JavaModelWrapper): """ Evaluator for multilabel classification. - :param predictionAndLabels: an RDD of (predictions, labels) pairs, - both are non-null Arrays, each with - unique elements. + .. versionadded:: 1.4.0 + + Parameters + ---------- + predictionAndLabels : :py:class:`pyspark.RDD` + an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with unique elements. + Examples + -------- >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) @@ -516,8 +540,6 @@ class MultilabelMetrics(JavaModelWrapper): 0.28... >>> metrics.accuracy 0.54... - - .. versionadded:: 1.4.0 """ def __init__(self, predictionAndLabels): diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index d95f9197ea..1d37ab8156 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -41,7 +41,10 @@ def transform(self, vector): """ Applies transformation on a vector. - :param vector: vector to be transformed. + Parameters + ---------- + vector : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + vector or convertible or RDD to be transformed. """ raise NotImplementedError @@ -56,8 +59,15 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. - :param p: Normalization in L^p^ space, p = 2 by default. + .. 
versionadded:: 1.2.0 + + Parameters + ---------- + p : float, optional + Normalization in L^p^ space, p = 2 by default. + Examples + -------- >>> from pyspark.mllib.linalg import Vectors >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) @@ -71,21 +81,27 @@ class Normalizer(VectorTransformer): >>> nor2 = Normalizer(float("inf")) >>> nor2.transform(v) DenseVector([0.0, 0.5, 1.0]) - - .. versionadded:: 1.2.0 """ def __init__(self, p=2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) - @since('1.2.0') def transform(self, vector): """ Applies unit length normalization on a vector. - :param vector: vector or RDD of vector to be normalized. - :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + .. versionadded:: 1.2.0 + + Parameters + ---------- + vector : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + vector or RDD of vector to be normalized. + + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + normalized vector(s). If the norm of the input is zero, it + will return the input vector. """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) @@ -103,11 +119,16 @@ def transform(self, vector): """ Applies transformation on a vector or an RDD[Vector]. - .. note:: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + Parameters + ---------- + vector : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Input vector(s) to be transformed. - :param vector: Vector or RDD of Vector to be transformed. + Notes + ----- + In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) @@ -123,19 +144,29 @@ class StandardScalerModel(JavaVectorTransformer): .. versionadded:: 1.2.0 """ - @since('1.2.0') def transform(self, vector): """ Applies standardization transformation on a vector. - .. note:: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. versionadded:: 1.2.0 + + Parameters + ---------- + vector : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Input vector(s) to be standardized. - :param vector: Vector or RDD of Vector to be standardized. - :return: Standardized vector. If the variance of a column is - zero, it will return default `0.0` for the column with - zero variance. + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Standardized vector(s). If the variance of a column is + zero, it will return default `0.0` for the column with + zero variance. + + Notes + ----- + In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. """ return JavaVectorTransformer.transform(self, vector) @@ -196,12 +227,20 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so take - care when applying to sparse input. - :param withStd: True by default. Scales the data to unit - standard deviation. + .. versionadded:: 1.2.0 + Parameters + ---------- + withMean : bool, optional + False by default. Centers the data with mean + before scaling. 
It will build a dense output, so take + care when applying to sparse input. + withStd : bool, optional + True by default. Scales the data to unit + standard deviation. + + Examples + -------- >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -218,8 +257,6 @@ class StandardScaler(object): True >>> model.withMean True - - .. versionadded:: 1.2.0 """ def __init__(self, withMean=False, withStd=True): if not (withMean or withStd): @@ -227,15 +264,22 @@ def __init__(self, withMean=False, withStd=True): self.withMean = withMean self.withStd = withStd - @since('1.2.0') def fit(self, dataset): """ Computes the mean and variance and stores as a model to be used for later scaling. - :param dataset: The data used to compute the mean and variance - to build the transformation model. - :return: a StandardScalarModel + .. versionadded:: 1.2.0 + + Parameters + ---------- + dataset : :py:class:`pyspark.RDD` + The data used to compute the mean and variance + to build the transformation model. + + Returns + ------- + :py:class:`StandardScalerModel` """ dataset = dataset.map(_convert_to_vector) jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) @@ -249,13 +293,21 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. versionadded:: 1.4.0 """ - @since('1.4.0') def transform(self, vector): """ Applies transformation on a vector. - :param vector: Vector or RDD of Vector to be transformed. - :return: transformed vector. + .. versionadded:: 1.4.0 + + Examples + -------- + vector : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Input vector(s) to be transformed. + + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + transformed vector(s). """ return JavaVectorTransformer.transform(self, vector) @@ -284,6 +336,10 @@ class ChiSqSelector(object): By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. + .. versionadded:: 1.4.0 + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector, DenseVector >>> from pyspark.mllib.regression import LabeledPoint >>> data = sc.parallelize([ @@ -306,8 +362,6 @@ class ChiSqSelector(object): >>> model = ChiSqSelector(selectorType="percentile", percentile=0.34).fit(data) >>> model.transform(DenseVector([7.0, 9.0, 5.0])) DenseVector([7.0]) - - .. versionadded:: 1.4.0 """ def __init__(self, numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05): @@ -372,15 +426,18 @@ def setSelectorType(self, selectorType): self.selectorType = str(selectorType) return self - @since('1.4.0') def fit(self, data): """ Returns a ChiSquared feature selector. - :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + .. versionadded:: 1.4.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` of :py:class:`pyspark.mllib.regression.LabeledPoint` + containing the labeled dataset with categorical features. + Real-valued features will be treated as categorical for each + distinct value. Apply feature discretizer before using this function. 
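# [Editor's sketch, not part of the patch] fit() as described above, on a tiny RDD of
# LabeledPoint whose features are treated as categorical. Assumes an active
# SparkContext `sc`; the data values are illustrative only.
from pyspark.mllib.feature import ChiSqSelector
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint
labeled = sc.parallelize([
    LabeledPoint(0.0, Vectors.dense([0.0, 1.0, 2.0])),
    LabeledPoint(1.0, Vectors.dense([1.0, 1.0, 0.0])),
    LabeledPoint(1.0, Vectors.dense([1.0, 0.0, 0.0])),
])
selector_model = ChiSqSelector(numTopFeatures=1).fit(labeled)
selector_model.transform(Vectors.dense([7.0, 9.0, 5.0]))   # keeps only the selected feature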
""" jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures, self.percentile, self.fpr, self.fdr, self.fwe, data) @@ -399,6 +456,10 @@ class PCA(object): """ A feature transformer that projects vectors to a low-dimensional space using PCA. + .. versionadded:: 1.5.0 + + Examples + -------- >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]), ... Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]), ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] @@ -408,20 +469,26 @@ class PCA(object): 1.648... >>> pcArray[1] -4.013... - - .. versionadded:: 1.5.0 """ def __init__(self, k): """ - :param k: number of principal components. + Parameters + ---------- + k : int + number of principal components. """ self.k = int(k) - @since('1.5.0') def fit(self, data): """ Computes a [[PCAModel]] that contains the principal components of the input vectors. - :param data: source vectors + + .. versionadded:: 1.5.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + source vectors """ jmodel = callMLlibFunc("fitPCA", self.k, data) return PCAModel(jmodel) @@ -432,16 +499,23 @@ class HashingTF(object): Maps a sequence of terms to their term frequencies using the hashing trick. - .. note:: The terms must be hashable (can not be dict/set/list...). + .. versionadded:: 1.2.0 + + Parameters + ---------- + numFeatures : int, optional + number of features (default: 2^20) - :param numFeatures: number of features (default: 2^20) + Notes + ----- + The terms must be hashable (can not be dict/set/list...). + Examples + -------- >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) - - .. versionadded:: 1.2.0 """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures @@ -485,7 +559,7 @@ class IDFModel(JavaVectorTransformer): .. versionadded:: 1.2.0 """ - @since('1.2.0') + def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -494,13 +568,24 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - .. note:: In Python, transform cannot currently be used within - an RDD transformation or action. - Call transform directly on the RDD instead. + .. versionadded:: 1.2.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + an RDD of term frequency vectors or a term frequency + vector - :param x: an RDD of term frequency vectors or a term frequency - vector - :return: an RDD of TF-IDF vectors or a TF-IDF vector + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + an RDD of TF-IDF vectors or a TF-IDF vector + + Notes + ----- + In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. """ return JavaVectorTransformer.transform(self, x) @@ -539,9 +624,15 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. - :param minDocFreq: minimum of documents in which a term - should appear for filtering + .. versionadded:: 1.2.0 + + Parameters + ---------- + minDocFreq : int + minimum of documents in which a term should appear for filtering + Examples + -------- >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... 
Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -560,18 +651,20 @@ class IDF(object): DenseVector([0.0, 0.0, 1.3863, 0.863]) >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) SparseVector(4, {1: 0.0, 3: 0.5754}) - - .. versionadded:: 1.2.0 """ def __init__(self, minDocFreq=0): self.minDocFreq = minDocFreq - @since('1.2.0') def fit(self, dataset): """ Computes the inverse document frequency. - :param dataset: an RDD of term frequency vectors + .. versionadded:: 1.2.0 + + Parameters + ---------- + dataset : :py:class:`pyspark.RDD` + an RDD of term frequency vectors """ if not isinstance(dataset, RDD): raise TypeError("dataset should be an RDD of term frequency vectors") @@ -582,34 +675,55 @@ def fit(self, dataset): class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model - - .. versionadded:: 1.2.0 """ - @since('1.2.0') + def transform(self, word): """ Transforms a word to its vector representation - .. note:: Local use only + .. versionadded:: 1.2.0 + + Parameters + ---------- + word : str + a word - :param word: a word - :return: vector representation of word(s) + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` + vector representation of word(s) + + Notes + ----- + Local use only """ try: return self.call("transform", word) except Py4JJavaError: raise ValueError("%s not found" % word) - @since('1.2.0') def findSynonyms(self, word, num): """ Find synonyms of a word - :param word: a word or a vector representation of word - :param num: number of synonyms to find - :return: array of (word, cosineSimilarity) + .. versionadded:: 1.2.0 + + Parameters + ---------- + + word : str or :py:class:`pyspark.mllib.linalg.Vector` + a word or a vector representation of word + num : int + number of synonyms to find + + Returns + ------- + :py:class:`collections.abc.Iterable` + array of (word, cosineSimilarity) - .. note:: Local use only + Notes + ----- + Local use only """ if not isinstance(word, str): word = _convert_to_vector(word) @@ -653,6 +767,10 @@ class Word2Vec(object): and Distributed Representations of Words and Phrases and their Compositionality. + .. versionadded:: 1.2.0 + + Examples + -------- >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) @@ -686,9 +804,6 @@ class Word2Vec(object): ... rmtree(path) ... except OSError: ... pass - - .. versionadded:: 1.2.0 - """ def __init__(self): """ @@ -761,13 +876,20 @@ def setWindowSize(self, windowSize): self.windowSize = windowSize return self - @since('1.2.0') def fit(self, data): """ Computes the vector representation of each word in vocabulary. - :param data: training data. RDD of list of string - :return: Word2VecModel instance + .. versionadded:: 1.2.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + training data. RDD of list of string + + Returns + ------- + :py:class:`Word2VecModel` """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") @@ -783,6 +905,10 @@ class ElementwiseProduct(VectorTransformer): Scales each column of the vector, with the supplied weight vector. i.e the elementwise product. + .. 
versionadded:: 1.5.0 + + Examples + -------- >>> weight = Vectors.dense([1.0, 2.0, 3.0]) >>> eprod = ElementwiseProduct(weight) >>> a = Vectors.dense([2.0, 1.0, 3.0]) @@ -792,8 +918,6 @@ class ElementwiseProduct(VectorTransformer): >>> rdd = sc.parallelize([a, b]) >>> eprod.transform(rdd).collect() [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] - - .. versionadded:: 1.5.0 """ def __init__(self, scalingVector): self.scalingVector = _convert_to_vector(scalingVector) diff --git a/python/pyspark/mllib/feature.pyi b/python/pyspark/mllib/feature.pyi index 9ccec36abd..24a46f6bee 100644 --- a/python/pyspark/mllib/feature.pyi +++ b/python/pyspark/mllib/feature.pyi @@ -17,7 +17,7 @@ # under the License. from typing import overload -from typing import Iterable, Hashable, List, Tuple +from typing import Iterable, Hashable, List, Tuple, Union from pyspark.mllib._typing import VectorLike from pyspark.context import SparkContext @@ -135,7 +135,7 @@ class IDF: class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader[Word2VecModel]): def transform(self, word: str) -> Vector: ... # type: ignore - def findSynonyms(self, word: str, num: int) -> Iterable[Tuple[str, float]]: ... + def findSynonyms(self, word: Union[str, VectorLike], num: int) -> Iterable[Tuple[str, float]]: ... def getVectors(self) -> JavaMap: ... @classmethod def load(cls, sc: SparkContext, path: str) -> Word2VecModel: ... diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index cbbd7b351b..1f87a15cb1 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -32,6 +32,10 @@ class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): A FP-Growth model for mining frequent itemsets using the Parallel FP-Growth algorithm. + .. versionadded:: 1.4.0 + + Examples + -------- >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) @@ -42,8 +46,6 @@ class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> sameModel = FPGrowthModel.load(sc, model_path) >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect()) True - - .. versionadded:: 1.4.0 """ @since("1.4.0") @@ -72,20 +74,23 @@ class FPGrowth(object): """ @classmethod - @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: - The input data set, each element contains a transaction. - :param minSupport: - The minimal support level. - (default: 0.3) - :param numPartitions: - The number of partitions used by parallel FP-growth. A value - of -1 will use the same number as input data. - (default: -1) + .. versionadded:: 1.4.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The input data set, each element contains a transaction. + minSupport : float, optional + The minimal support level. + (default: 0.3) + numPartitions : int, optional + The number of partitions used by parallel FP-growth. A value + of -1 will use the same number as input data. + (default: -1) """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) @@ -103,6 +108,10 @@ class PrefixSpanModel(JavaModelWrapper): """ Model fitted by PrefixSpan + .. versionadded:: 1.6.0 + + Examples + -------- >>> data = [ ... [["a", "b"], ["c"]], ... 
[["a"], ["c", "b"], ["a", "b"]], @@ -112,8 +121,6 @@ class PrefixSpanModel(JavaModelWrapper): >>> model = PrefixSpan.train(rdd) >>> sorted(model.freqSequences().collect()) [FreqSequence(sequence=[['a']], freq=3), FreqSequence(sequence=[['a'], ['a']], freq=1), ... - - .. versionadded:: 1.6.0 """ @since("1.6.0") @@ -125,38 +132,45 @@ def freqSequences(self): class PrefixSpan(object): """ A parallel PrefixSpan algorithm to mine frequent sequential patterns. - The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: - Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - ([[https://doi.org/10.1109/ICDE.2001.914830]]). + The PrefixSpan algorithm is described in Jian Pei et al (2001) [1]_ .. versionadded:: 1.6.0 + + .. [1] Jian Pei et al., + "PrefixSpan,: mining sequential patterns efficiently by prefix-projected pattern growth," + Proceedings 17th International Conference on Data Engineering, Heidelberg, + Germany, 2001, pp. 215-224, + doi: https://doi.org/10.1109/ICDE.2001.914830 """ @classmethod - @since("1.6.0") def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000): """ Finds the complete set of frequent sequential patterns in the input sequences of itemsets. - :param data: - The input data set, each element contains a sequence of - itemsets. - :param minSupport: - The minimal support level of the sequential pattern, any - pattern that appears more than (minSupport * - size-of-the-dataset) times will be output. - (default: 0.1) - :param maxPatternLength: - The maximal length of the sequential pattern, any pattern - that appears less than maxPatternLength will be output. - (default: 10) - :param maxLocalProjDBSize: - The maximum number of items (including delimiters used in the - internal storage format) allowed in a projected database before - local processing. If a projected database exceeds this size, - another iteration of distributed prefix growth is run. - (default: 32000000) + .. versionadded:: 1.6.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The input data set, each element contains a sequence of + itemsets. + minSupport : float, optional + The minimal support level of the sequential pattern, any + pattern that appears more than (minSupport * + size-of-the-dataset) times will be output. + (default: 0.1) + maxPatternLength : int, optional + The maximal length of the sequential pattern, any pattern + that appears less than maxPatternLength will be output. + (default: 10) + maxLocalProjDBSize : int, optional + The maximum number of items (including delimiters used in the + internal storage format) allowed in a projected database before + local processing. If a projected database exceeds this size, + another iteration of distributed prefix growth is run. + (default: 32000000) """ model = callMLlibFunc("trainPrefixSpanModel", data, minSupport, maxPatternLength, maxLocalProjDBSize) diff --git a/python/pyspark/mllib/fpm.pyi b/python/pyspark/mllib/fpm.pyi index 880baae1a9..c5a6b5f680 100644 --- a/python/pyspark/mllib/fpm.pyi +++ b/python/pyspark/mllib/fpm.pyi @@ -37,8 +37,8 @@ class FPGrowth: cls, data: RDD[List[T]], minSupport: float = ..., numPartitions: int = ... ) -> FPGrowthModel[T]: ... class FreqItemset(Generic[T]): - items = ... # List[T] - freq = ... # int + items: List[T] + freq: int class PrefixSpanModel(JavaModelWrapper, Generic[T]): def freqSequences(self) -> RDD[PrefixSpan.FreqSequence[T]]: ... 
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index c1402fb98a..f20004ab70 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -71,6 +71,8 @@ def _vector_size(v): """ Returns the size of the vector. + Examples + -------- >>> _vector_size([1., 2., 3.]) 3 >>> _vector_size((1., 2., 3.)) @@ -231,7 +233,9 @@ def toArray(self): """ Convert the vector into an numpy.ndarray - :return: numpy.ndarray + Returns + ------- + :py:class:`numpy.ndarray` """ raise NotImplementedError @@ -240,7 +244,9 @@ def asML(self): Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. - :return: :py:class:`pyspark.ml.linalg.Vector` + Returns + ------- + :py:class:`pyspark.ml.linalg.Vector` """ raise NotImplementedError @@ -251,6 +257,8 @@ class DenseVector(Vector): storage and arithmetics will be delegated to the underlying numpy array. + Examples + -------- >>> v = Vectors.dense([1.0, 2.0]) >>> u = Vectors.dense([3.0, 4.0]) >>> v + u @@ -282,6 +290,8 @@ def parse(s): """ Parse string representation back into the DenseVector. + Examples + -------- >>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]') DenseVector([0.0, 1.0, 2.0, 3.0]) """ @@ -312,6 +322,8 @@ def norm(self, p): """ Calculates the norm of a DenseVector. + Examples + -------- >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) 3.7... @@ -327,6 +339,8 @@ def dot(self, other): and a target NumPy array that is either 1- or 2-dimensional. Equivalent to calling numpy.dot of the two vectors. + Examples + -------- >>> dense = DenseVector(array.array('d', [1., 2.])) >>> dense.dot(dense) 5.0 @@ -367,6 +381,8 @@ def squared_distance(self, other): """ Squared distance of two Vectors. + Examples + -------- >>> dense1 = DenseVector(array.array('d', [1., 2.])) >>> dense1.squared_distance(dense1) 0.0 @@ -412,9 +428,11 @@ def asML(self): Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. - :return: :py:class:`pyspark.ml.linalg.DenseVector` - .. versionadded:: 2.0.0 + + Returns + ------- + :py:class:`pyspark.ml.linalg.DenseVector` """ return newlinalg.DenseVector(self.array) @@ -501,12 +519,18 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - :param size: Size of the vector. - :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly - increasing indices and a list of corresponding values [index, ...], - [value, ...]. Inactive entries are treated as zeros. - + Parameters + ---------- + size : int + Size of the vector. + args + Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. + + Examples + -------- >>> SparseVector(4, {1: 1.0, 3: 5.5}) SparseVector(4, {1: 1.0, 3: 5.5}) >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) @@ -556,6 +580,8 @@ def norm(self, p): """ Calculates the norm of a SparseVector. + Examples + -------- >>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) 7.0 @@ -574,6 +600,8 @@ def parse(s): """ Parse string representation back into the SparseVector. 
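# [Editor's sketch, not part of the patch] asML() as documented above: the old mllib
# DenseVector exposes the same data as a new pyspark.ml.linalg.DenseVector, without
# copying. Runs locally, no SparkContext needed.
from pyspark.mllib.linalg import Vectors
old_vec = Vectors.dense([1.0, 2.0, 3.0])
new_vec = old_vec.asML()          # pyspark.ml.linalg.DenseVector backed by the same array
type(new_vec).__module__          # 'pyspark.ml.linalg'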
+ Examples + -------- >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )') SparseVector(4, {0: 4.0, 1: 5.0}) """ @@ -622,6 +650,8 @@ def dot(self, other): """ Dot product with a SparseVector or 1- or 2-dimensional Numpy array. + Examples + -------- >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) >>> a.dot(a) 25.0 @@ -678,6 +708,8 @@ def squared_distance(self, other): """ Squared distance from a SparseVector or 1-dimensional NumPy array. + Examples + -------- >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) >>> a.squared_distance(a) 0.0 @@ -754,9 +786,11 @@ def asML(self): Convert this vector to the new mllib-local representation. This does NOT copy the data; it copies references. - :return: :py:class:`pyspark.ml.linalg.SparseVector` - .. versionadded:: 2.0.0 + + Returns + ------- + :py:class:`pyspark.ml.linalg.SparseVector` """ return newlinalg.SparseVector(self.size, self.indices, self.values) @@ -828,10 +862,12 @@ class Vectors(object): """ Factory methods for working with vectors. - .. note:: Dense vectors are simply represented as NumPy array objects, - so there is no need to covert them for use in MLlib. For sparse vectors, - the factory methods in this class create an MLlib-compatible type, or users - can pass in SciPy's `scipy.sparse` column vectors. + Notes + ----- + Dense vectors are simply represented as NumPy array objects, + so there is no need to covert them for use in MLlib. For sparse vectors, + the factory methods in this class create an MLlib-compatible type, or users + can pass in SciPy's `scipy.sparse` column vectors. """ @staticmethod @@ -841,10 +877,16 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tuples, - or two sorted lists containing indices and values. + Parameters + ---------- + size : int + Size of the vector. + args + Non-zero entries, as a dictionary, list of tuples, + or two sorted lists containing indices and values. + Examples + -------- >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) SparseVector(4, {1: 1.0, 3: 5.5}) >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) @@ -859,6 +901,8 @@ def dense(*elements): """ Create a dense vector of 64-bit floats from a Python list or numbers. + Examples + -------- >>> Vectors.dense([1, 2, 3]) DenseVector([1.0, 2.0, 3.0]) >>> Vectors.dense(1.0, 2.0) @@ -875,10 +919,15 @@ def fromML(vec): Convert a vector from the new mllib-local representation. This does NOT copy the data; it copies references. - :param vec: a :py:class:`pyspark.ml.linalg.Vector` - :return: a :py:class:`pyspark.mllib.linalg.Vector` - .. versionadded:: 2.0.0 + + Parameters + ---------- + vec : :py:class:`pyspark.ml.linalg.Vector` + + Returns + ------- + :py:class:`pyspark.mllib.linalg.Vector` """ if isinstance(vec, newlinalg.DenseVector): return DenseVector(vec.array) @@ -893,6 +942,8 @@ def stringify(vector): Converts a vector into a string, which can be recognized by Vectors.parse(). + Examples + -------- >>> Vectors.stringify(Vectors.sparse(2, [1], [1.0])) '(2,[1],[1.0])' >>> Vectors.stringify(Vectors.dense([0.0, 1.0])) @@ -907,6 +958,8 @@ def squared_distance(v1, v2): a and b can be of type SparseVector, DenseVector, np.ndarray or array.array. + Examples + -------- >>> a = Vectors.sparse(4, [(0, 1), (3, 4)]) >>> b = Vectors.dense([2, 5, 4, 1]) >>> a.squared_distance(b) @@ -926,6 +979,8 @@ def norm(vector, p): def parse(s): """Parse a string representation back into the Vector. 
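# [Editor's sketch, not part of the patch] Vectors.fromML() as documented above,
# converting a new-style ml vector back into an mllib vector without copying data.
# Runs locally, no SparkContext needed.
from pyspark.ml.linalg import Vectors as MLVectors
from pyspark.mllib.linalg import Vectors as MLlibVectors
ml_sparse = MLVectors.sparse(4, [1, 3], [1.0, 5.5])
mllib_sparse = MLlibVectors.fromML(ml_sparse)   # pyspark.mllib.linalg.SparseVector
MLlibVectors.stringify(mllib_sparse)            # '(4,[1,3],[1.0,5.5])'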
+ Examples + -------- >>> Vectors.parse('[2,1,2 ]') DenseVector([2.0, 1.0, 2.0]) >>> Vectors.parse(' ( 100, [0], [2])') @@ -1023,6 +1078,8 @@ def __str__(self): """ Pretty printing of a DenseMatrix + Examples + -------- >>> dm = DenseMatrix(2, 2, range(4)) >>> print(dm) DenseMatrix([[ 0., 2.], @@ -1044,6 +1101,8 @@ def __repr__(self): """ Representation of a DenseMatrix + Examples + -------- >>> dm = DenseMatrix(2, 2, range(4)) >>> dm DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False) @@ -1067,6 +1126,8 @@ def toArray(self): """ Return an numpy.ndarray + Examples + -------- >>> m = DenseMatrix(2, 2, range(4)) >>> m.toArray() array([[ 0., 2.], @@ -1098,9 +1159,11 @@ def asML(self): Convert this matrix to the new mllib-local representation. This does NOT copy the data; it copies references. - :return: :py:class:`pyspark.ml.linalg.DenseMatrix` - .. versionadded:: 2.0.0 + + Returns + ------- + :py:class:`pyspark.ml.linalg.DenseMatrix` """ return newlinalg.DenseMatrix(self.numRows, self.numCols, self.values, self.isTransposed) @@ -1154,6 +1217,8 @@ def __str__(self): """ Pretty printing of a SparseMatrix + Examples + -------- >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) >>> print(sm1) 2 X 2 CSCMatrix @@ -1200,6 +1265,8 @@ def __repr__(self): """ Representation of a SparseMatrix + Examples + -------- >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) >>> sm1 SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False) @@ -1281,9 +1348,11 @@ def asML(self): Convert this matrix to the new mllib-local representation. This does NOT copy the data; it copies references. - :return: :py:class:`pyspark.ml.linalg.SparseMatrix` - .. versionadded:: 2.0.0 + + Returns + ------- + :py:class:`pyspark.ml.linalg.SparseMatrix` """ return newlinalg.SparseMatrix(self.numRows, self.numCols, self.colPtrs, self.rowIndices, self.values, self.isTransposed) @@ -1314,10 +1383,15 @@ def fromML(mat): Convert a matrix from the new mllib-local representation. This does NOT copy the data; it copies references. - :param mat: a :py:class:`pyspark.ml.linalg.Matrix` - :return: a :py:class:`pyspark.mllib.linalg.Matrix` - .. versionadded:: 2.0.0 + + Parameters + ---------- + mat : :py:class:`pyspark.ml.linalg.Matrix` + + Returns + ------- + :py:class:`pyspark.mllib.linalg.Matrix` """ if isinstance(mat, newlinalg.DenseMatrix): return DenseMatrix(mat.numRows, mat.numCols, mat.values, mat.isTransposed) diff --git a/python/pyspark/mllib/linalg/__init__.pyi b/python/pyspark/mllib/linalg/__init__.pyi index c0719c535c..60d16b26f3 100644 --- a/python/pyspark/mllib/linalg/__init__.pyi +++ b/python/pyspark/mllib/linalg/__init__.pyi @@ -17,7 +17,18 @@ # under the License. from typing import overload -from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pyspark.ml import linalg as newlinalg from pyspark.sql.types import StructType, UserDefinedType from numpy import float64, ndarray # type: ignore[import] @@ -46,7 +57,7 @@ class MatrixUDT(UserDefinedType): @classmethod def scalaUDT(cls) -> str: ... def serialize( - self, obj + self, obj: Matrix ) -> Tuple[ int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool ]: ... @@ -67,8 +78,8 @@ class DenseVector(Vector): @overload def __init__(self, __arr: Iterable[float]) -> None: ... @staticmethod - def parse(s) -> DenseVector: ... - def __reduce__(self) -> Tuple[type, bytes]: ... 
+ def parse(s: str) -> DenseVector: ... + def __reduce__(self) -> Tuple[Type[DenseVector], bytes]: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -115,7 +126,7 @@ class SparseVector(Vector): def __init__(self, size: int, __map: Dict[int, float]) -> None: ... def numNonzeros(self) -> int: ... def norm(self, p: Union[float, str]) -> float64: ... - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[SparseVector], Tuple[int, bytes, bytes]]: ... @staticmethod def parse(s: str) -> SparseVector: ... def dot(self, other: Iterable[float]) -> float64: ... @@ -123,9 +134,9 @@ class SparseVector(Vector): def toArray(self) -> ndarray: ... def asML(self) -> newlinalg.SparseVector: ... def __len__(self) -> int: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... def __getitem__(self, index: int) -> float64: ... - def __ne__(self, other) -> bool: ... + def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... class Vectors: @@ -148,13 +159,13 @@ class Vectors: def sparse(size: int, __map: Dict[int, float]) -> SparseVector: ... @overload @staticmethod - def dense(self, *elements: float) -> DenseVector: ... + def dense(*elements: float) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: bytes) -> DenseVector: ... + def dense(__arr: bytes) -> DenseVector: ... @overload @staticmethod - def dense(self, __arr: Iterable[float]) -> DenseVector: ... + def dense(__arr: Iterable[float]) -> DenseVector: ... @staticmethod def fromML(vec: newlinalg.DenseVector) -> DenseVector: ... @staticmethod @@ -176,8 +187,8 @@ class Matrix: def __init__( self, numRows: int, numCols: int, isTransposed: bool = ... ) -> None: ... - def toArray(self): ... - def asML(self): ... + def toArray(self) -> ndarray: ... + def asML(self) -> newlinalg.Matrix: ... class DenseMatrix(Matrix): values: Any @@ -193,12 +204,12 @@ class DenseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, int]]: ... + def __reduce__(self) -> Tuple[Type[DenseMatrix], Tuple[int, int, bytes, int]]: ... def toArray(self) -> ndarray: ... def toSparse(self) -> SparseMatrix: ... def asML(self) -> newlinalg.DenseMatrix: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class SparseMatrix(Matrix): colPtrs: ndarray @@ -224,12 +235,14 @@ class SparseMatrix(Matrix): values: Iterable[float], isTransposed: bool = ..., ) -> None: ... - def __reduce__(self) -> Tuple[type, Tuple[int, int, bytes, bytes, bytes, int]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SparseMatrix], Tuple[int, int, bytes, bytes, bytes, int]]: ... def __getitem__(self, indices: Tuple[int, int]) -> float64: ... def toArray(self) -> ndarray: ... def toDense(self) -> DenseMatrix: ... def asML(self) -> newlinalg.SparseMatrix: ... - def __eq__(self, other) -> bool: ... + def __eq__(self, other: Any) -> bool: ... class Matrices: @overload diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 603d31d3d7..f0e889b15b 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -55,16 +55,22 @@ class RowMatrix(DistributedMatrix): Represents a row-oriented distributed Matrix with no meaningful row indices. - :param rows: An RDD or DataFrame of vectors. 
If a DataFrame is provided, it must have a single - vector typed column. - :param numRows: Number of rows in the matrix. A non-positive - value means unknown, at which point the number - of rows will be determined by the number of - records in the `rows` RDD. - :param numCols: Number of columns in the matrix. A non-positive - value means unknown, at which point the number - of columns will be determined by the size of - the first row. + + Parameters + ---------- + rows : :py:class:`pyspark.RDD` or :py:class:`pyspark.sql.DataFrame` + An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single + vector typed column. + numRows : int, optional + Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the number of + records in the `rows` RDD. + numCols : int, optional + Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. """ def __init__(self, rows, numRows=0, numCols=0): """ @@ -77,6 +83,8 @@ def __init__(self, rows, numRows=0, numCols=0): object, in which case we can wrap it directly. This assists in clean matrix conversions. + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) >>> mat = RowMatrix(rows) @@ -108,6 +116,8 @@ def rows(self): """ Rows of the RowMatrix stored as an RDD of vectors. + Examples + -------- >>> mat = RowMatrix(sc.parallelize([[1, 2, 3], [4, 5, 6]])) >>> rows = mat.rows >>> rows.first() @@ -119,6 +129,8 @@ def numRows(self): """ Get or compute the number of rows. + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], ... [7, 8, 9], [10, 11, 12]]) @@ -136,6 +148,8 @@ def numCols(self): """ Get or compute the number of cols. + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6], ... [7, 8, 9], [10, 11, 12]]) @@ -149,14 +163,19 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") - @since('2.0.0') def computeColumnSummaryStatistics(self): """ Computes column-wise summary statistics. - :return: :class:`MultivariateStatisticalSummary` object - containing column-wise summary statistics. + .. versionadded:: 2.0.0 + + Returns + ------- + :py:class:`MultivariateStatisticalSummary` + object containing column-wise summary statistics. + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) >>> mat = RowMatrix(rows) @@ -167,14 +186,19 @@ def computeColumnSummaryStatistics(self): java_col_stats = self._java_matrix_wrapper.call("computeColumnSummaryStatistics") return MultivariateStatisticalSummary(java_col_stats) - @since('2.0.0') def computeCovariance(self): """ Computes the covariance matrix, treating each row as an observation. - .. note:: This cannot be computed on matrices with more than 65535 columns. + .. versionadded:: 2.0.0 + + Notes + ----- + This cannot be computed on matrices with more than 65535 columns. + Examples + -------- >>> rows = sc.parallelize([[1, 2], [2, 1]]) >>> mat = RowMatrix(rows) @@ -183,13 +207,18 @@ def computeCovariance(self): """ return self._java_matrix_wrapper.call("computeCovariance") - @since('2.0.0') def computeGramianMatrix(self): """ Computes the Gramian matrix `A^T A`. - .. note:: This cannot be computed on matrices with more than 65535 columns. + .. versionadded:: 2.0.0 + Notes + ----- + This cannot be computed on matrices with more than 65535 columns. 
+ + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [4, 5, 6]]) >>> mat = RowMatrix(rows) @@ -220,11 +249,12 @@ def columnSimilarities(self, threshold=0.0): similarity threshold. To describe the guarantee, we set some notation: - * Let A be the smallest in magnitude non-zero element of - this matrix. - * Let B be the largest in magnitude non-zero element of - this matrix. - * Let L be the maximum number of non-zeros per row. + + - Let A be the smallest in magnitude non-zero element of + this matrix. + - Let B be the largest in magnitude non-zero element of + this matrix. + - Let L be the maximum number of non-zeros per row. For example, for {0,1} matrices: A=B=1. Another example, for the Netflix matrix: A=1, B=5 @@ -236,20 +266,31 @@ def columnSimilarities(self, threshold=0.0): The shuffle size is bounded by the *smaller* of the following two expressions: - * O(n log(n) L / (threshold * A)) - * O(m L^2^) + - O(n log(n) L / (threshold * A)) + - O(m L^2^) The latter is the cost of the brute-force approach, so for non-zero thresholds, the cost is always cheaper than the brute-force approach. - :param: threshold: Set to 0 for deterministic guaranteed - correctness. Similarities above this - threshold are estimated with the cost vs - estimate quality trade-off described above. - :return: An n x n sparse upper-triangular CoordinateMatrix of - cosine similarities between columns of this matrix. + .. versionadded:: 2.0.0 + + Parameters + ---------- + threshold : float, optional + Set to 0 for deterministic guaranteed + correctness. Similarities above this + threshold are estimated with the cost vs + estimate quality trade-off described above. + Returns + ------- + :py:class:`CoordinateMatrix` + An n x n sparse upper-triangular CoordinateMatrix of + cosine similarities between columns of this matrix. + + Examples + -------- >>> rows = sc.parallelize([[1, 2], [1, 5]]) >>> mat = RowMatrix(rows) @@ -260,23 +301,32 @@ def columnSimilarities(self, threshold=0.0): java_sims_mat = self._java_matrix_wrapper.call("columnSimilarities", float(threshold)) return CoordinateMatrix(java_sims_mat) - @since('2.0.0') def tallSkinnyQR(self, computeQ=False): """ Compute the QR decomposition of this RowMatrix. The implementation is designed to optimize the QR decomposition - (factorization) for the RowMatrix of a tall and skinny shape. + (factorization) for the RowMatrix of a tall and skinny shape [1]_. - Reference: - Paul G. Constantine, David F. Gleich. "Tall and skinny QR - factorizations in MapReduce architectures" - ([[https://doi.org/10.1145/1996092.1996103]]) + .. [1] Paul G. Constantine, David F. Gleich. "Tall and skinny QR + factorizations in MapReduce architectures" + https://doi.org/10.1145/1996092.1996103 - :param: computeQ: whether to computeQ - :return: QRDecomposition(Q: RowMatrix, R: Matrix), where - Q = None if computeQ = false. + .. versionadded:: 2.0.0 + Parameters + ---------- + computeQ : bool, optional + whether to computeQ + + Returns + ------- + :py:class:`pyspark.mllib.linalg.QRDecomposition` + QRDecomposition(Q: RowMatrix, R: Matrix), where + Q = None if computeQ = false. + + Examples + -------- >>> rows = sc.parallelize([[3, -6], [4, -8], [0, 1]]) >>> mat = RowMatrix(rows) >>> decomp = mat.tallSkinnyQR(True) @@ -301,7 +351,6 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) - @since('2.2.0') def computeSVD(self, k, computeU=False, rCond=1e-9): """ Computes the singular value decomposition of the RowMatrix. 
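The similarity-threshold trade-off and the tall-and-skinny QR factorization documented in the hunks above (and the SVD hunk that follows) can be exercised with a short usage sketch. This is not part of the patch; it assumes an active SparkContext `sc` and small illustrative data:

from pyspark.mllib.linalg.distributed import RowMatrix

rows = sc.parallelize([[1.0, 2.0], [1.0, 5.0], [3.0, 4.0]])
mat = RowMatrix(rows)

# threshold=0.0 keeps the deterministic, brute-force guarantee; a positive
# threshold only estimates similarities above that value, in exchange for
# the smaller shuffle bound described above.
exact = mat.columnSimilarities()        # upper-triangular CoordinateMatrix
approx = mat.columnSimilarities(0.5)

# Tall-and-skinny QR; with computeQ=False the returned QRDecomposition has Q = None.
decomp = mat.tallSkinnyQR(computeQ=True)
print(decomp.R)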
@@ -309,27 +358,39 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): The given row matrix A of dimension (m X n) is decomposed into U * s * V'T where - * U: (m X k) (left singular vectors) is a RowMatrix whose - columns are the eigenvectors of (A X A') - * s: DenseVector consisting of square root of the eigenvalues - (singular values) in descending order. - * v: (n X k) (right singular vectors) is a Matrix whose columns - are the eigenvectors of (A' X A) + - U: (m X k) (left singular vectors) is a RowMatrix whose + columns are the eigenvectors of (A X A') + - s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + - v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) For more specific details on implementation, please refer the Scala documentation. - :param k: Number of leading singular values to keep (`0 < k <= n`). - It might return less than k if there are numerically zero singular values - or there are not enough Ritz values converged before the maximum number of - Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). - :param computeU: Whether or not to compute U. If set to be - True, then U is computed by A * V * s^-1 - :param rCond: Reciprocal condition number. All singular values - smaller than rCond * s[0] are treated as zero - where s[0] is the largest singular value. - :returns: :py:class:`SingularValueDecomposition` - + .. versionadded:: 2.2.0 + + Parameters + ---------- + k : int + Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + computeU : bool, optional + Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + rCond : float, optional + Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + + Returns + ------- + :py:class:`SingularValueDecomposition` + + Examples + -------- >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]]) >>> rm = RowMatrix(rows) @@ -345,16 +406,27 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): "computeSVD", int(k), bool(computeU), float(rCond)) return SingularValueDecomposition(j_model) - @since('2.2.0') def computePrincipalComponents(self, k): """ Computes the k principal components of the given row matrix - .. note:: This cannot be computed on matrices with more than 65535 columns. + .. versionadded:: 2.2.0 + + Notes + ----- + This cannot be computed on matrices with more than 65535 columns. - :param k: Number of principal components to keep. - :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix` + Parameters + ---------- + k : int + Number of principal components to keep. + Returns + ------- + :py:class:`pyspark.mllib.linalg.DenseMatrix` + + Examples + -------- >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]]) >>> rm = RowMatrix(rows) @@ -370,15 +442,24 @@ def computePrincipalComponents(self, k): """ return self._java_matrix_wrapper.call("computePrincipalComponents", k) - @since('2.2.0') def multiply(self, matrix): """ Multiply this matrix by a local dense matrix on the right. - :param matrix: a local dense matrix whose number of rows must match the number of columns - of this matrix - :returns: :py:class:`RowMatrix` + .. 
versionadded:: 2.2.0 + + Parameters + ---------- + matrix : :py:class:`pyspark.mllib.linalg.Matrix` + a local dense matrix whose number of rows must match the number of columns + of this matrix + Returns + ------- + :py:class:`RowMatrix` + + Examples + -------- >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]])) >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])] @@ -438,8 +519,12 @@ class IndexedRow(object): Just a wrapper over a (int, vector) tuple. - :param index: The index for the given row. - :param vector: The row in the matrix at the given index. + Parameters + ---------- + index : int + The index for the given row. + vector : :py:class:`pyspark.mllib.linalg.Vector` or convertible + The row in the matrix at the given index. """ def __init__(self, index, vector): self.index = int(index) @@ -462,16 +547,21 @@ class IndexedRowMatrix(DistributedMatrix): """ Represents a row-oriented distributed Matrix with indexed rows. - :param rows: An RDD of IndexedRows or (int, vector) tuples or a DataFrame consisting of a - int typed column of indices and a vector typed column. - :param numRows: Number of rows in the matrix. A non-positive - value means unknown, at which point the number - of rows will be determined by the max row - index plus one. - :param numCols: Number of columns in the matrix. A non-positive - value means unknown, at which point the number - of columns will be determined by the size of - the first row. + Parameters + ---------- + rows : :py:class:`pyspark.RDD` + An RDD of IndexedRows or (int, vector) tuples or a DataFrame consisting of a + int typed column of indices and a vector typed column. + numRows : int, optional + Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + numCols : int, optional + Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the size of + the first row. """ def __init__(self, rows, numRows=0, numCols=0): """ @@ -484,6 +574,8 @@ def __init__(self, rows, numRows=0, numCols=0): object, in which case we can wrap it directly. This assists in clean matrix conversions. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(1, [4, 5, 6])]) >>> mat = IndexedRowMatrix(rows) @@ -524,6 +616,8 @@ def rows(self): """ Rows of the IndexedRowMatrix stored as an RDD of IndexedRows. + Examples + -------- >>> mat = IndexedRowMatrix(sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(1, [4, 5, 6])])) >>> rows = mat.rows @@ -542,6 +636,8 @@ def numRows(self): """ Get or compute the number of rows. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(1, [4, 5, 6]), ... IndexedRow(2, [7, 8, 9]), @@ -561,6 +657,8 @@ def numCols(self): """ Get or compute the number of cols. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(1, [4, 5, 6]), ... IndexedRow(2, [7, 8, 9]), @@ -580,6 +678,8 @@ def columnSimilarities(self): """ Compute all cosine similarities between columns. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... 
IndexedRow(6, [4, 5, 6])]) >>> mat = IndexedRowMatrix(rows) @@ -590,13 +690,18 @@ def columnSimilarities(self): java_coordinate_matrix = self._java_matrix_wrapper.call("columnSimilarities") return CoordinateMatrix(java_coordinate_matrix) - @since('2.0.0') def computeGramianMatrix(self): """ Computes the Gramian matrix `A^T A`. - .. note:: This cannot be computed on matrices with more than 65535 columns. + .. versionadded:: 2.0.0 + + Notes + ----- + This cannot be computed on matrices with more than 65535 columns. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(1, [4, 5, 6])]) >>> mat = IndexedRowMatrix(rows) @@ -610,6 +715,8 @@ def toRowMatrix(self): """ Convert this matrix to a RowMatrix. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(6, [4, 5, 6])]) >>> mat = IndexedRowMatrix(rows).toRowMatrix() @@ -623,6 +730,8 @@ def toCoordinateMatrix(self): """ Convert this matrix to a CoordinateMatrix. + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 0]), ... IndexedRow(6, [0, 5])]) >>> mat = IndexedRowMatrix(rows).toCoordinateMatrix() @@ -636,13 +745,19 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): """ Convert this matrix to a BlockMatrix. - :param rowsPerBlock: Number of rows that make up each block. - The blocks forming the final rows are not - required to have the given number of rows. - :param colsPerBlock: Number of columns that make up each block. - The blocks forming the final columns are not - required to have the given number of columns. - + Parameters + ---------- + rowsPerBlock : int, optional + Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + colsPerBlock : int, optional + Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + Examples + -------- >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), ... IndexedRow(6, [4, 5, 6])]) >>> mat = IndexedRowMatrix(rows).toBlockMatrix() @@ -661,7 +776,6 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): colsPerBlock) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) - @since('2.2.0') def computeSVD(self, k, computeU=False, rCond=1e-9): """ Computes the singular value decomposition of the IndexedRowMatrix. @@ -679,17 +793,29 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): For more specific details on implementation, please refer the scala documentation. - :param k: Number of leading singular values to keep (`0 < k <= n`). - It might return less than k if there are numerically zero singular values - or there are not enough Ritz values converged before the maximum number of - Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). - :param computeU: Whether or not to compute U. If set to be - True, then U is computed by A * V * s^-1 - :param rCond: Reciprocal condition number. All singular values - smaller than rCond * s[0] are treated as zero - where s[0] is the largest singular value. - :returns: SingularValueDecomposition object - + .. versionadded:: 2.2.0 + + Parameters + ---------- + k : int + Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). 
+ computeU : bool, optional + Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + rCond : float, optional + Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + + Returns + ------- + :py:class:`SingularValueDecomposition` + + Examples + -------- >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))] >>> irm = IndexedRowMatrix(sc.parallelize(rows)) >>> svd_model = irm.computeSVD(2, True) @@ -705,15 +831,24 @@ def computeSVD(self, k, computeU=False, rCond=1e-9): "computeSVD", int(k), bool(computeU), float(rCond)) return SingularValueDecomposition(j_model) - @since('2.2.0') def multiply(self, matrix): """ Multiply this matrix by a local dense matrix on the right. - :param matrix: a local dense matrix whose number of rows must match the number of columns - of this matrix - :returns: :py:class:`IndexedRowMatrix` + .. versionadded:: 2.2.0 + + Parameters + ---------- + matrix : :py:class:`pyspark.mllib.linalg.Matrix` + a local dense matrix whose number of rows must match the number of columns + of this matrix + Returns + ------- + :py:class:`IndexedRowMatrix` + + Examples + -------- >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))])) >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])] @@ -730,9 +865,14 @@ class MatrixEntry(object): Just a wrapper over a (int, int, float) tuple. - :param i: The row index of the matrix. - :param j: The column index of the matrix. - :param value: The (i, j)th entry of the matrix, as a float. + Parameters + ---------- + i : int + The row index of the matrix. + j : int + The column index of the matrix. + value : float + The (i, j)th entry of the matrix, as a float. """ def __init__(self, i, j, value): self.i = int(i) @@ -756,16 +896,21 @@ class CoordinateMatrix(DistributedMatrix): """ Represents a matrix in coordinate format. - :param entries: An RDD of MatrixEntry inputs or - (int, int, float) tuples. - :param numRows: Number of rows in the matrix. A non-positive - value means unknown, at which point the number - of rows will be determined by the max row - index plus one. - :param numCols: Number of columns in the matrix. A non-positive - value means unknown, at which point the number - of columns will be determined by the max row - index plus one. + Parameters + ---------- + entries : :py:class:`pyspark.RDD` + An RDD of MatrixEntry inputs or + (int, int, float) tuples. + numRows : int, optional + Number of rows in the matrix. A non-positive + value means unknown, at which point the number + of rows will be determined by the max row + index plus one. + numCols : int, optional + Number of columns in the matrix. A non-positive + value means unknown, at which point the number + of columns will be determined by the max row + index plus one. """ def __init__(self, entries, numRows=0, numCols=0): """ @@ -778,6 +923,8 @@ def __init__(self, entries, numRows=0, numCols=0): object, in which case we can wrap it directly. This assists in clean matrix conversions. + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) >>> mat = CoordinateMatrix(entries) @@ -817,6 +964,8 @@ def entries(self): Entries of the CoordinateMatrix stored as an RDD of MatrixEntries. + Examples + -------- >>> mat = CoordinateMatrix(sc.parallelize([MatrixEntry(0, 0, 1.2), ... 
MatrixEntry(6, 4, 2.1)])) >>> entries = mat.entries @@ -835,6 +984,8 @@ def numRows(self): """ Get or compute the number of rows. + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(1, 0, 2), ... MatrixEntry(2, 1, 3.7)]) @@ -853,6 +1004,8 @@ def numCols(self): """ Get or compute the number of cols. + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(1, 0, 2), ... MatrixEntry(2, 1, 3.7)]) @@ -867,11 +1020,14 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") - @since('2.0.0') def transpose(self): """ Transpose this CoordinateMatrix. + .. versionadded:: 2.0.0 + + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(1, 0, 2), ... MatrixEntry(2, 1, 3.7)]) @@ -891,6 +1047,8 @@ def toRowMatrix(self): """ Convert this matrix to a RowMatrix. + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) >>> mat = CoordinateMatrix(entries).toRowMatrix() @@ -915,6 +1073,8 @@ def toIndexedRowMatrix(self): """ Convert this matrix to an IndexedRowMatrix. + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() @@ -938,13 +1098,19 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): """ Convert this matrix to a BlockMatrix. - :param rowsPerBlock: Number of rows that make up each block. - The blocks forming the final rows are not - required to have the given number of rows. - :param colsPerBlock: Number of columns that make up each block. - The blocks forming the final columns are not - required to have the given number of columns. - + Parameters + ---------- + rowsPerBlock : int, optional + Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + colsPerBlock : int, optional + Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + Examples + -------- >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) >>> mat = CoordinateMatrix(entries).toBlockMatrix() @@ -983,26 +1149,33 @@ class BlockMatrix(DistributedMatrix): """ Represents a distributed matrix in blocks of local matrices. - :param blocks: An RDD of sub-matrix blocks - ((blockRowIndex, blockColIndex), sub-matrix) that - form this distributed matrix. If multiple blocks - with the same index exist, the results for - operations like add and multiply will be - unpredictable. - :param rowsPerBlock: Number of rows that make up each block. - The blocks forming the final rows are not - required to have the given number of rows. - :param colsPerBlock: Number of columns that make up each block. - The blocks forming the final columns are not - required to have the given number of columns. - :param numRows: Number of rows of this matrix. If the supplied - value is less than or equal to zero, the number - of rows will be calculated when `numRows` is - invoked. - :param numCols: Number of columns of this matrix. If the supplied - value is less than or equal to zero, the number - of columns will be calculated when `numCols` is - invoked. + Parameters + ---------- + blocks : :py:class:`pyspark.RDD` + An RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that + form this distributed matrix. 
If multiple blocks + with the same index exist, the results for + operations like add and multiply will be + unpredictable. + rowsPerBlock : int + Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + colsPerBlock : int + Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + numRows : int, optional + Number of rows of this matrix. If the supplied + value is less than or equal to zero, the number + of rows will be calculated when `numRows` is + invoked. + numCols : int, optional + Number of columns of this matrix. If the supplied + value is less than or equal to zero, the number + of columns will be calculated when `numCols` is + invoked. """ def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): """ @@ -1015,6 +1188,8 @@ def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): object, in which case we can wrap it directly. This assists in clean matrix conversions. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1058,6 +1233,8 @@ def blocks(self): ((blockRowIndex, blockColIndex), sub-matrix) that form this distributed matrix. + Examples + -------- >>> mat = BlockMatrix( ... sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2) @@ -1079,6 +1256,8 @@ def rowsPerBlock(self): """ Number of rows that make up each block. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1092,6 +1271,8 @@ def colsPerBlock(self): """ Number of columns that make up each block. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1105,6 +1286,8 @@ def numRowBlocks(self): """ Number of rows of blocks in the BlockMatrix. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1118,6 +1301,8 @@ def numColBlocks(self): """ Number of columns of blocks in the BlockMatrix. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1130,6 +1315,8 @@ def numRows(self): """ Get or compute the number of rows. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) @@ -1147,6 +1334,8 @@ def numCols(self): """ Get or compute the number of cols. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) @@ -1197,6 +1386,8 @@ def add(self, other): two dense sub matrix blocks are added, the output block will also be a DenseMatrix. 
+ Examples + -------- >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) @@ -1220,7 +1411,6 @@ def add(self, other): java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix) return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) - @since('2.0.0') def subtract(self, other): """ Subtracts the given block matrix `other` from this block matrix: @@ -1232,6 +1422,10 @@ def subtract(self, other): If two dense sub matrix blocks are subtracted, the output block will also be a DenseMatrix. + .. versionadded:: 2.0.0 + + Examples + -------- >>> dm1 = Matrices.dense(3, 2, [3, 1, 5, 4, 6, 2]) >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [1, 2, 3]) @@ -1265,6 +1459,8 @@ def multiply(self, other): This may cause some performance issues until support for multiplying two sparse matrices is added. + Examples + -------- >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6]) >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12]) >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) @@ -1290,12 +1486,15 @@ def multiply(self, other): java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix) return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) - @since('2.0.0') def transpose(self): """ Transpose this BlockMatrix. Returns a new BlockMatrix instance sharing the same underlying data. Is a lazy operation. + .. versionadded:: 2.0.0 + + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2) @@ -1311,6 +1510,8 @@ def toLocalMatrix(self): """ Collect the distributed matrix on the driver as a DenseMatrix. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix() @@ -1333,6 +1534,8 @@ def toIndexedRowMatrix(self): """ Convert this matrix to an IndexedRowMatrix. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix() @@ -1356,6 +1559,8 @@ def toCoordinateMatrix(self): """ Convert this matrix to a CoordinateMatrix. + Examples + -------- >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])), ... ((1, 0), Matrices.dense(1, 2, [7, 8]))]) >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix() diff --git a/python/pyspark/mllib/linalg/distributed.pyi b/python/pyspark/mllib/linalg/distributed.pyi index 238c4ea32e..7ec2d60c5a 100644 --- a/python/pyspark/mllib/linalg/distributed.pyi +++ b/python/pyspark/mllib/linalg/distributed.pyi @@ -22,6 +22,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.mllib.common import JavaModelWrapper from pyspark.mllib.linalg import Vector, Matrix, QRDecomposition from pyspark.mllib.stat import MultivariateStatisticalSummary +import pyspark.sql.dataframe from numpy import ndarray # noqa: F401 VectorLike = Union[Vector, Sequence[Union[float, int]]] @@ -35,7 +36,10 @@ class DistributedMatrix: class RowMatrix(DistributedMatrix): def __init__( - self, rows: RDD[Vector], numRows: int = ..., numCols: int = ... 
+ self, + rows: Union[RDD[Vector], pyspark.sql.dataframe.DataFrame], + numRows: int = ..., + numCols: int = ..., ) -> None: ... @property def rows(self) -> RDD[Vector]: ... diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 6106c58584..a33dfe26fb 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -22,7 +22,6 @@ import sys from functools import wraps -from pyspark import since from pyspark.mllib.common import callMLlibFunc @@ -46,7 +45,6 @@ class RandomRDDs(object): """ @staticmethod - @since("1.1.0") def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -56,12 +54,26 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): to U(a, b), use ``RandomRDDs.uniformRDD(sc, n, p, seed).map(lambda v: a + (b - a) * v)`` - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + used to create the RDD. + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + + Examples + -------- >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 @@ -76,7 +88,6 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @since("1.1.0") def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -86,12 +97,26 @@ def normalRDD(sc, size, numPartitions=None, seed=None): to some other normal N(mean, sigma^2), use ``RandomRDDs.normal(sc, n, p, seed).map(lambda v: mean + sigma * v)`` - :param sc: SparkContext used to create the RDD. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + used to create the RDD. + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + + Examples + -------- >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) >>> stats = x.stats() >>> stats.count() @@ -104,20 +129,34 @@ def normalRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @since("1.3.0") def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the log normal distribution with the input mean and standard distribution. - :param sc: SparkContext used to create the RDD. 
- :param mean: mean for the log Normal distribution - :param std: std for the log Normal distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std). - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + used to create the RDD. + mean : float + mean for the log Normal distribution + std : float + std for the log Normal distribution + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + RDD of float comprised of i.i.d. samples ~ log N(mean, std). + + Examples + -------- >>> from math import sqrt, exp >>> mean = 0.0 >>> std = 1.0 @@ -137,19 +176,33 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): size, numPartitions, seed) @staticmethod - @since("1.1.0") def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + mean : float + Mean, or lambda, for the Poisson distribution. + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of float comprised of i.i.d. samples ~ Pois(mean). + + Examples + -------- >>> mean = 100.0 >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() @@ -164,19 +217,33 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod - @since("1.3.0") def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Exponential distribution with the input mean. - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + mean : float + Mean, or 1 / lambda, for the Exponential distribution. + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of float comprised of i.i.d. samples ~ Exp(mean). 
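The shift-and-scale recipes quoted above for `uniformRDD` and `normalRDD` can be spelled out concretely. A minimal sketch, not from the patch, assuming an active SparkContext `sc`:

from pyspark.mllib.random import RandomRDDs

# U(0.0, 1.0) -> U(a, b)
a, b = -1.0, 3.0
u = RandomRDDs.uniformRDD(sc, 1000, seed=7).map(lambda v: a + (b - a) * v)

# N(0.0, 1.0) -> N(mean, sigma^2)
mean, sigma = 10.0, 2.0
n = RandomRDDs.normalRDD(sc, 1000, seed=7).map(lambda v: mean + sigma * v)

# Rough sanity checks on the transformed samples.
print(u.stats().mean(), n.stats().stdev())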
+ + Examples + -------- >>> mean = 2.0 >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() @@ -191,20 +258,35 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod - @since("1.3.0") def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Gamma distribution with the input shape and scale. - :param sc: SparkContext used to create the RDD. - :param shape: shape (> 0) parameter for the Gamma distribution - :param scale: scale (> 0) parameter for the Gamma distribution - :param size: Size of the RDD. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + shape : float + shape (> 0) parameter for the Gamma distribution + scale : float + scale (> 0) parameter for the Gamma distribution + size : int + Size of the RDD. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale). + + Examples + -------- >>> from math import sqrt >>> shape = 1.0 >>> scale = 2.0 @@ -224,19 +306,33 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): @staticmethod @toArray - @since("1.1.0") def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the uniform distribution U(0.0, 1.0). - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD. - :param seed: Seed for the RNG that generates the seed for the generator in each partition. - :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + numRows : int + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. + numPartitions : int, optional + Number of partitions in the RDD. + seed : int, optional + Seed for the RNG that generates the seed for the generator in each partition. + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. + + Examples + -------- >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape @@ -250,19 +346,33 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray - @since("1.1.0") def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. - :param sc: SparkContext used to create the RDD. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). 
- :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + numRows : int + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + + Examples + -------- >>> import numpy as np >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape @@ -276,21 +386,37 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray - @since("1.3.0") def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the log normal distribution. - :param sc: SparkContext used to create the RDD. - :param mean: Mean of the log normal distribution - :param std: Standard Deviation of the log normal distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + mean : float + Mean of the log normal distribution + std : float + Standard Deviation of the log normal distribution + numRows : int + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`. + + Examples + -------- >>> import numpy as np >>> from math import sqrt, exp >>> mean = 0.0 @@ -311,20 +437,35 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed @staticmethod @toArray - @since("1.1.0") def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or lambda, for the Poisson distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). - + .. versionadded:: 1.1.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + mean : float + Mean, or lambda, for the Poisson distribution. + numRows : float + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. 
+ numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`) + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + + Examples + -------- >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) @@ -342,20 +483,35 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray - @since("1.3.0") def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Exponential distribution with the input mean. - :param sc: SparkContext used to create the RDD. - :param mean: Mean, or 1 / lambda, for the Exponential distribution. - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + mean : float + Mean, or 1 / lambda, for the Exponential distribution. + numRows : int + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`) + seed : int, optional + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean). + + Examples + -------- >>> import numpy as np >>> mean = 0.5 >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) @@ -373,21 +529,37 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No @staticmethod @toArray - @since("1.3.0") def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Gamma distribution. - :param sc: SparkContext used to create the RDD. - :param shape: Shape (> 0) of the Gamma distribution - :param scale: Scale (> 0) of the Gamma distribution - :param numRows: Number of Vectors in the RDD. - :param numCols: Number of elements in each Vector. - :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). - :param seed: Random seed (default: a random long integer). - :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). - + .. versionadded:: 1.3.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + SparkContext used to create the RDD. + shape : float + Shape (> 0) of the Gamma distribution + scale : float + Scale (> 0) of the Gamma distribution + numRows : int + Number of Vectors in the RDD. + numCols : int + Number of elements in each Vector. + numPartitions : int, optional + Number of partitions in the RDD (default: `sc.defaultParallelism`). + seed : int, optional, + Random seed (default: a random long integer). + + Returns + ------- + :py:class:`pyspark.RDD` + RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale). 
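The vector generators above pair naturally with the distributed matrices from the earlier hunks; a hedged sketch, again assuming an active SparkContext `sc`:

from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.linalg.distributed import RowMatrix

# 100 rows of 10 i.i.d. N(0.0, 1.0) samples, wrapped as a distributed RowMatrix.
rows = RandomRDDs.normalVectorRDD(sc, numRows=100, numCols=10, seed=1)
mat = RowMatrix(rows)

summary = mat.computeColumnSummaryStatistics()
print(mat.numRows(), mat.numCols())
print(summary.mean())   # per-column means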
+ + Examples + -------- >>> import numpy as np >>> from math import sqrt >>> shape = 1.0 diff --git a/python/pyspark/mllib/random.pyi b/python/pyspark/mllib/random.pyi index dc5f470161..ec83170625 100644 --- a/python/pyspark/mllib/random.pyi +++ b/python/pyspark/mllib/random.pyi @@ -90,7 +90,7 @@ class RandomRDDs: def logNormalVectorRDD( sc: SparkContext, mean: float, - std, + std: float, numRows: int, numCols: int, numPartitions: Optional[int] = ..., diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 3dd7cb200c..7a5fb6e6ee 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -32,13 +32,15 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): """ Represents a (user, product, rating) tuple. + .. versionadded:: 1.2.0 + + Examples + -------- >>> r = Rating(1, 2, 5.0) >>> (r.user, r.product, r.rating) (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) - - .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -51,6 +53,10 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. + .. versionadded:: 0.9.0 + + Examples + -------- >>> r1 = (1, 1, 1.0) >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) @@ -126,8 +132,6 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass - - .. versionadded:: 0.9.0 """ @since("0.9.0") def predict(self, user, product): @@ -237,7 +241,6 @@ def _prepare(cls, ratings): return ratings @classmethod - @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): """ @@ -247,35 +250,38 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative features). To solve for these features, ALS is run iteratively with a configurable level of parallelism. - :param ratings: - RDD of `Rating` or (userID, productID, rating) tuple. - :param rank: - Number of features to use (also referred to as the number of latent factors). - :param iterations: - Number of iterations of ALS. - (default: 5) - :param lambda_: - Regularization parameter. - (default: 0.01) - :param blocks: - Number of blocks used to parallelize the computation. A value - of -1 will use an auto-configured number of blocks. - (default: -1) - :param nonnegative: - A value of True will solve least-squares with nonnegativity - constraints. - (default: False) - :param seed: - Random seed for initial matrix factorization model. A value - of None will use system time as the seed. - (default: None) + .. versionadded:: 0.9.0 + + Parameters + ---------- + ratings : :py:class:`pyspark.RDD` + RDD of `Rating` or (userID, productID, rating) tuple. + rank : int + Number of features to use (also referred to as the number of latent factors). + iterations : int, optional + Number of iterations of ALS. + (default: 5) + lambda\\_ : float, optional + Regularization parameter. + (default: 0.01) + blocks : int, optional + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + nonnegative : bool, optional + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + seed : bool, optional + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. 
+ (default: None) """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod - @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): """ @@ -285,31 +291,35 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp given rank (number of features). To solve for these features, ALS is run iteratively with a configurable level of parallelism. - :param ratings: - RDD of `Rating` or (userID, productID, rating) tuple. - :param rank: - Number of features to use (also referred to as the number of latent factors). - :param iterations: - Number of iterations of ALS. - (default: 5) - :param lambda_: - Regularization parameter. - (default: 0.01) - :param blocks: - Number of blocks used to parallelize the computation. A value - of -1 will use an auto-configured number of blocks. - (default: -1) - :param alpha: - A constant used in computing confidence. - (default: 0.01) - :param nonnegative: - A value of True will solve least-squares with nonnegativity - constraints. - (default: False) - :param seed: - Random seed for initial matrix factorization model. A value - of None will use system time as the seed. - (default: None) + .. versionadded:: 0.9.0 + + Parameters + ---------- + ratings : :py:class:`pyspark.RDD` + RDD of `Rating` or (userID, productID, rating) tuple. + rank : int + Number of features to use (also referred to as the number of latent factors). + iterations : int, optional + Number of iterations of ALS. + (default: 5) + lambda\\_ : float, optional + Regularization parameter. + (default: 0.01) + blocks : int, optional + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + alpha : float, optional + A constant used in computing confidence. + (default: 0.01) + nonnegative : bool, optional + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + seed : int, optional + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. + (default: None) """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) diff --git a/python/pyspark/mllib/recommendation.pyi b/python/pyspark/mllib/recommendation.pyi index e2f1549420..4fea0acf3c 100644 --- a/python/pyspark/mllib/recommendation.pyi +++ b/python/pyspark/mllib/recommendation.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import array from collections import namedtuple @@ -27,7 +27,7 @@ from pyspark.mllib.common import JavaModelWrapper from pyspark.mllib.util import JavaLoader, JavaSaveable class Rating(namedtuple("Rating", ["user", "product", "rating"])): - def __reduce__(self): ... + def __reduce__(self) -> Tuple[Type[Rating], Tuple[int, int, float]]: ... 
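The `ALS.train` / `ALS.trainImplicit` parameters documented in the hunks above translate into calls like the following. A minimal sketch, not from the patch, assuming an active SparkContext `sc` and toy ratings:

from pyspark.mllib.recommendation import ALS, Rating

ratings = sc.parallelize([Rating(1, 1, 5.0), Rating(1, 2, 1.0),
                          Rating(2, 1, 4.0), Rating(2, 2, 2.0)])

# Explicit feedback: `rank` latent factors, regularization via lambda_.
model = ALS.train(ratings, rank=2, iterations=5, lambda_=0.01,
                  nonnegative=True, seed=42)
print(model.predict(1, 2))

# Implicit feedback adds the confidence constant alpha.
implicit = ALS.trainImplicit(ratings, rank=2, iterations=5, alpha=0.01, seed=42)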
class MatrixFactorizationModel( JavaModelWrapper, JavaSaveable, JavaLoader[MatrixFactorizationModel] diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 77bca86ac1..e549b0ac43 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -39,15 +39,19 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: - Label for this data point. - :param features: - Vector of features for this point (NumPy array, list, - pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). - - .. note:: 'label' and 'features' are accessible as class attributes. - .. versionadded:: 1.0.0 + + Parameters + ---------- + label : int + Label for this data point. + features : :py:class:`pyspark.mllib.linalg.Vector` or convertible + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). + + Notes + ----- + 'label' and 'features' are accessible as class attributes. """ def __init__(self, label, features): @@ -69,12 +73,14 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: - Weights computed for every feature. - :param intercept: - Intercept computed for this model. - .. versionadded:: 0.9.0 + + Parameters + ---------- + weights : :py:class:`pyspark.mllib.linalg.Vector` + Weights computed for every feature. + intercept : float + Intercept computed for this model. """ def __init__(self, weights, intercept): @@ -102,14 +108,16 @@ class LinearRegressionModelBase(LinearModel): """A linear regression model. + .. versionadded:: 0.9.0 + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> lrmb = LinearRegressionModelBase(np.array([1.0, 2.0]), 0.1) >>> abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6 True >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True - - .. versionadded:: 0.9.0 """ @since("0.9.0") @@ -129,6 +137,10 @@ class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. + .. versionadded:: 0.9.0 + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -181,8 +193,6 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - - .. versionadded:: 0.9.0 """ @since("1.4.0") def save(self, sc, path): @@ -224,11 +234,13 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ + Train a linear regression model with no regularization using Stochastic Gradient Descent. + .. versionadded:: 0.9.0 - .. note:: Deprecated in 2.0.0. Use ml.regression.LinearRegression. + .. deprecated:: 2.0.0 + Use :py:class:`pyspark.ml.regression.LinearRegression`. """ @classmethod - @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -244,42 +256,47 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, corresponding right hand side label y. See also the documentation for the precise formulation. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param step: - The step parameter used in SGD. 
- (default: 1.0) - :param miniBatchFraction: - Fraction of data to be used for each SGD iteration. - (default: 1.0) - :param initialWeights: - The initial weights. - (default: None) - :param regParam: - The regularizer parameter. - (default: 0.0) - :param regType: - The type of regularizer used for training our model. - Supported values: + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of LabeledPoint. + iterations : int, optional + The number of iterations. + (default: 100) + step : float, optional + The step parameter used in SGD. + (default: 1.0) + miniBatchFraction : float, optional + Fraction of data to be used for each SGD iteration. + (default: 1.0) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. + (default: None) + regParam : float, optional + The regularizer parameter. + (default: 0.0) + regType : str, optional + The type of regularizer used for training our model. + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization - None for no regularization (default) - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e., whether bias - features are activated or not). - (default: False) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param convergenceTol: - A condition which decides iteration termination. - (default: 0.001) + + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + convergenceTol : float, optional + A condition which decides iteration termination. + (default: 0.001) """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression.", DeprecationWarning) @@ -299,6 +316,10 @@ class LassoModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an l_1 penalty term. + .. versionadded:: 0.9.0 + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -351,8 +372,6 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - - .. versionadded:: 0.9.0 """ @since("1.4.0") def save(self, sc, path): @@ -375,12 +394,14 @@ def load(cls, sc, path): class LassoWithSGD(object): """ + Train a regression model with L1-regularization using Stochastic Gradient Descent. + .. versionadded:: 0.9.0 - .. note:: Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. - Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression. + .. deprecated:: 2.0.0 + Use :py:class:`pyspark.ml.regression.LinearRegression` with elasticNetParam = 1.0. + Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression. 
""" @classmethod - @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -395,35 +416,39 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param step: - The step parameter used in SGD. - (default: 1.0) - :param regParam: - The regularizer parameter. - (default: 0.01) - :param miniBatchFraction: - Fraction of data to be used for each SGD iteration. - (default: 1.0) - :param initialWeights: - The initial weights. - (default: None) - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e. whether bias - features are activated or not). - (default: False) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param convergenceTol: - A condition which decides iteration termination. - (default: 0.001) + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of LabeledPoint. + iterations : int, optional + The number of iterations. + (default: 100) + step : float, optional + The step parameter used in SGD. + (default: 1.0) + regParam : float, optional + The regularizer parameter. + (default: 0.01) + miniBatchFraction : float, optional + Fraction of data to be used for each SGD iteration. + (default: 1.0) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. + (default: None) + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + convergenceTol : float, optional + A condition which decides iteration termination. + (default: 0.001) """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. " @@ -444,6 +469,10 @@ class RidgeRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an l_2 penalty term. + .. versionadded:: 0.9.0 + + Examples + -------- >>> from pyspark.mllib.linalg import SparseVector >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -496,8 +525,6 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - - .. versionadded:: 0.9.0 """ @since("1.4.0") def save(self, sc, path): @@ -520,13 +547,15 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ + Train a regression model with L2-regularization using Stochastic Gradient Descent. + .. versionadded:: 0.9.0 - .. note:: Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. - Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for - LinearRegression. + .. deprecated:: 2.0.0 + Use :py:class:`pyspark.ml.regression.LinearRegression` with elasticNetParam = 0.0. + Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for + LinearRegression. 
""" @classmethod - @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -541,35 +570,39 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. - :param data: - The training data, an RDD of LabeledPoint. - :param iterations: - The number of iterations. - (default: 100) - :param step: - The step parameter used in SGD. - (default: 1.0) - :param regParam: - The regularizer parameter. - (default: 0.01) - :param miniBatchFraction: - Fraction of data to be used for each SGD iteration. - (default: 1.0) - :param initialWeights: - The initial weights. - (default: None) - :param intercept: - Boolean parameter which indicates the use or not of the - augmented representation for training data (i.e. whether bias - features are activated or not). - (default: False) - :param validateData: - Boolean parameter which indicates if the algorithm should - validate data before training. - (default: True) - :param convergenceTol: - A condition which decides iteration termination. - (default: 0.001) + .. versionadded:: 0.9.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + The training data, an RDD of LabeledPoint. + iterations : int, optional + The number of iterations. + (default: 100) + step : float, optional + The step parameter used in SGD. + (default: 1.0) + regParam : float, optional + The regularizer parameter. + (default: 0.01) + miniBatchFraction : float, optional + Fraction of data to be used for each SGD iteration. + (default: 1.0) + initialWeights : :py:class:`pyspark.mllib.linalg.Vector` or convertible, optional + The initial weights. + (default: None) + intercept : bool, optional + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + validateData : bool, optional + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + convergenceTol : float, optional + A condition which decides iteration termination. + (default: 0.001) """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. " @@ -589,15 +622,21 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: - Array of boundaries for which predictions are known. Boundaries - must be sorted in increasing order. - :param predictions: - Array of predictions associated to the boundaries at the same - index. Results of isotonic regression and therefore monotone. - :param isotonic: - Indicates whether this is isotonic or antitonic. + .. versionadded:: 1.4.0 + Parameters + ---------- + boundaries : ndarray + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + predictions : ndarray + Array of predictions associated to the boundaries at the same + index. Results of isotonic regression and therefore monotone. + isotonic : true + Indicates whether this is isotonic or antitonic. + + Examples + -------- >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) >>> irm.predict(3) @@ -619,8 +658,6 @@ class IsotonicRegressionModel(Saveable, Loader): ... rmtree(path) ... 
except OSError: ... pass - - .. versionadded:: 1.4.0 """ def __init__(self, boundaries, predictions, isotonic): @@ -628,7 +665,6 @@ def __init__(self, boundaries, predictions, isotonic): self.predictions = predictions self.isotonic = isotonic - @since("1.4.0") def predict(self, x): """ Predict labels for provided features. @@ -647,8 +683,13 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: - Feature or RDD of Features to be labeled. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -680,35 +721,42 @@ class IsotonicRegression(object): Currently implemented using parallelized pool adjacent violators algorithm. Only univariate (single feature) algorithm supported. - Sequential PAV implementation based on: + .. versionadded:: 1.4.0 + + Notes + ----- + Sequential PAV implementation based on + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani (2011) [1]_ - Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. - "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf + Sequential PAV parallelization based on + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset (1996) [2]_ - Sequential PAV parallelization based on: + See also + `Isotonic regression (Wikipedia) `_. - Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + .. [1] Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf + .. [2] Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset "An approach to parallelizing isotonic regression." Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf - - See `Isotonic regression (Wikipedia) `_. - - .. versionadded:: 1.4.0 """ @classmethod - @since("1.4.0") def train(cls, data, isotonic=True): """ Train an isotonic regression model on the given data. - :param data: - RDD of (label, feature, weight) tuples. - :param isotonic: - Whether this is isotonic (which is default) or antitonic. - (default: True) + .. versionadded:: 1.4.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + RDD of (label, feature, weight) tuples. + isotonic : bool, optional + Whether this is isotonic (which is default) or antitonic. + (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -741,26 +789,32 @@ def _validate(self, dstream): raise ValueError( "Model must be intialized using setInitialWeights") - @since("1.5.0") def predictOn(self, dstream): """ Use the model to make predictions on batches of data from a DStream. - :return: - DStream containing predictions. + .. versionadded:: 1.5.0 + + Returns + ------- + :py:class:`pyspark.streaming.DStream` + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) - @since("1.5.0") def predictOnValues(self, dstream): """ Use the model to make predictions on the values of a DStream and carry over its keys. - :return: - DStream containing the input keys and the predictions as values. + .. 
versionadded:: 1.5.0 + + Returns + ------- + :py:class:`pyspark.streaming.DStream` + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -779,20 +833,22 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): of features must be constant. An initial weight vector must be provided. - :param stepSize: - Step size for each iteration of gradient descent. - (default: 0.1) - :param numIterations: - Number of iterations run for each batch of data. - (default: 50) - :param miniBatchFraction: - Fraction of each batch of data to use for updates. - (default: 1.0) - :param convergenceTol: - Value used to determine when to terminate iterations. - (default: 0.001) - .. versionadded:: 1.5.0 + + Parameters + ---------- + stepSize : float, optional + Step size for each iteration of gradient descent. + (default: 0.1) + numIterations : int, optional + Number of iterations run for each batch of data. + (default: 50) + miniBatchFraction : float, optional + Fraction of each batch of data to use for updates. + (default: 1.0) + convergenceTol : float, optional + Value used to determine when to terminate iterations. + (default: 0.001) """ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 56444c152f..1d4d43e535 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -26,6 +26,8 @@ class KernelDensity(object): Estimate probability density at required points given an RDD of samples from the population. + Examples + -------- >>> kd = KernelDensity() >>> sample = sc.parallelize([0.0, 1.0]) >>> kd.setSample(sample) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index 0fb3306183..d3b4ddf7e4 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -21,8 +21,9 @@ from pyspark.mllib.stat._statistics import Statistics, MultivariateStatisticalSummary from pyspark.mllib.stat.distribution import MultivariateGaussian -from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult from pyspark.mllib.stat.KernelDensity import KernelDensity -__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", +__all__ = ["Statistics", "MultivariateStatisticalSummary", + "ChiSqTestResult", "KolmogorovSmirnovTestResult", "MultivariateGaussian", "KernelDensity"] diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 43454ba518..a4b45cf55f 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -65,11 +65,19 @@ def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. - :param rdd: an RDD[Vector] for which column-wise summary statistics - are to be computed. - :return: :class:`MultivariateStatisticalSummary` object containing - column-wise summary statistics. - + Parameters + ---------- + rdd : :py:class:`pyspark.RDD` + an RDD[Vector] for which column-wise summary statistics + are to be computed. + + Returns + ------- + :class:`MultivariateStatisticalSummary` + object containing column-wise summary statistics. + + Examples + -------- >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ...
Vectors.dense([4, 5, 0, 3]), @@ -103,13 +111,24 @@ def corr(x, y=None, method=None): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. - :param x: an RDD of vector for which the correlation matrix is to be computed, - or an RDD of float of the same cardinality as y when y is specified. - :param y: an RDD of float of the same cardinality as x. - :param method: String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman` - :return: Correlation matrix comparing columns in x. - + Parameters + ---------- + x : :py:class:`pyspark.RDD` + an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + y : :py:class:`pyspark.RDD`, optional + an RDD of float of the same cardinality as x. + method : str, optional + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + + Returns + ------- + :py:class:`pyspark.mllib.linalg.Matrix` + Correlation matrix comparing columns in x. + + Examples + -------- >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -172,20 +191,33 @@ def chiSqTest(observed, expected=None): contingency matrix for which the chi-squared statistic is computed. All label and feature values must be categorical. - .. note:: `observed` cannot contain negative values - - :param observed: it could be a vector containing the observed categorical - counts/relative frequencies, or the contingency matrix - (containing either counts or relative frequencies), - or an RDD of LabeledPoint containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - :param expected: Vector containing the expected categorical counts/relative - frequencies. `expected` is rescaled if the `expected` sum - differs from the `observed` sum. - :return: ChiSquaredTest object containing the test statistic, degrees - of freedom, p-value, the method used, and the null hypothesis. - + Parameters + ---------- + observed : :py:class:`pyspark.mllib.linalg.Vector` or \ + :py:class:`pyspark.mllib.linalg.Matrix` + it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + expected : :py:class:`pyspark.mllib.linalg.Vector` + Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + + Returns + ------- + :py:class:`pyspark.mllib.stat.ChiSqTestResult` + object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. + + Notes + ----- + `observed` cannot contain negative values + + Examples + -------- >>> from pyspark.mllib.linalg import Vectors, Matrices >>> observed = Vectors.dense([4, 6, 5]) >>> pearson = Statistics.chiSqTest(observed) @@ -259,17 +291,28 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): For specific details of the implementation, please have a look at the Scala documentation. - :param data: RDD, samples from the data - :param distName: string, currently only "norm" is supported. 
- (Normal distribution) to calculate the - theoretical distribution of the data. - :param params: additional values which need to be provided for - a certain distribution. - If not provided, the default values are used. - :return: KolmogorovSmirnovTestResult object containing the test - statistic, degrees of freedom, p-value, - the method used, and the null hypothesis. + Parameters + ---------- + data : :py:class:`pyspark.RDD` + RDD, samples from the data + distName : str, optional + string, currently only "norm" is supported. + (Normal distribution) to calculate the + theoretical distribution of the data. + params + additional values which need to be provided for + a certain distribution. + If not provided, the default values are used. + + Returns + ------- + :py:class:`pyspark.mllib.stat.KolmogorovSmirnovTestResult` + object containing the test statistic, degrees of freedom, p-value, + the method used, and the null hypothesis. + + Examples + -------- >>> kstest = Statistics.kolmogorovSmirnovTest >>> data = sc.parallelize([-1.0, 0.0, 1.0]) >>> ksmodel = kstest(data, "norm") diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi index 4d2701d486..3834d51639 100644 --- a/python/pyspark/mllib/stat/_statistics.pyi +++ b/python/pyspark/mllib/stat/_statistics.pyi @@ -65,5 +65,5 @@ class Statistics: def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ... @staticmethod def kolmogorovSmirnovTest( - data, distName: Literal["norm"] = ..., *params: float + data: RDD[float], distName: Literal["norm"] = ..., *params: float ) -> KolmogorovSmirnovTestResult: ... diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py index 46f7a1d2f2..aa35ac6dfd 100644 --- a/python/pyspark/mllib/stat/distribution.py +++ b/python/pyspark/mllib/stat/distribution.py @@ -24,6 +24,8 @@ class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): """Represents a (mu, sigma) tuple + Examples + -------- >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) >>> (m.mu, m.sigma.toArray()) (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e05dfdb953..493dcf8db6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -33,15 +33,18 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): .. versionadded:: 1.3.0 """ - @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using the model trained. - .. note:: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. versionadded:: 1.3.0 + + Notes + ----- + In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -79,18 +82,23 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. versionadded:: 1.1.0 """ - @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. - .. note:: In Python, predict cannot currently be used within an RDD - transformation or action. - Call predict directly on the RDD instead. + .. 
versionadded:: 1.1.0 + + Parameters + ---------- + x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD` + Data point (feature vector), or an RDD of data points (feature + vectors). - :param x: - Data point (feature vector), or an RDD of data points (feature - vectors). + Notes + ----- + In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -143,45 +151,50 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod - @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ Train a decision tree model for classification. - :param data: - Training data: RDD of LabeledPoint. Labels should take values - {0, 1, ..., numClasses-1}. - :param numClasses: - Number of classes for classification. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. - :param impurity: - Criterion used for information gain calculation. - Supported values: "gini" or "entropy". - (default: "gini") - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 5) - :param maxBins: - Number of bins used for finding splits at each node. - (default: 32) - :param minInstancesPerNode: - Minimum number of instances required at child nodes to create - the parent split. - (default: 1) - :param minInfoGain: - Minimum info gain required to create a split. - (default: 0.0) - :return: - DecisionTreeModel. - - Example usage: - + .. versionadded:: 1.1.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training data: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + numClasses : int + Number of classes for classification. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + impurity : str, optional + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + maxBins : int, optional + Number of bins used for finding splits at each node. + (default: 32) + minInstancesPerNode : int, optional + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + minInfoGain : float, optional + Minimum info gain required to create a split. + (default: 0.0) + + Returns + ------- + :py:class:`DecisionTreeModel` + + Examples + -------- >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree @@ -222,35 +235,39 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, """ Train a decision tree model for regression. - :param data: - Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. 
- :param impurity: - Criterion used for information gain calculation. - The only supported value for regression is "variance". - (default: "variance") - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 5) - :param maxBins: - Number of bins used for finding splits at each node. - (default: 32) - :param minInstancesPerNode: - Minimum number of instances required at child nodes to create - the parent split. - (default: 1) - :param minInfoGain: - Minimum info gain required to create a split. - (default: 0.0) - :return: - DecisionTreeModel. - - Example usage: - + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training data: RDD of LabeledPoint. Labels are real numbers. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + impurity : str, optional + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + maxBins : int, optional + Number of bins used for finding splits at each node. + (default: 32) + minInstancesPerNode : int, optional + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + minInfoGain : float, optional + Minimum info gain required to create a split. + (default: 0.0) + + Returns + ------- + :py:class:`DecisionTreeModel` + + Examples + -------- >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -313,7 +330,6 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod - @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -321,44 +337,51 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Train a random forest model for binary or multiclass classification. - :param data: - Training dataset: RDD of LabeledPoint. Labels should take values - {0, 1, ..., numClasses-1}. - :param numClasses: - Number of classes for classification. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. - :param numTrees: - Number of trees in the random forest. - :param featureSubsetStrategy: - Number of features to consider for splits at each node. - Supported values: "auto", "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - (default: "auto") - :param impurity: - Criterion used for information gain calculation. - Supported values: "gini" or "entropy". - (default: "gini") - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 4) - :param maxBins: - Maximum number of bins used for splitting features. - (default: 32) - :param seed: - Random seed for bootstrapping and choosing feature subsets. - Set as None to generate seed based on system time. 
- (default: None) - :return: - RandomForestModel that can be used for prediction. - - Example usage: - + .. versionadded:: 1.2.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + numClasses : int + Number of classes for classification. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + numTrees : int + Number of trees in the random forest. + featureSubsetStrategy : str, optional + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + (default: "auto") + impurity : str, optional + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + maxBins : int, optional + Maximum number of bins used for splitting features. + (default: 32) + seed : int, Optional + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + + Returns + ------- + :py:class:`RandomForestModel` + that can be used for prediction. + + Examples + -------- >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import RandomForest >>> @@ -405,47 +428,55 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod - @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ Train a random forest model for regression. - :param data: - Training dataset: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. - :param numTrees: - Number of trees in the random forest. - :param featureSubsetStrategy: - Number of features to consider for splits at each node. - Supported values: "auto", "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - (default: "auto") - :param impurity: - Criterion used for information gain calculation. - The only supported value for regression is "variance". - (default: "variance") - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 4) - :param maxBins: - Maximum number of bins used for splitting features. - (default: 32) - :param seed: - Random seed for bootstrapping and choosing feature subsets. - Set as None to generate seed based on system time. - (default: None) - :return: - RandomForestModel that can be used for prediction. - - Example usage: - + .. versionadded:: 1.2.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training dataset: RDD of LabeledPoint. Labels are real numbers. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. 
An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + numTrees : int + Number of trees in the random forest. + featureSubsetStrategy : str, optional + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + + - if numTrees == 1, set to "all"; + - if numTrees > 1 (forest) set to "onethird" for regression. + + (default: "auto") + impurity : str, optional + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + maxBins : int, optional + Maximum number of bins used for splitting features. + (default: 32) + seed : int, optional + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + + Returns + ------- + :py:class:`RandomForestModel` + that can be used for prediction. + + Examples + -------- >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import RandomForest >>> from pyspark.mllib.linalg import SparseVector @@ -505,45 +536,51 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod - @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ Train a gradient-boosted trees model for classification. - :param data: - Training dataset: RDD of LabeledPoint. Labels should take values - {0, 1}. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. - :param loss: - Loss function used for minimization during gradient boosting. - Supported values: "logLoss", "leastSquaresError", - "leastAbsoluteError". - (default: "logLoss") - :param numIterations: - Number of iterations of boosting. - (default: 100) - :param learningRate: - Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1]. - (default: 0.1) - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 3) - :param maxBins: - Maximum number of bins used for splitting features. DecisionTree - requires maxBins >= max categories. - (default: 32) - :return: - GradientBoostedTreesModel that can be used for prediction. - - Example usage: - + .. versionadded:: 1.3.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1}. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + loss : str, optional + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "logLoss") + numIterations : int, optional + Number of iterations of boosting. + (default: 100) + learningRate : float, optional + Learning rate for shrinking the contribution of each estimator. + The learning rate should be in the interval (0, 1].
+ (default: 0.1) + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + maxBins : int, optional + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + + Returns + ------- + :py:class:`GradientBoostedTreesModel` + that can be used for prediction. + + Examples + -------- >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import GradientBoostedTrees >>> @@ -574,44 +611,50 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod - @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ Train a gradient-boosted trees model for regression. - :param data: - Training dataset: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: - Map storing arity of categorical features. An entry (n -> k) - indicates that feature n is categorical with k categories - indexed from 0: {0, 1, ..., k-1}. - :param loss: - Loss function used for minimization during gradient boosting. - Supported values: "logLoss", "leastSquaresError", - "leastAbsoluteError". - (default: "leastSquaresError") - :param numIterations: - Number of iterations of boosting. - (default: 100) - :param learningRate: - Learning rate for shrinking the contribution of each estimator. - The learning rate should be between in the interval (0, 1]. - (default: 0.1) - :param maxDepth: - Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 - means 1 internal node + 2 leaf nodes). - (default: 3) - :param maxBins: - Maximum number of bins used for splitting features. DecisionTree - requires maxBins >= max categories. - (default: 32) - :return: - GradientBoostedTreesModel that can be used for prediction. - - Example usage: - + .. versionadded:: 1.3.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + Training dataset: RDD of LabeledPoint. Labels are real numbers. + categoricalFeaturesInfo : dict + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + loss : str, optional + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "leastSquaresError") + numIterations : int, optional + Number of iterations of boosting. + (default: 100) + learningRate : float, optional + Learning rate for shrinking the contribution of each estimator. + The learning rate should be in the interval (0, 1]. + (default: 0.1) + maxDepth : int, optional + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + maxBins : int, optional + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + + Returns + ------- + :py:class:`GradientBoostedTreesModel` + that can be used for prediction.
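For reference, a minimal usage sketch of the gradient-boosted trees API documented above, assuming an active SparkContext named `sc`; the tiny dataset and parameter values are illustrative only.

from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import GradientBoostedTrees

# Tiny, illustrative training set: the label grows with the single feature.
data = sc.parallelize([
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(2.0, [2.0]),
    LabeledPoint(3.0, [3.0]),
])

# Keyword arguments mirror the parameters documented in the docstring above;
# an empty categoricalFeaturesInfo dict means all features are continuous.
model = GradientBoostedTrees.trainRegressor(
    data, categoricalFeaturesInfo={}, numIterations=10, maxDepth=3)
print(model.predict([1.5]))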
+ + Examples + -------- >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import GradientBoostedTrees >>> from pyspark.mllib.linalg import SparseVector diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index a0be29a82e..68feb95638 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -65,7 +65,6 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod - @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -79,20 +78,33 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - :param sc: Spark context - :param path: file or directory path in any Hadoop-supported file - system URI - :param numFeatures: number of features, which will be determined - from the input data if a nonpositive value - is given. This is useful when the dataset is - already split into multiple files and you - want to load them separately, because some - features may not present in certain files, - which leads to inconsistent feature - dimensions. - :param minPartitions: min number of partitions - :return: labeled data stored as an RDD of LabeledPoint - + .. versionadded:: 1.0.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + Spark context + path : str + file or directory path in any Hadoop-supported file system URI + numFeatures : int, optional + number of features, which will be determined + from the input data if a nonpositive value + is given. This is useful when the dataset is + already split into multiple files and you + want to load them separately, because some + features may not present in certain files, + which leads to inconsistent feature + dimensions. + minPartitions : int, optional + min number of partitions + + Returns + ------- + :py:class:`pyspark.RDD` + labeled data stored as an RDD of LabeledPoint + + Examples + -------- >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint @@ -118,14 +130,21 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod - @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. - :param data: an RDD of LabeledPoint to be saved - :param dir: directory to save the data + .. versionadded:: 1.0.0 + + Parameters + ---------- + data : :py:class:`pyspark.RDD` + an RDD of LabeledPoint to be saved + dir : str + directory to save the data + Examples + -------- >>> from tempfile import NamedTemporaryFile >>> from fileinput import input >>> from pyspark.mllib.regression import LabeledPoint @@ -143,17 +162,28 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod - @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. - :param sc: Spark context - :param path: file or directory path in any Hadoop-supported file - system URI - :param minPartitions: min number of partitions - :return: labeled data stored as an RDD of LabeledPoint + .. 
versionadded:: 1.0.0 + + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + Spark context + path : str + file or directory path in any Hadoop-supported file system URI + minPartitions : int, optional + min number of partitions + Returns + ------- + :py:class:`pyspark.RDD` + labeled data stored as an RDD of LabeledPoint + + Examples + -------- >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint @@ -193,7 +223,6 @@ def loadVectors(sc, path): return callMLlibFunc("loadVectors", sc, path) @staticmethod - @since("2.0.0") def convertVectorColumnsToML(dataset, *cols): """ Converts vector columns in an input DataFrame from the @@ -201,16 +230,26 @@ def convertVectorColumnsToML(dataset, *cols): :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` package. - :param dataset: - input dataset - :param cols: - a list of vector columns to be converted. - New vector columns will be ignored. If unspecified, all old - vector columns will be converted excepted nested ones. - :return: - the input dataset with old vector columns converted to the - new vector type + .. versionadded:: 2.0.0 + + Parameters + ---------- + dataset : :py:class:`pyspark.sql.DataFrame` + input dataset + \\*cols : str + Vector columns to be converted. + New vector columns will be ignored. If unspecified, all old + vector columns will be converted excepted nested ones. + + Returns + ------- + :py:class:`pyspark.sql.DataFrame` + the input dataset with old vector columns converted to the + new vector type + + Examples + -------- >>> import pyspark >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.mllib.util import MLUtils @@ -233,7 +272,6 @@ def convertVectorColumnsToML(dataset, *cols): return callMLlibFunc("convertVectorColumnsToML", dataset, list(cols)) @staticmethod - @since("2.0.0") def convertVectorColumnsFromML(dataset, *cols): """ Converts vector columns in an input DataFrame to the @@ -241,16 +279,26 @@ def convertVectorColumnsFromML(dataset, *cols): :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` package. - :param dataset: - input dataset - :param cols: - a list of vector columns to be converted. - Old vector columns will be ignored. If unspecified, all new - vector columns will be converted except nested ones. - :return: - the input dataset with new vector columns converted to the - old vector type + .. versionadded:: 2.0.0 + + Parameters + ---------- + dataset : :py:class:`pyspark.sql.DataFrame` + input dataset + \\*cols : str + Vector columns to be converted. + + Old vector columns will be ignored. If unspecified, all new + vector columns will be converted except nested ones. + + Returns + ------- + :py:class:`pyspark.sql.DataFrame` + the input dataset with new vector columns converted to the + old vector type + Examples + -------- >>> import pyspark >>> from pyspark.ml.linalg import Vectors >>> from pyspark.mllib.util import MLUtils @@ -273,7 +321,6 @@ def convertVectorColumnsFromML(dataset, *cols): return callMLlibFunc("convertVectorColumnsFromML", dataset, list(cols)) @staticmethod - @since("2.0.0") def convertMatrixColumnsToML(dataset, *cols): """ Converts matrix columns in an input DataFrame from the @@ -281,16 +328,26 @@ def convertMatrixColumnsToML(dataset, *cols): :py:class:`pyspark.ml.linalg.Matrix` type under the `spark.ml` package. - :param dataset: - input dataset - :param cols: - a list of matrix columns to be converted. - New matrix columns will be ignored. 
If unspecified, all old - matrix columns will be converted excepted nested ones. - :return: - the input dataset with old matrix columns converted to the - new matrix type + .. versionadded:: 2.0.0 + Parameters + ---------- + dataset : :py:class:`pyspark.sql.DataFrame` + input dataset + \\*cols : str + Matrix columns to be converted. + + New matrix columns will be ignored. If unspecified, all old + matrix columns will be converted excepted nested ones. + + Returns + ------- + :py:class:`pyspark.sql.DataFrame` + the input dataset with old matrix columns converted to the + new matrix type + + Examples + -------- >>> import pyspark >>> from pyspark.mllib.linalg import Matrices >>> from pyspark.mllib.util import MLUtils @@ -313,7 +370,6 @@ def convertMatrixColumnsToML(dataset, *cols): return callMLlibFunc("convertMatrixColumnsToML", dataset, list(cols)) @staticmethod - @since("2.0.0") def convertMatrixColumnsFromML(dataset, *cols): """ Converts matrix columns in an input DataFrame to the @@ -321,16 +377,26 @@ def convertMatrixColumnsFromML(dataset, *cols): :py:class:`pyspark.ml.linalg.Matrix` type under the `spark.ml` package. - :param dataset: - input dataset - :param cols: - a list of matrix columns to be converted. - Old matrix columns will be ignored. If unspecified, all new - matrix columns will be converted except nested ones. - :return: - the input dataset with new matrix columns converted to the - old matrix type + .. versionadded:: 2.0.0 + + Parameters + ---------- + dataset : :py:class:`pyspark.sql.DataFrame` + input dataset + \\*cols : str + Matrix columns to be converted. + + Old matrix columns will be ignored. If unspecified, all new + matrix columns will be converted except nested ones. + Returns + ------- + :py:class:`pyspark.sql.DataFrame` + the input dataset with new matrix columns converted to the + old matrix type + + Examples + -------- >>> import pyspark >>> from pyspark.ml.linalg import Matrices >>> from pyspark.mllib.util import MLUtils @@ -370,10 +436,14 @@ def save(self, sc, path): The model may be loaded using :py:meth:`Loader.load`. - :param sc: Spark context used to save model data. - :param path: Path specifying the directory in which to save - this model. If the directory already exists, - this method throws an exception. + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + Spark context used to save model data. + path : str + Path specifying the directory in which to save + this model. If the directory already exists, + this method throws an exception. """ raise NotImplementedError @@ -410,10 +480,17 @@ def load(cls, sc, path): Load a model from the given path. The model should have been saved using :py:meth:`Saveable.save`. - :param sc: Spark context used for loading model files. - :param path: Path specifying the directory to which the model - was saved. - :return: model instance + Parameters + ---------- + sc : :py:class:`pyspark.SparkContext` + Spark context used for loading model files. + path : str + Path specifying the directory to which the model was saved. + + Returns + ------- + object + model instance """ raise NotImplementedError @@ -463,20 +540,33 @@ class LinearDataGenerator(object): """ @staticmethod - @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ - :param: intercept bias factor, the term c in X'w + c - :param: weights feature vector, the term w in X'w + c - :param: xMean Point around which the data X is centered. 
- :param: xVariance Variance of the given data - :param: nPoints Number of points to be generated - :param: seed Random Seed - :param: eps Used to scale the noise. If eps is set high, - the amount of gaussian noise added is more. - - Returns a list of LabeledPoints of length nPoints + .. versionadded:: 1.5.0 + + Parameters + ---------- + intercept : float + bias factor, the term c in X'w + c + weights : :py:class:`pyspark.mllib.linalg.Vector` or convertible + feature vector, the term w in X'w + c + xMean : :py:class:`pyspark.mllib.linalg.Vector` or convertible + Point around which the data X is centered. + xVariance : :py:class:`pyspark.mllib.linalg.Vector` or convertible + Variance of the given data + nPoints : int + Number of points to be generated + seed : int + Random Seed + eps : float + Used to scale the noise. If eps is set high, + the amount of gaussian noise added is more. + + Returns + ------- + list + of :py:class:`pyspark.mllib.regression.LabeledPoints` of length nPoints """ weights = [float(weight) for weight in weights] xMean = [float(mean) for mean in xMean] diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 35c49e952b..a277cd9f7e 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -85,12 +85,16 @@ class PythonEvalType: SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType class BoundedFloat(float): - def __new__(cls, mean: float, confidence: float, low: float, high: float): ... + def __new__( + cls, mean: float, confidence: float, low: float, high: float + ) -> BoundedFloat: ... class Partitioner: numPartitions: int partitionFunc: Callable[[Any], int] - def __init__(self, numPartitions, partitionFunc) -> None: ... + def __init__( + self, numPartitions: int, partitionFunc: Callable[[Any], int] + ) -> None: ... def __eq__(self, other: Any) -> bool: ... def __call__(self, k: Any) -> int: ... diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi index 6763baf659..0483869243 100644 --- a/python/pyspark/resource/profile.pyi +++ b/python/pyspark/resource/profile.pyi @@ -49,7 +49,7 @@ class ResourceProfileBuilder: def __init__(self) -> None: ... def require( self, resourceRequest: Union[ExecutorResourceRequest, TaskResourceRequests] - ): ... + ) -> ResourceProfileBuilder: ... def clearExecutorResourceRequests(self) -> None: ... def clearTaskResourceRequests(self) -> None: ... @property diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi index 0fbb10053f..1f63e65b3d 100644 --- a/python/pyspark/sql/column.pyi +++ b/python/pyspark/sql/column.pyi @@ -32,7 +32,7 @@ from pyspark.sql.window import WindowSpec from py4j.java_gateway import JavaObject # type: ignore[import] class Column: - def __init__(self, JavaObject) -> None: ... + def __init__(self, jc: JavaObject) -> None: ... def __neg__(self) -> Column: ... def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ... @@ -105,7 +105,11 @@ class Column: def name(self, *alias: str) -> Column: ... def cast(self, dataType: Union[DataType, str]) -> Column: ... def astype(self, dataType: Union[DataType, str]) -> Column: ... - def between(self, lowerBound, upperBound) -> Column: ... + def between( + self, + lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral], + ) -> Column: ... def when(self, condition: Column, value: Any) -> Column: ... 
def otherwise(self, value: Any) -> Column: ... def over(self, window: WindowSpec) -> Column: ... diff --git a/python/pyspark/sql/context.pyi b/python/pyspark/sql/context.pyi index 64927b37ac..915a0fe1f6 100644 --- a/python/pyspark/sql/context.pyi +++ b/python/pyspark/sql/context.pyi @@ -43,14 +43,14 @@ class SQLContext: sparkSession: SparkSession def __init__( self, - sparkContext, + sparkContext: SparkContext, sparkSession: Optional[SparkSession] = ..., jsqlContext: Optional[JavaObject] = ..., ) -> None: ... @classmethod def getOrCreate(cls: type, sc: SparkContext) -> SQLContext: ... def newSession(self) -> SQLContext: ... - def setConf(self, key: str, value) -> None: ... + def setConf(self, key: str, value: Union[bool, int, str]) -> None: ... def getConf(self, key: str, defaultValue: Optional[str] = ...) -> str: ... @property def udf(self) -> UDFRegistration: ... @@ -116,7 +116,7 @@ class SQLContext: path: Optional[str] = ..., source: Optional[str] = ..., schema: Optional[StructType] = ..., - **options + **options: str ) -> DataFrame: ... def sql(self, sqlQuery: str) -> DataFrame: ... def table(self, tableName: str) -> DataFrame: ... diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 281c1d7543..252f883b5f 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -65,13 +65,13 @@ def round(col: ColumnOrName, scale: int = ...) -> Column: ... def bround(col: ColumnOrName, scale: int = ...) -> Column: ... def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ... def shiftRight(col: ColumnOrName, numBits: int) -> Column: ... -def shiftRightUnsigned(col, numBits) -> Column: ... +def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ... def spark_partition_id() -> Column: ... def expr(str: str) -> Column: ... def struct(*cols: ColumnOrName) -> Column: ... def greatest(*cols: ColumnOrName) -> Column: ... def least(*cols: Column) -> Column: ... -def when(condition: Column, value) -> Column: ... +def when(condition: Column, value: Any) -> Column: ... @overload def log(arg1: ColumnOrName) -> Column: ... @overload @@ -174,7 +174,9 @@ def create_map(*cols: ColumnOrName) -> Column: ... def array(*cols: ColumnOrName) -> Column: ... def array_contains(col: ColumnOrName, value: Any) -> Column: ... def arrays_overlap(a1: ColumnOrName, a2: ColumnOrName) -> Column: ... -def slice(x: ColumnOrName, start: Union[Column, int], length: Union[Column, int]) -> Column: ... +def slice( + x: ColumnOrName, start: Union[Column, int], length: Union[Column, int] +) -> Column: ... def array_join( col: ColumnOrName, delimiter: str, null_replacement: Optional[str] = ... ) -> Column: ... diff --git a/python/pyspark/sql/session.pyi b/python/pyspark/sql/session.pyi index 17ba8894c1..6cd2d3bed2 100644 --- a/python/pyspark/sql/session.pyi +++ b/python/pyspark/sql/session.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from types import TracebackType from py4j.java_gateway import JavaObject # type: ignore[import] @@ -122,4 +123,9 @@ class SparkSession(SparkConversionMixin): def streams(self) -> StreamingQueryManager: ... def stop(self) -> None: ... def __enter__(self) -> SparkSession: ... - def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... 
+ def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi index 31765e9488..3adf823d99 100644 --- a/python/pyspark/sql/types.pyi +++ b/python/pyspark/sql/types.pyi @@ -17,7 +17,8 @@ # under the License. from typing import overload -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TypeVar +from py4j.java_gateway import JavaGateway, JavaObject import datetime T = TypeVar("T") @@ -37,7 +38,7 @@ class DataType: def fromInternal(self, obj: Any) -> Any: ... class DataTypeSingleton(type): - def __call__(cls): ... + def __call__(cls: Type[T]) -> T: ... # type: ignore class NullType(DataType, metaclass=DataTypeSingleton): ... class AtomicType(DataType): ... @@ -85,8 +86,8 @@ class ShortType(IntegralType): class ArrayType(DataType): elementType: DataType containsNull: bool - def __init__(self, elementType=DataType, containsNull: bool = ...) -> None: ... - def simpleString(self): ... + def __init__(self, elementType: DataType, containsNull: bool = ...) -> None: ... + def simpleString(self) -> str: ... def jsonValue(self) -> Dict[str, Any]: ... @classmethod def fromJson(cls, json: Dict[str, Any]) -> ArrayType: ... @@ -197,8 +198,8 @@ class Row(tuple): class DateConverter: def can_convert(self, obj: Any) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def convert(self, obj: datetime.date, gateway_client: JavaGateway) -> JavaObject: ... class DatetimeConverter: - def can_convert(self, obj) -> bool: ... - def convert(self, obj, gateway_client) -> Any: ... + def can_convert(self, obj: Any) -> bool: ... + def convert(self, obj: datetime.datetime, gateway_client: JavaGateway) -> JavaObject: ... diff --git a/python/pyspark/sql/udf.pyi b/python/pyspark/sql/udf.pyi index 87c3672780..ea61397a67 100644 --- a/python/pyspark/sql/udf.pyi +++ b/python/pyspark/sql/udf.pyi @@ -18,8 +18,9 @@ from typing import Any, Callable, Optional -from pyspark.sql._typing import ColumnOrName, DataTypeOrString +from pyspark.sql._typing import ColumnOrName, DataTypeOrString, UserDefinedFunctionLike from pyspark.sql.column import Column +from pyspark.sql.types import DataType import pyspark.sql.session class UserDefinedFunction: @@ -35,7 +36,7 @@ class UserDefinedFunction: deterministic: bool = ..., ) -> None: ... @property - def returnType(self): ... + def returnType(self) -> DataType: ... def __call__(self, *cols: ColumnOrName) -> Column: ... def asNondeterministic(self) -> UserDefinedFunction: ... @@ -47,7 +48,7 @@ class UDFRegistration: name: str, f: Callable[..., Any], returnType: Optional[DataTypeOrString] = ..., - ): ... + ) -> UserDefinedFunctionLike: ... def registerJavaFunction( self, name: str, diff --git a/python/pyspark/streaming/context.pyi b/python/pyspark/streaming/context.pyi index 026163fc9a..117a6742e6 100644 --- a/python/pyspark/streaming/context.pyi +++ b/python/pyspark/streaming/context.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar from py4j.java_gateway import JavaObject # type: ignore[import] diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi index 7b76ce4c65..1521d838fc 100644 --- a/python/pyspark/streaming/dstream.pyi +++ b/python/pyspark/streaming/dstream.pyi @@ -30,9 +30,12 @@ from typing import ( ) import datetime from pyspark.rdd import RDD +import pyspark.serializers from pyspark.storagelevel import StorageLevel import pyspark.streaming.context +from py4j.java_gateway import JavaObject + S = TypeVar("S") T = TypeVar("T") U = TypeVar("U") @@ -42,7 +45,12 @@ V = TypeVar("V") class DStream(Generic[T]): is_cached: bool is_checkpointed: bool - def __init__(self, jdstream, ssc, jrdd_deserializer) -> None: ... + def __init__( + self, + jdstream: JavaObject, + ssc: pyspark.streaming.context.StreamingContext, + jrdd_deserializer: pyspark.serializers.Serializer, + ) -> None: ... def context(self) -> pyspark.streaming.context.StreamingContext: ... def count(self) -> DStream[int]: ... def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... diff --git a/python/pyspark/streaming/kinesis.pyi b/python/pyspark/streaming/kinesis.pyi index af7cd6f6ec..399c37f869 100644 --- a/python/pyspark/streaming/kinesis.pyi +++ b/python/pyspark/streaming/kinesis.pyi @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Optional, TypeVar +from typing import Callable, Optional, TypeVar from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index b5a3601676..4620bdb005 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -313,7 +313,6 @@ trait MesosSchedulerUtils extends Logging { // offer has the required attribute and subsumes the required values for that attribute case (name, requiredValues) => offerAttributes.get(name) match { - case None => false case Some(_) if requiredValues.isEmpty => true // empty value matches presence case Some(scalarValue: Value.Scalar) => // check if provided values is less than equal to the offered values @@ -332,6 +331,7 @@ trait MesosSchedulerUtils extends Logging { // check if the specified value is equal, if multiple values are specified // we succeed if any of them match. 
requiredValues.contains(textValue.getValue) + case _ => false } } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 67ecf3242f..6a6514569c 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -178,7 +178,7 @@ class MesosFineGrainedSchedulerBackendSuite val (execInfo, _) = backend.createExecutorInfo( Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) - assert(execInfo.getContainer.getDocker.getForcePullImage.equals(true)) + assert(execInfo.getContainer.getDocker.getForcePullImage) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) assert(portmaps.get(0).getContainerPort.equals(8080)) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6b6b751cc3..5d17028c32 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -119,20 +119,9 @@ statement (RESTRICT | CASCADE)? #dropNamespace | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? (LIKE? pattern=STRING)? #showNamespaces - | createTableHeader ('(' colTypeList ')')? tableProvider + | createTableHeader ('(' colTypeList ')')? tableProvider? createTableClauses (AS? query)? #createTable - | createTableHeader ('(' columns=colTypeList ')')? - (commentSpec | - (PARTITIONED BY '(' partitionColumns=colTypeList ')' | - PARTITIONED BY partitionColumnNames=identifierList) | - bucketSpec | - skewSpec | - rowFormat | - createFileFormat | - locationSpec | - (TBLPROPERTIES tableProps=tablePropertyList))* - (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier (tableProvider | @@ -140,7 +129,7 @@ statement createFileFormat | locationSpec | (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike - | replaceTableHeader ('(' colTypeList ')')? tableProvider + | replaceTableHeader ('(' colTypeList ')')? tableProvider? createTableClauses (AS? query)? #replaceTable | ANALYZE TABLE multipartIdentifier partitionSpec? 
COMPUTE STATISTICS @@ -393,8 +382,11 @@ tableProvider createTableClauses :((OPTIONS options=tablePropertyList) | - (PARTITIONED BY partitioning=transformList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | bucketSpec | + rowFormat | + createFileFormat | locationSpec | commentSpec | (TBLPROPERTIES tableProps=tablePropertyList))* @@ -741,8 +733,13 @@ namedExpressionSeq : namedExpression (',' namedExpression)* ; -transformList - : '(' transforms+=transform (',' transforms+=transform)* ')' +partitionFieldList + : '(' fields+=partitionField (',' fields+=partitionField)* ')' + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn ; transform diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java index 446ea14633..380717d2e0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java @@ -106,10 +106,19 @@ Map loadPartitionMetadata(InternalRow ident) throws UnsupportedOperationException; /** - * List the identifiers of all partitions that contains the ident in a table. + * List the identifiers of all partitions that have the ident prefix in a table. * * @param ident a prefix of partition identifier * @return an array of Identifiers for the partitions */ InternalRow[] listPartitionIdentifiers(InternalRow ident); + + /** + * List the identifiers of all partitions that match to the ident by names. + * + * @param names the names of partition values in the identifier. + * @param ident a partition identifier values. + * @return an array of Identifiers for the partitions + */ + InternalRow[] listPartitionByNames(String[] names, InternalRow ident); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index 92079d127b..52a74ab9dd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -46,6 +46,11 @@ public interface TableCatalog extends CatalogPlugin { */ String PROP_LOCATION = "location"; + /** + * A reserved property to specify a table was created with EXTERNAL. + */ + String PROP_EXTERNAL = "external"; + /** * A reserved property to specify the description of the table. */ @@ -61,6 +66,11 @@ public interface TableCatalog extends CatalogPlugin { */ String PROP_OWNER = "owner"; + /** + * A prefix used to pass OPTIONS in table properties + */ + String OPTION_PREFIX = "option."; + /** * List the tables in a namespace from the catalog. *

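The listPartitionByNames contract added above is easiest to read with a concrete case. A minimal, self-contained sketch in plain Scala (the collections stand in for a real catalog implementation; the partition schema p0 STRING, p1 INT and the sample values are made up):

// Partitions of a hypothetical table partitioned by (p0 STRING, p1 INT).
case class PartitionIdent(values: Map[String, Any])

val partitions = Seq(
  PartitionIdent(Map("p0" -> "abc", "p1" -> 0)),
  PartitionIdent(Map("p0" -> "abc", "p1" -> 1)),
  PartitionIdent(Map("p0" -> "xyz", "p1" -> 2)))

// listPartitionByNames(Array("p0"), ident("abc")) should keep every partition whose p0
// value is "abc", whereas listPartitionIdentifiers matches on a left-to-right prefix.
def listByNames(names: Seq[String], values: Seq[Any]): Seq[PartitionIdent] =
  partitions.filter(p => names.zip(values).forall { case (n, v) => p.values(n) == v })

assert(listByNames(Seq("p0"), Seq("abc")).size == 2)
assert(listByNames(Seq("p0", "p1"), Seq("abc", 1)).size == 1)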
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index deeb8215d2..7354d2478b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -143,7 +143,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) RenameTable(catalog.asTableCatalog, oldName.asIdentifier, newNameParts.asIdentifier) case c @ CreateTableStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( @@ -152,11 +152,11 @@ class ResolveCatalogs(val catalogManager: CatalogManager) c.tableSchema, // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c.properties, c.options, c.location, c.comment, c.provider), + convertTableProperties(c), ignoreIfExists = c.ifNotExists) case c @ CreateTableAsSelectStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _, _) => if (c.asSelect.resolved) { assertNoNullTypeInSchema(c.asSelect.schema) } @@ -166,12 +166,12 @@ class ResolveCatalogs(val catalogManager: CatalogManager) // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), c.asSelect, - convertTableProperties(c.properties, c.options, c.location, c.comment, c.provider), + convertTableProperties(c), writeOptions = c.writeOptions, ignoreIfExists = c.ifNotExists) case c @ ReplaceTableStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( @@ -180,11 +180,11 @@ class ResolveCatalogs(val catalogManager: CatalogManager) c.tableSchema, // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c.properties, c.options, c.location, c.comment, c.provider), + convertTableProperties(c), orCreate = c.orCreate) case c @ ReplaceTableAsSelectStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => if (c.asSelect.resolved) { assertNoNullTypeInSchema(c.asSelect.schema) } @@ -194,7 +194,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), c.asSelect, - convertTableProperties(c.properties, c.options, c.location, c.comment, c.provider), + convertTableProperties(c), writeOptions = c.writeOptions, orCreate = c.orCreate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 531d40f431..6d061fce06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -17,9 +17,9 @@ 
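For reference, the convertTableProperties(c) calls above fold the statement's serde info, options, location, comment and provider into a single property map. With the helpers added later in this patch, a hypothetical CREATE EXTERNAL TABLE t (i INT) STORED AS parquet LOCATION '/tmp/t' is expected to surface roughly as:

// Illustrative only; the exact keys come from TableCatalog and CatalogV2Util below.
Map(
  "hive.stored-as" -> "parquet",
  "external"       -> "true",
  "location"       -> "/tmp/t")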
package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement @@ -65,31 +65,8 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { conf.resolver) val partValues = partSchema.map { part => - val partValue = normalizedSpec.get(part.name).orNull - if (partValue == null) { - null - } else { - // TODO: Support other datatypes, such as DateType - part.dataType match { - case _: ByteType => - partValue.toByte - case _: ShortType => - partValue.toShort - case _: IntegerType => - partValue.toInt - case _: LongType => - partValue.toLong - case _: FloatType => - partValue.toFloat - case _: DoubleType => - partValue.toDouble - case _: StringType => - partValue - case _ => - throw new AnalysisException( - s"Type ${part.dataType.typeName} is not supported for partition.") - } - } + val raw = normalizedSpec.get(part.name).orNull + Cast(Literal.create(raw, StringType), part.dataType, Some(conf.sessionLocalTimeZone)).eval() } InternalRow.fromSeq(partValues) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5afc308e52..e6f585cacc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -59,8 +59,7 @@ object Cast { case (StringType, TimestampType) => true case (BooleanType, TimestampType) => true case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => - SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP) + case (_: NumericType, TimestampType) => true case (StringType, DateType) => true case (TimestampType, DateType) => true @@ -263,6 +262,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit */ def canCast(from: DataType, to: DataType): Boolean + /** + * Returns the error message if casting from one type to another one is invalid. + */ + def typeCheckFailureMessage: String + override def toString: String = { val ansi = if (ansiEnabled) "ansi_" else "" s"${ansi}cast($child as ${dataType.simpleString})" @@ -272,16 +276,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit if (canCast(child.dataType, dataType)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - if (child.dataType.isInstanceOf[NumericType] && dataType.isInstanceOf[TimestampType]) { - s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}," + - "you can enable the casting by setting " + - s"${SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key} to true," + - "but we strongly recommend using function " + - "TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS instead." 
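A minimal sketch of the conversion introduced above in ResolvePartitionSpec: the raw partition spec string is wrapped in a Literal and cast to the partition column's data type, instead of being switched on type by hand (the timezone value here is illustrative):

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{DataType, DateType, IntegerType, StringType}

def toPartitionValue(raw: String, dataType: DataType): Any =
  Cast(Literal.create(raw, StringType), dataType, Some("UTC")).eval()

toPartitionValue("15", IntegerType)       // 15
toPartitionValue("2020-10-01", DateType)  // DateType's internal value (days since epoch)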
- } else { - s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}" - }) + TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage) } } @@ -1764,6 +1759,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } else { Cast.canCast(from, to) } + + override def typeCheckFailureMessage: String = if (ansiEnabled) { + AnsiCast.typeCheckFailureMessage(child.dataType, dataType, SQLConf.ANSI_ENABLED.key, "false") + } else { + s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}" + } } /** @@ -1783,6 +1784,14 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St override protected val ansiEnabled: Boolean = true override def canCast(from: DataType, to: DataType): Boolean = AnsiCast.canCast(from, to) + + // For now, this expression is only used in table insertion. + // If there are more scenarios for this expression, we should update the error message on type + // check failure. + override def typeCheckFailureMessage: String = + AnsiCast.typeCheckFailureMessage(child.dataType, dataType, + SQLConf.STORE_ASSIGNMENT_POLICY.key, SQLConf.StoreAssignmentPolicy.LEGACY.toString) + } object AnsiCast { @@ -1885,6 +1894,35 @@ object AnsiCast { case _ => false } + + def typeCheckFailureMessage( + from: DataType, + to: DataType, + fallbackConfKey: String, + fallbackConfValue: String): String = + (from, to) match { + case (_: NumericType, TimestampType) => + // scalastyle:off line.size.limit + s""" + | cannot cast ${from.catalogString} to ${to.catalogString}. + | To convert values from ${from.catalogString} to ${to.catalogString}, you can use functions TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS instead. + |""".stripMargin + + case (_: ArrayType, StringType) => + s""" + | cannot cast ${from.catalogString} to ${to.catalogString} with ANSI mode on. + | If you have to cast ${from.catalogString} to ${to.catalogString}, you can use the function ARRAY_JOIN or set $fallbackConfKey as $fallbackConfValue. + |""".stripMargin + + case _ if Cast.canCast(from, to) => + s""" + | cannot cast ${from.catalogString} to ${to.catalogString} with ANSI mode on. + | If you have to cast ${from.catalogString} to ${to.catalogString}, you can set $fallbackConfKey as $fallbackConfValue. 
+ |""".stripMargin + + case _ => s"cannot cast ${from.catalogString} to ${to.catalogString}" + // scalastyle:on line.size.limit + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 39d9eb5a36..a363615d3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -94,7 +94,7 @@ private[this] object JsonPathParser extends RegexParsers { case Success(result, _) => Some(result) - case NoSuccess(msg, next) => + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1e69814673..810cecff37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -322,7 +322,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { case (a: Array[Byte], b: Array[Byte]) => util.Arrays.equals(a, b) case (a: ArrayBasedMapData, b: ArrayBasedMapData) => a.keyArray == b.keyArray && a.valueArray == b.valueArray - case (a, b) => a != null && a.equals(b) + case (a: Double, b: Double) if a.isNaN && b.isNaN => true + case (a: Float, b: Float) if a.isNaN && b.isNaN => true + case (a, b) => a != null && a == b } case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9701420e65..9303df75af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -981,7 +981,7 @@ case class MapObjects private( (genValue: String) => s"$builder.add($genValue);", s"$builder;" ) - case None => + case _ => // array ( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index de396a4c63..a39f06628b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -190,6 +190,9 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { } case VALUE_TRUE | VALUE_FALSE => BooleanType + + case _ => + throw new SparkException("Malformed JSON") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index b65fc7f7e2..bf3fced0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -197,9 +197,9 @@ object StarSchemaDetection extends PredicateHelper with SQLConfHelper { } else { false } - case None => false + case _ => false } - case None => false + case _ => false } case _ => false } @@ -239,7 +239,7 @@ object StarSchemaDetection extends PredicateHelper 
with SQLConfHelper { case Some(col) if t.outputSet.contains(col) => val stats = t.stats stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) - case None => false + case _ => false } case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 55a45f4410..d1eb3b07d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -685,6 +685,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case LeftOuter => newJoin.right.output case RightOuter => newJoin.left.output case FullOuter => newJoin.left.output ++ newJoin.right.output + case _ => Nil }) val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { case (attr, _) => missDerivedAttrsSet.contains(attr) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ea4baafbac..25423e5101 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -967,6 +967,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) case Some(c) if c.booleanExpression != null => (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new ParseException(s"Unimplemented joinCriteria: $c", ctx) case None if join.NATURAL != null => if (baseJoinType == Cross) { throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) @@ -2457,10 +2459,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg /** * Type to keep track of table clauses: - * (partitioning, bucketSpec, properties, options, location, comment). + * - partition transforms + * - partition columns + * - bucketSpec + * - properties + * - options + * - location + * - comment + * - serde + * + * Note: Partition transforms are based on existing table schema definition. It can be simple + * column names, or functions like `year(date_col)`. Partition columns are column names with data + * types like `i INT`, which should be appended to the existing table schema. */ - type TableClauses = (Seq[Transform], Option[BucketSpec], Map[String, String], - Map[String, String], Option[String], Option[String]) + type TableClauses = ( + Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String], + Map[String, String], Option[String], Option[String], Option[SerdeInfo]) /** * Validate a create table statement and return the [[TableIdentifier]]. @@ -2493,9 +2507,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Parse a list of transforms. + * Parse a list of transforms or columns. 
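 * Each field comes back either as a [[Transform]] (a bare column name or a transform call)
 * or as a [[StructField]] (a `name type` column definition), so the caller can tell which
 * PARTITIONED BY form was used.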
*/ - override def visitTransformList(ctx: TransformListContext): Seq[Transform] = withOrigin(ctx) { + override def visitPartitionFieldList( + ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) { + val (transforms, columns) = ctx.fields.asScala.map { + case transform: PartitionTransformContext => + (Some(visitPartitionTransform(transform)), None) + case field: PartitionColumnContext => + (None, Some(visitColType(field.colType))) + }.unzip + + (transforms.flatten.toSeq, columns.flatten.toSeq) + } + + override def visitPartitionTransform( + ctx: PartitionTransformContext): Transform = withOrigin(ctx) { def getFieldReference( ctx: ApplyTransformContext, arg: V2Expression): FieldReference = { @@ -2522,7 +2549,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } - ctx.transforms.asScala.map { + ctx.transform match { case identityCtx: IdentityTransformContext => IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) @@ -2561,7 +2588,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case name => ApplyTransform(name, arguments) } - }.toSeq + } } /** @@ -2761,16 +2788,157 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg (filtered, path) } + /** + * Create a [[SerdeInfo]] for creating tables. + * + * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format) + */ + override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt)))) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + SerdeInfo(storedAs = Some(c.identifier.getText)) + case (null, storageHandler) => + operationNotAllowed("STORED BY", ctx) + case _ => + throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) + } + } + + /** + * Create a [[SerdeInfo]] used for creating tables. + * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) { + import ctx._ + SerdeInfo( + serde = Some(string(name)), + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + // TODO we need proper support for the NULL format. 
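    // For example (illustrative): ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' ESCAPED BY '\\'
    // LINES TERMINATED BY '\n' is collected below as
    //   Map("field.delim" -> ",", "serialization.format" -> ",",
    //       "escape.delim" -> "\\", "line.delim" -> "\n")
    // (the field delimiter is intentionally recorded under both keys).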
+ val entries = + entry("field.delim", ctx.fieldsTerminatedBy) ++ + entry("serialization.format", ctx.fieldsTerminatedBy) ++ + entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... + entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ + entry("mapkey.delim", ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "line.delim" -> value + } + SerdeInfo(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... + */ + protected def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx == null || createFileFormatCtx == null) { + return + } + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } + } + + protected def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = { checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - val partitioning: Seq[Transform] = - Option(ctx.partitioning).map(visitTransformList).getOrElse(Nil) + if (ctx.skewSpec.size > 0) { + operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) + } + + val (partTransforms, partCols) = + Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil)) val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val cleanedProperties = cleanTableProperties(ctx, properties) @@ -2778,7 +2946,45 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val location = visitLocationSpecList(ctx.locationSpec()) val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location) val comment = visitCommentSpecList(ctx.commentSpec()) - (partitioning, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment) + val serdeInfo = getSerdeInfo(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, + serdeInfo) + } + + protected def getSerdeInfo( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + ctx: ParserRuleContext, + skipCheck: Boolean = false): Option[SerdeInfo] = { + if (!skipCheck) validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx) + val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat) + val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat) + (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r)) + } + + private def partitionExpressions( + partTransforms: Seq[Transform], + partCols: Seq[StructField], + ctx: ParserRuleContext): Seq[Transform] = { + if (partTransforms.nonEmpty) { + if (partCols.nonEmpty) { + val references = partTransforms.map(_.describe()).mkString(", ") + val columns = partCols + .map(field => s"${field.name} ${field.dataType.simpleString}") + .mkString(", ") + operationNotAllowed( + s"""PARTITION BY: Cannot mix partition expressions and partition columns: + |Expressions: $references + |Columns: $columns""".stripMargin, ctx) + + } + partTransforms + } else { + // columns were added to create the schema. convert to column references + partCols.map { column => + IdentityTransform(FieldReference(Seq(column.name))) + } + } } /** @@ -2787,13 +2993,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Expected format: * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name - * USING table_provider + * [USING table_provider] * create_table_clauses * [[AS] select_statement]; * * create_table_clauses (order insensitive): + * [PARTITIONED BY (partition_fields)] * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] * [CLUSTERED BY (col_name, col_name, ...) * [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS @@ -2801,40 +3009,55 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * [LOCATION path] * [COMMENT table_comment] * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... 
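 *
 *   (the two partition_fields forms cannot be mixed in one clause; typed partition columns
 *    are appended to the table schema, while transforms must reference columns that are
 *    already defined in the column list)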
* }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - if (external) { - operationNotAllowed("CREATE EXTERNAL TABLE ...", ctx) - } - val schema = Option(ctx.colTypeList()).map(createSchema) + + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) - val (partitioning, bucketSpec, properties, options, location, comment) = + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = visitCreateTableClauses(ctx.createTableClauses()) - Option(ctx.query).map(plan) match { - case Some(_) if temp => - operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx) + } - case Some(_) if schema.isDefined => + if (temp) { + val asSelect = if (ctx.query == null) "" else " AS ..." + operationNotAllowed( + s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => operationNotAllowed( "Schema may not be specified in a Create Table As Select (CTAS) statement", ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) + case Some(query) => CreateTableAsSelectStatement( table, query, partitioning, bucketSpec, properties, provider, options, location, comment, - writeOptions = Map.empty, ifNotExists = ifNotExists) - - case None if temp => - // CREATE TEMPORARY TABLE ... USING ... is not supported by the catalyst parser. - // Use CREATE TEMPORARY VIEW ... USING ... instead. - operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) + writeOptions = Map.empty, serdeInfo, external = external, ifNotExists = ifNotExists) case _ => - CreateTableStatement(table, schema.getOrElse(new StructType), partitioning, bucketSpec, - properties, provider, options, location, comment, ifNotExists = ifNotExists) + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists) } } @@ -2844,13 +3067,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Expected format: * {{{ * [CREATE OR] REPLACE TABLE [db_name.]table_name - * USING table_provider + * [USING table_provider] * replace_table_clauses * [[AS] select_statement]; * * replace_table_clauses (order insensitive): * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)] + * [PARTITIONED BY (partition_fields)] * [CLUSTERED BY (col_name, col_name, ...) 
* [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS @@ -2858,33 +3081,63 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * [LOCATION path] * [COMMENT table_comment] * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... * }}} */ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { - val (table, _, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader) + val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader) + val orCreate = ctx.replaceTableHeader().CREATE() != null + + if (temp) { + val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE" + operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx) + } + if (external) { - operationNotAllowed("REPLACE EXTERNAL TABLE ... USING", ctx) + operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx) + } + + if (ifNotExists) { + operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx) } - val (partitioning, bucketSpec, properties, options, location, comment) = + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = visitCreateTableClauses(ctx.createTableClauses()) - val schema = Option(ctx.colTypeList()).map(createSchema) + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) - val orCreate = ctx.replaceTableHeader().CREATE() != null + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) Option(ctx.query).map(plan) match { - case Some(_) if schema.isDefined => + case Some(_) if columns.nonEmpty => operationNotAllowed( "Schema may not be specified in a Replace Table As Select (RTAS) statement", ctx) + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Replace Table As Select (RTAS)", + ctx) + case Some(query) => ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties, - provider, options, location, comment, writeOptions = Map.empty, orCreate = orCreate) + provider, options, location, comment, writeOptions = Map.empty, serdeInfo, + orCreate = orCreate) case _ => - ReplaceTableStatement(table, schema.getOrElse(new StructType), partitioning, - bucketSpec, properties, provider, options, location, comment, orCreate = orCreate) + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + ReplaceTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, orCreate = orCreate) } } @@ -3354,7 +3607,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a [[TruncateTableStatement]] command. + * Create a [[TruncateTable]] command. 
* * For example: * {{{ @@ -3362,8 +3615,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { - TruncateTableStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), + TruncateTable( + UnresolvedTable(visitMultipartIdentifier(ctx.multipartIdentifier), "TRUNCATE TABLE"), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } @@ -3398,7 +3651,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg /** * A command for users to list the column names for a table. - * This function creates a [[ShowColumnsStatement]] logical plan. + * This function creates a [[ShowColumns]] logical plan. * * The syntax of using this command in SQL is: * {{{ @@ -3407,9 +3660,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * }}} */ override def visitShowColumns(ctx: ShowColumnsContext): LogicalPlan = withOrigin(ctx) { - val table = visitMultipartIdentifier(ctx.table) + val nameParts = visitMultipartIdentifier(ctx.table) val namespace = Option(ctx.ns).map(visitMultipartIdentifier) - ShowColumnsStatement(table, namespace) + // Use namespace only if table name doesn't specify it. If namespace is already specified + // in the table name, it's checked against the given namespace after table/view is resolved. + val tableName = if (namespace.isDefined && nameParts.length == 1) { + namespace.get ++ nameParts + } else { + nameParts + } + ShowColumns(UnresolvedTableOrView(tableName), namespace) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f96e07863f..c7108ea8ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -362,7 +362,7 @@ case class Join( left.constraints case RightOuter => right.constraints - case FullOuter => + case _ => ExpressionSet() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 39bc5a5604..281d57b364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -53,6 +53,81 @@ abstract class ParsedStatement extends LogicalPlan { final override lazy val resolved = false } +/** + * Type to keep track of Hive serde info + */ +case class SerdeInfo( + storedAs: Option[String] = None, + formatClasses: Option[FormatClasses] = None, + serde: Option[String] = None, + serdeProperties: Map[String, String] = Map.empty) { + // this uses assertions because validation is done in validateRowFormatFileFormat etc. 
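  // Illustrative merge (hypothetical inputs): combining
  //   SerdeInfo(serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
  // with
  //   SerdeInfo(storedAs = Some("textfile"), serdeProperties = Map("field.delim" -> ","))
  // yields one SerdeInfo carrying the serde, the STORED AS format and the properties.
  // Conflicting values for the same field, or overlapping SERDEPROPERTIES keys, fail fast.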
+ assert(storedAs.isEmpty || formatClasses.isEmpty, + "Cannot specify both STORED AS and INPUTFORMAT/OUTPUTFORMAT") + + def describe: String = { + val serdeString = if (serde.isDefined || serdeProperties.nonEmpty) { + "ROW FORMAT " + serde.map(sd => s"SERDE $sd").getOrElse("DELIMITED") + } else { + "" + } + + this match { + case SerdeInfo(Some(storedAs), _, _, _) => + s"STORED AS $storedAs $serdeString" + case SerdeInfo(_, Some(formatClasses), _, _) => + s"STORED AS $formatClasses $serdeString" + case _ => + serdeString + } + } + + def merge(other: SerdeInfo): SerdeInfo = { + def getOnly[T](desc: String, left: Option[T], right: Option[T]): Option[T] = { + (left, right) match { + case (Some(l), Some(r)) => + assert(l == r, s"Conflicting $desc values: $l != $r") + left + case (Some(_), _) => + left + case (_, Some(_)) => + right + case _ => + None + } + } + + SerdeInfo.checkSerdePropMerging(serdeProperties, other.serdeProperties) + SerdeInfo( + getOnly("STORED AS", storedAs, other.storedAs), + getOnly("INPUTFORMAT/OUTPUTFORMAT", formatClasses, other.formatClasses), + getOnly("SERDE", serde, other.serde), + serdeProperties ++ other.serdeProperties) + } +} + +case class FormatClasses(input: String, output: String) { + override def toString: String = s"INPUTFORMAT $input OUTPUTFORMAT $output" +} + +object SerdeInfo { + val empty: SerdeInfo = SerdeInfo(None, None, None, Map.empty) + + def checkSerdePropMerging( + props1: Map[String, String], props2: Map[String, String]): Unit = { + val conflictKeys = props1.keySet.intersect(props2.keySet) + if (conflictKeys.nonEmpty) { + throw new UnsupportedOperationException( + s""" + |Cannot safely merge SERDEPROPERTIES: + |${props1.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}")} + |${props2.map { case (k, v) => s"$k=$v" }.mkString("{", ",", "}")} + |The conflict keys: ${conflictKeys.mkString(", ")} + |""".stripMargin) + } + } +} + /** * A CREATE TABLE command, as parsed from SQL. 
* @@ -68,6 +143,8 @@ case class CreateTableStatement( options: Map[String, String], location: Option[String], comment: Option[String], + serde: Option[SerdeInfo], + external: Boolean, ifNotExists: Boolean) extends ParsedStatement /** @@ -84,6 +161,8 @@ case class CreateTableAsSelectStatement( location: Option[String], comment: Option[String], writeOptions: Map[String, String], + serde: Option[SerdeInfo], + external: Boolean, ifNotExists: Boolean) extends ParsedStatement { override def children: Seq[LogicalPlan] = Seq(asSelect) @@ -119,6 +198,7 @@ case class ReplaceTableStatement( options: Map[String, String], location: Option[String], comment: Option[String], + serde: Option[SerdeInfo], orCreate: Boolean) extends ParsedStatement /** @@ -135,6 +215,7 @@ case class ReplaceTableAsSelectStatement( location: Option[String], comment: Option[String], writeOptions: Map[String, String], + serde: Option[SerdeInfo], orCreate: Boolean) extends ParsedStatement { override def children: Seq[LogicalPlan] = Seq(asSelect) @@ -359,13 +440,6 @@ case class ShowPartitionsStatement( tableName: Seq[String], partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement -/** - * A SHOW COLUMNS statement, as parsed from SQL - */ -case class ShowColumnsStatement( - table: Seq[String], - namespace: Option[Seq[String]]) extends ParsedStatement - /** * A SHOW CURRENT NAMESPACE statement, as parsed from SQL */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5bda2b5b8d..ebf41f6a6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -670,3 +670,22 @@ case class LoadData( case class ShowCreateTable(child: LogicalPlan, asSerde: Boolean = false) extends Command { override def children: Seq[LogicalPlan] = child :: Nil } + +/** + * The logical plan of the SHOW COLUMN command. + */ +case class ShowColumns( + child: LogicalPlan, + namespace: Option[Seq[String]]) extends Command { + override def children: Seq[LogicalPlan] = child :: Nil +} + +/** + * The logical plan of the TRUNCATE TABLE command. + */ +case class TruncateTable( + child: LogicalPlan, + partitionSpec: Option[TablePartitionSpec]) extends Command { + override def children: Seq[LogicalPlan] = child :: Nil +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 81f412c143..e46d730afb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -120,7 +120,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! 
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (!o1.equals(o2)) { + case _ => if (o1.getClass != o2.getClass || o1 != o2) { return false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 1a3a7207c6..b6dc4f61c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.AlterTable +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} @@ -295,18 +295,65 @@ private[sql] object CatalogV2Util { catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) } - def convertTableProperties( + def convertTableProperties(c: CreateTableStatement): Map[String, String] = { + convertTableProperties( + c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) + } + + def convertTableProperties(c: CreateTableAsSelectStatement): Map[String, String] = { + convertTableProperties( + c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) + } + + def convertTableProperties(r: ReplaceTableStatement): Map[String, String] = { + convertTableProperties(r.properties, r.options, r.serde, r.location, r.comment, r.provider) + } + + def convertTableProperties(r: ReplaceTableAsSelectStatement): Map[String, String] = { + convertTableProperties(r.properties, r.options, r.serde, r.location, r.comment, r.provider) + } + + private def convertTableProperties( properties: Map[String, String], options: Map[String, String], + serdeInfo: Option[SerdeInfo], location: Option[String], comment: Option[String], - provider: Option[String]): Map[String, String] = { - properties ++ options ++ + provider: Option[String], + external: Boolean = false): Map[String, String] = { + properties ++ + options ++ // to make the transition to the "option." prefix easier, add both + options.map { case (key, value) => TableCatalog.OPTION_PREFIX + key -> value } ++ + convertToProperties(serdeInfo) ++ + (if (external) Some(TableCatalog.PROP_EXTERNAL -> "true") else None) ++ provider.map(TableCatalog.PROP_PROVIDER -> _) ++ comment.map(TableCatalog.PROP_COMMENT -> _) ++ location.map(TableCatalog.PROP_LOCATION -> _) } + /** + * Converts Hive Serde info to table properties. The mapped property keys are: + * - INPUTFORMAT/OUTPUTFORMAT: hive.input/output-format + * - STORED AS: hive.stored-as + * - ROW FORMAT SERDE: hive.serde + * - SERDEPROPERTIES: add "option." 
prefix + */ + private def convertToProperties(serdeInfo: Option[SerdeInfo]): Map[String, String] = { + serdeInfo match { + case Some(s) => + s.formatClasses.map { f => + Map("hive.input-format" -> f.input, "hive.output-format" -> f.output) + }.getOrElse(Map.empty) ++ + s.storedAs.map("hive.stored-as" -> _) ++ + s.serde.map("hive.serde" -> _) ++ + s.serdeProperties.map { + case (key, value) => TableCatalog.OPTION_PREFIX + key -> value + } + case None => + Map.empty + } + } + def withDefaultOwnership(properties: Map[String, String]): Map[String, String] = { properties ++ Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ef974dc176..0738478888 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2824,15 +2824,6 @@ object SQLConf { .checkValue(_ > 0, "The timeout value must be positive") .createWithDefault(10L) - val LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP = - buildConf("spark.sql.legacy.allowCastNumericToTimestamp") - .internal() - .doc("When true, allow casting numeric to timestamp," + - "when false, forbid the cast, more details in SPARK-31710") - .version("3.1.0") - .booleanConf - .createWithDefault(true) - val COALESCE_BUCKETS_IN_JOIN_ENABLED = buildConf("spark.sql.bucketing.coalesceBucketsInJoin.enabled") .doc("When true, if two bucketed tables with the different number of buckets are joined, " + @@ -3550,9 +3541,6 @@ class SQLConf extends Serializable with Logging { def integerGroupingIdEnabled: Boolean = getConf(SQLConf.LEGACY_INTEGER_GROUPING_ID) - def legacyAllowCastNumericToTimestamp: Boolean = - getConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP) - def metadataCacheTTL: Long = getConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS) def coalesceBucketsInJoinEnabled: Boolean = getConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index ca1074fcf6..02cb6f2962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -249,4 +249,16 @@ object StaticSQLConf { .version("3.1.0") .timeConf(TimeUnit.SECONDS) .createWithDefault(-1) + + val ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST = + buildStaticConf("spark.sql.streaming.ui.enabledCustomMetricList") + .internal() + .doc("Configures a list of custom metrics on Structured Streaming UI, which are enabled. " + + "The list contains the name of the custom metrics separated by comma. In aggregation" + + "only sum used. 
The list of supported custom metrics is state store provider specific " + + "and it can be found out for example from query progress log entry.") + .version("3.1.0") + .stringConf + .toSequence + .createWithDefault(Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index afb76d8a5a..f1fc921e40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -25,6 +25,7 @@ import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} @@ -841,12 +842,28 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented") } + protected def setConfigurationHint: String + + private def verifyCastFailure(c: CastBase, optionalExpectedMsg: Option[String] = None): Unit = { + val typeCheckResult = c.checkInputDataTypes() + assert(typeCheckResult.isFailure) + assert(typeCheckResult.isInstanceOf[TypeCheckFailure]) + val message = typeCheckResult.asInstanceOf[TypeCheckFailure].message + + if (optionalExpectedMsg.isDefined) { + assert(message.contains(optionalExpectedMsg.get)) + } else { + assert(message.contains("with ANSI mode on")) + assert(message.contains(setConfigurationHint)) + } + } + test("ANSI mode: disallow type conversions between Numeric types and Timestamp type") { import DataTypeTestUtils.numericTypes checkInvalidCastFromNumericType(TimestampType) val timestampLiteral = Literal(1L, TimestampType) numericTypes.foreach { numericType => - assert(cast(timestampLiteral, numericType).checkInputDataTypes().isFailure) + verifyCastFailure(cast(timestampLiteral, numericType)) } } @@ -855,7 +872,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { checkInvalidCastFromNumericType(DateType) val dateLiteral = Literal(1, DateType) numericTypes.foreach { numericType => - assert(cast(dateLiteral, numericType).checkInputDataTypes().isFailure) + verifyCastFailure(cast(dateLiteral, numericType)) } } @@ -880,9 +897,9 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase { } test("ANSI mode: disallow casting complex types as String type") { - assert(cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType).checkInputDataTypes().isFailure) - assert(cast(Literal.create(Map(1 -> "a")), StringType).checkInputDataTypes().isFailure) - assert(cast(Literal.create((1, "a", 0.1)), StringType).checkInputDataTypes().isFailure) + verifyCastFailure(cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)) + verifyCastFailure(cast(Literal.create(Map(1 -> "a")), StringType)) + verifyCastFailure(cast(Literal.create((1, "a", 0.1)), StringType)) } test("cast from invalid string to numeric should throw NumberFormatException") { @@ -1311,20 +1328,6 @@ class CastSuite extends CastSuiteBase { } } - test("SPARK-31710: fail casting from numeric to timestamp if it is forbidden") { - Seq(true, false).foreach { enable => - withSQLConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP.key -> 
enable.toString) { - assert(cast(2.toByte, TimestampType).resolved == enable) - assert(cast(10.toShort, TimestampType).resolved == enable) - assert(cast(3, TimestampType).resolved == enable) - assert(cast(10L, TimestampType).resolved == enable) - assert(cast(Decimal(1.2), TimestampType).resolved == enable) - assert(cast(1.7f, TimestampType).resolved == enable) - assert(cast(2.3d, TimestampType).resolved == enable) - } - } - } - test("SPARK-32828: cast from a derived user-defined type to a base type") { val v = Literal.create(Row(1), new ExampleSubTypeUDT()) checkEvaluation(cast(v, new ExampleBaseTypeUDT), Row(1)) @@ -1503,6 +1506,9 @@ class CastSuiteWithAnsiModeOn extends AnsiCastSuiteBase { case _ => Cast(Literal(v), targetType, timeZoneId) } } + + override def setConfigurationHint: String = + s"set ${SQLConf.ANSI_ENABLED.key} as false" } /** @@ -1525,6 +1531,10 @@ class AnsiCastSuiteWithAnsiModeOn extends AnsiCastSuiteBase { case _ => AnsiCast(Literal(v), targetType, timeZoneId) } } + + override def setConfigurationHint: String = + s"set ${SQLConf.STORE_ASSIGNMENT_POLICY.key} as" + + s" ${SQLConf.StoreAssignmentPolicy.LEGACY.toString}" } /** @@ -1547,4 +1557,8 @@ class AnsiCastSuiteWithAnsiModeOff extends AnsiCastSuiteBase { case _ => AnsiCast(Literal(v), targetType, timeZoneId) } } + + override def setConfigurationHint: String = + s"set ${SQLConf.STORE_ASSIGNMENT_POLICY.key} as" + + s" ${SQLConf.StoreAssignmentPolicy.LEGACY.toString}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index bd28484b23..f650922e75 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -63,6 +63,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => @@ -70,7 +71,7 @@ class DDLParserSuite extends AnalysisTest { } intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", - "no viable alternative at input") + "extraneous input ':'") } test("create/replace table - with IF NOT EXISTS") { @@ -86,6 +87,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None), expectedIfNotExists = true) } @@ -106,6 +108,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) @@ -160,6 +163,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) @@ -182,6 +186,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) @@ -200,7 +205,8 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, - Some("abc")) + Some("abc"), + None) Seq(createSql, replaceSql).foreach{ sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) } @@ -220,6 +226,7 @@ class DDLParserSuite extends AnalysisTest { 
Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) @@ -238,6 +245,7 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], Some("/tmp/file"), + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) @@ -256,19 +264,309 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], None, + None, None) Seq(createSql, replaceSql).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) } } + test("create/replace table - partition column definitions") { + val createSql = "CREATE TABLE my_tab (id bigint) PARTITIONED BY (part string)" + val replaceSql = "REPLACE TABLE my_tab (id bigint) PARTITIONED BY (part string)" + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + None) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - empty columns list") { + val createSql = "CREATE TABLE my_tab PARTITIONED BY (part string)" + val replaceSql = "REPLACE TABLE my_tab PARTITIONED BY (part string)" + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + None) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - using with partition column definitions") { + val createSql = "CREATE TABLE my_tab (id bigint) USING parquet PARTITIONED BY (part string)" + val replaceSql = "REPLACE TABLE my_tab (id bigint) USING parquet PARTITIONED BY (part string)" + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + Some("parquet"), + Map.empty[String, String], + None, + None, + None) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - mixed partition references and column definitions") { + val createSql = "CREATE TABLE my_tab (id bigint, p1 string) PARTITIONED BY (p1, p2 string)" + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq( + "PARTITION BY: Cannot mix partition expressions and partition columns", + "Expressions: p1", + "Columns: p2 string")) + } + + val createSqlWithExpr = + "CREATE TABLE my_tab (id bigint, p1 string) PARTITIONED BY (p2 string, truncate(p1, 16))" + val replaceSqlWithExpr = createSqlWithExpr.replaceFirst("CREATE", "REPLACE") + Seq(createSqlWithExpr, replaceSqlWithExpr).foreach { sql => + assertUnsupported(sql, Seq( + "PARTITION BY: Cannot mix partition expressions and partition columns", + "Expressions: truncate(p1, 16)", + "Columns: p2 string")) + } + } + + test("create/replace table - stored as") { + val createSql = + """CREATE 
TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS parquet + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + Some(SerdeInfo(storedAs = Some("parquet")))) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - stored as format with serde") { + Seq("sequencefile", "textfile", "rcfile").foreach { format => + val createSql = + s"""CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS $format + |ROW FORMAT SERDE 'customSerde' + |WITH SERDEPROPERTIES ('prop'='value') + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + Some(SerdeInfo(storedAs = Some(format), serde = Some("customSerde"), serdeProperties = Map( + "prop" -> "value" + )))) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + val createSql = + s"""CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS otherFormat + |ROW FORMAT SERDE 'customSerde' + |WITH SERDEPROPERTIES ('prop'='value') + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq("ROW FORMAT SERDE is incompatible with format 'otherFormat'")) + } + } + + test("create/replace table - stored as format with delimited clauses") { + val createSql = + s"""CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS textfile + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY ',' ESCAPED BY '\\\\' -- double escape for Scala and for SQL + |COLLECTION ITEMS TERMINATED BY '#' + |MAP KEYS TERMINATED BY '=' + |LINES TERMINATED BY '\\n' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + Some(SerdeInfo(storedAs = Some("textfile"), serdeProperties = Map( + "field.delim" -> ",", "serialization.format" -> ",", "escape.delim" -> "\\", + "colelction.delim" -> "#", "mapkey.delim" -> "=", "line.delim" -> "\n" + )))) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + + val createFailSql = + s"""CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS otherFormat + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY ',' + """.stripMargin + val replaceFailSql = createFailSql.replaceFirst("CREATE", "REPLACE") + Seq(createFailSql, replaceFailSql).foreach { sql => + assertUnsupported(sql, Seq( + "ROW FORMAT DELIMITED is only compatible with 'textfile', not 'otherFormat'")) + } + } + + test("create/replace table - stored as inputformat/outputformat") { + val createSql = + 
"""CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS INPUTFORMAT 'inFormat' OUTPUTFORMAT 'outFormat' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + Some(SerdeInfo(formatClasses = Some(FormatClasses("inFormat", "outFormat"))))) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - stored as inputformat/outputformat with serde") { + val createSql = + """CREATE TABLE my_tab (id bigint) + |PARTITIONED BY (part string) + |STORED AS INPUTFORMAT 'inFormat' OUTPUTFORMAT 'outFormat' + |ROW FORMAT SERDE 'customSerde' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(new StructType().add("id", LongType).add("part", StringType)), + Seq(IdentityTransform(FieldReference("part"))), + None, + Map.empty[String, String], + None, + Map.empty[String, String], + None, + None, + Some(SerdeInfo( + formatClasses = Some(FormatClasses("inFormat", "outFormat")), + serde = Some("customSerde")))) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + + test("create/replace table - using with stored as") { + val createSql = + """CREATE TABLE my_tab (id bigint, part string) + |USING parquet + |STORED AS parquet + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq("CREATE TABLE ... USING ... STORED AS")) + } + } + + test("create/replace table - using with row format serde") { + val createSql = + """CREATE TABLE my_tab (id bigint, part string) + |USING parquet + |ROW FORMAT SERDE 'customSerde' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq("CREATE TABLE ... USING ... ROW FORMAT SERDE")) + } + } + + test("create/replace table - using with row format delimited") { + val createSql = + """CREATE TABLE my_tab (id bigint, part string) + |USING parquet + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq("CREATE TABLE ... USING ... ROW FORMAT DELIMITED")) + } + } + + test("create/replace table - stored by") { + val createSql = + """CREATE TABLE my_tab (id bigint, p1 string) + |STORED BY 'handler' + """.stripMargin + val replaceSql = createSql.replaceFirst("CREATE", "REPLACE") + Seq(createSql, replaceSql).foreach { sql => + assertUnsupported(sql, Seq("stored by")) + } + } + + test("Unsupported skew clause - create/replace table") { + intercept("CREATE TABLE my_tab (id bigint) SKEWED BY (id) ON (1,2,3)", + "CREATE TABLE ... SKEWED BY") + intercept("REPLACE TABLE my_tab (id bigint) SKEWED BY (id) ON (1,2,3)", + "CREATE TABLE ... 
SKEWED BY") + } + test("Duplicate clauses - create/replace table") { def createTableHeader(duplicateClause: String): String = { - s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause" + s"CREATE TABLE my_tab(a INT, b STRING) $duplicateClause $duplicateClause" } def replaceTableHeader(duplicateClause: String): String = { - s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause" + s"CREATE TABLE my_tab(a INT, b STRING) $duplicateClause $duplicateClause" } intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), @@ -281,6 +579,14 @@ class DDLParserSuite extends AnalysisTest { "Found duplicate clauses: CLUSTERED BY") intercept(createTableHeader("PARTITIONED BY (b)"), "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("PARTITIONED BY (c int)"), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet"), + "Found duplicate clauses: STORED AS") + intercept(createTableHeader("STORED AS INPUTFORMAT 'in' OUTPUTFORMAT 'out'"), + "Found duplicate clauses: STORED AS") + intercept(createTableHeader("ROW FORMAT SERDE 'serde'"), + "Found duplicate clauses: ROW FORMAT") intercept(replaceTableHeader("TBLPROPERTIES('test' = 'test2')"), "Found duplicate clauses: TBLPROPERTIES") @@ -292,6 +598,14 @@ class DDLParserSuite extends AnalysisTest { "Found duplicate clauses: CLUSTERED BY") intercept(replaceTableHeader("PARTITIONED BY (b)"), "Found duplicate clauses: PARTITIONED BY") + intercept(replaceTableHeader("PARTITIONED BY (c int)"), + "Found duplicate clauses: PARTITIONED BY") + intercept(replaceTableHeader("STORED AS parquet"), + "Found duplicate clauses: STORED AS") + intercept(replaceTableHeader("STORED AS INPUTFORMAT 'in' OUTPUTFORMAT 'out'"), + "Found duplicate clauses: STORED AS") + intercept(replaceTableHeader("ROW FORMAT SERDE 'serde'"), + "Found duplicate clauses: ROW FORMAT") } test("support for other types in OPTIONS") { @@ -317,6 +631,7 @@ class DDLParserSuite extends AnalysisTest { Some("json"), Map("a" -> "1", "b" -> "0.1", "c" -> "true"), None, + None, None), expectedIfNotExists = false) } @@ -372,7 +687,8 @@ class DDLParserSuite extends AnalysisTest { Some("parquet"), Map.empty[String, String], Some("/user/external/page_view"), - Some("This is the staging page view table")) + Some("This is the staging page view table"), + None) Seq(s1, s2, s3, s4).foreach { sql => testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = true) } @@ -1621,11 +1937,13 @@ class DDLParserSuite extends AnalysisTest { test("TRUNCATE table") { comparePlans( parsePlan("TRUNCATE TABLE a.b.c"), - TruncateTableStatement(Seq("a", "b", "c"), None)) + TruncateTable(UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE"), None)) comparePlans( parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"), - TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10")))) + TruncateTable( + UnresolvedTable(Seq("a", "b", "c"), "TRUNCATE TABLE"), + Some(Map("ds" -> "2017-06-10")))) } test("REFRESH TABLE") { @@ -1641,13 +1959,13 @@ class DDLParserSuite extends AnalysisTest { val sql4 = "SHOW COLUMNS FROM db1.t1 IN db1" val parsed1 = parsePlan(sql1) - val expected1 = ShowColumnsStatement(Seq("t1"), None) + val expected1 = ShowColumns(UnresolvedTableOrView(Seq("t1")), None) val parsed2 = parsePlan(sql2) - val expected2 = ShowColumnsStatement(Seq("db1", "t1"), None) + val expected2 = ShowColumns(UnresolvedTableOrView(Seq("db1", "t1")), None) val parsed3 = parsePlan(sql3) - val 
expected3 = ShowColumnsStatement(Seq("t1"), Some(Seq("db1"))) + val expected3 = ShowColumns(UnresolvedTableOrView(Seq("db1", "t1")), Some(Seq("db1"))) val parsed4 = parsePlan(sql4) - val expected4 = ShowColumnsStatement(Seq("db1", "t1"), Some(Seq("db1"))) + val expected4 = ShowColumns(UnresolvedTableOrView(Seq("db1", "t1")), Some(Seq("db1"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -2103,7 +2421,9 @@ class DDLParserSuite extends AnalysisTest { provider: Option[String], options: Map[String, String], location: Option[String], - comment: Option[String]) + comment: Option[String], + serdeInfo: Option[SerdeInfo], + external: Boolean = false) private object TableSpec { def apply(plan: LogicalPlan): TableSpec = { @@ -2118,7 +2438,9 @@ class DDLParserSuite extends AnalysisTest { create.provider, create.options, create.location, - create.comment) + create.comment, + create.serde, + create.external) case replace: ReplaceTableStatement => TableSpec( replace.tableName, @@ -2129,7 +2451,8 @@ class DDLParserSuite extends AnalysisTest { replace.provider, replace.options, replace.location, - replace.comment) + replace.comment, + replace.serde) case ctas: CreateTableAsSelectStatement => TableSpec( ctas.tableName, @@ -2140,7 +2463,9 @@ class DDLParserSuite extends AnalysisTest { ctas.provider, ctas.options, ctas.location, - ctas.comment) + ctas.comment, + ctas.serde, + ctas.external) case rtas: ReplaceTableAsSelectStatement => TableSpec( rtas.tableName, @@ -2151,7 +2476,8 @@ class DDLParserSuite extends AnalysisTest { rtas.provider, rtas.options, rtas.location, - rtas.comment) + rtas.comment, + rtas.serde) case other => fail(s"Expected to parse Create, CTAS, Replace, or RTAS plan" + s" from query, got ${other.getClass.getName}.") @@ -2177,8 +2503,7 @@ class DDLParserSuite extends AnalysisTest { CommentOnTable(UnresolvedTable(Seq("a", "b", "c"), "COMMENT ON TABLE"), "xYz")) } - // TODO: ignored by SPARK-31707, restore the test after create table syntax unification - ignore("create table - without using") { + test("create table - without using") { val sql = "CREATE TABLE 1m.2g(a INT)" val expectedTableSpec = TableSpec( Seq("1m", "2g"), @@ -2189,6 +2514,7 @@ class DDLParserSuite extends AnalysisTest { None, Map.empty[String, String], None, + None, None) testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala index 7790f467a8..1290f77034 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala @@ -39,6 +39,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colB) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } @@ -50,6 +51,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === colB) assert(filters.size === 1) + case _ => assert(false) } } @@ -65,6 +67,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects.size === 2) assert(projects(0) === colA) assert(projects(1) === aliasId) + case _ => assert(false) } } @@ -81,6 +84,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } @@ -93,6 
+97,7 @@ class ScanOperationSuite extends SparkFunSuite { assert(projects(0) === colA) assert(projects(1) === aliasR) assert(filters.size === 1) + case _ => assert(false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index 1e430351b5..9c3aaea0f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -45,7 +45,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { if (e != null) { elementDt match { // For Nan, etc. - case FloatType | DoubleType => assert(seq(i).equals(e)) + case FloatType | DoubleType => assert(seq(i) == e) case _ => assert(seq(i) === e) } } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala index 23987e909a..ba762a58b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -96,4 +97,25 @@ class InMemoryPartitionTable( override protected def addPartitionKey(key: Seq[Any]): Unit = { memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava) } + + override def listPartitionByNames( + names: Array[String], + ident: InternalRow): Array[InternalRow] = { + assert(names.length == ident.numFields, + s"Number of partition names (${names.length}) must be equal to " + + s"the number of partition values (${ident.numFields}).") + val schema = partitionSchema + assert(names.forall(fieldName => schema.fieldNames.contains(fieldName)), + s"Some partition names ${names.mkString("[", ", ", "]")} don't belong to " + + s"the partition schema '${schema.sql}'.") + val indexes = names.map(schema.fieldIndex) + val dataTypes = names.map(schema(_).dataType) + val currentRow = new GenericInternalRow(new Array[Any](names.length)) + memoryTablePartitions.keySet().asScala.filter { key => + for (i <- 0 until names.length) { + currentRow.values(i) = key.get(indexes(i), dataTypes(i)) + } + currentRow == ident + }.toArray + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index c93053abc5..ffff00b54f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -156,7 +156,9 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case BucketTransform(numBuckets, ref) => - (extractor(ref.fieldNames, schema, row).hashCode() & Integer.MAX_VALUE) % numBuckets + val (value, dataType) = extractor(ref.fieldNames, schema, row) + val 
valueHashCode = if (value == null) 0 else value.hashCode + ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index e8e28e3422..caf7e91612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryTableCatalog} +import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference} import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -140,4 +140,45 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { partTable.dropPartition(partIdent1) assert(partTable.listPartitionIdentifiers(InternalRow.empty).isEmpty) } + + test("listPartitionByNames") { + val partCatalog = new InMemoryPartitionTableCatalog + partCatalog.initialize("test", CaseInsensitiveStringMap.empty()) + val table = partCatalog.createTable( + ident, + new StructType() + .add("col0", IntegerType) + .add("part0", IntegerType) + .add("part1", StringType), + Array(LogicalExpressions.identity(ref("part0")), LogicalExpressions.identity(ref("part1"))), + util.Collections.emptyMap[String, String]) + val partTable = table.asInstanceOf[InMemoryPartitionTable] + + Seq( + InternalRow(0, "abc"), + InternalRow(0, "def"), + InternalRow(1, "abc")).foreach { partIdent => + partTable.createPartition(partIdent, new util.HashMap[String, String]()) + } + + Seq( + (Array("part0", "part1"), InternalRow(0, "abc")) -> Set(InternalRow(0, "abc")), + (Array("part0"), InternalRow(0)) -> Set(InternalRow(0, "abc"), InternalRow(0, "def")), + (Array("part1"), InternalRow("abc")) -> Set(InternalRow(0, "abc"), InternalRow(1, "abc")), + (Array.empty[String], InternalRow.empty) -> + Set(InternalRow(0, "abc"), InternalRow(0, "def"), InternalRow(1, "abc")), + (Array("part0", "part1"), InternalRow(3, "xyz")) -> Set(), + (Array("part1"), InternalRow(3.14f)) -> Set() + ).foreach { case ((names, idents), expected) => + assert(partTable.listPartitionByNames(names, idents).toSet === expected) + } + // Check invalid parameters + Seq( + (Array("part0", "part1"), InternalRow(0)), + (Array("col0", "part1"), InternalRow(0, 1)), + (Array("wrong"), InternalRow("invalid")) + ).foreach { case (names, idents) => + intercept[AssertionError](partTable.listPartitionByNames(names, idents)) + } + } } diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt index 3d2b2e5c8e..1eb7b534d2 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt @@ -5,11 +5,21 @@ Benchmark for performance of subexpression elimination Preparing data for benchmarking ... 
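
The listPartitionByNames method added to InMemoryPartitionTable, together with the new SupportsPartitionManagementSuite case above, implements partial-partition-key lookup: a subset of the partition columns plus matching values selects every partition whose projection onto those columns equals the supplied values, and an empty name list matches everything. A minimal standalone sketch of that matching rule, using plain Scala maps instead of InternalRow and the table's partition schema (the object name and the part0/part1 values are only illustrative):

object PartialPartitionMatch {
  // Keep the partitions whose values for `names` equal the supplied (possibly partial) ident.
  def filter(
      partitions: Seq[Map[String, Any]],
      names: Seq[String],
      ident: Seq[Any]): Seq[Map[String, Any]] = {
    require(names.length == ident.length,
      "Number of partition names must be equal to the number of partition values")
    partitions.filter(p => names.map(p) == ident)
  }

  def main(args: Array[String]): Unit = {
    val parts = Seq(
      Map("part0" -> 0, "part1" -> "abc"),
      Map("part0" -> 0, "part1" -> "def"),
      Map("part0" -> 1, "part1" -> "abc"))
    println(filter(parts, Seq("part0"), Seq(0)))      // both partitions with part0 = 0
    println(filter(parts, Seq("part1"), Seq("abc")))  // both partitions with part1 = "abc"
    println(filter(parts, Nil, Nil))                  // empty key lists every partition
  }
}
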
OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz -from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 25932 26908 916 0.0 259320042.3 1.0X -subexpressionElimination off, codegen off 26085 26159 65 0.0 260848905.0 1.0X -subexpressionElimination on, codegen on 2860 2939 72 0.0 28603312.9 9.1X -subexpressionElimination on, codegen off 2517 2617 93 0.0 25165157.7 10.3X +from_json as subExpr in Project: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subExprElimination false, codegen: true 26447 27127 605 0.0 264467933.4 1.0X +subExprElimination false, codegen: false 25673 26035 546 0.0 256732419.1 1.0X +subExprElimination true, codegen: true 1384 1448 102 0.0 13842910.3 19.1X +subExprElimination true, codegen: false 1244 1347 123 0.0 12442389.3 21.3X + +Preparing data for benchmarking ... +OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6 +Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz +from_json as subExpr in Filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subexpressionElimination off, codegen on 34631 35449 833 0.0 346309884.0 1.0X +subexpressionElimination off, codegen on 34480 34851 353 0.0 344798490.4 1.0X +subexpressionElimination off, codegen on 16618 16811 291 0.0 166176642.6 2.1X +subexpressionElimination off, codegen on 34316 34667 310 0.0 343157094.7 1.0X diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt index ca2a9c6497..801f519ca7 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt @@ -5,11 +5,21 @@ Benchmark for performance of subexpression elimination Preparing data for benchmarking ... OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz -from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 26503 27622 1937 0.0 265033362.4 1.0X -subexpressionElimination off, codegen off 24920 25376 430 0.0 249196978.2 1.1X -subexpressionElimination on, codegen on 2421 2466 39 0.0 24213606.1 10.9X -subexpressionElimination on, codegen off 2360 2435 87 0.0 23604320.7 11.2X +from_json as subExpr in Project: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subExprElimination false, codegen: true 22767 23240 424 0.0 227665316.7 1.0X +subExprElimination false, codegen: false 22869 23351 465 0.0 228693464.1 1.0X +subExprElimination true, codegen: true 1328 1340 10 0.0 13280056.2 17.1X +subExprElimination true, codegen: false 1248 1276 31 0.0 12476135.1 18.2X + +Preparing data for benchmarking ... 
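
Stepping back to the InMemoryTable change further above: BucketTransform now hashes the (value, dataType) pair returned by the updated extractor and tolerates null bucket-column values. A standalone sketch of that bucket-id computation; the dataType argument stands in for Spark's DataType, and any value with a stable hashCode works for illustration:

object BucketIdSketch {
  // Null-safe bucket id that also mixes in the column's data type hash,
  // mirroring the (value, dataType) pair used by the in-memory test table.
  def bucketId(value: Any, dataType: Any, numBuckets: Int): Int = {
    val valueHashCode = if (value == null) 0 else value.hashCode
    ((valueHashCode + 31 * dataType.hashCode) & Integer.MAX_VALUE) % numBuckets
  }

  def main(args: Array[String]): Unit = {
    println(bucketId("abc", "StringType", 16))  // stable bucket for a non-null value
    println(bucketId(null, "StringType", 16))   // null values no longer throw; they hash to 0
  }
}
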
+OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6 +Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz +from_json as subExpr in Filter: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +subexpressionElimination off, codegen on 37691 38846 1004 0.0 376913767.9 1.0X +subexpressionElimination off, codegen on 37852 39124 1103 0.0 378517745.5 1.0X +subexpressionElimination off, codegen on 22900 23085 202 0.0 229000242.5 1.6X +subexpressionElimination off, codegen on 38298 38598 374 0.0 382978731.3 1.0X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 31b4c158aa..a8688bdf15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -658,6 +658,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), extraOptions.get(TableCatalog.PROP_COMMENT), extraOptions.toMap, + None, orCreate = true) // Create the table if it doesn't exist case (other, _) => @@ -675,7 +676,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), extraOptions.get(TableCatalog.PROP_COMMENT), extraOptions.toMap, - ifNotExists = other == SaveMode.Ignore) + None, + ifNotExists = other == SaveMode.Ignore, + external = false) } runCommand(df.sparkSession, "saveAsTable") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index d55b5c3103..9a49fc3d74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -119,7 +119,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) None, None, options.toMap, - ifNotExists = false) + None, + ifNotExists = false, + external = false) } } @@ -207,6 +209,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) None, None, options.toMap, + None, orCreate = orCreate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 303ae47f06..f49caf7f04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} /** @@ -265,16 +266,17 @@ class ResolveSessionCatalog( // For CREATE TABLE [AS SELECT], we should use the v1 command if the catalog is resolved to the // session catalog and the table provider is not v2. 
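
The DataFrameWriter and DataFrameWriterV2 changes above only thread the new serde slot (None) and the external flag (false) through the CTAS/RTAS statements; the public API is untouched. For context, a minimal sketch of the calls that can reach these constructors (local session, made-up table names; the v2 statements are only built when the target resolves to a catalog plugin rather than the v1 path):

import org.apache.spark.sql.SparkSession

object WriterPathsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("writer-paths").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "data")

    // DataFrameWriter.saveAsTable: for tables resolved through a v2 catalog this builds
    // the CTAS/RTAS statements that now carry serde = None and external = false.
    df.write.format("parquet").mode("ignore").saveAsTable("demo_t1")

    // DataFrameWriterV2.create: builds CreateTableAsSelectStatement,
    // likewise with serde = None and external = false.
    df.writeTo("demo_t2").using("parquet").create()

    spark.stop()
  }
}

The resolution rule stated in the comment above (use the v1 command when the session catalog is targeted and the provider is not v2) is what the following case clauses implement.
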
case c @ CreateTableStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - val provider = c.provider.getOrElse(conf.defaultDataSourceName) + val (storageFormat, provider) = getStorageFormatAndProvider( + c.provider, c.options, c.location, c.serde, ctas = false) if (!isV2Provider(provider)) { if (!DDLUtils.isHiveTable(Some(provider))) { assertNoCharTypeInSchema(c.tableSchema) } val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, - c.partitioning, c.bucketSpec, c.properties, provider, c.options, c.location, - c.comment, c.ifNotExists) + c.partitioning, c.bucketSpec, c.properties, provider, c.location, + c.comment, storageFormat, c.external) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, None) } else { @@ -285,30 +287,32 @@ class ResolveSessionCatalog( c.tableSchema, // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c.properties, c.options, c.location, c.comment, Some(provider)), + convertTableProperties(c), ignoreIfExists = c.ifNotExists) } case c @ CreateTableAsSelectStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _, _) => if (c.asSelect.resolved) { assertNoNullTypeInSchema(c.asSelect.schema) } - val provider = c.provider.getOrElse(conf.defaultDataSourceName) + val (storageFormat, provider) = getStorageFormatAndProvider( + c.provider, c.options, c.location, c.serde, ctas = true) if (!isV2Provider(provider)) { val tableDesc = buildCatalogTable(tbl.asTableIdentifier, new StructType, - c.partitioning, c.bucketSpec, c.properties, provider, c.options, c.location, - c.comment, c.ifNotExists) + c.partitioning, c.bucketSpec, c.properties, provider, c.location, + c.comment, storageFormat, c.external) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, Some(c.asSelect)) } else { + assertNoCharTypeInSchema(c.schema) CreateTableAsSelect( catalog.asTableCatalog, tbl.asIdentifier, // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), c.asSelect, - convertTableProperties(c.properties, c.options, c.location, c.comment, Some(provider)), + convertTableProperties(c), writeOptions = c.writeOptions, ignoreIfExists = c.ifNotExists) } @@ -322,7 +326,7 @@ class ResolveSessionCatalog( // For REPLACE TABLE [AS SELECT], we should fail if the catalog is resolved to the // session catalog and the table provider is not v2. 
case c @ ReplaceTableStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => + SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { @@ -335,12 +339,12 @@ class ResolveSessionCatalog( c.tableSchema, // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c.properties, c.options, c.location, c.comment, Some(provider)), + convertTableProperties(c), orCreate = c.orCreate) } case c @ ReplaceTableAsSelectStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _) => + SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => if (c.asSelect.resolved) { assertNoNullTypeInSchema(c.asSelect.schema) } @@ -354,7 +358,7 @@ class ResolveSessionCatalog( // convert the bucket spec and add it as a transform c.partitioning ++ c.bucketSpec.map(_.asTransform), c.asSelect, - convertTableProperties(c.properties, c.options, c.location, c.comment, Some(provider)), + convertTableProperties(c), writeOptions = c.writeOptions, orCreate = c.orCreate) } @@ -456,10 +460,9 @@ class ResolveSessionCatalog( val name = parseTempViewOrV1Table(tbl, "UNCACHE TABLE") UncacheTableCommand(name.asTableIdentifier, ifExists) - case TruncateTableStatement(tbl, partitionSpec) => - val v1TableName = parseV1Table(tbl, "TRUNCATE TABLE") + case TruncateTable(ResolvedV1TableIdentifier(ident), partitionSpec) => TruncateTableCommand( - v1TableName.asTableIdentifier, + ident.asTableIdentifier, partitionSpec) case ShowPartitionsStatement(tbl, partitionSpec) => @@ -468,25 +471,13 @@ class ResolveSessionCatalog( v1TableName.asTableIdentifier, partitionSpec) - case ShowColumnsStatement(tbl, ns) => - if (ns.isDefined && ns.get.length > 1) { - throw new AnalysisException( - s"Namespace name should have only one part if specified: ${ns.get.quoted}") - } - // Use namespace only if table name doesn't specify it. If namespace is already specified - // in the table name, it's checked against the given namespace below. 
- val nameParts = if (ns.isDefined && tbl.length == 1) { - ns.get ++ tbl - } else { - tbl - } - val sql = "SHOW COLUMNS" - val v1TableName = parseTempViewOrV1Table(nameParts, sql).asTableIdentifier + case ShowColumns(ResolvedV1TableOrViewIdentifier(ident), ns) => + val v1TableName = ident.asTableIdentifier val resolver = conf.resolver val db = ns match { case Some(db) if v1TableName.database.exists(!resolver(_, db.head)) => throw new AnalysisException( - s"SHOW COLUMNS with conflicting databases: " + + "SHOW COLUMNS with conflicting databases: " + s"'${db.head}' != '${v1TableName.database.get}'") case _ => ns.map(_.head) } @@ -634,6 +625,64 @@ class ResolveSessionCatalog( case _ => throw new AnalysisException(s"$sql is only supported with temp views or v1 tables.") } + private def getStorageFormatAndProvider( + provider: Option[String], + options: Map[String, String], + location: Option[String], + maybeSerdeInfo: Option[SerdeInfo], + ctas: Boolean): (CatalogStorageFormat, String) = { + val nonHiveStorageFormat = CatalogStorageFormat.empty.copy( + locationUri = location.map(CatalogUtils.stringToURI), + properties = options) + val defaultHiveStorage = HiveSerDe.getDefaultStorage(conf).copy( + locationUri = location.map(CatalogUtils.stringToURI), + properties = options) + + if (provider.isDefined) { + // The parser guarantees that USING and STORED AS/ROW FORMAT won't co-exist. + if (maybeSerdeInfo.isDefined) { + throw new AnalysisException( + s"Cannot create table with both USING $provider and ${maybeSerdeInfo.get.describe}") + } + (nonHiveStorageFormat, provider.get) + } else if (maybeSerdeInfo.isDefined) { + val serdeInfo = maybeSerdeInfo.get + SerdeInfo.checkSerdePropMerging(serdeInfo.serdeProperties, defaultHiveStorage.properties) + val storageFormat = if (serdeInfo.storedAs.isDefined) { + // If `STORED AS fileFormat` is used, infer inputFormat, outputFormat and serde from it. + HiveSerDe.sourceToSerDe(serdeInfo.storedAs.get) match { + case Some(hiveSerde) => + defaultHiveStorage.copy( + inputFormat = hiveSerde.inputFormat.orElse(defaultHiveStorage.inputFormat), + outputFormat = hiveSerde.outputFormat.orElse(defaultHiveStorage.outputFormat), + // User specified serde takes precedence over the one inferred from file format. + serde = serdeInfo.serde.orElse(hiveSerde.serde).orElse(defaultHiveStorage.serde), + properties = serdeInfo.serdeProperties ++ defaultHiveStorage.properties) + case _ => throw new AnalysisException( + s"STORED AS with file format '${serdeInfo.storedAs.get}' is invalid.") + } + } else { + defaultHiveStorage.copy( + inputFormat = + serdeInfo.formatClasses.map(_.input).orElse(defaultHiveStorage.inputFormat), + outputFormat = + serdeInfo.formatClasses.map(_.output).orElse(defaultHiveStorage.outputFormat), + serde = serdeInfo.serde.orElse(defaultHiveStorage.serde), + properties = serdeInfo.serdeProperties ++ defaultHiveStorage.properties) + } + (storageFormat, DDLUtils.HIVE_PROVIDER) + } else { + // If neither USING nor STORED AS/ROW FORMAT is specified, we create native data source + // tables if it's a CTAS and `conf.convertCTAS` is true. + // TODO: create native data source table by default for non-CTAS. 
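
A compact way to read getStorageFormatAndProvider above (its CTAS fallback continues just below): an explicit USING keeps the native storage format and the given provider, STORED AS / ROW FORMAT forces the Hive provider with the inferred formats and serde, and with neither clause only a CTAS under conf.convertCTAS picks the native default source. Illustrative statements for the session catalog, assuming a Hive-enabled local session and made-up table names; this is a sketch of the routing, not an exhaustive test:

import org.apache.spark.sql.SparkSession

object CreateTableRoutingSketch {
  def main(args: Array[String]): Unit = {
    // Hive support is needed for the Hive-provider branches below.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("create-table-routing")
      .enableHiveSupport()
      .getOrCreate()

    spark.sql("CREATE TABLE t_using (id INT) USING parquet")   // USING wins: provider = parquet, non-Hive storage
    spark.sql("CREATE TABLE t_stored (id INT) STORED AS orc")  // STORED AS: provider = hive, ORC formats/serde inferred
    spark.sql("CREATE TABLE t_plain (id INT)")                 // neither clause, not CTAS: default Hive storage
    spark.sql("CREATE TABLE t_ctas AS SELECT 1 AS id")         // CTAS: native default source if conf.convertCTAS, else hive

    spark.stop()
  }
}
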
+ if (ctas && conf.convertCTAS) { + (nonHiveStorageFormat, conf.defaultDataSourceName) + } else { + (defaultHiveStorage, DDLUtils.HIVE_PROVIDER) + } + } + } + private def buildCatalogTable( table: TableIdentifier, schema: StructType, @@ -641,13 +690,19 @@ class ResolveSessionCatalog( bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: String, - options: Map[String, String], location: Option[String], comment: Option[String], - ifNotExists: Boolean): CatalogTable = { - val storage = CatalogStorageFormat.empty.copy( - locationUri = location.map(CatalogUtils.stringToURI), - properties = options) + storageFormat: CatalogStorageFormat, + external: Boolean): CatalogTable = { + if (external) { + if (DDLUtils.isHiveTable(Some(provider))) { + if (location.isEmpty) { + throw new AnalysisException(s"CREATE EXTERNAL TABLE must be accompanied by LOCATION") + } + } else { + throw new AnalysisException(s"Operation not allowed: CREATE EXTERNAL TABLE ... USING") + } + } val tableType = if (location.isDefined) { CatalogTableType.EXTERNAL @@ -658,7 +713,7 @@ class ResolveSessionCatalog( CatalogTable( identifier = table, tableType = tableType, - storage = storage, + storage = storageFormat, schema = schema, provider = Some(provider), partitionColumnNames = partitioning.asPartitionColumns, @@ -730,6 +785,9 @@ class ResolveSessionCatalog( } private def isV2Provider(provider: String): Boolean = { + // Return earlier since `lookupDataSourceV2` may fail to resolve provider "hive" to + // `HiveFileFormat`, when running tests in sql/core. + if (DDLUtils.isHiveTable(Some(provider))) return false DataSource.lookupDataSourceV2(provider, conf) match { // TODO(SPARK-28396): Currently file source v2 can't work with tables. case Some(_: FileDataSourceV2) => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 85476bcd21..a92f0775f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -26,7 +26,6 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression @@ -37,7 +36,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.StructType /** * Concrete parser for Spark SQL statements. 
@@ -279,7 +277,7 @@ class SparkSqlAstBuilder extends AstBuilder { operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } - val (_, _, _, options, location, _) = visitCreateTableClauses(ctx.createTableClauses()) + val (_, _, _, _, options, location, _, _) = visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw new ParseException("CREATE TEMPORARY TABLE without a provider is not allowed.", ctx)) val schema = Option(ctx.colTypeList()).map(createSchema) @@ -382,153 +380,34 @@ class SparkSqlAstBuilder extends AstBuilder { } } - /** - * Create a Hive serde table, returning a [[CreateTable]] logical plan. - * - * This is a legacy syntax for Hive compatibility, we recommend users to use the Spark SQL - * CREATE TABLE syntax to create Hive serde table, e.g. "CREATE TABLE ... USING hive ..." - * - * Note: several features are currently not supported - temporary tables, bucketing, - * skewed columns and storage handlers (STORED BY). - * - * Expected format: - * {{{ - * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name - * [(col1[:] data_type [COMMENT col_comment], ...)] - * create_table_clauses - * [AS select_statement]; - * - * create_table_clauses (order insensitive): - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] - * }}} - */ - override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { - val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - // TODO: implement temporary tables - if (temp) { - throw new ParseException( - "CREATE TEMPORARY TABLE is not supported yet. " + - "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) - } - if (ctx.skewSpec.size > 0) { - operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) - } - - checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) - checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) - checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) - checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) - checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) - checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) - checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - - val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) - val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - - // Note: Hive requires partition columns to be distinct from the schema, so we need - // to include the partition columns here explicitly - val schema = StructType(dataCols ++ partitionCols) - - // Storage format - val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat( - ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) - val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) - .getOrElse(CatalogStorageFormat.empty) - val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) - .getOrElse(CatalogStorageFormat.empty) - val location = visitLocationSpecList(ctx.locationSpec()) - // If we are creating an EXTERNAL table, then the LOCATION field is required - if (external && location.isEmpty) { - operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) - } - - val locUri = location.map(CatalogUtils.stringToURI(_)) - val storage = CatalogStorageFormat( - locationUri = locUri, - inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), - outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), - serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), - compressed = false, - properties = rowStorage.properties ++ fileStorage.properties) - // If location is defined, we'll assume this is an external table. - // Otherwise, we may accidentally delete existing data. - val tableType = if (external || location.isDefined) { - CatalogTableType.EXTERNAL + private def toStorageFormat( + location: Option[String], + maybeSerdeInfo: Option[SerdeInfo], + ctx: ParserRuleContext): CatalogStorageFormat = { + if (maybeSerdeInfo.isEmpty) { + CatalogStorageFormat.empty.copy(locationUri = location.map(CatalogUtils.stringToURI)) } else { - CatalogTableType.MANAGED - } - - val name = tableIdentifier(ident, "CREATE TABLE ... STORED AS ...", ctx) - - // TODO support the sql text - have a proper location for this! - val tableDesc = CatalogTable( - identifier = name, - tableType = tableType, - storage = storage, - schema = schema, - bucketSpec = bucketSpec, - provider = Some(DDLUtils.HIVE_PROVIDER), - partitionColumnNames = partitionCols.map(_.name), - properties = properties, - comment = visitCommentSpecList(ctx.commentSpec())) - - val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists - - selectQuery match { - case Some(q) => - // Don't allow explicit specification of schema for CTAS. 
- if (dataCols.nonEmpty) { - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) - } - - // When creating partitioned table with CTAS statement, we can't specify data type for the - // partition columns. - if (partitionCols.nonEmpty) { - val errorMessage = "Create Partitioned Table As Select cannot specify data type for " + - "the partition columns of the target table." - operationNotAllowed(errorMessage, ctx) - } - - // Hive CTAS supports dynamic partition by specifying partition column names. - val partitionColumnNames = - Option(ctx.partitionColumnNames) - .map(visitIdentifierList(_).toArray) - .getOrElse(Array.empty[String]) - - val tableDescWithPartitionColNames = - tableDesc.copy(partitionColumnNames = partitionColumnNames) - - val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) - if (conf.convertCTAS && !hasStorageProperties) { - // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties - // are empty Maps. - val newTableDesc = tableDescWithPartitionColNames.copy( - storage = CatalogStorageFormat.empty.copy(locationUri = locUri), - provider = Some(conf.defaultDataSourceName)) - CreateTable(newTableDesc, mode, Some(q)) - } else { - CreateTable(tableDescWithPartitionColNames, mode, Some(q)) - } - case None => - // When creating partitioned table, we must specify data type for the partition columns. - if (Option(ctx.partitionColumnNames).isDefined) { - val errorMessage = "Must specify a data type for each partition column while creating " + - "Hive partitioned table." - operationNotAllowed(errorMessage, ctx) + val serdeInfo = maybeSerdeInfo.get + if (serdeInfo.storedAs.isEmpty) { + CatalogStorageFormat.empty.copy( + locationUri = location.map(CatalogUtils.stringToURI), + inputFormat = serdeInfo.formatClasses.map(_.input), + outputFormat = serdeInfo.formatClasses.map(_.output), + serde = serdeInfo.serde, + properties = serdeInfo.serdeProperties) + } else { + HiveSerDe.sourceToSerDe(serdeInfo.storedAs.get) match { + case Some(hiveSerde) => + CatalogStorageFormat.empty.copy( + locationUri = location.map(CatalogUtils.stringToURI), + inputFormat = hiveSerde.inputFormat, + outputFormat = hiveSerde.outputFormat, + serde = serdeInfo.serde.orElse(hiveSerde.serde), + properties = serdeInfo.serdeProperties) + case _ => + operationNotAllowed(s"STORED AS with file format '${serdeInfo.storedAs.get}'", ctx) } - - CreateTable(tableDesc, mode, None) + } } } @@ -559,189 +438,27 @@ class SparkSqlAstBuilder extends AstBuilder { checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) val provider = ctx.tableProvider.asScala.headOption.map(_.multipartIdentifier.getText) val location = visitLocationSpecList(ctx.locationSpec()) - // rowStorage used to determine CatalogStorageFormat.serde and - // CatalogStorageFormat.properties in STORED AS clause. 
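
In the CREATE TABLE ... LIKE path, the new toStorageFormat helper above takes over from the removed visitCreateFileFormat/visitRowFormat visitors: no SerdeInfo yields an empty storage format carrying only the location, STORED AS infers the Hive input/output formats and serde via HiveSerDe.sourceToSerDe, and explicit INPUTFORMAT/OUTPUTFORMAT classes are passed through as given. A few illustrative statements, assuming spark-hive on the classpath; table names are made up and the snippet is a sketch of which branch each form hits:

import org.apache.spark.sql.SparkSession

object CreateTableLikeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("create-table-like")
      .enableHiveSupport()
      .getOrCreate()

    spark.sql("CREATE TABLE src (id INT) USING parquet")

    // No SerdeInfo: empty storage format, only the (optional) LOCATION is carried over.
    spark.sql("CREATE TABLE dst1 LIKE src")
    // STORED AS file format: input/output formats and serde inferred from the format name.
    spark.sql("CREATE TABLE dst2 LIKE src STORED AS parquet")
    // Explicit INPUTFORMAT/OUTPUTFORMAT classes are passed through unchanged.
    spark.sql("CREATE TABLE dst3 LIKE src STORED AS INPUTFORMAT " +
      "'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT " +
      "'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'")
    // USING provider: only rejected when combined with STORED AS / ROW FORMAT.
    spark.sql("CREATE TABLE dst4 LIKE src USING parquet")

    spark.stop()
  }
}
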
- val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) - .getOrElse(CatalogStorageFormat.empty) - val fileFormat = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) match { - case Some(f) => - if (provider.isDefined) { - throw new ParseException("'STORED AS hiveFormats' and 'USING provider' " + - "should not be specified both", ctx) - } - f.copy( - locationUri = location.map(CatalogUtils.stringToURI), - serde = rowStorage.serde.orElse(f.serde), - properties = rowStorage.properties ++ f.properties) - case None => - if (rowStorage.serde.isDefined) { - throw new ParseException("'ROW FORMAT' must be used with 'STORED AS'", ctx) - } - CatalogStorageFormat.empty.copy(locationUri = location.map(CatalogUtils.stringToURI)) - } - val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - CreateTableLikeCommand( - targetTable, sourceTable, fileFormat, provider, properties, ctx.EXISTS != null) - } - - /** - * Create a [[CatalogStorageFormat]] for creating tables. - * - * Format: STORED AS ... - */ - override def visitCreateFileFormat( - ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - (ctx.fileFormat, ctx.storageHandler) match { - // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format - case (c: TableFileFormatContext, null) => - visitTableFileFormat(c) - // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO - case (c: GenericFileFormatContext, null) => - visitGenericFileFormat(c) - case (null, storageHandler) => - operationNotAllowed("STORED BY", ctx) - case _ => - throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) + // TODO: Do not skip serde check for CREATE TABLE LIKE. + val serdeInfo = getSerdeInfo( + ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx, skipCheck = true) + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE LIKE ... USING ... ${serdeInfo.get.describe}", ctx) } - } - - /** - * Create a [[CatalogStorageFormat]]. - */ - override def visitTableFileFormat( - ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - CatalogStorageFormat.empty.copy( - inputFormat = Option(string(ctx.inFmt)), - outputFormat = Option(string(ctx.outFmt))) - } - - /** - * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. - */ - override def visitGenericFileFormat( - ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - val source = ctx.identifier.getText - HiveSerDe.sourceToSerDe(source) match { - case Some(s) => - CatalogStorageFormat.empty.copy( - inputFormat = s.inputFormat, - outputFormat = s.outputFormat, - serde = s.serde) - case None => - operationNotAllowed(s"STORED AS with file format '$source'", ctx) - } - } - - /** - * Create a [[CatalogStorageFormat]] used for creating tables. 
- * - * Example format: - * {{{ - * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] - * }}} - * - * OR - * - * {{{ - * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] - * [COLLECTION ITEMS TERMINATED BY char] - * [MAP KEYS TERMINATED BY char] - * [LINES TERMINATED BY char] - * [NULL DEFINED AS char] - * }}} - */ - private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) { - ctx match { - case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) - case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) - } - } - - /** - * Create SERDE row format name and properties pair. - */ - override def visitRowFormatSerde( - ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { - import ctx._ - CatalogStorageFormat.empty.copy( - serde = Option(string(name)), - properties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) - } - - /** - * Create a delimited row format properties object. - */ - override def visitRowFormatDelimited( - ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { - // TODO we need proper support for the NULL format. - val entries = - entry("field.delim", ctx.fieldsTerminatedBy) ++ - entry("serialization.format", ctx.fieldsTerminatedBy) ++ - entry("escape.delim", ctx.escapedBy) ++ - // The following typo is inherited from Hive... - entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ - entry("mapkey.delim", ctx.keysTerminatedBy) ++ - Option(ctx.linesSeparatedBy).toSeq.map { token => - val value = string(token) - validate( - value == "\n", - s"LINES TERMINATED BY only supports newline '\\n' right now: $value", - ctx) - "line.delim" -> value - } - CatalogStorageFormat.empty.copy(properties = entries.toMap) - } - /** - * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT - * and STORED AS. - * - * The following are allowed. Anything else is not: - * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] - * ROW FORMAT DELIMITED ... STORED AS TEXTFILE - * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... - */ - private def validateRowFormatFileFormat( - rowFormatCtx: RowFormatContext, - createFileFormatCtx: CreateFileFormatContext, - parentCtx: ParserRuleContext): Unit = { - if (rowFormatCtx == null || createFileFormatCtx == null) { - return - } - (rowFormatCtx, createFileFormatCtx.fileFormat) match { - case (_, ffTable: TableFileFormatContext) => // OK - case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { - case ("sequencefile" | "textfile" | "rcfile") => // OK - case fmt => - operationNotAllowed( - s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", - parentCtx) - } - case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { - case "textfile" => // OK - case fmt => operationNotAllowed( - s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + // TODO: remove this restriction as it seems unnecessary. 
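
The removed validateRowFormatFileFormat scaladoc above still summarizes which ROW FORMAT and STORED AS combinations are accepted; the same rules are now enforced through getSerdeInfo. A hedged sketch of DDL that exercises them, assuming a SparkSession named spark with Hive support and hypothetical table names, is:

// Allowed: ROW FORMAT SERDE with SEQUENCEFILE, RCFILE or TEXTFILE.
spark.sql(
  """CREATE TABLE fmt_serde (k INT, v STRING)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |STORED AS TEXTFILE""".stripMargin)

// Allowed: ROW FORMAT DELIMITED only together with TEXTFILE.
spark.sql(
  """CREATE TABLE fmt_delimited (k INT, v STRING)
    |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
    |STORED AS TEXTFILE""".stripMargin)

// Rejected: ROW FORMAT DELIMITED with a format that carries its own serde, e.g. PARQUET.
// spark.sql("CREATE TABLE fmt_bad (k INT) ROW FORMAT DELIMITED STORED AS PARQUET")
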
+ serdeInfo match { + case Some(SerdeInfo(storedAs, formatClasses, serde, _)) => + if (storedAs.isEmpty && formatClasses.isEmpty && serde.isDefined) { + throw new ParseException("'ROW FORMAT' must be used with 'STORED AS'", ctx) } case _ => - // should never happen - def str(ctx: ParserRuleContext): String = { - (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") - } - operationNotAllowed( - s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", - parentCtx) } - } - private def validateRowFormatFileFormat( - rowFormatCtx: Seq[RowFormatContext], - createFileFormatCtx: Seq[CreateFileFormatContext], - parentCtx: ParserRuleContext): Unit = { - if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { - validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) - } + // TODO: also look at `HiveSerDe.getDefaultStorage`. + val storage = toStorageFormat(location, serdeInfo, ctx) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + CreateTableLikeCommand( + targetTable, sourceTable, storage, provider, properties, ctx.EXISTS != null) } /** @@ -788,7 +505,7 @@ class SparkSqlAstBuilder extends AstBuilder { case c: RowFormatSerdeContext => // Use a serde format. - val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) + val SerdeInfo(None, None, Some(name), props) = visitRowFormatSerde(c) // SPARK-10310: Special cases LazySimpleSerDe val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { @@ -868,12 +585,12 @@ class SparkSqlAstBuilder extends AstBuilder { // assert if directory is local when LOCAL keyword is mentioned val scheme = Option(storage.locationUri.get.getScheme) scheme match { - case None => + case Some(pathScheme) if (!pathScheme.equals("file")) => + throw new ParseException("LOCAL is supported only with file: scheme", ctx) + case _ => // force scheme to be file rather than fs.default.name val loc = Some(UriBuilder.fromUri(CatalogUtils.stringToURI(path)).scheme("file").build()) storage = storage.copy(locationUri = loc) - case Some(pathScheme) if (!pathScheme.equals("file")) => - throw new ParseException("LOCAL is supported only with file: scheme", ctx) } } @@ -896,28 +613,21 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitInsertOverwriteHiveDir( ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) { - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) - .getOrElse(CatalogStorageFormat.empty) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) - .getOrElse(CatalogStorageFormat.empty) - + val serdeInfo = getSerdeInfo( + Option(ctx.rowFormat).toSeq, Option(ctx.createFileFormat).toSeq, ctx) val path = string(ctx.path) // The path field is required if (path.isEmpty) { operationNotAllowed("INSERT OVERWRITE DIRECTORY must be accompanied by path", ctx) } - val defaultStorage = HiveSerDe.getDefaultStorage(conf) - - val storage = CatalogStorageFormat( - locationUri = Some(CatalogUtils.stringToURI(path)), - inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), - outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), - serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), - compressed = false, - properties = rowStorage.properties ++ fileStorage.properties) + val default = HiveSerDe.getDefaultStorage(conf) 
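
Since visitInsertOverwriteHiveDir now builds storage from the parsed SerdeInfo and then backfills whatever is missing from HiveSerDe.getDefaultStorage(conf), the user-visible effect is that omitted clauses fall back to the session's default Hive format. A rough usage sketch, assuming a SparkSession named spark with Hive support, an existing table src, and placeholder output paths, is:

// Explicit format: input/output formats and serde come from STORED AS PARQUET.
spark.sql(
  """INSERT OVERWRITE DIRECTORY '/tmp/out_parquet'
    |STORED AS PARQUET
    |SELECT * FROM src""".stripMargin)

// Delimited text: ROW FORMAT supplies serde properties, STORED AS TEXTFILE the formats;
// anything still unset is filled from HiveSerDe.getDefaultStorage(conf).
spark.sql(
  """INSERT OVERWRITE DIRECTORY '/tmp/out_text'
    |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
    |STORED AS TEXTFILE
    |SELECT * FROM src""".stripMargin)
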
+ val storage = toStorageFormat(Some(path), serdeInfo, ctx) + val finalStorage = storage.copy( + inputFormat = storage.inputFormat.orElse(default.inputFormat), + outputFormat = storage.outputFormat.orElse(default.outputFormat), + serde = storage.serde.orElse(default.serde)) - (ctx.LOCAL != null, storage, Some(DDLUtils.HIVE_PROVIDER)) + (ctx.LOCAL != null, finalStorage, Some(DDLUtils.HIVE_PROVIDER)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e9b1aa8189..f5f77b03c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRe import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeExec} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -670,7 +670,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), - planLater(child), noUserSpecifiedNumPartition = false) :: Nil + planLater(child), REPARTITION_WITH_NUM) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -703,10 +703,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => - exchange.ShuffleExchangeExec( - r.partitioning, - planLater(r.child), - noUserSpecifiedNumPartition = r.optNumPartitions.isEmpty) :: Nil + val shuffleOrigin = if (r.optNumPartitions.isEmpty) { + REPARTITION + } else { + REPARTITION_WITH_NUM + } + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 89ff528d7a..0cf3ab0cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike} import org.apache.spark.sql.internal.SQLConf /** @@ -47,7 +49,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl val shuffleStages = collectShuffleStages(plan) // 
ShuffleExchanges introduced by repartition do not support changing the number of partitions. // We change the number of partitions in the stage only if all the ShuffleExchanges support it. - if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) { + if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) { plan } else { // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions, @@ -82,4 +84,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl } } } + + private def supportCoalesce(s: ShuffleExchangeLike): Boolean = { + s.outputPartitioning != SinglePartition && + (s.shuffleOrigin == ENSURE_REQUIREMENTS || s.shuffleOrigin == REPARTITION) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 8db2827bea..8f57947cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf @@ -136,9 +137,13 @@ object OptimizeLocalShuffleReader extends Rule[SparkPlan] { def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match { case s: ShuffleQueryStageExec => - s.shuffle.canChangeNumPartitions && s.mapStats.isDefined + s.mapStats.isDefined && supportLocalReader(s.shuffle) case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) => - s.shuffle.canChangeNumPartitions && s.mapStats.isDefined && partitionSpecs.nonEmpty + s.mapStats.isDefined && partitionSpecs.nonEmpty && supportLocalReader(s.shuffle) case _ => false } + + private def supportLocalReader(s: ShuffleExchangeLike): Boolean = { + s.outputPartitioning != SinglePartition && s.shuffleOrigin == ENSURE_REQUIREMENTS + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index efba51706c..c676609bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -91,7 +91,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case Some(exprs) => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index fe733f4238..db7264d0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -106,9 +106,10 @@ case class InsertIntoHadoopFsRelationCommand( fs, catalogTable.get, qualifiedOutputPath, matchingPartitions) } + val jobId = java.util.UUID.randomUUID().toString val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, - jobId = java.util.UUID.randomUUID().toString, + jobId = jobId, outputPath = outputPath.toString, dynamicPartitionOverwrite = dynamicPartitionOverwrite) @@ -163,6 +164,15 @@ case class InsertIntoHadoopFsRelationCommand( } } + // For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files + // will be renamed from staging path to final output path during commit job + val committerOutputPath = if (dynamicPartitionOverwrite) { + FileCommitProtocol.getStagingDir(outputPath.toString, jobId) + .makeQualified(fs.getUri, fs.getWorkingDirectory) + } else { + qualifiedOutputPath + } + val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, @@ -170,7 +180,7 @@ case class InsertIntoHadoopFsRelationCommand( fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( - qualifiedOutputPath.toString, customPartitionLocations, outputColumns), + committerOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 39c594a9bc..144be2316f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -55,7 +55,8 @@ class SQLHadoopMapReduceCommitProtocol( // The specified output committer is a FileOutputCommitter. // So, we will use the FileOutputCommitter-specified constructor. val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(new Path(path), context) + val committerOutputPath = if (dynamicPartitionOverwrite) stagingDir else new Path(path) + committer = ctor.newInstance(committerOutputPath, context) } else { // The specified output committer is just an OutputCommitter. // So, we will use the no-argument constructor. 
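
The committer changes above are only exercised when dynamic partition overwrite is requested. A minimal sketch of a write that takes this path, assuming a SparkSession named spark and an existing table t partitioned by part (both made up), is:

// With dynamic mode, only the partitions present in the incoming data are replaced,
// and the job writes through the committer output path derived from getStagingDir above.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

spark.range(10)
  .selectExpr("id AS value", "id % 2 AS part")
  .write
  .mode("overwrite")
  .insertInto("t")
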
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala index 1c0513f982..890205f2f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/BasicConnectionProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.jdbc.connection import java.sql.{Connection, Driver} import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.jdbc.JdbcConnectionProvider @@ -40,7 +42,7 @@ private[jdbc] class BasicConnectionProvider extends JdbcConnectionProvider with override def getConnection(driver: Driver, options: Map[String, String]): Connection = { val jdbcOptions = new JDBCOptions(options) val properties = getAdditionalProperties(jdbcOptions) - options.foreach { case(k, v) => + jdbcOptions.asProperties.asScala.foreach { case(k, v) => properties.put(k, v) } logDebug(s"JDBC connection initiated with URL: ${jdbcOptions.url} and properties: $properties") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index e5c29312b8..eb0d701004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -302,6 +302,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case ShowCreateTable(_: ResolvedTable, _) => throw new AnalysisException("SHOW CREATE TABLE is not supported for v2 tables.") + case TruncateTable(_: ResolvedTable, _) => + throw new AnalysisException("TRUNCATE TABLE is not supported for v2 tables.") + + case ShowColumns(_: ResolvedTable, _) => + throw new AnalysisException("SHOW COLUMNS is not supported for v2 tables.") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 9ee145580c..f330d6a8c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -85,7 +85,7 @@ class V2SessionCatalog(catalog: SessionCatalog) val provider = properties.getOrDefault(TableCatalog.PROP_PROVIDER, conf.defaultDataSourceName) val tableProperties = properties.asScala val location = Option(properties.get(TableCatalog.PROP_LOCATION)) - val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap) + val storage = DataSource.buildStorageFormatFromOptions(toOptions(tableProperties.toMap)) .copy(locationUri = location.map(CatalogUtils.stringToURI)) val tableType = if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED @@ -111,6 +111,12 @@ class V2SessionCatalog(catalog: SessionCatalog) loadTable(ident) } + private def toOptions(properties: Map[String, String]): Map[String, String] = { + 
properties.filterKeys(_.startsWith(TableCatalog.OPTION_PREFIX)).map { + case (key, value) => key.drop(TableCatalog.OPTION_PREFIX.length) -> value + } + } + override def alterTable( ident: Identifier, changes: TableChange*): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 6af4b098be..affa92de69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -57,9 +57,9 @@ trait ShuffleExchangeLike extends Exchange { def numPartitions: Int /** - * Returns whether the shuffle partition number can be changed. + * The origin of this shuffle operator. */ - def canChangeNumPartitions: Boolean + def shuffleOrigin: ShuffleOrigin /** * The asynchronous job that materializes the shuffle. @@ -77,18 +77,30 @@ trait ShuffleExchangeLike extends Exchange { def runtimeStatistics: Statistics } +// Describes where the shuffle operator comes from. +sealed trait ShuffleOrigin + +// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It +// means that the shuffle operator is used to ensure internal data partitioning requirements and +// Spark is free to optimize it as long as the requirements are still ensured. +case object ENSURE_REQUIREMENTS extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark +// can still optimize it via changing shuffle partition number, as data partitioning won't change. +case object REPARTITION extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator with +// a certain partition number. Spark can't optimize it. +case object REPARTITION_WITH_NUM extends ShuffleOrigin + /** * Performs a shuffle that will result in the desired partitioning. */ case class ShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, - noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike { - - // If users specify the num partitions via APIs like `repartition`, we shouldn't change it. - // For `SinglePartition`, it requires exactly one partition and we can't change it either. 
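
As a rough illustration of how the new ShuffleOrigin values are meant to be consumed (mirroring the supportCoalesce and supportLocalReader checks earlier in this diff), a rule can gate its rewrite on the origin rather than on a boolean flag. This is a sketch, not the rules' actual code:

import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, REPARTITION_WITH_NUM, ShuffleExchangeLike}

// Sketch: may an adaptive rule change this shuffle's partition number?
// df.repartition(10) plans REPARTITION_WITH_NUM (a user-fixed number, left alone),
// df.repartition(col("c")) plans REPARTITION, and planner-inserted shuffles carry
// ENSURE_REQUIREMENTS; a SinglePartition shuffle can never be changed.
def canTunePartitionNumber(s: ShuffleExchangeLike): Boolean = {
  s.outputPartitioning != SinglePartition && (s.shuffleOrigin match {
    case ENSURE_REQUIREMENTS | REPARTITION => true
    case REPARTITION_WITH_NUM => false
  })
}
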
- override def canChangeNumPartitions: Boolean = - noUserSpecifiedNumPartition && outputPartitioning != SinglePartition + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeLike { private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index c6b98d48d7..9832e5cd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -71,6 +71,9 @@ trait WindowExecBase extends UnaryExecNode { case (RowFrame, IntegerLiteral(offset)) => RowBoundOrdering(offset) + case (RowFrame, _) => + sys.error(s"Unhandled bound in windows expressions: $bound") + case (RangeFrame, CurrentRow) => val ordering = RowOrdering.create(orderSpec, child.output) RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) @@ -249,6 +252,9 @@ trait WindowExecBase extends UnaryExecNode { createBoundOrdering(frameType, lower, timeZone), createBoundOrdering(frameType, upper, timeZone)) } + + case _ => + sys.error(s"Unsupported factory: $key") } // Keep track of the number of expressions. This is a side-effect in a map... diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index f48672afb4..24709ba470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -19,18 +19,32 @@ package org.apache.spark.sql.streaming.ui import java.{util => ju} import java.lang.{Long => JLong} -import java.util.UUID +import java.util.{Locale, UUID} import javax.servlet.http.HttpServletRequest +import scala.collection.JavaConverters._ import scala.xml.{Node, NodeBuffer, Unparsed} import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.state.StateStoreProvider +import org.apache.spark.sql.internal.SQLConf.STATE_STORE_PROVIDER_CLASS +import org.apache.spark.sql.internal.StaticSQLConf.ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST import org.apache.spark.sql.streaming.ui.UIUtils._ import org.apache.spark.ui.{GraphUIData, JsCollector, UIUtils => SparkUIUtils, WebUIPage} private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) extends WebUIPage("statistics") with Logging { + // State store provider implementation mustn't do any heavyweight initialization in constructor + // but in its init method. + private val supportedCustomMetrics = StateStoreProvider.create( + parent.parent.conf.get(STATE_STORE_PROVIDER_CLASS)).supportedCustomMetrics + logDebug(s"Supported custom metrics: $supportedCustomMetrics") + + private val enabledCustomMetrics = + parent.parent.conf.get(ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST).map(_.toLowerCase(Locale.ROOT)) + logDebug(s"Enabled custom metrics: $enabledCustomMetrics") + def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off @@ -126,6 +140,58 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
} + def generateWatermark( + query: StreamingQueryUIData, + minBatchTime: Long, + maxBatchTime: Long, + jsCollector: JsCollector): Seq[Node] = { + // This is made sure on caller side but put it here to be defensive + require(query.lastProgress != null) + if (query.lastProgress.eventTime.containsKey("watermark")) { + val watermarkData = query.recentProgress.flatMap { p => + val batchTimestamp = parseProgressTimestamp(p.timestamp) + val watermarkValue = parseProgressTimestamp(p.eventTime.get("watermark")) + if (watermarkValue > 0L) { + // seconds + Some((batchTimestamp, ((batchTimestamp - watermarkValue) / 1000.0))) + } else { + None + } + } + + if (watermarkData.nonEmpty) { + val maxWatermark = watermarkData.maxBy(_._2)._2 + val graphUIDataForWatermark = + new GraphUIData( + "watermark-gap-timeline", + "watermark-gap-histogram", + watermarkData, + minBatchTime, + maxBatchTime, + 0, + maxWatermark, + "seconds") + graphUIDataForWatermark.generateDataJs(jsCollector) + + // scalastyle:off + + +

+
Global Watermark Gap {SparkUIUtils.tooltip("The gap between batch timestamp and global watermark for the batch.", "right")}
+
+ + {graphUIDataForWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForWatermark.generateHistogramHtml(jsCollector)} + + // scalastyle:on + } else { + Seq.empty[Node] + } + } else { + Seq.empty[Node] + } + } + def generateAggregatedStateOperators( query: StreamingQueryUIData, minBatchTime: Long, @@ -199,49 +265,100 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) "records") graphUIDataForNumRowsDroppedByWatermark.generateDataJs(jsCollector) - // scalastyle:off - - -
-
Aggregated Number Of Total State Rows {SparkUIUtils.tooltip("Aggregated number of total state rows.", "right")}
-
- - {graphUIDataForNumberTotalRows.generateTimelineHtml(jsCollector)} - {graphUIDataForNumberTotalRows.generateHistogramHtml(jsCollector)} - - - -
-
Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
-
- - {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} - {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} - - - -
-
Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
-
- - {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} - {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} - - - -
-
Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
-
- - {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} - {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} - - // scalastyle:on + val result = + // scalastyle:off + + +
+
Aggregated Number Of Total State Rows {SparkUIUtils.tooltip("Aggregated number of total state rows.", "right")}
+
+ + {graphUIDataForNumberTotalRows.generateTimelineHtml(jsCollector)} + {graphUIDataForNumberTotalRows.generateHistogramHtml(jsCollector)} + + + +
+
Aggregated Number Of Updated State Rows {SparkUIUtils.tooltip("Aggregated number of updated state rows.", "right")}
+
+ + {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} + {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} + + + +
+
Aggregated State Memory Used In Bytes {SparkUIUtils.tooltip("Aggregated state memory used in bytes.", "right")}
+
+ + {graphUIDataForMemoryUsedBytes.generateTimelineHtml(jsCollector)} + {graphUIDataForMemoryUsedBytes.generateHistogramHtml(jsCollector)} + + + +
+
Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
+
+ + {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} + {graphUIDataForNumRowsDroppedByWatermark.generateHistogramHtml(jsCollector)} + + // scalastyle:on + + if (enabledCustomMetrics.nonEmpty) { + result ++= generateAggregatedCustomMetrics(query, minBatchTime, maxBatchTime, jsCollector) + } + result } else { new NodeBuffer() } } + def generateAggregatedCustomMetrics( + query: StreamingQueryUIData, + minBatchTime: Long, + maxBatchTime: Long, + jsCollector: JsCollector): NodeBuffer = { + val result: NodeBuffer = new NodeBuffer + + // This is made sure on caller side but put it here to be defensive + require(query.lastProgress.stateOperators.nonEmpty) + query.lastProgress.stateOperators.head.customMetrics.keySet().asScala + .filter(m => enabledCustomMetrics.contains(m.toLowerCase(Locale.ROOT))).map { metricName => + val data = query.recentProgress.map(p => (parseProgressTimestamp(p.timestamp), + p.stateOperators.map(_.customMetrics.get(metricName).toDouble).sum)) + val max = data.maxBy(_._2)._2 + val metric = supportedCustomMetrics.find(_.name.equalsIgnoreCase(metricName)).get + + val graphUIData = + new GraphUIData( + s"aggregated-$metricName-timeline", + s"aggregated-$metricName-histogram", + data, + minBatchTime, + maxBatchTime, + 0, + max, + "") + graphUIData.generateDataJs(jsCollector) + + result ++= + // scalastyle:off + + +
+
Aggregated Custom Metric {s"$metricName"} {SparkUIUtils.tooltip(metric.desc, "right")}
+
+ + {graphUIData.generateTimelineHtml(jsCollector)} + {graphUIData.generateHistogramHtml(jsCollector)} + + // scalastyle:on + } + + result + } + def generateStatTable(query: StreamingQueryUIData): Seq[Node] = { val batchToTimestamps = withNoProgress(query, query.recentProgress.map(p => (p.batchId, parseProgressTimestamp(p.timestamp))), @@ -400,6 +517,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) {graphUIDataForDuration.generateAreaStackHtmlWithData(jsCollector, operationDurationData)} + {generateWatermark(query, minBatchTime, maxBatchTime, jsCollector)} {generateAggregatedStateOperators(query, minBatchTime, maxBatchTime, jsCollector)} diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 567e0eabe1..578b0a807f 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 23 +-- Number of queries: 24 -- !query @@ -67,10 +67,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL] == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL]) - +- Exchange SinglePartition, true, [id=#x] + +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x] +- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL]) +- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) - +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x] + +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x] +- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct @@ -116,7 +116,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ -127,7 +127,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x] (6) Exchange Input [2]: [key#x, max(val)#x] -Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x] +Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x] (7) Sort Input [2]: [key#x, max(val)#x] @@ -179,7 +179,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ -254,7 +254,7 @@ Results [2]: [key#x, val#x] (7) Exchange Input [2]: [key#x, val#x] -Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x] (8) HashAggregate Input [2]: [key#x, val#x] @@ -576,7 +576,7 @@ Results [2]: [key#x, max#x] (4) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate Input [2]: [key#x, max#x] @@ 
-605,7 +605,7 @@ Results [2]: [key#x, max#x] (9) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate Input [2]: [key#x, max#x] @@ -687,7 +687,7 @@ Results [3]: [count#xL, sum#xL, count#xL] (3) Exchange Input [3]: [count#xL, sum#xL, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (4) HashAggregate Input [3]: [count#xL, sum#xL, count#xL] @@ -732,7 +732,7 @@ Results [2]: [key#x, buf#x] (3) Exchange Input [2]: [key#x, buf#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (4) ObjectHashAggregate Input [2]: [key#x, buf#x] @@ -783,7 +783,7 @@ Results [2]: [key#x, min#x] (4) Exchange Input [2]: [key#x, min#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) Sort Input [2]: [key#x, min#x] diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index fcd69549f2..886b98e538 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 23 +-- Number of queries: 24 -- !query @@ -66,10 +66,10 @@ Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL] == Physical Plan == *HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL]) -+- Exchange SinglePartition, true, [id=#x] ++- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x] +- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL]) +- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) - +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), true, [id=#x] + +- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x] +- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL]) +- *ColumnarToRow +- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct @@ -119,7 +119,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 2] Input [2]: [key#x, max#x] @@ -130,7 +130,7 @@ Results [2]: [key#x, max(val#x)#x AS max(val)#x] (7) Exchange Input [2]: [key#x, max(val)#x] -Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), true, [id=#x] +Arguments: rangepartitioning(key#x ASC NULLS FIRST, 4), ENSURE_REQUIREMENTS, [id=#x] (8) Sort [codegen id : 3] Input [2]: [key#x, max(val)#x] @@ -181,7 +181,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 2] Input [2]: [key#x, max#x] @@ -259,7 +259,7 @@ Results [2]: [key#x, val#x] (9) Exchange Input [2]: [key#x, val#x] -Arguments: hashpartitioning(key#x, val#x, 4), true, [id=#x] 
+Arguments: hashpartitioning(key#x, val#x, 4), ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 4] Input [2]: [key#x, val#x] @@ -452,7 +452,7 @@ Results [1]: [max#x] (9) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -498,7 +498,7 @@ Results [1]: [max#x] (16) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (17) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -580,7 +580,7 @@ Results [1]: [max#x] (9) Exchange Input [1]: [max#x] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (10) HashAggregate [codegen id : 2] Input [1]: [max#x] @@ -626,7 +626,7 @@ Results [2]: [sum#x, count#xL] (16) Exchange Input [2]: [sum#x, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (17) HashAggregate [codegen id : 2] Input [2]: [sum#x, count#xL] @@ -690,7 +690,7 @@ Results [2]: [sum#x, count#xL] (7) Exchange Input [2]: [sum#x, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (8) HashAggregate [codegen id : 2] Input [2]: [sum#x, count#xL] @@ -810,7 +810,7 @@ Results [2]: [key#x, max#x] (5) Exchange Input [2]: [key#x, max#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) HashAggregate [codegen id : 4] Input [2]: [key#x, max#x] @@ -901,7 +901,7 @@ Results [3]: [count#xL, sum#xL, count#xL] (4) Exchange Input [3]: [count#xL, sum#xL, count#xL] -Arguments: SinglePartition, true, [id=#x] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x] (5) HashAggregate [codegen id : 2] Input [3]: [count#xL, sum#xL, count#xL] @@ -945,7 +945,7 @@ Results [2]: [key#x, buf#x] (4) Exchange Input [2]: [key#x, buf#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (5) ObjectHashAggregate Input [2]: [key#x, buf#x] @@ -995,7 +995,7 @@ Results [2]: [key#x, min#x] (5) Exchange Input [2]: [key#x, min#x] -Arguments: hashpartitioning(key#x, 4), true, [id=#x] +Arguments: hashpartitioning(key#x, 4), ENSURE_REQUIREMENTS, [id=#x] (6) Sort [codegen id : 2] Input [2]: [key#x, min#x] diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out index 4f5db7f6c6..6ddffb8998 100644 --- a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out @@ -93,8 +93,8 @@ SHOW COLUMNS IN badtable FROM showdb -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchTableException -Table or view 'badtable' not found in database 'showdb'; +org.apache.spark.sql.AnalysisException +Table or view not found: showdb.badtable; line 1 pos 0 -- !query @@ -129,8 +129,8 @@ SHOW COLUMNS IN showdb.showcolumn3 -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchTableException -Table or view 'showcolumn3' not found in database 'showdb'; +org.apache.spark.sql.AnalysisException +Table or view not found: showdb.showcolumn3; line 1 pos 0 -- !query @@ -138,8 +138,8 @@ SHOW COLUMNS IN showcolumn3 FROM showdb -- !query schema struct<> -- !query output 
-org.apache.spark.sql.catalyst.analysis.NoSuchTableException -Table or view 'showcolumn3' not found in database 'showdb'; +org.apache.spark.sql.AnalysisException +Table or view not found: showdb.showcolumn3; line 1 pos 0 -- !query @@ -147,8 +147,8 @@ SHOW COLUMNS IN showcolumn4 -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchTableException -Table or view 'showcolumn4' not found in database 'showdb'; +org.apache.spark.sql.AnalysisException +Table or view not found: showcolumn4; line 1 pos 0 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 951b72a863..12abd31b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE @@ -766,7 +766,9 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleExchangeLike { override def numMappers: Int = delegate.numMappers override def numPartitions: Int = delegate.numPartitions - override def canChangeNumPartitions: Boolean = delegate.canChangeNumPartitions + override def shuffleOrigin: ShuffleOrigin = { + delegate.shuffleOrigin + } override def mapOutputStatisticsFuture: Future[MapOutputStatistics] = delegate.mapOutputStatisticsFuture override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala index e05c2c09ac..4cacd5ec2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTablePartitionV2SQLSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.connector +import java.time.{LocalDate, LocalDateTime} + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.unsafe.types.UTF8String class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { @@ -185,4 +189,58 @@ class AlterTablePartitionV2SQLSuite extends DatasourceV2SQLBase { } } } + + test("SPARK-33521: universal type conversions of 
partition values") { + val t = "testpart.ns1.ns2.tbl" + withTable(t) { + sql(s""" + |CREATE TABLE $t ( + | part0 tinyint, + | part1 smallint, + | part2 int, + | part3 bigint, + | part4 float, + | part5 double, + | part6 string, + | part7 boolean, + | part8 date, + | part9 timestamp + |) USING foo + |PARTITIONED BY (part0, part1, part2, part3, part4, part5, part6, part7, part8, part9) + |""".stripMargin) + val partTable = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")) + .asPartitionable + val expectedPartition = InternalRow.fromSeq(Seq[Any]( + -1, // tinyint + 0, // smallint + 1, // int + 2, // bigint + 3.14F, // float + 3.14D, // double + UTF8String.fromString("abc"), // string + true, // boolean + LocalDate.parse("2020-11-23").toEpochDay, + DateTimeUtils.instantToMicros( + LocalDateTime.parse("2020-11-23T22:13:10.123456").atZone(DateTimeTestUtils.LA).toInstant) + )) + assert(!partTable.partitionExists(expectedPartition)) + val partSpec = """ + | part0 = -1, + | part1 = 0, + | part2 = 1, + | part3 = 2, + | part4 = 3.14, + | part5 = 3.14, + | part6 = 'abc', + | part7 = true, + | part8 = '2020-11-23', + | part9 = '2020-11-23T22:13:10.123456' + |""".stripMargin + sql(s"ALTER TABLE $t ADD PARTITION ($partSpec) LOCATION 'loc1'") + assert(partTable.partitionExists(expectedPartition)) + sql(s" ALTER TABLE $t DROP PARTITION ($partSpec)") + assert(!partTable.partitionExists(expectedPartition)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index da53936239..f2b57f9442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -432,7 +432,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat.table_name" + - s" USING foo OPTIONS (`${InMemoryTable.SIMULATE_FAILED_WRITE_OPTION}`=true)" + + s" USING foo TBLPROPERTIES (`${InMemoryTable.SIMULATE_FAILED_WRITE_OPTION}`=true)" + s" AS SELECT id FROM source") } @@ -465,7 +465,7 @@ class DataSourceV2SQLSuite intercept[Exception] { spark.sql("REPLACE TABLE testcat_atomic.table_name" + - s" USING foo OPTIONS (`${InMemoryTable.SIMULATE_FAILED_WRITE_OPTION}=true)" + + s" USING foo TBLPROPERTIES (`${InMemoryTable.SIMULATE_FAILED_WRITE_OPTION}=true)" + s" AS SELECT id FROM source") } @@ -1986,8 +1986,8 @@ class DataSourceV2SQLSuite |PARTITIONED BY (id) """.stripMargin) - testV1Command("TRUNCATE TABLE", t) - testV1Command("TRUNCATE TABLE", s"$t PARTITION(id='1')") + testNotSupportedV2Command("TRUNCATE TABLE", t) + testNotSupportedV2Command("TRUNCATE TABLE", s"$t PARTITION(id='1')") } } @@ -2047,14 +2047,9 @@ class DataSourceV2SQLSuite withTable(t) { spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - testV1CommandSupportingTempView("SHOW COLUMNS", s"FROM $t") - testV1CommandSupportingTempView("SHOW COLUMNS", s"IN $t") - - val e3 = intercept[AnalysisException] { - sql(s"SHOW COLUMNS FROM tbl IN testcat.ns1.ns2") - } - assert(e3.message.contains("Namespace name should have " + - "only one part if specified: testcat.ns1.ns2")) + testNotSupportedV2Command("SHOW COLUMNS", s"FROM $t") + testNotSupportedV2Command("SHOW COLUMNS", s"IN $t") + testNotSupportedV2Command("SHOW COLUMNS", "FROM tbl IN testcat.ns1.ns2") } } @@ -2511,7 +2506,7 @@ class DataSourceV2SQLSuite checkAnswer( spark.sql(s"SELECT id, data, _partition 
FROM $t1"), - Seq(Row(1, "a", "3/1"), Row(2, "b", "2/2"), Row(3, "c", "2/3"))) + Seq(Row(1, "a", "3/1"), Row(2, "b", "0/2"), Row(3, "c", "1/3"))) } } @@ -2524,7 +2519,7 @@ class DataSourceV2SQLSuite checkAnswer( spark.sql(s"SELECT index, data, _partition FROM $t1"), - Seq(Row(3, "c", "2/3"), Row(2, "b", "2/2"), Row(1, "a", "3/1"))) + Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 504cc57dc1..edeebde7db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -176,15 +176,18 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql(s"""LOAD DATA LOCAL INPATH "$dataFilePath" INTO TABLE $viewName""") }.getMessage assert(e2.contains(s"$viewName is a temp view. 'LOAD DATA' expects a table")) - assertNoSuchTable(s"TRUNCATE TABLE $viewName") val e3 = intercept[AnalysisException] { - sql(s"SHOW CREATE TABLE $viewName") + sql(s"TRUNCATE TABLE $viewName") }.getMessage - assert(e3.contains(s"$viewName is a temp view not table or permanent view")) + assert(e3.contains(s"$viewName is a temp view. 'TRUNCATE TABLE' expects a table")) val e4 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + sql(s"SHOW CREATE TABLE $viewName") }.getMessage assert(e4.contains(s"$viewName is a temp view not table or permanent view")) + val e5 = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + }.getMessage + assert(e5.contains(s"$viewName is a temp view not table or permanent view")) assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") } } @@ -219,7 +222,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { e = intercept[AnalysisException] { sql(s"TRUNCATE TABLE $viewName") }.getMessage - assert(e.contains(s"Operation not allowed: TRUNCATE TABLE on views: `default`.`testview`")) + assert(e.contains("default.testView is a view. 
'TRUNCATE TABLE' expects a table")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index f55fbc9809..61c16baedb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -20,16 +20,14 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ import org.apache.spark.internal.config.ConfigEntry -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing, RefreshResource} -import org.apache.spark.sql.internal.{HiveSerDe, StaticSQLConf} -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} +import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.types.StringType /** * Parser test cases for rules defined in [[SparkSqlParser]]. @@ -42,23 +40,8 @@ class SparkSqlParserSuite extends AnalysisTest { private lazy val parser = new SparkSqlParser() - /** - * Normalizes plans: - * - CreateTable the createTime in tableDesc will replaced by -1L. 
- */ - override def normalizePlan(plan: LogicalPlan): LogicalPlan = { - plan match { - case CreateTable(tableDesc, mode, query) => - val newTableDesc = tableDesc.copy(createTime = -1L) - CreateTable(newTableDesc, mode, query) - case _ => plan // Don't transform - } - } - private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - val normalized1 = normalizePlan(parser.parsePlan(sqlCommand)) - val normalized2 = normalizePlan(plan) - comparePlans(normalized1, normalized2) + comparePlans(parser.parsePlan(sqlCommand), plan) } private def intercept(sqlCommand: String, messages: String*): Unit = @@ -210,110 +193,6 @@ class SparkSqlParserSuite extends AnalysisTest { Map("path" -> "/data/tmp/testspark1"))) } - private def createTableUsing( - table: String, - database: Option[String] = None, - tableType: CatalogTableType = CatalogTableType.MANAGED, - storage: CatalogStorageFormat = CatalogStorageFormat.empty, - schema: StructType = new StructType, - provider: Option[String] = Some("parquet"), - partitionColumnNames: Seq[String] = Seq.empty, - bucketSpec: Option[BucketSpec] = None, - mode: SaveMode = SaveMode.ErrorIfExists, - query: Option[LogicalPlan] = None): CreateTable = { - CreateTable( - CatalogTable( - identifier = TableIdentifier(table, database), - tableType = tableType, - storage = storage, - schema = schema, - provider = provider, - partitionColumnNames = partitionColumnNames, - bucketSpec = bucketSpec - ), mode, query - ) - } - - private def createTable( - table: String, - database: Option[String] = None, - tableType: CatalogTableType = CatalogTableType.MANAGED, - storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy( - inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat, - outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat, - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), - schema: StructType = new StructType, - provider: Option[String] = Some("hive"), - partitionColumnNames: Seq[String] = Seq.empty, - comment: Option[String] = None, - mode: SaveMode = SaveMode.ErrorIfExists, - query: Option[LogicalPlan] = None): CreateTable = { - CreateTable( - CatalogTable( - identifier = TableIdentifier(table, database), - tableType = tableType, - storage = storage, - schema = schema, - provider = provider, - partitionColumnNames = partitionColumnNames, - comment = comment - ), mode, query - ) - } - - test("create table - schema") { - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) STORED AS textfile", - createTable( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - ) - ) - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + - "PARTITIONED BY (c INT, d STRING COMMENT 'test2')", - createTable( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - .add("c", IntegerType) - .add("d", StringType, nullable = true, "test2"), - partitionColumnNames = Seq("c", "d") - ) - ) - assertEqual("CREATE TABLE my_tab(id BIGINT, nested STRUCT) " + - "STORED AS textfile", - createTable( - table = "my_tab", - schema = (new StructType) - .add("id", LongType) - .add("nested", (new StructType) - .add("col1", StringType) - .add("col2", IntegerType) - ) - ) - ) - // Partitioned by a StructType should be accepted by `SparkSqlParser` but will fail an analyze - // rule in `AnalyzeCreateTable`. 
- assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + - "PARTITIONED BY (nested STRUCT)", - createTable( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - .add("nested", (new StructType) - .add("col1", StringType) - .add("col2", IntegerType) - ), - partitionColumnNames = Seq("nested") - ) - ) - intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", - "no viable alternative at input") - } - test("describe query") { val query = "SELECT * FROM t" assertEqual("DESCRIBE QUERY " + query, DescribeQueryCommand(query, parser.parsePlan(query))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala index 34b4a70d05..e26acbcb3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Or} import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -39,7 +41,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { import spark.implicits._ def withFromJson(rowsNum: Int, numIters: Int): Unit = { - val benchmark = new Benchmark("from_json as subExpr", rowsNum, output = output) + val benchmark = new Benchmark("from_json as subExpr in Project", rowsNum, output = output) withTempPath { path => prepareDataInfo(benchmark) @@ -50,57 +52,65 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { from_json('value, schema).getField(s"col$idx") } - // We only benchmark subexpression performance under codegen/non-codegen, so disabling - // json optimization. - benchmark.addCase("subexpressionElimination off, codegen on", numIters) { _ => - withSQLConf( - SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false", - SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", - SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", - SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") { - val df = spark.read - .text(path.getAbsolutePath) - .select(cols: _*) - df.collect() + Seq( + ("false", "true", "CODEGEN_ONLY"), + ("false", "false", "NO_CODEGEN"), + ("true", "true", "CODEGEN_ONLY"), + ("true", "false", "NO_CODEGEN") + ).foreach { case (subExprEliminationEnabled, codegenEnabled, codegenFactory) => + // We only benchmark subexpression performance under codegen/non-codegen, so disabling + // json optimization. 
+        val caseName = s"subExprElimination $subExprEliminationEnabled, codegen: $codegenEnabled"
+        benchmark.addCase(caseName, numIters) { _ =>
+          withSQLConf(
+            SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled,
+            SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled,
+            SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactory,
+            SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") {
+            val df = spark.read
+              .text(path.getAbsolutePath)
+              .select(cols: _*)
+            df.write.mode("overwrite").format("noop").save()
+          }
         }
       }
 
-      benchmark.addCase("subexpressionElimination off, codegen off", numIters) { _ =>
-        withSQLConf(
-          SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false",
-          SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
-          SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
-          SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") {
-          val df = spark.read
-            .text(path.getAbsolutePath)
-            .select(cols: _*)
-          df.collect()
-        }
-      }
+      benchmark.run()
+    }
+  }
 
-      benchmark.addCase("subexpressionElimination on, codegen on", numIters) { _ =>
-        withSQLConf(
-          SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "true",
-          SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
-          SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY",
-          SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") {
-          val df = spark.read
-            .text(path.getAbsolutePath)
-            .select(cols: _*)
-          df.collect()
-        }
-      }
+  def withFilter(rowsNum: Int, numIters: Int): Unit = {
+    val benchmark = new Benchmark("from_json as subExpr in Filter", rowsNum, output = output)
 
-      benchmark.addCase("subexpressionElimination on, codegen off", numIters) { _ =>
-        withSQLConf(
-          SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "true",
-          SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
-          SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
-          SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") {
-          val df = spark.read
-            .text(path.getAbsolutePath)
-            .select(cols: _*)
-          df.collect()
+    withTempPath { path =>
+      prepareDataInfo(benchmark)
+      val numCols = 1000
+      val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols)
+
+      val predicate = (0 until numCols).map { idx =>
+        (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr
+      }.asInstanceOf[Seq[Expression]].reduce(Or)
+
+      Seq(
+        ("false", "true", "CODEGEN_ONLY"),
+        ("false", "false", "NO_CODEGEN"),
+        ("true", "true", "CODEGEN_ONLY"),
+        ("true", "false", "NO_CODEGEN")
+      ).foreach { case (subExprEliminationEnabled, codegenEnabled, codegenFactory) =>
+        // We only benchmark subexpression performance under codegen/non-codegen, so disabling
+        // json optimization.
+        val caseName = s"subExprElimination $subExprEliminationEnabled, codegen: $codegenEnabled"
+        benchmark.addCase(caseName, numIters) { _ =>
+          withSQLConf(
+            SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled,
+            SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled,
+            SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactory,
+            SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "false") {
+            val df = spark.read
+              .text(path.getAbsolutePath)
+              .where(Column(predicate))
+            df.write.mode("overwrite").format("noop").save()
+          }
         }
       }
 
@@ -108,11 +118,11 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark {
     }
   }
 
-
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
     val numIters = 3
     runBenchmark("Benchmark for performance of subexpression elimination") {
       withFromJson(100, numIters)
+      withFilter(100, numIters)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 38a323b1c0..758965954b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1307,4 +1307,14 @@ class AdaptiveQueryExecSuite
       spark.listenerManager.unregister(listener)
     }
   }
+
+  test("SPARK-33494: Do not use local shuffle reader for repartition") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val df = spark.table("testData").repartition('key)
+      df.collect()
+      // local shuffle reader breaks partitioning and shouldn't be used for repartition operation
+      // which is specified by users.
+      checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 8ce4bcbadc..96f9421e1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -17,14 +17,10 @@ package org.apache.spark.sql.execution.command
-import java.net.URI
 import java.util.Locale
 
-import scala.reflect.{classTag, ClassTag}
-
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
@@ -32,10 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.execution.datasources.CreateTable
-import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
 
 class DDLParserSuite extends AnalysisTest with SharedSparkSession {
   private lazy val parser = new SparkSqlParser()
@@ -50,159 +43,17 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
     }
   }
 
-  private def intercept(sqlCommand: String, messages: String*): Unit =
-
interceptParseException(parser.parsePlan)(sqlCommand, messages: _*) - - private def parseAs[T: ClassTag](query: String): T = { - parser.parsePlan(query) match { - case t: T => t - case other => - fail(s"Expected to parse ${classTag[T].runtimeClass} from query," + - s"got ${other.getClass.getName}: $query") - } - } - private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) comparePlans(plan, expected, checkAnalysis = false) } - private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { - parser.parsePlan(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) - }.head - } - test("alter database - property values must be set") { assertUnsupported( sql = "ALTER DATABASE my_db SET DBPROPERTIES('key_without_value', 'key_with_value'='x')", containsThesePhrases = Seq("key_without_value")) } - test("create hive table - table file format") { - val allSources = Seq("parquet", "parquetfile", "orc", "orcfile", "avro", "avrofile", - "sequencefile", "rcfile", "textfile") - - allSources.foreach { s => - val query = s"CREATE TABLE my_tab STORED AS $s" - val ct = parseAs[CreateTable](query) - val hiveSerde = HiveSerDe.sourceToSerDe(s) - assert(hiveSerde.isDefined) - assert(ct.tableDesc.storage.serde == - hiveSerde.get.serde.orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) - assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) - assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) - } - } - - test("create hive table - row format and table file format") { - val createTableStart = "CREATE TABLE my_tab ROW FORMAT" - val fileFormat = s"STORED AS INPUTFORMAT 'inputfmt' OUTPUTFORMAT 'outputfmt'" - val query1 = s"$createTableStart SERDE 'anything' $fileFormat" - val query2 = s"$createTableStart DELIMITED FIELDS TERMINATED BY ' ' $fileFormat" - - // No conflicting serdes here, OK - val parsed1 = parseAs[CreateTable](query1) - assert(parsed1.tableDesc.storage.serde == Some("anything")) - assert(parsed1.tableDesc.storage.inputFormat == Some("inputfmt")) - assert(parsed1.tableDesc.storage.outputFormat == Some("outputfmt")) - - val parsed2 = parseAs[CreateTable](query2) - assert(parsed2.tableDesc.storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(parsed2.tableDesc.storage.inputFormat == Some("inputfmt")) - assert(parsed2.tableDesc.storage.outputFormat == Some("outputfmt")) - } - - test("create hive table - row format serde and generic file format") { - val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") - val supportedSources = Set("sequencefile", "rcfile", "textfile") - - allSources.foreach { s => - val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s" - if (supportedSources.contains(s)) { - val ct = parseAs[CreateTable](query) - val hiveSerde = HiveSerDe.sourceToSerDe(s) - assert(hiveSerde.isDefined) - assert(ct.tableDesc.storage.serde == Some("anything")) - assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) - assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) - } else { - assertUnsupported(query, Seq("row format serde", "incompatible", s)) - } - } - } - - test("create hive table - row format delimited and generic file format") { - val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") - val supportedSources = Set("textfile") - - 
allSources.foreach { s => - val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s" - if (supportedSources.contains(s)) { - val ct = parseAs[CreateTable](query) - val hiveSerde = HiveSerDe.sourceToSerDe(s) - assert(hiveSerde.isDefined) - assert(ct.tableDesc.storage.serde == - hiveSerde.get.serde.orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) - assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) - assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) - } else { - assertUnsupported(query, Seq("row format delimited", "only compatible with 'textfile'", s)) - } - } - } - - test("create hive external table - location must be specified") { - assertUnsupported( - sql = "CREATE EXTERNAL TABLE my_tab STORED AS parquet", - containsThesePhrases = Seq("create external table", "location")) - val query = "CREATE EXTERNAL TABLE my_tab STORED AS parquet LOCATION '/something/anything'" - val ct = parseAs[CreateTable](query) - assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) - } - - test("create hive table - property values must be set") { - assertUnsupported( - sql = "CREATE TABLE my_tab STORED AS parquet " + - "TBLPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) - assertUnsupported( - sql = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + - "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", - containsThesePhrases = Seq("key_without_value")) - } - - test("create hive table - location implies external") { - val query = "CREATE TABLE my_tab STORED AS parquet LOCATION '/something/anything'" - val ct = parseAs[CreateTable](query) - assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) - } - - test("Duplicate clauses - create hive table") { - def createTableHeader(duplicateClause: String): String = { - s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" - } - - intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), - "Found duplicate clauses: TBLPROPERTIES") - intercept(createTableHeader("LOCATION '/tmp/file'"), - "Found duplicate clauses: LOCATION") - intercept(createTableHeader("COMMENT 'a table'"), - "Found duplicate clauses: COMMENT") - intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), - "Found duplicate clauses: CLUSTERED BY") - intercept(createTableHeader("PARTITIONED BY (k int)"), - "Found duplicate clauses: PARTITIONED BY") - intercept(createTableHeader("STORED AS parquet"), - "Found duplicate clauses: STORED AS/BY") - intercept( - createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), - "Found duplicate clauses: ROW FORMAT") - } - test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { @@ -359,180 +210,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { assert(e.contains("Found duplicate keys 'a'")) } - test("Test CTAS #1") { - val s1 = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |STORED AS RCFILE - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s2 = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |STORED AS 
RCFILE - |COMMENT 'This is the staging page view table' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |LOCATION '/user/external/page_view' - |AS SELECT * FROM src - """.stripMargin - - val s3 = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |LOCATION '/user/external/page_view' - |STORED AS RCFILE - |COMMENT 'This is the staging page view table' - |AS SELECT * FROM src - """.stripMargin - - checkParsing(s1) - checkParsing(s2) - checkParsing(s3) - - def checkParsing(sql: String): Unit = { - val (desc, exists) = extractTableDesc(sql) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewCatalogAndNamespace.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - } - - test("Test CTAS #2") { - val s1 = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |COMMENT 'This is the staging page view table' - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s2 = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |COMMENT 'This is the staging page view table' - |AS SELECT * FROM src - """.stripMargin - - checkParsing(s1) - checkParsing(s2) - - def checkParsing(sql: String): Unit = { - val (desc, exists) = extractTableDesc(sql) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewCatalogAndNamespace.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - } - 
- test("Test CTAS #3") { - val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val (desc, exists) = extractTableDesc(s3) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.properties == Map()) - } - - test("Test CTAS #4") { - val s4 = - """CREATE TABLE page_view - |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } - } - - test("Test CTAS #5") { - val s5 = """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin - val (desc, exists) = extractTableDesc(s5) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "ctas2") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.storage.locationUri == None) - assert(desc.schema.isEmpty) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.viewCatalogAndNamespace.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) - } - - test("CTAS statement with a PARTITIONED BY clause is not allowed") { - assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + - " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") - } - - test("CTAS statement with schema") { - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") - assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") - } - test("unsupported operations") { intercept[ParseException] { parser.parsePlan( @@ -642,205 +319,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { """.stripMargin) } - test("create table - basic") { - val query = "CREATE TABLE my_table (id int, name string)" - val (desc, allowExisting) = extractTableDesc(query) - assert(!allowExisting) - assert(desc.identifier.database.isEmpty) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.MANAGED) - assert(desc.schema == new StructType().add("id", "int").add("name", "string")) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri.isEmpty) - assert(desc.storage.inputFormat == - 
Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc.storage.properties.isEmpty) - assert(desc.properties.isEmpty) - assert(desc.comment.isEmpty) - } - - test("create table - with database name") { - val query = "CREATE TABLE dbx.my_table (id int, name string)" - val (desc, _) = extractTableDesc(query) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - } - - test("create table - temporary") { - val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" - val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) - } - - test("create table - external") { - val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" - val (desc, _) = extractTableDesc(query) - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) - } - - test("create table - if not exists") { - val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" - val (_, allowExisting) = extractTableDesc(query) - assert(allowExisting) - } - - test("create table - comment") { - val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" - val (desc, _) = extractTableDesc(query) - assert(desc.comment == Some("its hot as hell below")) - } - - test("create table - partitioned columns") { - val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" - val (desc, _) = extractTableDesc(query) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - } - - test("create table - clustered by") { - val numBuckets = 10 - val bucketedColumn = "id" - val sortColumn = "id" - val baseQuery = - s""" - CREATE TABLE my_table ( - $bucketedColumn int, - name string) - CLUSTERED BY($bucketedColumn) - """ - - val query1 = s"$baseQuery INTO $numBuckets BUCKETS" - val (desc1, _) = extractTableDesc(query1) - assert(desc1.bucketSpec.isDefined) - val bucketSpec1 = desc1.bucketSpec.get - assert(bucketSpec1.numBuckets == numBuckets) - assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec1.sortColumnNames.isEmpty) - - val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" - val (desc2, _) = extractTableDesc(query2) - assert(desc2.bucketSpec.isDefined) - val bucketSpec2 = desc2.bucketSpec.get - assert(bucketSpec2.numBuckets == numBuckets) - assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) - assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) - } - - test("create table(hive) - skewed by") { - val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" - val query1 = s"$baseQuery(id) ON (1, 10, 100)" - val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" - val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - val e3 = intercept[ParseException] { parser.parsePlan(query3) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - 
assert(e3.getMessage.contains("Operation not allowed")) - } - - test("create table(hive) - row format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" - val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" - val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" - val query3 = - s""" - |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' - |COLLECTION ITEMS TERMINATED BY 'a' - |MAP KEYS TERMINATED BY 'b' - |LINES TERMINATED BY '\n' - |NULL DEFINED AS 'c' - """.stripMargin - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - val (desc3, _) = extractTableDesc(query3) - assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc1.storage.properties.isEmpty) - assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc2.storage.properties == Map("k1" -> "v1")) - assert(desc3.storage.properties == Map( - "field.delim" -> "x", - "escape.delim" -> "y", - "serialization.format" -> "x", - "line.delim" -> "\n", - "colelction.delim" -> "a", // yes, it's a typo from Hive :) - "mapkey.delim" -> "b")) - } - - test("create table(hive) - file format") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" - val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" - val query2 = s"$baseQuery ORC" - val (desc1, _) = extractTableDesc(query1) - val (desc2, _) = extractTableDesc(query2) - assert(desc1.storage.inputFormat == Some("winput")) - assert(desc1.storage.outputFormat == Some("wowput")) - assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) - assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - test("create table(hive) - storage handler") { - val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" - val query1 = s"$baseQuery 'org.papachi.StorageHandler'" - val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" - val e1 = intercept[ParseException] { parser.parsePlan(query1) } - val e2 = intercept[ParseException] { parser.parsePlan(query2) } - assert(e1.getMessage.contains("Operation not allowed")) - assert(e2.getMessage.contains("Operation not allowed")) - } - - test("create table - properties") { - val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - val (desc, _) = extractTableDesc(query) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - } - - test("create table(hive) - everything!") { - val query = - """ - |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) - |COMMENT 'no comment' - |PARTITIONED BY (month int) - |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') - |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' - |LOCATION '/path/to/mercury' - |TBLPROPERTIES ('k1'='v1', 'k2'='v2') - """.stripMargin - val (desc, allowExisting) = extractTableDesc(query) - assert(allowExisting) - assert(desc.identifier.database == Some("dbx")) - assert(desc.identifier.table == "my_table") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.schema == new StructType() - .add("id", "int") - .add("name", "string") - .add("month", "int")) - assert(desc.partitionColumnNames == Seq("month")) - 
assert(desc.bucketSpec.isEmpty) - assert(desc.viewText.isEmpty) - assert(desc.viewCatalogAndNamespace.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) - assert(desc.storage.inputFormat == Some("winput")) - assert(desc.storage.outputFormat == Some("wowput")) - assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) - assert(desc.storage.properties == Map("k1" -> "v1")) - assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) - assert(desc.comment == Some("no comment")) - } - test("create table like") { val v1 = "CREATE TABLE table1 LIKE table2" val (target, source, fileFormat, provider, properties, exists) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 43a33860d2..4f79e71419 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2169,11 +2169,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") sql(s"CREATE TABLE my_ext_tab using parquet LOCATION '${tempDir.toURI}'") sql(s"CREATE VIEW my_view AS SELECT 1") - intercept[NoSuchTableException] { + val e1 = intercept[AnalysisException] { sql("TRUNCATE TABLE my_temp_tab") - } + }.getMessage + assert(e1.contains("my_temp_tab is a temp view. 'TRUNCATE TABLE' expects a table")) assertUnsupported("TRUNCATE TABLE my_ext_tab") - assertUnsupported("TRUNCATE TABLE my_view") + val e2 = intercept[AnalysisException] { + sql("TRUNCATE TABLE my_view") + }.getMessage + assert(e2.contains("default.my_view is a view. 
'TRUNCATE TABLE' expects a table")) } } } @@ -2262,6 +2266,17 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("show columns - invalid db name") { + withTable("tbl") { + sql("CREATE TABLE tbl(col1 int, col2 string) USING parquet ") + val message = intercept[AnalysisException] { + sql("SHOW COLUMNS IN tbl FROM a.b.c") + }.getMessage + assert(message.contains( + "The namespace in session catalog must have exactly one name part: a.b.c.tbl")) + } + } + test("SPARK-18009 calling toLocalIterator on commands") { import scala.collection.JavaConverters._ val df = sql("show databases") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index fd1978c513..92c114e116 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, CTESubstitution, EmptyFunctionRegistry, NoSuchTableException, ResolveCatalogs, ResolvedTable, ResolveInlineTables, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedV2Relation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} +import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable} import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.TableChange.{UpdateColumnComment, UpdateColumnType} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.sources.SimpleScanSource import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} @@ -178,6 +178,16 @@ class PlanResolutionSuite extends AnalysisTest { }.head } + private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { + val e = intercept[ParseException] { + parsePlan(sql) + 
} + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } + } + test("create table - with partitioned by") { val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + "USING parquet PARTITIONED BY (a)" @@ -428,10 +438,11 @@ class PlanResolutionSuite extends AnalysisTest { val expectedProperties = Map( "p1" -> "v1", "p2" -> "v2", - "other" -> "20", + "option.other" -> "20", "provider" -> "parquet", "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment") + "comment" -> "table comment", + "other" -> "20") parseAndResolve(sql) match { case create: CreateV2Table => @@ -467,10 +478,11 @@ class PlanResolutionSuite extends AnalysisTest { val expectedProperties = Map( "p1" -> "v1", "p2" -> "v2", - "other" -> "20", + "option.other" -> "20", "provider" -> "parquet", "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment") + "comment" -> "table comment", + "other" -> "20") parseAndResolve(sql, withDefault = true) match { case create: CreateV2Table => @@ -542,10 +554,11 @@ class PlanResolutionSuite extends AnalysisTest { val expectedProperties = Map( "p1" -> "v1", "p2" -> "v2", - "other" -> "20", + "option.other" -> "20", "provider" -> "parquet", "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment") + "comment" -> "table comment", + "other" -> "20") parseAndResolve(sql) match { case ctas: CreateTableAsSelect => @@ -576,10 +589,11 @@ class PlanResolutionSuite extends AnalysisTest { val expectedProperties = Map( "p1" -> "v1", "p2" -> "v2", - "other" -> "20", + "option.other" -> "20", "provider" -> "parquet", "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment") + "comment" -> "table comment", + "other" -> "20") parseAndResolve(sql, withDefault = true) match { case ctas: CreateTableAsSelect => @@ -1557,6 +1571,630 @@ class PlanResolutionSuite extends AnalysisTest { checkFailure("testcat.tab", "foo") } + private def compareNormalized(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + /** + * Normalizes plans: + * - CreateTable the createTime in tableDesc will replaced by -1L. 
+ */ + def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case CreateTable(tableDesc, mode, query) => + val newTableDesc = tableDesc.copy(createTime = -1L) + CreateTable(newTableDesc, mode, query) + case _ => plan // Don't transform + } + } + comparePlans(normalizePlan(plan1), normalizePlan(plan2)) + } + + test("create table - schema") { + def createTable( + table: String, + database: Option[String] = None, + tableType: CatalogTableType = CatalogTableType.MANAGED, + storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy( + inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat, + outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat, + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), + schema: StructType = new StructType, + provider: Option[String] = Some("hive"), + partitionColumnNames: Seq[String] = Seq.empty, + comment: Option[String] = None, + mode: SaveMode = SaveMode.ErrorIfExists, + query: Option[LogicalPlan] = None): CreateTable = { + CreateTable( + CatalogTable( + identifier = TableIdentifier(table, database), + tableType = tableType, + storage = storage, + schema = schema, + provider = provider, + partitionColumnNames = partitionColumnNames, + comment = comment + ), mode, query + ) + } + + def compare(sql: String, plan: LogicalPlan): Unit = { + compareNormalized(parseAndResolve(sql), plan) + } + + compare("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) STORED AS textfile", + createTable( + table = "my_tab", + database = Some("default"), + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + compare("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (c INT, d STRING COMMENT 'test2')", + createTable( + table = "my_tab", + database = Some("default"), + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("c", IntegerType) + .add("d", StringType, nullable = true, "test2"), + partitionColumnNames = Seq("c", "d") + ) + ) + compare("CREATE TABLE my_tab(id BIGINT, nested STRUCT) " + + "STORED AS textfile", + createTable( + table = "my_tab", + database = Some("default"), + schema = (new StructType) + .add("id", LongType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ) + ) + ) + // Partitioned by a StructType should be accepted by `SparkSqlParser` but will fail an analyze + // rule in `AnalyzeCreateTable`. 
+ compare("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " + + "PARTITIONED BY (nested STRUCT)", + createTable( + table = "my_tab", + database = Some("default"), + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + .add("nested", (new StructType) + .add("col1", StringType) + .add("col2", IntegerType) + ), + partitionColumnNames = Seq("nested") + ) + ) + + interceptParseException(parsePlan)( + "CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)", + "extraneous input ':'") + } + + test("create hive table - table file format") { + val allSources = Seq("parquet", "parquetfile", "orc", "orcfile", "avro", "avrofile", + "sequencefile", "rcfile", "textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab STORED AS $s" + parseAndResolve(query) match { + case ct: CreateTable => + val hiveSerde = HiveSerDe.sourceToSerDe(s) + assert(hiveSerde.isDefined) + assert(ct.tableDesc.storage.serde == + hiveSerde.get.serde.orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) + assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) + } + } + } + + test("create hive table - row format and table file format") { + val createTableStart = "CREATE TABLE my_tab ROW FORMAT" + val fileFormat = s"STORED AS INPUTFORMAT 'inputfmt' OUTPUTFORMAT 'outputfmt'" + val query1 = s"$createTableStart SERDE 'anything' $fileFormat" + val query2 = s"$createTableStart DELIMITED FIELDS TERMINATED BY ' ' $fileFormat" + + // No conflicting serdes here, OK + parseAndResolve(query1) match { + case parsed1: CreateTable => + assert(parsed1.tableDesc.storage.serde == Some("anything")) + assert(parsed1.tableDesc.storage.inputFormat == Some("inputfmt")) + assert(parsed1.tableDesc.storage.outputFormat == Some("outputfmt")) + } + + parseAndResolve(query2) match { + case parsed2: CreateTable => + assert(parsed2.tableDesc.storage.serde == + Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(parsed2.tableDesc.storage.inputFormat == Some("inputfmt")) + assert(parsed2.tableDesc.storage.outputFormat == Some("outputfmt")) + } + } + + test("create hive table - row format serde and generic file format") { + val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") + val supportedSources = Set("sequencefile", "rcfile", "textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s" + if (supportedSources.contains(s)) { + parseAndResolve(query) match { + case ct: CreateTable => + val hiveSerde = HiveSerDe.sourceToSerDe(s) + assert(hiveSerde.isDefined) + assert(ct.tableDesc.storage.serde == Some("anything")) + assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) + } + } else { + assertUnsupported(query, Seq("row format serde", "incompatible", s)) + } + } + } + + test("create hive table - row format delimited and generic file format") { + val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") + val supportedSources = Set("textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s" + if (supportedSources.contains(s)) { + parseAndResolve(query) match { + case ct: CreateTable => + val hiveSerde = HiveSerDe.sourceToSerDe(s) + assert(hiveSerde.isDefined) + assert(ct.tableDesc.storage.serde == 
hiveSerde.get.serde + .orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) + assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) + } + } else { + assertUnsupported(query, Seq("row format delimited", "only compatible with 'textfile'", s)) + } + } + } + + test("create hive external table - location must be specified") { + val exc = intercept[AnalysisException] { + parseAndResolve("CREATE EXTERNAL TABLE my_tab STORED AS parquet") + } + assert(exc.getMessage.contains("CREATE EXTERNAL TABLE must be accompanied by LOCATION")) + + val query = "CREATE EXTERNAL TABLE my_tab STORED AS parquet LOCATION '/something/anything'" + parseAndResolve(query) match { + case ct: CreateTable => + assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) + } + } + + test("create hive table - property values must be set") { + assertUnsupported( + sql = "CREATE TABLE my_tab STORED AS parquet " + + "TBLPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + assertUnsupported( + sql = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + + "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + + test("create hive table - location implies external") { + val query = "CREATE TABLE my_tab STORED AS parquet LOCATION '/something/anything'" + parseAndResolve(query) match { + case ct: CreateTable => + assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) + } + } + + test("Duplicate clauses - create hive table") { + def intercept(sqlCommand: String, messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*) + + def createTableHeader(duplicateClause: String): String = { + s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" + } + + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'"), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'"), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), + "Found duplicate clauses: CLUSTERED BY") + intercept(createTableHeader("PARTITIONED BY (k int)"), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet"), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), + "Found duplicate clauses: ROW FORMAT") + } + + test("Test CTAS #1") { + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + 
|COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewCatalogAndNamespace.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + + test("Test CTAS #2") { + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |COMMENT 'This is the staging page view table' + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewCatalogAndNamespace.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == Some("default")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == 
None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + } + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.identifier.database == Some("default")) + assert(desc.identifier.table == "ctas2") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.storage.locationUri == None) + assert(desc.schema.isEmpty) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewCatalogAndNamespace.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("CTAS statement with a PARTITIONED BY clause is not allowed") { + assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + } + + test("CTAS statement with schema") { + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + } + + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database == Some("default")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED) + assert(desc.schema == new StructType().add("id", "int").add("name", "string")) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc.storage.properties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + 
assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parsePlan(query) } + assert(e.message.contains("Operation not allowed: CREATE TEMPORARY TABLE")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val numBuckets = 10 + val bucketedColumn = "id" + val sortColumn = "id" + val baseQuery = + s""" + CREATE TABLE my_table ( + $bucketedColumn int, + name string) + CLUSTERED BY($bucketedColumn) + """ + + val query1 = s"$baseQuery INTO $numBuckets BUCKETS" + val (desc1, _) = extractTableDesc(query1) + assert(desc1.bucketSpec.isDefined) + val bucketSpec1 = desc1.bucketSpec.get + assert(bucketSpec1.numBuckets == numBuckets) + assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec1.sortColumnNames.isEmpty) + + val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS" + val (desc2, _) = extractTableDesc(query2) + assert(desc2.bucketSpec.isDefined) + val bucketSpec2 = desc2.bucketSpec.get + assert(bucketSpec2.numBuckets == numBuckets) + assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn)) + assert(bucketSpec2.sortColumnNames.head.equals(sortColumn)) + } + + test("create table(hive) - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parsePlan(query1) } + val e2 = intercept[ParseException] { parsePlan(query2) } + val e3 = intercept[ParseException] { parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table(hive) - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' 
+ """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.properties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.properties == Map("k1" -> "v1")) + assert(desc3.storage.properties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table(hive) - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table(hive) - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parsePlan(query1) } + val e2 = intercept[ParseException] { parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + parsePlan(query) match { + case state: CreateTableStatement => + assert(state.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + } + + test("create table(hive) - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.schema == new StructType() + .add("id", "int") + .add("name", "string") + .add("month", "int")) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.bucketSpec.isEmpty) + assert(desc.viewText.isEmpty) + assert(desc.viewCatalogAndNamespace.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.properties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + 
assert(desc.comment == Some("no comment")) + } + // TODO: add tests for more commands. } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 983209051c..00c599065c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -166,13 +166,13 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSparkSession { ) }.getMessage assert(error.contains("Operation not allowed") && - error.contains("CREATE TEMPORARY TABLE ... USING ... AS query")) + error.contains("CREATE TEMPORARY TABLE")) } } test("disallows CREATE EXTERNAL TABLE ... USING ... AS query") { withTable("t") { - val error = intercept[ParseException] { + val error = intercept[AnalysisException] { sql( s""" |CREATE EXTERNAL TABLE t USING PARQUET diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 6df1c5db14..52825a155e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.sources import java.io.File import java.sql.Timestamp -import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.spark.TestUtils import org.apache.spark.internal.Logging @@ -164,4 +165,48 @@ class PartitionedWriteSuite extends QueryTest with SharedSparkSession { assert(e.getMessage.contains("Found duplicate column(s) b, b: `b`;")) } } + + test("SPARK-27194 SPARK-29302: Fix commit collision in dynamic partition overwrite mode") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> + SQLConf.PartitionOverwriteMode.DYNAMIC.toString, + SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> + classOf[PartitionFileExistCommitProtocol].getName) { + withTempDir { d => + withTable("t") { + sql( + s""" + | create table t(c1 int, p1 int) using parquet partitioned by (p1) + | location '${d.getAbsolutePath}' + """.stripMargin) + + val df = Seq((1, 2)).toDF("c1", "p1") + df.write + .partitionBy("p1") + .mode("overwrite") + .saveAsTable("t") + checkAnswer(sql("select * from t"), df) + } + } + } + } +} + +/** + * A file commit protocol with pre-created partition file. 
When trying to overwrite a partition directory + * in dynamic partition overwrite mode, a FileAlreadyExistsException would be raised without the fix for SPARK-27194. + */ +private class PartitionFileExistCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean) + extends SQLHadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) { + override def setupJob(jobContext: JobContext): Unit = { + super.setupJob(jobContext) + val stagingDir = new File(new Path(path).toUri.getPath, s".spark-staging-$jobId") + stagingDir.mkdirs() + val stagingPartDir = new File(stagingDir, "p1=2") + stagingPartDir.mkdirs() + val conflictTaskFile = new File(stagingPartDir, s"part-00000-$jobId.c000.snappy.parquet") + conflictTaskFile.createNewFile() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala index 640c21c52a..c2b6688faf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/StreamingQueryPageSuite.scala @@ -24,8 +24,10 @@ import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.scalatest.BeforeAndAfter import scala.xml.Node +import org.apache.spark.SparkConf import org.apache.spark.sql.streaming.StreamingQueryProgress import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.ui.SparkUI class StreamingQueryPageSuite extends SharedSparkSession with BeforeAndAfter { @@ -65,10 +67,13 @@ class StreamingQueryPageSuite extends SharedSparkSession with BeforeAndAfter { val request = mock(classOf[HttpServletRequest]) val tab = mock(classOf[StreamingQueryTab], RETURNS_SMART_NULLS) val statusListener = mock(classOf[StreamingQueryStatusListener], RETURNS_SMART_NULLS) + val ui = mock(classOf[SparkUI]) when(request.getParameter("id")).thenReturn(id.toString) when(tab.appName).thenReturn("testing") when(tab.headerTabs).thenReturn(Seq.empty) when(tab.statusListener).thenReturn(statusListener) + when(ui.conf).thenReturn(new SparkConf()) + when(tab.parent).thenReturn(ui) val streamQuery = createStreamQueryUIData(id) when(statusListener.allQueryStatus).thenReturn(Seq(streamQuery)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index 307479db33..db3d6529c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -31,7 +31,10 @@ import org.apache.spark.internal.config.UI.{UI_ENABLED, UI_PORT} import org.apache.spark.sql.LocalSparkSession.withSparkSession import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.streaming.StreamingQueryException +import org.apache.spark.sql.functions.{window => windowFn, _} +import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS +import org.apache.spark.sql.internal.StaticSQLConf.ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST +import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll { @@ -51,8 +54,10 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val conf = new SparkConf() .setMaster(master)
.setAppName("ui-test") + .set(SHUFFLE_PARTITIONS, 5) .set(UI_ENABLED, true) .set(UI_PORT, 0) + .set(ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST, Seq("stateOnCurrentVersionSizeBytes")) additionalConfs.foreach { case (k, v) => conf.set(k, v) } val spark = SparkSession.builder().master(master).config(conf).getOrCreate() assert(spark.sparkContext.ui.isDefined) @@ -77,10 +82,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val input1 = spark.readStream.format("rate").load() val input2 = spark.readStream.format("rate").load() + val input3 = spark.readStream.format("rate").load() val activeQuery = - input1.join(input2, "value").writeStream.format("noop").start() + input1.selectExpr("timestamp", "mod(value, 100) as mod", "value") + .withWatermark("timestamp", "0 second") + .groupBy(windowFn($"timestamp", "10 seconds", "2 seconds"), $"mod") + .agg(avg("value").as("avg_value")) + .writeStream.format("noop").trigger(Trigger.ProcessingTime("5 seconds")).start() val completedQuery = - input1.join(input2, "value").writeStream.format("noop").start() + input2.join(input3, "value").writeStream.format("noop").start() completedQuery.stop() val failedQuery = spark.readStream.format("rate").load().select("value").as[Long] .map(_ / 0).writeStream.format("noop").start() @@ -136,10 +146,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B summaryText should contain ("Input Rows (?)") summaryText should contain ("Batch Duration (?)") summaryText should contain ("Operation Duration (?)") + summaryText should contain ("Global Watermark Gap (?)") summaryText should contain ("Aggregated Number Of Total State Rows (?)") summaryText should contain ("Aggregated Number Of Updated State Rows (?)") summaryText should contain ("Aggregated State Memory Used In Bytes (?)") summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)") + summaryText should contain ("Aggregated Custom Metric stateOnCurrentVersionSizeBytes" + + " (?)") + summaryText should not contain ("Aggregated Custom Metric loadedMapCacheHitCount (?)") + summaryText should not contain ("Aggregated Custom Metric loadedMapCacheMissCount (?)") } } finally { spark.streams.active.foreach(_.stop()) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d9b6bb43c2..462206d8c5 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,8 +40,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone - private val originalLegacyAllowCastNumericToTimestamp = - TestHive.conf.legacyAllowCastNumericToTimestamp def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -61,8 +59,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - // Ensures that 
cast numeric to timestamp enabled so that we can test them - TestHive.setConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP, true) RuleExecutor.resetMetrics() } @@ -73,8 +69,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) - TestHive.setConf(SQLConf.LEGACY_ALLOW_CAST_NUMERIC_TO_TIMESTAMP, - originalLegacyAllowCastNumericToTimestamp) // For debugging dump some statistics about how much time was spent in various optimizer rules logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 907bb86ad0..54c237f78c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient @@ -1264,11 +1264,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) + val timeZoneId = CaseInsensitiveMap(catalogTable.storage.properties).getOrElse( + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId) val partColNameMap = buildLowerCasePartColNameMap(catalogTable) val clientPrunedPartitions = - client.getPartitionsByFilter(rawTable, predicates).map { part => + client.getPartitionsByFilter(rawTable, predicates, timeZoneId).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 8ab6e28366..9213173bbc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1039,6 +1039,7 @@ private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case dt => throw new AnalysisException(s"${dt.catalogString} is not supported.") } def toTypeInfo: TypeInfo = dt match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 3ea80eaf6f..48f3837740 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala 
@@ -233,7 +233,8 @@ private[hive] trait HiveClient { /** Returns partitions filtered by predicates for the given table. */ def getPartitionsByFilter( catalogTable: CatalogTable, - predicates: Seq[Expression]): Seq[CatalogTablePartition] + predicates: Seq[Expression], + timeZoneId: String): Seq[CatalogTablePartition] /** Loads a static partition into an existing table. */ def loadPartition( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 9bc99b08c2..b2f0867114 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -733,9 +733,11 @@ private[hive] class HiveClientImpl( override def getPartitionsByFilter( table: CatalogTable, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { + predicates: Seq[Expression], + timeZoneId: String): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table, Some(userName)) - val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + val parts = shim.getPartitionsByFilter(client, hiveTable, predicates, timeZoneId) + .map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d989f0154e..17a64a67df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -45,9 +45,9 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TypeUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -79,7 +79,11 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] - def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] + def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression], + timeZoneId: String): Seq[Partition] def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor @@ -349,7 +353,8 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getPartitionsByFilter( hive: Hive, table: Table, - predicates: Seq[Expression]): Seq[Partition] = { + predicates: Seq[Expression], + timeZoneId: String): Seq[Partition] = { // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. // See HIVE-4888. logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + @@ -632,7 +637,9 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * * Unsupported predicates are skipped. 
*/ - def convertFilters(table: Table, filters: Seq[Expression]): String = { + def convertFilters(table: Table, filters: Seq[Expression], timeZoneId: String): String = { + lazy val dateFormatter = DateFormatter(DateTimeUtils.getZoneId(timeZoneId)) + /** * An extractor that matches all binary comparison operators except null-safe equality. * @@ -650,6 +657,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { case Literal(null, _) => None // `null`s can be cast as other types; we want to avoid NPEs. case Literal(value, _: IntegralType) => Some(value.toString) case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString)) + case Literal(value, _: DateType) => + Some(dateFormatter.format(value.asInstanceOf[Int])) case _ => None } } @@ -700,6 +709,21 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } } + object ExtractableDateValues { + private lazy val valueToLiteralString: PartialFunction[Any, String] = { + case value: Int => dateFormatter.format(value) + } + + def unapply(values: Set[Any]): Option[Seq[String]] = { + val extractables = values.toSeq.map(valueToLiteralString.lift) + if (extractables.nonEmpty && extractables.forall(_.isDefined)) { + Some(extractables.map(_.get)) + } else { + None + } + } + } + object SupportedAttribute { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. private val varcharKeys = table.getPartitionKeys.asScala @@ -711,7 +735,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val resolver = SQLConf.get.resolver if (varcharKeys.exists(c => resolver(c, attr.name))) { None - } else if (attr.dataType.isInstanceOf[IntegralType] || attr.dataType == StringType) { + } else if (attr.dataType.isInstanceOf[IntegralType] || attr.dataType == StringType || + attr.dataType == DateType) { Some(attr.name) } else { None @@ -748,6 +773,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { convert(And(GreaterThanOrEqual(child, Literal(sortedValues.head, dataType)), LessThanOrEqual(child, Literal(sortedValues.last, dataType)))) + case InSet(child @ ExtractAttribute(SupportedAttribute(name)), ExtractableDateValues(values)) + if useAdvanced && child.dataType == DateType => + Some(convertInToOr(name, values)) + case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values)) if useAdvanced => Some(convertInToOr(name, values)) @@ -803,11 +832,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getPartitionsByFilter( hive: Hive, table: Table, - predicates: Seq[Expression]): Seq[Partition] = { + predicates: Seq[Expression], + timeZoneId: String): Seq[Partition] = { // Hive getPartitionsByFilter() takes a string that represents partition // predicates like "str_key=\"value\" and int_key=1 ..." 
- val filter = convertFilters(table, predicates) + val filter = convertFilters(table, predicates, timeZoneId) val partitions = if (filter.isEmpty) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 12b409e487..6c0531182e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.client +import java.sql.Date import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * A set of tests for the filter conversion logic used when pushing partition pruning into the @@ -63,6 +65,28 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, "1 = intcol and \"a\" = strcol") + filterTest("date filter", + (a("datecol", DateType) === Literal(Date.valueOf("2019-01-01"))) :: Nil, + "datecol = 2019-01-01") + + filterTest("date filter with IN predicate", + (a("datecol", DateType) in + (Literal(Date.valueOf("2019-01-01")), Literal(Date.valueOf("2019-01-07")))) :: Nil, + "(datecol = 2019-01-01 or datecol = 2019-01-07)") + + filterTest("date and string filter", + (Literal(Date.valueOf("2019-01-01")) === a("datecol", DateType)) :: + (Literal("a") === a("strcol", IntegerType)) :: Nil, + "2019-01-01 = datecol and \"a\" = strcol") + + filterTest("date filter with null", + (a("datecol", DateType) === Literal(null)) :: Nil, + "") + + filterTest("string filter with InSet predicate", + InSet(a("strcol", StringType), Set("1", "2").map(s => UTF8String.fromString(s))) :: Nil, + "(strcol = \"1\" or strcol = \"2\")") + filterTest("skip varchar", (Literal("") === a("varchar", StringType)) :: Nil, "") @@ -89,7 +113,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { - val converted = shim.convertFilters(testTable, filters) + val converted = shim.convertFilters(testTable, filters, conf.sessionLocalTimeZone) if (converted != result) { fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") } @@ -104,7 +128,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { val filters = (Literal(1) === a("intcol", IntegerType) || Literal(2) === a("intcol", IntegerType)) :: Nil - val converted = shim.convertFilters(testTable, filters) + val converted = shim.convertFilters(testTable, filters, conf.sessionLocalTimeZone) if (enabled) { assert(converted == "(1 = intcol or 2 = intcol)") } else { @@ -116,7 +140,7 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { test("SPARK-33416: Avoid Hive metastore stack overflow when InSet predicate have many values") { def checkConverted(inSet: InSet, result: String): Unit = { - assert(shim.convertFilters(testTable, inSet :: Nil) == result) + assert(shim.convertFilters(testTable, inSet :: Nil, conf.sessionLocalTimeZone) == result) } 
withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD.key -> "15") { @@ -139,6 +163,11 @@ class FiltersSuite extends SparkFunSuite with Logging with PlanTest { InSet(a("doublecol", DoubleType), Range(1, 20).map(s => Literal(s.toDouble).eval(EmptyRow)).toSet), "") + + checkConverted( + InSet(a("datecol", DateType), + Range(1, 20).map(d => Literal(d, DateType).eval(EmptyRow)).toSet), + "(datecol >= 1970-01-02 and datecol <= 1970-01-20)") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala index 81186909bb..ab83f751f1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.client +import java.sql.Date + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat @@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BooleanType, DateType, IntegerType, LongType, StringType, StructType} import org.apache.spark.util.Utils class HivePartitionFilteringSuite(version: String) @@ -38,15 +41,16 @@ class HivePartitionFilteringSuite(version: String) private val testPartitionCount = 3 * 5 * 4 - private def init(tryDirectSql: Boolean): HiveClient = { - val storageFormat = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - compressed = false, - properties = Map.empty) + private val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[TextInputFormat].getName), + outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[LazySimpleSerDe].getName()), + compressed = false, + properties = Map.empty + ) + private def init(tryDirectSql: Boolean): HiveClient = { val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) hadoopConf.set("hive.metastore.warehouse.dir", Utils.createTempDir().toURI().toString()) @@ -58,14 +62,7 @@ class HivePartitionFilteringSuite(version: String) tableType = CatalogTableType.MANAGED, schema = tableSchema, partitionColumnNames = Seq("ds", "h", "chunk"), - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = Some(classOf[TextInputFormat].getName), - outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = Some(classOf[LazySimpleSerDe].getName()), - compressed = false, - properties = Map.empty - )) + storage = storageFormat) client.createTable(table, ignoreIfExists = false) val partitions = @@ -102,7 +99,7 @@ class HivePartitionFilteringSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(attr("ds") === 20170101)) + Seq(attr("ds") === 20170101), SQLConf.get.sessionLocalTimeZone) assert(filteredPartitions.size == 
testPartitionCount) } @@ -297,6 +294,63 @@ class HivePartitionFilteringSuite(version: String) day :: Nil) } + test("getPartitionsByFilter: date type pruning by metastore") { + val table = CatalogTable( + identifier = TableIdentifier("test_date", Some("default")), + tableType = CatalogTableType.MANAGED, + schema = new StructType().add("value", "int").add("part", "date"), + partitionColumnNames = Seq("part"), + storage = storageFormat) + client.createTable(table, ignoreIfExists = false) + + val partitions = + for { + date <- Seq("2019-01-01", "2019-01-02", "2019-01-03", "2019-01-04") + } yield CatalogTablePartition(Map( + "part" -> date + ), storageFormat) + assert(partitions.size == 4) + + client.createPartitions("default", "test_date", partitions, ignoreIfExists = false) + + def testDataTypeFiltering( + filterExprs: Seq[Expression], + expectedPartitionCubes: Seq[Seq[Date]]): Unit = { + val filteredPartitions = client.getPartitionsByFilter( + client.getTable("default", "test_date"), + filterExprs, + SQLConf.get.sessionLocalTimeZone) + + val expectedPartitions = expectedPartitionCubes.map { + expectedDt => + for { + dt <- expectedDt + } yield Set( + "part" -> dt.toString + ) + }.reduce(_ ++ _) + + assert(filteredPartitions.map(_.spec.toSet).toSet == expectedPartitions.toSet) + } + + val dateAttr: Attribute = AttributeReference("part", DateType)() + + testDataTypeFiltering( + Seq(dateAttr === Date.valueOf("2019-01-01")), + Seq("2019-01-01").map(Date.valueOf) :: Nil) + testDataTypeFiltering( + Seq(dateAttr > Date.valueOf("2019-01-02")), + Seq("2019-01-03", "2019-01-04").map(Date.valueOf) :: Nil) + testDataTypeFiltering( + Seq(In(dateAttr, + Seq("2019-01-01", "2019-01-02").map(d => Literal(Date.valueOf(d))))), + Seq("2019-01-01", "2019-01-02").map(Date.valueOf) :: Nil) + testDataTypeFiltering( + Seq(InSet(dateAttr, + Set("2019-01-01", "2019-01-02").map(d => Literal(Date.valueOf(d)).eval(EmptyRow)))), + Seq("2019-01-01", "2019-01-02").map(Date.valueOf) :: Nil) + } + private def testMetastorePartitionFiltering( filterExpr: Expression, expectedDs: Seq[Int], @@ -333,7 +387,7 @@ class HivePartitionFilteringSuite(version: String) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( transform(filterExpr) - )) + ), SQLConf.get.sessionLocalTimeZone) val expectedPartitionCount = expectedPartitionCubes.map { case (expectedDs, expectedH, expectedChunks) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c5c92ddad9..d9ba6dd80e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -488,7 +488,8 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: getPartitionsByFilter") { // Only one partition [1, 1] for key2 == 1 val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), - Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) + Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1))), + versionSpark.conf.sessionLocalTimeZone) // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. 
if (version != "0.12") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 56b8716444..b8b1da4cb9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -598,8 +598,7 @@ class HiveDDLSuite val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int) PARTITIONED BY (b) STORED AS parquet") } - assert(e.message.contains("Must specify a data type for each partition column while creating " + - "Hive partitioned table.")) + assert(e.message.contains("partition column b is not defined in table")) } test("add/drop partition with location - managed table") { @@ -2701,8 +2700,7 @@ class HiveDDLSuite |AS SELECT 1 as a, "a" as b """.stripMargin) }.getMessage - assert(err1.contains("Schema may not be specified in a Create Table As Select " + - "(CTAS) statement")) + assert(err1.contains("Schema may not be specified in a Create Table As Select")) val err2 = intercept[ParseException] { spark.sql( @@ -2713,8 +2711,7 @@ class HiveDDLSuite |AS SELECT 1 as a, "a" as b """.stripMargin) }.getMessage - assert(err2.contains("Create Partitioned Table As Select cannot specify data type for " + - "the partition columns of the target table")) + assert(err2.contains("Partition column types may not be specified in Create Table As Select")) } test("Hive CTAS with dynamic partition") { @@ -2783,7 +2780,7 @@ class HiveDDLSuite |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' """.stripMargin) }.getMessage - assert(e.contains("'ROW FORMAT' must be used with 'STORED AS'")) + assert(e.contains("Operation not allowed: CREATE TABLE LIKE ... USING ... ROW FORMAT SERDE")) // row format doesn't work with provider hive e = intercept[AnalysisException] { @@ -2794,7 +2791,7 @@ class HiveDDLSuite |WITH SERDEPROPERTIES ('test' = 'test') """.stripMargin) }.getMessage - assert(e.contains("'ROW FORMAT' must be used with 'STORED AS'")) + assert(e.contains("Operation not allowed: CREATE TABLE LIKE ... USING ... ROW FORMAT SERDE")) // row format doesn't work without 'STORED AS' e = intercept[AnalysisException] { @@ -2807,6 +2804,17 @@ class HiveDDLSuite }.getMessage assert(e.contains("'ROW FORMAT' must be used with 'STORED AS'")) + // 'INPUTFORMAT' and 'OUTPUTFORMAT' conflict with 'USING' + e = intercept[AnalysisException] { + spark.sql( + """ + |CREATE TABLE targetDsTable LIKE sourceDsTable USING format + |STORED AS INPUTFORMAT 'inFormat' OUTPUTFORMAT 'outFormat' + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + """.stripMargin) + }.getMessage + assert(e.contains("Operation not allowed: CREATE TABLE LIKE ... USING ... 
STORED AS")) + // row format works with STORED AS hive format (from hive table) spark.sql( """ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 24b1e34053..f723c9f80c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -21,11 +21,10 @@ import java.net.URI import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -71,8 +70,8 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte } private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { - TestHive.sessionState.sqlParser.parsePlan(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + TestHive.sessionState.analyzer.execute(TestHive.sessionState.sqlParser.parsePlan(sql)).collect { + case CreateTableCommand(tableDesc, ifNotExists) => (tableDesc, ifNotExists) }.head } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 712f81d987..79b3c3efe5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -712,8 +712,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi |AS SELECT key, value FROM mytable1 """.stripMargin) }.getMessage - assert(e.contains("Create Partitioned Table As Select cannot specify data type for " + - "the partition columns of the target table")) + assert(e.contains("Partition column types may not be specified in Create Table As Select")) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 2e5000159b..d1f9dfb791 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -293,7 +293,7 @@ private[streaming] object FileBasedWriteAheadLog { val startTime = startTimeStr.toLong val stopTime = stopTimeStr.toLong Some(LogInfo(startTime, stopTime, file.toString)) - case None => + case None | Some(_) => None } }.sortBy { _.startTime }