From 4292432f3de61c69008b5df4aa4f23aaeac1fd23 Mon Sep 17 00:00:00 2001 From: Liangcai Li Date: Fri, 11 Sep 2020 11:14:02 +0800 Subject: [PATCH 1/2] Support running Pandas UDFs on GPUs in Python processes. (#640) Add support to run Pandas UDFs on GPUs, mainly consisting of two things: Overriding all the 6 related plans to build GPU context of device and memory for Python processes. Introducing 2 new python modules rapids.worker and rapids.daemon to execute the GPU memory initialization by leveraging RMM Python APIs. Signed-off-by: Firestarman Co-authored-by: Liangcai Li Co-authored-by: Robert (Bobby) Evans Co-authored-by: shotai --- docs/configs.md | 12 + docs/get-started/getting-started-on-prem.md | 65 +++ integration_tests/README.md | 10 + integration_tests/pytest.ini | 1 + integration_tests/src/main/python/marks.py | 1 + .../src/main/python/udf_cudf_test.py | 285 ++++++++++ integration_tests/src/main/python/udf_test.py | 151 ++++++ jenkins/Dockerfile.integration.centos7 | 21 +- jenkins/Dockerfile.ubuntu16 | 2 +- jenkins/spark-premerge-build.sh | 6 +- jenkins/spark-tests.sh | 9 +- pom.xml | 2 +- python/rapids/__init__.py | 16 + python/rapids/daemon.py | 165 ++++++ python/rapids/worker.py | 81 +++ sql-plugin/pom.xml | 4 + .../nvidia/spark/rapids/GpuOverrides.scala | 65 ++- .../com/nvidia/spark/rapids/Plugin.scala | 5 +- .../com/nvidia/spark/rapids/RapidsConf.scala | 32 ++ .../com/nvidia/spark/rapids/RapidsMeta.scala | 11 +- .../rapids/python/PythonConfEntries.scala | 78 +++ .../rapids/python/PythonWorkerSemaphore.scala | 144 +++++ .../python/rapids/GpuPandasUtils.scala | 51 ++ .../python/GpuAggregateInPandasExec.scala | 198 +++++++ .../python/GpuArrowEvalPythonExec.scala | 495 ++++++++++++++++++ .../GpuFlatMapCoGroupsInPandasExec.scala | 139 +++++ .../python/GpuFlatMapGroupsInPandasExec.scala | 131 +++++ .../execution/python/GpuMapInPandasExec.scala | 136 +++++ .../execution/python/GpuPythonHelper.scala | 139 +++++ .../python/GpuWindowInPandasExec.scala | 407 ++++++++++++++ .../rapids/execution/python/RowUtils.scala | 295 +++++++++++ 31 files changed, 3143 insertions(+), 14 deletions(-) create mode 100644 integration_tests/src/main/python/udf_cudf_test.py create mode 100644 integration_tests/src/main/python/udf_test.py create mode 100644 python/rapids/__init__.py create mode 100644 python/rapids/daemon.py create mode 100644 python/rapids/worker.py create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonConfEntries.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/execution/python/rapids/GpuPandasUtils.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExec.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/RowUtils.scala diff --git a/docs/configs.md b/docs/configs.md index c8c5ec517b0..190533dc91a 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -36,6 +36,10 @@ Name | Description | Default Value spark.rapids.memory.gpu.reserve|The amount of GPU memory that should remain unallocated by RMM and left for system use such as memory needed for kernels, kernel launches or JIT compilation.|1073741824 spark.rapids.memory.host.spillStorageSize|Amount of off-heap host memory to use for buffering spilled GPU data before spilling to local disk|1073741824 spark.rapids.memory.pinnedPool.size|The size of the pinned memory pool in bytes unless otherwise specified. Use 0 to disable the pool.|0 +spark.rapids.python.concurrentPythonWorkers|Set the number of Python worker processes that can execute concurrently per GPU. Python worker processes may temporarily block when the number of concurrent Python worker processes started by the same executor exceeds this amount. Allowing too many concurrent tasks on the same GPU may lead to GPU out of memory errors. >0 means enabled, while <=0 means unlimited|0 +spark.rapids.python.memory.gpu.allocFraction|The fraction of total GPU memory that should be initially allocated for pooled memory for all the Python workers. It supposes to be less than (1 - $(spark.rapids.memory.gpu.allocFraction)), since the executor will share the GPU with its owning Python workers. Half of the rest will be used if not specified|None +spark.rapids.python.memory.gpu.maxAllocFraction|The fraction of total GPU memory that limits the maximum size of the RMM pool for all the Python workers. It supposes to be less than (1 - $(spark.rapids.memory.gpu.maxAllocFraction)), since the executor will share the GPU with its owning Python workers. when setting to 0 it means no limit.|0.0 +spark.rapids.python.memory.gpu.pooling.enabled|Should RMM in Python workers act as a pooling allocator for GPU memory, or should it just pass through to CUDA memory allocation directly. When not specified, It will honor the value of config 'spark.rapids.memory.gpu.pooling.enabled'|None spark.rapids.shuffle.transport.enabled|When set to true, enable the Rapids Shuffle Transport for accelerated shuffle.|false spark.rapids.shuffle.transport.maxReceiveInflightBytes|Maximum aggregate amount of bytes that be fetched at any given time from peers during shuffle|1073741824 spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null @@ -65,6 +69,7 @@ Name | Description | Default Value spark.rapids.sql.improvedFloatOps.enabled|For some floating point operations spark uses one way to compute the value and the underlying cudf implementation can use an improved algorithm. In some cases this can result in cudf producing an answer when spark overflows. Because this is not as compatible with spark, we have it disabled by default.|false spark.rapids.sql.improvedTimeOps.enabled|When set to true, some operators will avoid overflowing by converting epoch days directly to seconds without first converting to microseconds|false spark.rapids.sql.incompatibleOps.enabled|For operations that work, but are not 100% compatible with the Spark equivalent set if they should be enabled by default or disabled by default.|false +spark.rapids.sql.python.gpu.enabled|This is an experimental feature and is likely to change in the future. Enable (true) or disable (false) support for scheduling Python Pandas UDFs with GPU resources. When enabled, pandas UDFs are assumed to share the same GPU that the RAPIDs accelerator uses and will honor the python GPU configs|false spark.rapids.sql.reader.batchSizeBytes|Soft limit on the maximum number of bytes the reader reads per batch. The readers will read chunks of data until this limit is met or exceeded. Note that the reader may estimate the number of bytes that will be used on the GPU in some cases based on the schema and number of rows in each batch.|2147483647 spark.rapids.sql.reader.batchSizeRows|Soft limit on the maximum number of rows the reader will read per batch. The orc and parquet readers will read row groups until this limit is met or exceeded. The limit is respected by the csv reader.|2147483647 spark.rapids.sql.replaceSortMergeJoin.enabled|Allow replacing sortMergeJoin with HashJoin|true @@ -169,6 +174,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.Or|`or`|Logical OR|true|None| spark.rapids.sql.expression.Pmod|`pmod`|Pmod|true|None| spark.rapids.sql.expression.Pow|`pow`, `power`|lhs ^ rhs|true|None| +spark.rapids.sql.expression.PythonUDF| |UDF run in an external python process. Does not actually run on the GPU, but the transfer of data to/from it can be accelerated.|true|None| spark.rapids.sql.expression.Quarter|`quarter`|Returns the quarter of the year for date, in the range 1 to 4|true|None| spark.rapids.sql.expression.Rand|`random`, `rand`|Generate a random column with i.i.d. uniformly distributed values in [0, 1)|true|None| spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|RegExpReplace support for string literal input patterns|true|None| @@ -253,6 +259,12 @@ Name | Description | Default Value | Notes spark.rapids.sql.exec.CartesianProductExec|Implementation of join using brute force|false|This is disabled by default because large joins can cause out of memory errors| spark.rapids.sql.exec.ShuffledHashJoinExec|Implementation of join using hashed shuffled data|true|None| spark.rapids.sql.exec.SortMergeJoinExec|Sort merge join, replacing with shuffled hash join|true|None| +spark.rapids.sql.exec.AggregateInPandasExec|The backend for Grouped Aggregation Pandas UDF, it runs on CPU itself now but supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF|false|This is disabled by default because Performance is not ideal now| +spark.rapids.sql.exec.ArrowEvalPythonExec|The backend of the Scalar Pandas UDFs, it supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF, also accelerates the data transfer between the Java process and Python process|false|This is disabled by default because Performance is not ideal for UDFs that take a long time| +spark.rapids.sql.exec.FlatMapCoGroupsInPandasExec|The backend for CoGrouped Aggregation Pandas UDF, it runs on CPU itself now but supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF|false|This is disabled by default because Performance is not ideal now| +spark.rapids.sql.exec.FlatMapGroupsInPandasExec|The backend for Grouped Map Pandas UDF, it runs on CPU itself now but supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF|false|This is disabled by default because Performance is not ideal now| +spark.rapids.sql.exec.MapInPandasExec|The backend for Map Pandas Iterator UDF, it runs on CPU itself now but supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF|false|This is disabled by default because Performance is not ideal now| +spark.rapids.sql.exec.WindowInPandasExec|The backend for Pandas UDF with window functions, it runs on CPU itself now but supports running the Python UDFs code on GPU when calling cuDF APIs in the UDF|false|This is disabled by default because Performance is not ideal now| spark.rapids.sql.exec.WindowExec|Window-operator backend|true|None| ### Scans diff --git a/docs/get-started/getting-started-on-prem.md b/docs/get-started/getting-started-on-prem.md index bc5f165362b..489da716246 100644 --- a/docs/get-started/getting-started-on-prem.md +++ b/docs/get-started/getting-started-on-prem.md @@ -475,6 +475,71 @@ This setting controls the amount of host memory (RAM) that can be utilized to sp the GPU is out of memory, before going to disk. Please verify the [defaults](../configs.md). - `spark.rapids.memory.host.spillStorageSize` +## GPU Scheduling For Pandas UDF +--- +**NOTE** + +The _GPU Scheduling for Pandas UDF_ is an experimental feature, and may change at any point it time. + +--- + +_GPU Scheduling for Pandas UDF_ is built on Apache Spark's [Pandas UDF(user defined function)](https://spark.apache.org/docs/3.0.0/sql-pyspark-pandas-with-arrow.html#pandas-udfs-aka-vectorized-udfs), and has two components: + +- **Share GPU with JVM**: Let the Python process share JVM GPU. The Python process could run on the same GPU with JVM. + +- **Increase Speed**: Make the data transport faster between JVM process and Python process. + + + +To enable _GPU Scheduling for Pandas UDF_, you need to configure your spark job with extra settings. + +1. Make sure GPU exclusive mode is disabled. Note that this will not work if you are using exclusive mode to assign GPUs under spark. +2. Currently the python files are packed into the spark rapids plugin jar. + + On Yarn, you need to add + ```shell + ... + --py-files ${SPARK_RAPIDS_PLUGIN_JAR} + ``` + + + On Standalone, you need to add + ```shell + ... + --conf spark.executorEnv.PYTHONPATH=rapids-4-spark_2.12-0.2.0-SNAPSHOT.jar \ + --py-files ${SPARK_RAPIDS_PLUGIN_JAR} + ``` + +3. Enable GPU Scheduling for Pandas UDF. + + ```shell + ... + --conf spark.rapids.python.gpu.enabled=true \ + --conf spark.rapids.python.memory.gpu.pooling.enabled=false \ + --conf spark.rapids.sql.exec.ArrowEvalPythonExec=true \ + --conf spark.rapids.sql.exec.MapInPandasExec=true \ + --conf spark.rapids.sql.exec.FlatMapGroupsInPandasExec=true \ + --conf spark.rapids.sql.exec.AggregateInPandasExec=true \ + --conf spark.rapids.sql.exec.FlatMapCoGroupsInPandasExec=true \ + --conf spark.rapids.sql.exec.WindowInPandasExec=true + ``` + +Please note the data transfer acceleration only supports scalar UDF and Scalar iterator UDF currently. +You could choose the exec you need to enable. + +### Other Configuration + +Following configuration settings are also for _GPU Scheduling for Pandas UDF_ +``` +spark.rapids.python.concurrentPythonWorkers +spark.rapids.python.memory.gpu.allocFraction +spark.rapids.python.memory.gpu.maxAllocFraction +``` + +To find details on the above Python configuration settings, please see the [RAPIDS Accelerator for Apache Spark Configuration Guide](../configs.md). + + + ## Advanced Configuration See the [RAPIDS Accelerator for Apache Spark Configuration Guide](../configs.md) for details on all diff --git a/integration_tests/README.md b/integration_tests/README.md index 0e9bbda14f5..b2e01f2ce59 100644 --- a/integration_tests/README.md +++ b/integration_tests/README.md @@ -24,6 +24,16 @@ Should be enough to get the basics started. `sre_yield` provides a set of APIs to generate string data from a regular expression. +### pandas +`pip install pandas` + +`pandas` is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool. + +### pyarrow +`pip install pyarrow` + +`pyarrow` provides a Python API for functionality provided by the Arrow C++ libraries, along with tools for Arrow integration and interoperability with pandas, NumPy, and other software in the Python ecosystem. + ## Running Running the tests follows the pytest conventions, the main difference is using diff --git a/integration_tests/pytest.ini b/integration_tests/pytest.ini index 19d8d87e2ea..4549c2e1281 100644 --- a/integration_tests/pytest.ini +++ b/integration_tests/pytest.ini @@ -20,4 +20,5 @@ markers = incompat: Enable incompat operators limit(num_rows): Limit the number of rows that will be check in a result qarun: Mark qa test + cudf_udf: Mark udf cudf test diff --git a/integration_tests/src/main/python/marks.py b/integration_tests/src/main/python/marks.py index 8f5b4fca242..e10d6be44bc 100644 --- a/integration_tests/src/main/python/marks.py +++ b/integration_tests/src/main/python/marks.py @@ -20,3 +20,4 @@ incompat = pytest.mark.incompat limit = pytest.mark.limit qarun = pytest.mark.qarun +cudf_udf = pytest.mark.cudf_udf diff --git a/integration_tests/src/main/python/udf_cudf_test.py b/integration_tests/src/main/python/udf_cudf_test.py new file mode 100644 index 00000000000..e06984a4993 --- /dev/null +++ b/integration_tests/src/main/python/udf_cudf_test.py @@ -0,0 +1,285 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pandas as pd +import time +from typing import Iterator +from pyspark.sql import Window +from pyspark.sql.functions import pandas_udf, PandasUDFType +from spark_session import with_cpu_session, with_gpu_session +from marks import allow_non_gpu, cudf_udf + + +_conf = { + 'spark.rapids.sql.exec.ArrowEvalPythonExec':'true', + 'spark.rapids.sql.exec.MapInPandasExec':'true', + 'spark.rapids.sql.exec.FlatMapGroupsInPandasExec': 'true', + 'spark.rapids.sql.exec.AggregateInPandasExec': 'true', + 'spark.rapids.sql.exec.FlatMapCoGroupsInPandasExec': 'true', + 'spark.rapids.sql.exec.WindowInPandasExec': 'true', + 'spark.rapids.sql.python.gpu.enabled': 'true' + } + +def _create_df(spark): + return spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v") + ) + +# since this test requires to run different functions on CPU and GPU(need cudf), +# create its own assert function +def _assert_cpu_gpu(cpu_func, gpu_func, cpu_conf={}, gpu_conf={}, is_sort=False): + print('### CPU RUN ###') + cpu_start = time.time() + cpu_ret = with_cpu_session(cpu_func, conf=cpu_conf) + cpu_end = time.time() + print('### GPU RUN ###') + gpu_start = time.time() + gpu_ret = with_gpu_session(gpu_func, conf=gpu_conf) + gpu_end = time.time() + print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format( + gpu_end - gpu_start, cpu_end - cpu_start)) + if is_sort: + assert cpu_ret.sort() == gpu_ret.sort() + else: + assert cpu_ret == gpu_ret + + +@pandas_udf('int') +def _plus_one_cpu_func(v: pd.Series) -> pd.Series: + return v + 1 + +@pandas_udf('int') +def _plus_one_gpu_func(v: pd.Series) -> pd.Series: + import cudf + gpu_serises = cudf.Series(v) + gpu_serises = gpu_serises + 1 + return gpu_serises.to_pandas() + +@allow_non_gpu(any=True) +@pytest.mark.skip("exception in docker: OSError: Invalid IPC stream: negative continuation token, skip for now") +@cudf_udf +def test_with_column(): + def cpu_run(spark): + df = _create_df(spark) + return df.withColumn("v1", _plus_one_cpu_func(df.v)).collect() + + def gpu_run(spark): + df = _create_df(spark) + return df.withColumn("v1", _plus_one_gpu_func(df.v)).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf) + +@allow_non_gpu(any=True) +@pytest.mark.skip("exception in docker: OSError: Invalid IPC stream: negative continuation token, skip for now") +@cudf_udf +def test_sql(): + def cpu_run(spark): + _ = spark.udf.register("add_one_cpu", _plus_one_cpu_func) + return spark.sql("SELECT add_one_cpu(id) FROM range(3)").collect() + def gpu_run(spark): + _ = spark.udf.register("add_one_gpu", _plus_one_gpu_func) + return spark.sql("SELECT add_one_gpu(id) FROM range(3)").collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf) + + +@pandas_udf("long") +def _plus_one_cpu_iter_func(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + for s in iterator: + yield s + 1 + +@pandas_udf("long") +def _plus_one_gpu_iter_func(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + import cudf + for s in iterator: + gpu_serises = cudf.Series(s) + gpu_serises = gpu_serises + 1 + yield gpu_serises.to_pandas() + +@allow_non_gpu(any=True) +@pytest.mark.skip("exception in docker: OSError: Invalid IPC stream: negative continuation token, skip for now") +@cudf_udf +def test_select(): + def cpu_run(spark): + df = _create_df(spark) + return df.select(_plus_one_cpu_iter_func(df.v)).collect() + + def gpu_run(spark): + df = _create_df(spark) + return df.select(_plus_one_gpu_iter_func(df.v)).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf) + + +@allow_non_gpu('GpuMapInPandasExec','PythonUDF') +@cudf_udf +def test_map_in_pandas(): + def cpu_run(spark): + df = _create_df(spark) + def _filter_cpu_func(iterator): + for pdf in iterator: + yield pdf[pdf.id == 1] + return df.mapInPandas(_filter_cpu_func, df.schema).collect() + + def gpu_run(spark): + df = _create_df(spark) + def _filter_gpu_func(iterator): + import cudf + for pdf in iterator: + gdf = cudf.from_pandas(pdf) + yield gdf[gdf.id == 1].to_pandas() + return df.mapInPandas(_filter_gpu_func, df.schema).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf) + + +# To solve: Invalid udf: the udf argument must be a pandas_udf of type GROUPED_MAP +# need to add udf type +@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) +def _normalize_cpu_func(df): + v = df.v + return df.assign(v=(v - v.mean()) / v.std()) + +@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) +def _normalize_gpu_func(df): + import cudf + gdf = cudf.from_pandas(df) + v = gdf.v + return gdf.assign(v=(v - v.mean()) / v.std()).to_pandas() + +@allow_non_gpu('GpuFlatMapGroupsInPandasExec','PythonUDF') +@cudf_udf +def test_group_apply(): + def cpu_run(spark): + df = _create_df(spark) + return df.groupby("id").apply(_normalize_cpu_func).collect() + + def gpu_run(spark): + df = _create_df(spark) + return df.groupby("id").apply(_normalize_gpu_func).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + +@allow_non_gpu('GpuFlatMapGroupsInPandasExec','PythonUDF') +@cudf_udf +def test_group_apply_in_pandas(): + def cpu_run(spark): + df = _create_df(spark) + def _normalize_cpu_in_pandas_func(df): + v = df.v + return df.assign(v=(v - v.mean()) / v.std()) + return df.groupby("id").applyInPandas(_normalize_cpu_in_pandas_func, df.schema).collect() + + def gpu_run(spark): + df = _create_df(spark) + def _normalize_gpu_in_pandas_func(df): + import cudf + gdf = cudf.from_pandas(df) + v = gdf.v + return gdf.assign(v=(v - v.mean()) / v.std()).to_pandas() + return df.groupby("id").applyInPandas(_normalize_gpu_in_pandas_func, df.schema).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + +@pandas_udf("int") +def _sum_cpu_func(v: pd.Series) -> int: + return v.sum() + +@pandas_udf("integer") +def _sum_gpu_func(v: pd.Series) -> int: + import cudf + gpu_serises = cudf.Series(v) + return gpu_serises.sum() + +@allow_non_gpu('GpuAggregateInPandasExec','PythonUDF','Alias') +@cudf_udf +def test_group_agg(): + def cpu_run(spark): + df = _create_df(spark) + return df.groupby("id").agg(_sum_cpu_func(df.v)).collect() + + def gpu_run(spark): + df = _create_df(spark) + return df.groupby("id").agg(_sum_gpu_func(df.v)).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + +@allow_non_gpu('GpuAggregateInPandasExec','PythonUDF','Alias') +@cudf_udf +def test_sql_group(): + def cpu_run(spark): + _ = spark.udf.register("sum_cpu_udf", _sum_cpu_func) + q = "SELECT sum_cpu_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + return spark.sql(q).collect() + + def gpu_run(spark): + _ = spark.udf.register("sum_gpu_udf", _sum_gpu_func) + q = "SELECT sum_gpu_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + return spark.sql(q).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + +@allow_non_gpu('GpuWindowInPandasExec','PythonUDF','Alias','WindowExpression','WindowSpecDefinition','SpecifiedWindowFrame','UnboundedPreceding$', 'UnboundedFollowing$') +@cudf_udf +def test_window(): + def cpu_run(spark): + df = _create_df(spark) + w = Window.partitionBy('id').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + return df.withColumn('sum_v', _sum_cpu_func('v').over(w)).collect() + + def gpu_run(spark): + df = _create_df(spark) + w = Window.partitionBy('id').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + return df.withColumn('sum_v', _sum_gpu_func('v').over(w)).collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + +@allow_non_gpu('GpuFlatMapCoGroupsInPandasExec','PythonUDF') +@cudf_udf +def test_cogroup(): + def cpu_run(spark): + df1 = spark.createDataFrame( + [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], + ("time", "id", "v1")) + df2 = spark.createDataFrame( + [(20000101, 1, "x"), (20000101, 2, "y")], + ("time", "id", "v2")) + def _cpu_join_func(l, r): + return pd.merge(l, r, on="time") + return df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(_cpu_join_func, schema="time int, id_x int, id_y int, v1 double, v2 string").collect() + + def gpu_run(spark): + df1 = spark.createDataFrame( + [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], + ("time", "id", "v1")) + df2 = spark.createDataFrame( + [(20000101, 1, "x"), (20000101, 2, "y")], + ("time", "id", "v2")) + def _gpu_join_func(l, r): + import cudf + gl = cudf.from_pandas(l) + gr = cudf.from_pandas(r) + return gl.merge(gr, on="time").to_pandas() + return df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(_gpu_join_func, schema="time int, id_x int, id_y int, v1 double, v2 string").collect() + + _assert_cpu_gpu(cpu_run, gpu_run, gpu_conf=_conf, is_sort=True) + + diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py new file mode 100644 index 00000000000..276fe15d8d6 --- /dev/null +++ b/integration_tests/src/main/python/udf_test.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect +from data_gen import * +from marks import incompat, approximate_float, allow_non_gpu, ignore_order +from pyspark.sql import Window +from pyspark.sql.types import * +import pyspark.sql.functions as f +from pyspark.sql.pandas.utils import require_minimum_pyarrow_version +import pandas as pd +from typing import Iterator, Tuple + +try: + require_minimum_pyarrow_version() +except Exception as e: + pytestmark = pytest.mark.skip(reason=str(e)) + +arrow_udf_conf = {'spark.sql.execution.arrow.pyspark.enabled': 'true', + 'spark.rapids.sql.exec.ArrowEvalPythonExec': 'true'} + +#################################################################### +# NOTE: pytest does not play well with pyspark udfs, because pyspark +# tries to import the dependencies for top level functions and +# pytest messes around with imports. To make this work, all UDFs +# must either be lambdas or totally defined within the test method +# itself. +#################################################################### + +@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn) +def test_pandas_math_udf(data_gen): + def add(a, b): + return a + b + my_udf = f.pandas_udf(add, returnType=LongType()) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen).select( + my_udf(f.col('a') - 3, f.col('b'))), + conf=arrow_udf_conf) + +@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn) +def test_iterator_math_udf(data_gen): + def iterator_add(to_process: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]: + for a, b in to_process: + yield a + b + + my_udf = f.pandas_udf(iterator_add, returnType=LongType()) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen).select( + my_udf(f.col('a'), f.col('b'))), + conf=arrow_udf_conf) + +@allow_non_gpu('AggregateInPandasExec', 'PythonUDF', 'Alias') +@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn) +def test_single_aggregate_udf(data_gen): + @f.pandas_udf('double') + def pandas_sum(to_process: pd.Series) -> float: + return to_process.sum() + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).select( + pandas_sum(f.col('a'))), + conf=arrow_udf_conf) + +@ignore_order +@allow_non_gpu('AggregateInPandasExec', 'PythonUDF', 'Alias') +@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn) +def test_group_aggregate_udf(data_gen): + @f.pandas_udf('long') + def pandas_sum(to_process: pd.Series) -> int: + return to_process.sum() + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen)\ + .groupBy('a')\ + .agg(pandas_sum(f.col('b'))), + conf=arrow_udf_conf) + +@ignore_order +@allow_non_gpu('WindowInPandasExec', 'PythonUDF', 'WindowExpression', 'Alias', 'WindowSpecDefinition', 'SpecifiedWindowFrame', 'UnboundedPreceding$', 'UnboundedFollowing$') +@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn) +def test_window_aggregate_udf(data_gen): + @f.pandas_udf('long') + def pandas_sum(to_process: pd.Series) -> int: + return to_process.sum() + + w = Window\ + .partitionBy('a') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen).select( + pandas_sum(f.col('b')).over(w)), + conf=arrow_udf_conf) + +@ignore_order +@allow_non_gpu('FlatMapGroupsInPandasExec', 'PythonUDF', 'Alias') +@pytest.mark.parametrize('data_gen', [LongGen()], ids=idfn) +def test_group_apply_udf(data_gen): + def pandas_add(data): + data.sum = data.b + data.a + return data + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen)\ + .groupBy('a')\ + .applyInPandas(pandas_add, schema="a long, b long"), + conf=arrow_udf_conf) + + +@allow_non_gpu('MapInPandasExec', 'PythonUDF', 'Alias') +@pytest.mark.parametrize('data_gen', [LongGen()], ids=idfn) +def test_map_apply_udf(data_gen): + def pandas_filter(iterator): + for data in iterator: + yield data[data.b <= data.a] + + assert_gpu_and_cpu_are_equal_collect( + lambda spark : binary_op_df(spark, data_gen)\ + .mapInPandas(pandas_filter, schema="a long, b long"), + conf=arrow_udf_conf) + +def create_df(spark, data_gen, left_length, right_length): + left = binary_op_df(spark, data_gen, length=left_length) + right = binary_op_df(spark, data_gen, length=right_length) + return left, right + +@ignore_order +@allow_non_gpu('FlatMapCoGroupsInPandasExec', 'PythonUDF', 'Alias') +@pytest.mark.parametrize('data_gen', [ShortGen(nullable=False)], ids=idfn) +def test_cogroup_apply_udf(data_gen): + def asof_join(l, r): + return pd.merge_asof(l, r, on='a', by='b') + + def do_it(spark): + left, right = create_df(spark, data_gen, 500, 500) + return left.groupby('a').cogroup( + right.groupby('a')).applyInPandas( + asof_join, schema="a int, b int") + assert_gpu_and_cpu_are_equal_collect(do_it, conf=arrow_udf_conf) diff --git a/jenkins/Dockerfile.integration.centos7 b/jenkins/Dockerfile.integration.centos7 index 4b6cce36a5e..5ec4b652314 100644 --- a/jenkins/Dockerfile.integration.centos7 +++ b/jenkins/Dockerfile.integration.centos7 @@ -39,7 +39,26 @@ RUN wget ${URM_URL}/org/apache/maven/apache-maven/3.6.3/apache-maven-3.6.3-bin.t rm -f $MAVEN_HOME-bin.tar.gz ENV PATH "$MAVEN_HOME/bin:/opt/rh/rh-python36/root/usr/bin/:$PATH" -RUN python -m pip install pytest sre_yield + +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ + /bin/bash ~/miniconda.sh -b -p /opt/conda + +RUN . /opt/conda/bin/activate && \ + conda init + +ENV PATH="/opt/conda/bin:${PATH}" + +RUN conda --version + +ARG CUDA_TOOLKIT_VER=10.1 +RUN conda install -c rapidsai-nightly -c nvidia -c conda-forge \ + -c defaults cudf=0.15 python=3.7 cudatoolkit=${CUDA_TOOLKIT_VER} + +RUN conda install spacy && \ + python -m spacy download en_core_web_sm + +RUN conda install -c anaconda pytest requests pandas pyarrow +RUN conda install -c conda-forge sre_yield # Set ENV for mvn ENV JAVA_HOME "/usr/lib/jvm/java-1.8.0-openjdk" diff --git a/jenkins/Dockerfile.ubuntu16 b/jenkins/Dockerfile.ubuntu16 index 2bfdc6067e3..8c093f94a51 100644 --- a/jenkins/Dockerfile.ubuntu16 +++ b/jenkins/Dockerfile.ubuntu16 @@ -35,7 +35,7 @@ RUN add-apt-repository ppa:deadsnakes/ppa && \ openjdk-8-jdk python3.6 python3-pip tzdata git RUN ln -s /usr/bin/python3.6 /usr/bin/python -RUN python -m pip install pytest sre_yield requests +RUN python -m pip install pytest sre_yield requests pandas pyarrow RUN adduser --uid 26576 --gid 30 --shell /bin/bash svcngcc USER svcngcc diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh index f26b82cb9a6..34578e8df9f 100755 --- a/jenkins/spark-premerge-build.sh +++ b/jenkins/spark-premerge-build.sh @@ -37,10 +37,10 @@ export PATH="$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH" tar zxf $SPARK_HOME.tgz -C $ARTF_ROOT && \ rm -f $SPARK_HOME.tgz -mvn -U -B $MVN_URM_MIRROR '-Pinclude-databricks,!snapshot-shims' clean verify -Dpytest.TEST_TAGS='' +mvn -U -B $MVN_URM_MIRROR '-Pinclude-databricks,!snapshot-shims' clean verify -Dpytest.TEST_TAGS='not cudf_udf' # Run the unit tests for other Spark versions but dont run full python integration tests -env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS='' -env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark310tests,snapshot-shims test -Dpytest.TEST_TAGS='' +env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS='not cudf_udf' +env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark310tests,snapshot-shims test -Dpytest.TEST_TAGS='not cudf_udf' # The jacoco coverage should have been collected, but because of how the shade plugin # works and jacoco we need to clean some things up so jacoco will only report for the diff --git a/jenkins/spark-tests.sh b/jenkins/spark-tests.sh index 4b090432cd9..7c02d0cdcfd 100755 --- a/jenkins/spark-tests.sh +++ b/jenkins/spark-tests.sh @@ -70,6 +70,12 @@ MORTGAGE_SPARK_SUBMIT_ARGS=" --conf spark.plugins=com.nvidia.spark.SQLPlugin \ --class com.nvidia.spark.rapids.tests.mortgage.Main \ $RAPIDS_TEST_JAR" +# need to disable pooling for udf test to prevent cudaErrorMemoryAllocation +CUDF_UDF_TEST_ARGS="--conf spark.rapids.python.memory.gpu.pooling.enabled=false \ + --conf spark.rapids.memory.gpu.pooling.enabled=false \ + --conf spark.executorEnv.PYTHONPATH=rapids-4-spark_2.12-0.2.0-SNAPSHOT.jar \ + --py-files ${RAPIDS_PLUGIN_JAR}" + TEST_PARAMS="$SPARK_VER $PARQUET_PERF $PARQUET_ACQ $OUTPUT" export PATH="$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH" @@ -84,4 +90,5 @@ jps echo "----------------------------START TEST------------------------------------" rm -rf $OUTPUT spark-submit $BASE_SPARK_SUBMIT_ARGS $MORTGAGE_SPARK_SUBMIT_ARGS $TEST_PARAMS -cd $RAPIDS_INT_TESTS_HOME && spark-submit $BASE_SPARK_SUBMIT_ARGS --jars $RAPIDS_TEST_JAR ./runtests.py -v -rfExXs --std_input_path="$WORKSPACE/integration_tests/src/test/resources/" +cd $RAPIDS_INT_TESTS_HOME && spark-submit $BASE_SPARK_SUBMIT_ARGS --jars $RAPIDS_TEST_JAR ./runtests.py -m "not cudf_udf" -v -rfExXs --std_input_path="$WORKSPACE/integration_tests/src/test/resources/" +spark-submit $BASE_SPARK_SUBMIT_ARGS $CUDF_UDF_TEST_ARGS --jars $RAPIDS_TEST_JAR ./runtests.py -m "cudf_udf" -v -rfExXs diff --git a/pom.xml b/pom.xml index 3035229b29f..0737d5689c9 100644 --- a/pom.xml +++ b/pom.xml @@ -164,7 +164,7 @@ false UTF-8 UTF-8 - not qarun + not qarun and not cudf_udf false 1.7.30 3.0.0 diff --git a/python/rapids/__init__.py b/python/rapids/__init__.py new file mode 100644 index 00000000000..7e778cc73fa --- /dev/null +++ b/python/rapids/__init__.py @@ -0,0 +1,16 @@ +## +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +## diff --git a/python/rapids/daemon.py b/python/rapids/daemon.py new file mode 100644 index 00000000000..31353bd92ab --- /dev/null +++ b/python/rapids/daemon.py @@ -0,0 +1,165 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import signal +import select +import socket +import sys +import traceback +import time +import gc +from errno import EINTR, EAGAIN +from socket import AF_INET, SOCK_STREAM, SOMAXCONN +from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN + +from pyspark.serializers import read_int, write_int +from pyspark.daemon import worker + +from rapids.worker import initialize_gpu_mem + + +def manager(): + # Create a new process group to corral our children + os.setpgid(0, 0) + + # Create a listening socket on the AF_INET loopback interface + listen_sock = socket.socket(AF_INET, SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(max(1024, SOMAXCONN)) + listen_host, listen_port = listen_sock.getsockname() + + # re-open stdin/stdout in 'wb' mode + stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) + stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) + write_int(listen_port, stdout_bin) + stdout_bin.flush() + + def shutdown(code): + signal.signal(SIGTERM, SIG_DFL) + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) + sys.exit(code) + + def handle_sigterm(*args): + shutdown(1) + signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM + signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + signal.signal(SIGCHLD, SIG_IGN) + + reuse = os.environ.get("SPARK_REUSE_WORKER") + + # Initialization complete + try: + while True: + try: + ready_fds = select.select([0, listen_sock], [], [], 1)[0] + except select.error as ex: + if ex[0] == EINTR: + continue + else: + raise + + if 0 in ready_fds: + try: + worker_pid = read_int(stdin_bin) + except EOFError: + # Spark told us to exit by closing stdin + shutdown(0) + try: + os.kill(worker_pid, signal.SIGKILL) + except OSError: + pass # process already died + + if listen_sock in ready_fds: + try: + sock, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + + # Launch a worker process + try: + pid = os.fork() + except OSError as e: + if e.errno in (EAGAIN, EINTR): + time.sleep(1) + pid = os.fork() # error here will shutdown daemon + else: + outfile = sock.makefile(mode='wb') + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() + sock.close() + continue + + if pid == 0: + # in child process + listen_sock.close() + + # It should close the standard input in the child process so that + # Python native function executions stay intact. + # + # Note that if we just close the standard input (file descriptor 0), + # the lowest file descriptor (file descriptor 0) will be allocated, + # later when other file descriptors should happen to open. + # + # Therefore, here we redirects it to '/dev/null' by duplicating + # another file descriptor for '/dev/null' to the standard input (0). + # See SPARK-26175. + devnull = open(os.devnull, 'r') + os.dup2(devnull.fileno(), 0) + devnull.close() + + try: + # GPU context setup + initialize_gpu_mem() + + # Acknowledge that the fork was successful + outfile = sock.makefile(mode="wb") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + authenticated = False + while True: + code = worker(sock, authenticated) + if code == 0: + authenticated = True + if not reuse or code: + # wait for closing + try: + while sock.recv(1024): + pass + except Exception: + pass + break + gc.collect() + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: + sock.close() + + finally: + shutdown(1) + + +if __name__ == '__main__': + manager() diff --git a/python/rapids/worker.py b/python/rapids/worker.py new file mode 100644 index 00000000000..a6aaac669f8 --- /dev/null +++ b/python/rapids/worker.py @@ -0,0 +1,81 @@ +## +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +## + +import os +from pyspark.worker import local_connect_and_auth, main as worker_main + + +def initialize_gpu_mem(): + # CUDA device(s) info + cuda_devices_str = os.environ.get('CUDA_VISIBLE_DEVICES') + python_gpu_disabled = os.environ.get('RAPIDS_PYTHON_ENABLED', 'false').lower() == 'false' + if python_gpu_disabled or not cuda_devices_str: + # Skip gpu initialization due to no CUDA device or python on gpu is disabled. + # One case to come here is the test runs with cpu session in integration tests. + return + + print("INFO: Process {} found CUDA visible device(s): {}".format( + os.getpid(), cuda_devices_str)) + # Initialize RMM only when requiring to enable pooled or managed memory. + pool_enabled = os.environ.get('RAPIDS_POOLED_MEM_ENABLED', 'false').lower() == 'true' + uvm_enabled = os.environ.get('RAPIDS_UVM_ENABLED', 'false').lower() == 'true' + if pool_enabled: + from cudf import rmm + ''' + RMM will be initialized with default configures (pool disabled) when importing cudf + as above. So overwrite the initialization when asking for pooled memory, + along with a pool size and max pool size. + Meanwhile, the above `import` precedes the `import` in UDF, make our initialization + not be overwritten again by the `import` in UDF, since Python will ignore duplicated + `import`. + ''' + import sys + max_size = sys.maxint if sys.version_info.major == 2 else sys.maxsize + pool_size = int(os.environ.get('RAPIDS_POOLED_MEM_SIZE', 0)) + pool_max_size = int(os.environ.get('RAPIDS_POOLED_MEM_MAX_SIZE', 0)) + if 0 < pool_max_size < pool_size: + raise ValueError("Value of `RAPIDS_POOLED_MEM_MAX_SIZE` should not be less than " + "`RAPIDS_POOLED_MEM_SIZE`.") + if pool_max_size == 0: + pool_max_size = max_size + print("DEBUG: Pooled memory, pool size: {} MiB, max size: {} MiB".format( + pool_size / 1024.0 / 1024, + ('unlimited' if pool_max_size == max_size else pool_max_size / 1024.0 / 1024))) + base_t = rmm.mr.ManagedMemoryResource if uvm_enabled else rmm.mr.CudaMemoryResource + rmm.mr.set_current_device_resource(rmm.mr.PoolMemoryResource(base_t(), pool_size, pool_max_size)) + elif uvm_enabled: + from cudf import rmm + rmm.mr.set_current_device_resource(rmm.mr.ManagedMemoryResource()) + else: + # Do nothing, whether to use RMM (default mode) or not depends on UDF definition. + pass + + +if __name__ == '__main__': + # GPU context setup + initialize_gpu_mem() + + # Code below is all copied from Pyspark/worker.py + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) + # Use the `sock_file` as both input and output will cause EOFException in JVM side, + # So open a new file object on the same socket as output, similar behavior + # with that in `pyspark/daemon.py`. + buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536)) + outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size) + worker_main(sock_file, outfile) diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml index f450d5483e6..3a65798b2f7 100644 --- a/sql-plugin/pom.xml +++ b/sql-plugin/pom.xml @@ -102,6 +102,10 @@ LICENSE + + + ${project.basedir}/../python/ + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 9b0ae460766..7f6cd72968d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -41,11 +41,13 @@ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec} import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta} +import org.apache.spark.sql.rapids.execution.python._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -1287,6 +1289,20 @@ object GpuOverrides { override def convertToGpu(child: Expression): GpuExpression = GpuAverage(child) }), + expr[PythonUDF]( + "UDF run in an external python process. Does not actually run on the GPU, but " + + "the transfer of data to/from it can be accelerated.", + (a, conf, p, r) => new ExprMeta[PythonUDF](a, conf, p, r) { + override def couldReplaceMessage: String = "does not block GPU acceleration" + override def noReplacementPossibleMessage(reasons: String): String = + s"blocks running on GPU because $reasons" + + override def convertToGpu(): GpuExpression = + GpuPythonUDF(a.name, a.func, a.dataType, + childExprs.map(_.convertToGpu()), + a.evalType, a.udfDeterministic, a.resultId) + } + ), expr[Rand]( "Generate a random column with i.i.d. uniformly distributed values in [0, 1)", (a, conf, p, r) => new UnaryExprMeta[Rand](a, conf, p, r) { @@ -1701,6 +1717,28 @@ object GpuOverrides { override def convertToGpu(): GpuExec = GpuLocalLimitExec(localLimitExec.limit, childPlans(0).convertIfNeeded()) }), + exec[ArrowEvalPythonExec]( + "The backend of the Scalar Pandas UDFs, it supports running the Python UDFs code on GPU" + + " when calling cuDF APIs in the UDF, also accelerates the data transfer between the" + + " Java process and Python process", + (e, conf, p, r) => + new SparkPlanMeta[ArrowEvalPythonExec](e, conf, p, r) { + val udfs: Seq[BaseExprMeta[PythonUDF]] = + e.udfs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + val resultAttrs: Seq[BaseExprMeta[Attribute]] = + e.resultAttrs.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + override val childExprs: Seq[BaseExprMeta[_]] = udfs ++ resultAttrs + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + override def convertToGpu(): GpuExec = + GpuArrowEvalPythonExec(udfs.map(_.convertToGpu()).asInstanceOf[Seq[GpuPythonUDF]], + resultAttrs.map(_.convertToGpu()).asInstanceOf[Seq[Attribute]], + childPlans.head.convertIfNeeded(), + e.evalType) + }).disabledByDefault("Performance is not ideal for UDFs that take a long time"), exec[GlobalLimitExec]( "Limiting of results across partitions", (globalLimitExec, conf, p, r) => @@ -1780,6 +1818,31 @@ object GpuOverrides { exec.partitionSpecs) } }), + exec[MapInPandasExec]( + "The backend for Map Pandas Iterator UDF, it runs on CPU itself now but supports running" + + " the Python UDFs code on GPU when calling cuDF APIs in the UDF", + (mapPy, conf, p, r) => new GpuMapInPandasExecMeta(mapPy, conf, p, r)) + .disabledByDefault("Performance is not ideal now"), + exec[FlatMapGroupsInPandasExec]( + "The backend for Grouped Map Pandas UDF, it runs on CPU itself now but supports running" + + " the Python UDFs code on GPU when calling cuDF APIs in the UDF", + (flatPy, conf, p, r) => new GpuFlatMapGroupsInPandasExecMeta(flatPy, conf, p, r)) + .disabledByDefault("Performance is not ideal now"), + exec[AggregateInPandasExec]( + "The backend for Grouped Aggregation Pandas UDF, it runs on CPU itself now but supports" + + " running the Python UDFs code on GPU when calling cuDF APIs in the UDF", + (aggPy, conf, p, r) => new GpuAggregateInPandasExecMeta(aggPy, conf, p, r)) + .disabledByDefault("Performance is not ideal now"), + exec[FlatMapCoGroupsInPandasExec]( + "The backend for CoGrouped Aggregation Pandas UDF, it runs on CPU itself now but supports" + + " running the Python UDFs code on GPU when calling cuDF APIs in the UDF", + (flatCoPy, conf, p, r) => new GpuFlatMapCoGroupsInPandasExecMeta(flatCoPy, conf, p, r)) + .disabledByDefault("Performance is not ideal now"), + exec[WindowInPandasExec]( + "The backend for Pandas UDF with window functions, it runs on CPU itself now but supports" + + " running the Python UDFs code on GPU when calling cuDF APIs in the UDF", + (winPy, conf, p, r) => new GpuWindowInPandasExecMeta(winPy, conf, p, r)) + .disabledByDefault("Performance is not ideal now"), neverReplaceExec[AlterNamespaceSetPropertiesExec]("Namespace metadata operation"), neverReplaceExec[CreateNamespaceExec]("Namespace metadata operation"), neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 688e27aa08e..0d6efd08f12 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.JavaConverters._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext} import org.apache.spark.internal.Logging @@ -139,6 +141,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { override def shutdown(): Unit = { GpuSemaphore.shutdown() + PythonWorkerSemaphore.shutdown() } } @@ -231,4 +234,4 @@ class ExecutionPlanCaptureCallback extends QueryExecutionListener { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = captureIfNeeded(qe) -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index f8de4941962..efe883d8bc3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -145,6 +145,29 @@ class ConfEntryWithDefault[T](key: String, converter: String => T, doc: String, } } +class OptionalConfEntry[T](key: String, val rawConverter: String => T, doc: String, + isInternal: Boolean) + extends ConfEntry[Option[T]](key, s => Some(rawConverter(s)), doc, isInternal) { + + override def get(conf: Map[String, String]): Option[T] = { + conf.get(key).map(rawConverter) + } + + override def help(asTable: Boolean = false): Unit = { + if (!isInternal) { + if (asTable) { + import ConfHelper.makeConfAnchor + println(s"${makeConfAnchor(key)}|$doc|None") + } else { + println(s"$key:") + println(s"\t$doc") + println("\tNone") + println() + } + } + } +} + class TypedConfBuilder[T]( val parent: ConfBuilder, val converter: String => T, @@ -181,6 +204,13 @@ class TypedConfBuilder[T]( new TypedConfBuilder(parent, ConfHelper.stringToSeq(_, converter), ConfHelper.seqToString(_, stringConverter)) } + + def createOptional: OptionalConfEntry[T] = { + val ret = new OptionalConfEntry[T](parent.key, converter, + parent.doc, parent.isInternal) + parent.register(ret) + ret + } } class ConfBuilder(val key: String, val register: ConfEntry[_] => Unit) { @@ -772,6 +802,8 @@ object RapidsConf { } } def main(args: Array[String]): Unit = { + // Include the configs in PythonConfEntries + com.nvidia.spark.rapids.python.PythonConfEntries.init() val out = new FileOutputStream(new File(args(0))) Console.withOut(out) { Console.withErr(out) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index f86e6908e5e..ca7495c5c4c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -227,14 +227,15 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( private def indent(append: StringBuilder, depth: Int): Unit = append.append(" " * depth) + def couldReplaceMessage: String = "could run on GPU" + def noReplacementPossibleMessage(reasons: String): String = s"cannot run on GPU because $reasons" def suppressWillWorkOnGpuInfo: Boolean = false private def willWorkOnGpuInfo: String = cannotBeReplacedReasons match { case None => "NOT EVALUATED FOR GPU YET" - case Some(v) if v.isEmpty => "could run on GPU" + case Some(v) if v.isEmpty => couldReplaceMessage case Some(v) => - val reasons = v mkString "; " - s"cannot run on GPU because ${reasons}" + noReplacementPossibleMessage(v mkString "; ") } private def willBeRemovedInfo: String = shouldBeRemovedReasons match { @@ -242,7 +243,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( case Some(v) if v.isEmpty => "" case Some(v) => val reasons = v mkString "; " - s" but is going to be removed because ${reasons}" + s" but is going to be removed because $reasons" } /** @@ -404,7 +405,7 @@ final class RuleNotFoundDataWritingCommandMeta[INPUT <: DataWritingCommand]( extends DataWritingCommandMeta[INPUT](cmd, conf, parent, new NoRuleConfKeysAndIncompat) { override def tagSelfForGpu(): Unit = { - willNotWorkOnGpu(s"no GPU enabled version of command ${cmd.getClass} could be found") + willNotWorkOnGpu(s"no GPU accelerated version of command ${cmd.getClass} could be found") } override def convertToGpu(): GpuDataWritingCommand = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonConfEntries.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonConfEntries.scala new file mode 100644 index 00000000000..82ed32f5bb0 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonConfEntries.scala @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.python + +import com.nvidia.spark.rapids.RapidsConf.{POOLED_MEM, UVM_ENABLED} +import com.nvidia.spark.rapids.RapidsConf.conf + +object PythonConfEntries { + + val PYTHON_GPU_ENABLED = conf("spark.rapids.sql.python.gpu.enabled") + .doc("This is an experimental feature and is likely to change in the future." + + " Enable (true) or disable (false) support for scheduling Python Pandas UDFs with" + + " GPU resources. When enabled, pandas UDFs are assumed to share the same GPU that" + + " the RAPIDs accelerator uses and will honor the python GPU configs") + .booleanConf + .createWithDefault(false) + + val CONCURRENT_PYTHON_WORKERS = conf("spark.rapids.python.concurrentPythonWorkers") + .doc("Set the number of Python worker processes that can execute concurrently per GPU. " + + "Python worker processes may temporarily block when the number of concurrent Python " + + "worker processes started by the same executor exceeds this amount. Allowing too " + + "many concurrent tasks on the same GPU may lead to GPU out of memory errors. " + + ">0 means enabled, while <=0 means unlimited") + .integerConf + .createWithDefault(0) + + val PYTHON_RMM_ALLOC_FRACTION = conf("spark.rapids.python.memory.gpu.allocFraction") + .doc("The fraction of total GPU memory that should be initially allocated " + + "for pooled memory for all the Python workers. It supposes to be less than " + + "(1 - $(spark.rapids.memory.gpu.allocFraction)), since the executor will share the " + + "GPU with its owning Python workers. Half of the rest will be used if not specified") + .doubleConf + .checkValue(v => v >= 0 && v <= 1, "The fraction value for Python workers must be in [0, 1].") + .createOptional + + val PYTHON_RMM_MAX_ALLOC_FRACTION = conf("spark.rapids.python.memory.gpu.maxAllocFraction") + .doc("The fraction of total GPU memory that limits the maximum size of the RMM pool " + + "for all the Python workers. It supposes to be less than " + + "(1 - $(spark.rapids.memory.gpu.maxAllocFraction)), since the executor will share the " + + "GPU with its owning Python workers. when setting to 0 it means no limit.") + .doubleConf + .checkValue(v => v >= 0 && v <= 1, "The value of maxAllocFraction for Python workers must be" + + " in [0, 1].") + .createWithDefault(0.0) + + val PYTHON_POOLED_MEM = conf("spark.rapids.python.memory.gpu.pooling.enabled") + .doc("Should RMM in Python workers act as a pooling allocator for GPU memory, or" + + " should it just pass through to CUDA memory allocation directly. When not specified," + + s" It will honor the value of config '${POOLED_MEM.key}'") + .booleanConf + .createOptional + + val PYTHON_UVM_ENABLED = conf("spark.rapids.python.memory.uvm.enabled") + .doc(s"Similar with '${UVM_ENABLED.key}', but this conf is for" + + s" python workers. When not specified, it will honor the value of config" + + s" '${UVM_ENABLED.key}'. This is an experimental feature.") + .internal() + .booleanConf + .createOptional + + // An empty function called by RapidsConf to initialize the config definitions above for + // doc generation + def init(): Unit = {} +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala new file mode 100644 index 00000000000..c8594aba03c --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/python/PythonWorkerSemaphore.scala @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.python + +import java.util.concurrent.{ConcurrentHashMap, Semaphore} + +import com.nvidia.spark.rapids.RapidsConf +import com.nvidia.spark.rapids.python.PythonConfEntries.CONCURRENT_PYTHON_WORKERS +import org.apache.commons.lang3.mutable.MutableInt + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging + +/* + * PythonWorkerSemaphore is used to limit the number of Python workers(processes) to be started + * by an executor. + * + * This PythonWorkerSemaphore will not initialize the GPU, different from GpuSemaphore. Since + * tasks calling the API `acquireIfNecessary` are supposed not to use the GPU directly, but + * delegate the permits to the Python workers respectively. + * + * Call `acquireIfNecessary` or `releaseIfNecessary` directly when needed, since the inner + * semaphore will be initialized implicitly, but need to call `shutdown` explicitly to release + * the inner semaphore when no longer needed. + * + */ +object PythonWorkerSemaphore extends Logging { + + private lazy val rapidsConf = new RapidsConf(SparkEnv.get.conf) + private lazy val workersPerGpu = rapidsConf.get(CONCURRENT_PYTHON_WORKERS) + private lazy val enabled = workersPerGpu > 0 + + // DO NOT ACCESS DIRECTLY! Use `getInstance` instead. + @volatile + private var instance: PythonWorkerSemaphore = _ + + private def getInstance(): PythonWorkerSemaphore = { + if (instance == null) { + synchronized { + if (instance == null) { + logDebug(s"Initialize the python workers semaphore with number: $workersPerGpu") + instance = new PythonWorkerSemaphore(workersPerGpu) + } + } + } + instance + } + + /* + * Tasks must call this when they begin to start a Python worker who will use GPU. + * If the task has not already acquired the GPU semaphore then it is acquired, + * blocking if necessary. + * NOTE: A task completion listener will automatically be installed to ensure + * the semaphore is always released by the time the task completes. + */ + def acquireIfNecessary(context: TaskContext): Unit = { + if (enabled && context != null) { + getInstance.acquireIfNecessary(context) + } + } + + /* + * Tasks must call this when they are finished using the GPU. + */ + def releaseIfNecessary(context: TaskContext): Unit = { + if (enabled && context != null) { + getInstance.releaseIfNecessary(context) + } + } + + /* + * Release the inner semaphore. + * NOTE: This does not wait for active tasks to release! + */ + def shutdown(): Unit = synchronized { + if (instance != null) { + instance.shutdown() + instance = null + } + } +} + +private final class PythonWorkerSemaphore(tasksPerGpu: Int) extends Logging { + private val semaphore = new Semaphore(tasksPerGpu) + // Map to track which tasks have acquired the semaphore. + private val activeTasks = new ConcurrentHashMap[Long, MutableInt] + + def acquireIfNecessary(context: TaskContext): Unit = { + val taskAttemptId = context.taskAttemptId() + val refs = activeTasks.get(taskAttemptId) + if (refs == null) { + // first time this task has been seen + activeTasks.put(taskAttemptId, new MutableInt(1)) + context.addTaskCompletionListener[Unit](completeTask) + } else { + refs.increment() + } + logDebug(s"Task $taskAttemptId acquiring GPU for python worker") + semaphore.acquire() + } + + def releaseIfNecessary(context: TaskContext): Unit = { + val taskAttemptId = context.taskAttemptId() + val refs = activeTasks.get(taskAttemptId) + if (refs != null && refs.getValue > 0) { + logDebug(s"Task $taskAttemptId releasing GPU for python worker") + semaphore.release(refs.getValue) + refs.setValue(0) + } + } + + def completeTask(context: TaskContext): Unit = { + val taskAttemptId = context.taskAttemptId() + val refs = activeTasks.remove(taskAttemptId) + if (refs == null) { + throw new IllegalStateException(s"Completion of unknown task $taskAttemptId") + } + if (refs.getValue > 0) { + logDebug(s"Task $taskAttemptId releasing all GPU resources for python worker") + semaphore.release(refs.getValue) + } + } + + def shutdown(): Unit = { + if (!activeTasks.isEmpty) { + logDebug(s"Shutting down Python worker semaphore with ${activeTasks.size} " + + s"tasks still registered") + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/execution/python/rapids/GpuPandasUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/python/rapids/GpuPandasUtils.scala new file mode 100644 index 00000000000..acf2c211bcb --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/execution/python/rapids/GpuPandasUtils.scala @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python.rapids + +import org.apache.spark.api.python.BasePythonRunner +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.python.PandasGroupUtils +import org.apache.spark.sql.vectorized.ColumnarBatch + +/* + * This is to expose the APIs of PandasGroupUtils to rapids Execs + */ +private[sql] object GpuPandasUtils { + + def executePython[T]( + data: Iterator[T], + output: Seq[Attribute], + runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = { + PandasGroupUtils.executePython(data, output, runner) + } + + def groupAndProject( + input: Iterator[InternalRow], + groupingAttributes: Seq[Attribute], + inputSchema: Seq[Attribute], + dedupSchema: Seq[Attribute]): + Iterator[(InternalRow, Iterator[InternalRow])] = { + PandasGroupUtils.groupAndProject(input, groupingAttributes, inputSchema, dedupSchema) + } + + def resolveArgOffsets( + child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + PandasGroupUtils.resolveArgOffsets(child, groupingAttributes) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala new file mode 100644 index 00000000000..c3be87167aa --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io.File + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, + Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowPythonRunner} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + + +class GpuAggregateInPandasExecMeta( + aggPandas: AggregateInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends SparkPlanMeta[AggregateInPandasExec](aggPandas, conf, parent, rule) { + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + // Ignore the expressions since columnar way is not supported yet + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + + override def convertToGpu(): GpuExec = + GpuAggregateInPandasExec( + aggPandas.groupingExpressions, + aggPandas.udfExpressions, + aggPandas.resultExpressions, + childPlans.head.convertIfNeeded() + ) +} + +/* + * This GpuAggregateInPandasExec aims at supporting running Pandas UDF code + * on GPU at Python side. + * + * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * + */ +case class GpuAggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode with GpuExec { + + override def supportsColumnar = false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new IllegalStateException(s"Columnar execution is not supported by $this yet") + } + + // Most code is copied from AggregateInPandasExec, except two GPU related calls + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + val inputRDD = child.execute() + + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }.toSeq) + + // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty + inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener[Unit] { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + // Start of GPU things + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(pyFuncs, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(context) + } + // End of GPU things + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + }} + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala new file mode 100644 index 00000000000..73acdef228e --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala @@ -0,0 +1,495 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean + +import ai.rapids.cudf._ +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.GpuMetricNames._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.PythonUDFRunner +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + + +class RebatchingIterator( + wrapped: Iterator[ColumnarBatch], + targetRowSize: Int, + inputRows: SQLMetric, + inputBatches: SQLMetric) + extends Iterator[ColumnarBatch] with Arm { + var pending: Table = _ + + override def hasNext: Boolean = pending != null || wrapped.hasNext + + private[this] def updatePending(): Unit = { + if (pending == null) { + withResource(wrapped.next()) { cb => + inputBatches += 1 + inputRows += cb.numRows() + pending = GpuColumnVector.from(cb) + } + } + } + + override def next(): ColumnarBatch = { + updatePending() + + while (pending.getRowCount < targetRowSize) { + if (wrapped.hasNext) { + val combined = withResource(wrapped.next()) { cb => + inputBatches += 1 + inputRows += cb.numRows() + withResource(GpuColumnVector.from(cb)) { nextTable => + Table.concatenate(pending, nextTable) + } + } + pending.close() + pending = combined + } else { + // No more to data so return what is left + val ret = withResource(pending) { p => + GpuColumnVector.from(p) + } + pending = null + return ret + } + } + + // We got lucky + if (pending.getRowCount == targetRowSize) { + val ret = withResource(pending) { p => + GpuColumnVector.from(p) + } + pending = null + return ret + } + + val split = pending.contiguousSplit(targetRowSize) + split.foreach(_.getBuffer.close()) + pending.close() + pending = split(1).getTable + withResource(split.head.getTable) { ret => + GpuColumnVector.from(ret) + } + } +} + +// TODO extend this with spilling and other wonderful things +class BatchQueue extends AutoCloseable { + // TODO for now we will use an built in queue + private val queue: mutable.Queue[ColumnarBatch] = mutable.Queue[ColumnarBatch]() + + def add(batch: ColumnarBatch): Unit = synchronized { + // If you cannot add something blow up + queue.enqueue(batch) + } + + def remove(): ColumnarBatch = synchronized { + if (queue.isEmpty) { + null + } else { + queue.dequeue() + } + } + + override def close(): Unit = synchronized { + while(queue.nonEmpty) { + queue.dequeue().close() + } + } +} + +/* + * Helper functions for [[GpuPythonUDF]] + */ +object GpuPythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[GpuPythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[GpuPythonUDF].evalType) + } + + def isGroupedAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[GpuPythonUDF] && + e.asInstanceOf[GpuPythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) +} + +/* + * A serialized version of a Python lambda function. This is a special expression, which needs a + * dedicated physical operator to execute it, and thus can't be pushed down to data sources. + */ +case class GpuPythonUDF( + name: String, + func: PythonFunction, + dataType: DataType, + children: Seq[Expression], + evalType: Int, + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) + extends Expression with GpuUnevaluable with NonSQLExpression with UserDefinedExpression { + + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + + override def toString: String = s"$name(${children.mkString(", ")})" + + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + + override def nullable: Boolean = true + + override lazy val canonicalized: Expression = { + val canonicalizedChildren = children.map(_.canonicalized) + // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result. + this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren) + } +} + +/* + * A trait that can be mixed-in with `BasePythonRunner`. It implements the logic from + * Python (Arrow) to GPU/JVM (ColumnarBatch). + */ +trait GpuPythonArrowOutput extends Arm { self: BasePythonRunner[_, ColumnarBatch] => + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private[this] var arrowReader: StreamedTableReader = _ + + context.addTaskCompletionListener[Unit] { _ => + if (arrowReader != null) { + arrowReader.close() + arrowReader = null + } + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (arrowReader != null && batchLoaded) { + val table = + withResource(new NvtxRange("read python batch", NvtxColor.DARK_GREEN)) { _ => + arrowReader.getNextIfAvailable + } + if (table == null) { + batchLoaded = false + arrowReader.close() + arrowReader = null + read() + } else { + withResource(table) { table => + batchLoaded = true + GpuColumnVector.from(table) + } + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + arrowReader = Table.readArrowIPCChunked(new StreamToBufferProvider(stream)) + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} + + +/* + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class GpuArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + conf: Map[String, String]) + extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs, evalType, argOffsets) + with GpuPythonArrowOutput { + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[ColumnarBatch], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val writer = { + val builder = ArrowIPCWriterOptions.builder() + schema.foreach { field => + if (field.nullable) { + builder.withColumnNames(field.name) + } else { + builder.withNotNullableColumnNames(field.name) + } + } + Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut)) + } + Utils.tryWithSafeFinally { + while(inputIterator.hasNext) { + withResource(inputIterator.next()) { nextBatch => + withResource(GpuColumnVector.from(nextBatch)) { table => + withResource(new NvtxRange("write python batch", NvtxColor.DARK_GREEN)) { _ => + writer.write(table) + } + } + } + } + } { + writer.close() + dataOut.flush() + } + } + } + } +} + +class BufferToStreamWriter(outputStream: DataOutputStream) extends HostBufferConsumer with Arm { + private[this] val tempBuffer = new Array[Byte](128 * 1024) + + override def handleBuffer(hostBuffer: HostMemoryBuffer, length: Long): Unit = { + withResource(hostBuffer) { buffer => + var len = length + var offset: Long = 0 + while(len > 0) { + val toCopy = math.min(tempBuffer.length, len).toInt + buffer.getBytes(tempBuffer, 0, offset, toCopy) + outputStream.write(tempBuffer, 0, toCopy) + len = len - toCopy + offset = offset + toCopy + } + } + } +} + +class StreamToBufferProvider(inputStream: DataInputStream) extends HostBufferProvider { + private[this] val tempBuffer = new Array[Byte](128 * 1024) + + override def readInto(hostBuffer: HostMemoryBuffer, length: Long): Long = { + var amountLeft = length + var totalRead : Long = 0 + while (amountLeft > 0) { + val amountToRead = Math.min(tempBuffer.length, amountLeft).toInt + val amountRead = inputStream.read(tempBuffer, 0, amountToRead) + if (amountRead <= 0) { + // Reached EOF + amountLeft = 0 + } else { + amountLeft -= amountRead + hostBuffer.setBytes(totalRead, tempBuffer, 0, amountRead) + totalRead += amountRead + } + } + totalRead + } +} + +/* + * A physical plan that evaluates a [[GpuPythonUDF]]. The transformation of the data to arrow + * happens on the GPU (practically a noop), But execution of the UDFs are on the CPU or GPU. + */ +case class GpuArrowEvalPythonExec( + udfs: Seq[GpuPythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) extends UnaryExecNode with GpuExec { + + // We split the input batch up into small pieces when sending to python for compatibility reasons + override def coalesceAfter: Boolean = true + + override def output: Seq[Attribute] = child.output ++ resultAttrs + + override def producedAttributes: AttributeSet = AttributeSet(resultAttrs) + + private def collectFunctions(udf: GpuPythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: GpuPythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[GpuPythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + + override protected def doExecute(): RDD[InternalRow] = + throw new IllegalStateException(s"Row-based execution should not occur for $this") + + override lazy val metrics: Map[String, SQLMetric] = Map( + NUM_OUTPUT_ROWS -> SQLMetrics.createMetric(sparkContext, DESCRIPTION_NUM_OUTPUT_ROWS), + NUM_OUTPUT_BATCHES -> SQLMetrics.createMetric(sparkContext, DESCRIPTION_NUM_OUTPUT_BATCHES), + NUM_INPUT_ROWS -> SQLMetrics.createMetric(sparkContext, DESCRIPTION_NUM_INPUT_ROWS), + NUM_INPUT_BATCHES -> SQLMetrics.createMetric(sparkContext, DESCRIPTION_NUM_INPUT_BATCHES) + ) + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric(NUM_OUTPUT_ROWS) + val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES) + val numInputRows = longMetric(NUM_INPUT_ROWS) + val numInputBatches = longMetric(NUM_INPUT_BATCHES) + + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + val inputRDD = child.executeColumnar() + inputRDD.mapPartitions { iter => + val queue: BatchQueue = new BatchQueue() + val context = TaskContext.get() + context.addTaskCompletionListener[Unit](_ => queue.close()) + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // Not sure why we are doing this in every task. It is not going to change, but it might + // just be less that we have to ship. + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + // TODO eventually we should just do type checking on these, but that can get a little complex + // with how things are setup for replacement... + // perhaps it needs to be with the special, it is an gpu compatible expression, but not a + // gpu expression... + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + val boundReferences = GpuBindReferences.bindReferences(allInputs, child.output) + val batchedIterator = new RebatchingIterator(iter, batchSize, numInputRows, numInputBatches) + val projectedIterator = batchedIterator.map { batch => + queue.add(batch) + GpuProjectExec.project(batch, boundReferences) + } + + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(pyFuncs, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(context) + } + + val outputBatchIterator = new GpuArrowPythonRunner( + pyFuncs, + evalType, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedIterator, + context.partitionId(), + context) + + outputBatchIterator.map { outputBatch => + withResource(outputBatch) { outBatch => + withResource(queue.remove()) { origBatch => + val rows = origBatch.numRows() + assert(outBatch.numRows() == rows) + val lColumns = GpuColumnVector.extractColumns(origBatch) + val rColumns = GpuColumnVector.extractColumns(outBatch) + numOutputBatches += 1 + numOutputRows += rows + new ColumnarBatch(lColumns.map(_.incRefCount()) ++ rColumns.map(_.incRefCount()), + rows) + } + } + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala new file mode 100644 index 00000000000..49bd548f8fb --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, + Distribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.python.{CoGroupedArrowPythonRunner, + FlatMapCoGroupsInPandasExec} +import org.apache.spark.sql.execution.python.rapids.GpuPandasUtils._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch + + +class GpuFlatMapCoGroupsInPandasExecMeta( + flatPandas: FlatMapCoGroupsInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends SparkPlanMeta[FlatMapCoGroupsInPandasExec](flatPandas, conf, parent, rule) { + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + // Ignore the expressions since columnar way is not supported yet + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + + override def convertToGpu(): GpuExec = + GpuFlatMapCoGroupsInPandasExec( + flatPandas.leftGroup, flatPandas.rightGroup, + flatPandas.func, + flatPandas.output, + childPlans.head.convertIfNeeded(), childPlans(1).convertIfNeeded() + ) +} + +/* + * + * This GpuFlatMapCoGroupsInPandasExec aims at supporting running Pandas functional code + * on GPU at Python side. + * + * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * + */ +case class GpuFlatMapCoGroupsInPandasExec( + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) + extends SparkPlan with BinaryExecNode with GpuExec { + + override def supportsColumnar = false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new IllegalStateException(s"Columnar execution is not supported by $this yet") + } + + // Most code is copied from FlatMapCoGroupsInPandasExec, except two GPU related calls + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pandasFunction = func.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup) + val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup) + leftDist :: rightDist :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + override protected def doExecute(): RDD[InternalRow] = { + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + + // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else { + + val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup) + val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup) + val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map { case (_, l, r) => (l, r) } + + // Start of GPU things + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(chainedFunc, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(TaskContext.get()) + } + // End of GPU things + + val runner = new CoGroupedArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + Array(leftArgOffsets ++ rightArgOffsets), + StructType.fromAttributes(leftDedup), + StructType.fromAttributes(rightDedup), + sessionLocalTimeZone, + pythonRunnerConf) + + executePython(data, output, runner) + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala new file mode 100644 index 00000000000..485b747b67a --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapGroupsInPandasExec.scala @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, + Distribution, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, FlatMapGroupsInPandasExec} +import org.apache.spark.sql.execution.python.rapids.GpuPandasUtils._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch + + +class GpuFlatMapGroupsInPandasExecMeta( + flatPandas: FlatMapGroupsInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends SparkPlanMeta[FlatMapGroupsInPandasExec](flatPandas, conf, parent, rule) { + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + // Ignore the expressions since columnar way is not supported yet + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + + override def convertToGpu(): GpuExec = + GpuFlatMapGroupsInPandasExec( + flatPandas.groupingAttributes, + flatPandas.func, + flatPandas.output, + childPlans.head.convertIfNeeded() + ) +} + +/* + * + * This GpuFlatMapGroupsInPandasExec aims at supporting running Pandas functional code + * on GPU at Python side. + * + * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * + */ +case class GpuFlatMapGroupsInPandasExec( + groupingAttributes: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends SparkPlan with UnaryExecNode with GpuExec { + + override def supportsColumnar = false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new IllegalStateException(s"Columnar execution is not supported by $this yet") + } + + // Most code is copied from FlatMapGroupsInPandasExec, except two GPU related calls + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pandasFunction = func.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + val inputRDD = child.execute() + + val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + + // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty + inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + + val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes) + .map { case (_, x) => x } + + // Start of GPU things + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(chainedFunc, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(TaskContext.get()) + } + // End of GPU things + + val runner = new ArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf) + + executePython(data, output, runner) + }} + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala new file mode 100644 index 00000000000..b6b2e5b0e90 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, + PythonUDF, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python._ +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +class GpuMapInPandasExecMeta( + mapPandas: MapInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends SparkPlanMeta[MapInPandasExec](mapPandas, conf, parent, rule) { + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + // Ignore the udf since columnar way is not supported yet + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + + override def convertToGpu(): GpuExec = + GpuMapInPandasExec( + mapPandas.func, + mapPandas.output, + childPlans.head.convertIfNeeded() + ) +} + +/* + * A relation produced by applying a function that takes an iterator of pandas DataFrames + * and outputs an iterator of pandas DataFrames. + * + * This GpuMapInPandasExec aims at supporting running Pandas functional code + * on GPU at Python side. + * + * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * + */ +case class GpuMapInPandasExec( + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode with GpuExec { + + override def supportsColumnar = false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new IllegalStateException(s"Columnar execution is not supported by $this yet") + } + + // Most code is copied from MapInPandasExec, except two GPU related calls + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def producedAttributes: AttributeSet = AttributeSet(output) + + private val batchSize = conf.arrowMaxRecordsPerBatch + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override protected def doExecute(): RDD[InternalRow] = { + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + child.execute().mapPartitionsInternal { inputIter => + // Single function with one struct. + val argOffsets = Array(Array(0)) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + val outputTypes = child.schema + + // Here we wrap it via another row so that Python sides understand it + // as a DataFrame. + val wrappedIter = inputIter.map(InternalRow(_)) + + // DO NOT use iter.grouped(). See BatchIterator. + val batchIter = + if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) + + val context = TaskContext.get() + + // Start of GPU things + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(chainedFunc, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(context) + } + // End of GPU things + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + argOffsets, + StructType(StructField("struct", outputTypes) :: Nil), + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select + // the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala new file mode 100644 index 00000000000..0743e33b5ad --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuPythonHelper.scala @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import ai.rapids.cudf.Cuda +import com.nvidia.spark.rapids.{GpuDeviceManager, RapidsConf} +import com.nvidia.spark.rapids.python.PythonConfEntries._ + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{CPUS_PER_TASK, EXECUTOR_CORES} +import org.apache.spark.internal.config.Python._ +import org.apache.spark.sql.internal.SQLConf + +object GpuPythonHelper extends Logging { + + private val sparkConf = SparkEnv.get.conf + private lazy val rapidsConf = new RapidsConf(sparkConf) + private lazy val gpuId = GpuDeviceManager.getDeviceId() + .getOrElse(throw new IllegalStateException("No gpu id!")) + .toString + private lazy val isPythonPooledMemEnabled = rapidsConf.get(PYTHON_POOLED_MEM) + .getOrElse(rapidsConf.isPooledMemEnabled) + .toString + private lazy val isPythonUvmEnabled = rapidsConf.get(PYTHON_UVM_ENABLED) + .getOrElse(rapidsConf.isUvmEnabled) + .toString + private lazy val (initAllocPerWorker, maxAllocPerWorker) = { + val info = Cuda.memGetInfo() + val maxFactionTotal = rapidsConf.get(PYTHON_RMM_MAX_ALLOC_FRACTION) + val maxAllocTotal = (maxFactionTotal * info.total).toLong + // Initialize pool size for all pythons workers. If the fraction is not set, + // use half of the free memory as default. + val initAllocTotal = rapidsConf.get(PYTHON_RMM_ALLOC_FRACTION) + .map { fraction => + if (0 < maxFactionTotal && maxFactionTotal < fraction) { + throw new IllegalArgumentException(s"The value of '$PYTHON_RMM_MAX_ALLOC_FRACTION' " + + s"should not be less than that of '$PYTHON_RMM_ALLOC_FRACTION', but found " + + s"$maxFactionTotal < $fraction") + } + (fraction * info.total).toLong + } + .getOrElse((0.5 * info.free).toLong) + if (initAllocTotal > info.free) { + logWarning(s"Initial RMM allocation(${initAllocTotal / 1024.0 / 1024} MB) for " + + s"all the Python workers is larger than free memory(${info.free / 1024.0 / 1024} MB)") + } else { + logDebug(s"Configure ${initAllocTotal / 1024.0 / 1024}MB GPU memory for " + + s"all the Python workers.") + } + + // Calculate the pool size for each Python worker. + val concurrentPythonWorkers = rapidsConf.get(CONCURRENT_PYTHON_WORKERS) + if (0 < concurrentPythonWorkers) { + (initAllocTotal / concurrentPythonWorkers, maxAllocTotal / concurrentPythonWorkers) + } else { + // When semaphore is disabled or invalid, use the number of cpu task slots instead. + // Spark does not throw exception even the value of CPUS_PER_TASK is negative, so + // return 1 if it is less than zero to continue the task. + val cpuTaskSlots = sparkConf.get(EXECUTOR_CORES) / Math.max(1, sparkConf.get(CPUS_PER_TASK)) + (initAllocTotal / cpuTaskSlots, maxAllocTotal / cpuTaskSlots) + } + } + + def isPythonOnGpuEnabled(sqlConf: SQLConf): Boolean = { + val pythonEnabled = new RapidsConf(sqlConf).get(PYTHON_GPU_ENABLED) + if (pythonEnabled) { + checkPythonConfigs(sparkConf) + } + pythonEnabled + } + + // Called in each task at the executor side + def injectGpuInfo(funcs: Seq[ChainedPythonFunctions], isPythonOnGpuEnabled: Boolean): Unit = { + // Insert GPU related env(s) into `envVars` for all the PythonFunction(s). + // Yes `PythonRunner` will only use the first one, but just make sure it will + // take effect no matter the order changes or not. + funcs.foreach(_.funcs.foreach { pyF => + pyF.envVars.put("CUDA_VISIBLE_DEVICES", gpuId) + pyF.envVars.put("RAPIDS_PYTHON_ENABLED", isPythonOnGpuEnabled.toString) + pyF.envVars.put("RAPIDS_UVM_ENABLED", isPythonUvmEnabled) + pyF.envVars.put("RAPIDS_POOLED_MEM_ENABLED", isPythonPooledMemEnabled) + pyF.envVars.put("RAPIDS_POOLED_MEM_SIZE", initAllocPerWorker.toString) + pyF.envVars.put("RAPIDS_POOLED_MEM_MAX_SIZE", maxAllocPerWorker.toString) + }) + } + + // Check the related conf(s) to launch our rapids daemon or worker for + // the GPU initialization when python on gpu enabled. + // - python worker module if useDaemon is false, otherwise + // - python daemon module. + private[sql] def checkPythonConfigs(conf: SparkConf): Unit = synchronized { + val useDaemon = { + val useDaemonEnabled = conf.get(PYTHON_USE_DAEMON) + // This flag is ignored on Windows as it's unable to fork. + !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled + } + if (useDaemon) { + val oDaemon = conf.get(PYTHON_DAEMON_MODULE) + if (oDaemon.nonEmpty) { + val daemon = oDaemon.get + if (daemon != "rapids.daemon") { + throw new IllegalArgumentException("Python daemon module config conflicts." + + s" Expect 'rapids.daemon' but set to $daemon") + } + } else { + // Set daemon only when not specified + conf.set(PYTHON_DAEMON_MODULE, "rapids.daemon") + } + } else { + val oWorker = conf.get(PYTHON_WORKER_MODULE) + if (oWorker.nonEmpty) { + val worker = oWorker.get + if (worker != "rapids.worker") { + throw new IllegalArgumentException("Python worker module config conflicts." + + s" Expect 'rapids.worker' but set to $worker") + } + } else { + // Set worker only when not specified + conf.set(PYTHON_WORKER_MODULE, "rapids.worker") + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExec.scala new file mode 100644 index 00000000000..837814f0825 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExec.scala @@ -0,0 +1,407 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io.File + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, + Distribution, Partitioning} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} +import org.apache.spark.sql.execution.python._ +import org.apache.spark.sql.execution.window._ +import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +class GpuWindowInPandasExecMeta( + winPandas: WindowInPandasExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends SparkPlanMeta[WindowInPandasExec](winPandas, conf, parent, rule) { + + override def couldReplaceMessage: String = "could partially run on GPU" + override def noReplacementPossibleMessage(reasons: String): String = + s"cannot run even partially on the GPU because $reasons" + + // Ignore the expressions since columnar way is not supported yet + override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty + + override def convertToGpu(): GpuExec = + GpuWindowInPandasExec( + winPandas.windowExpression, + winPandas.partitionSpec, + winPandas.orderSpec, + childPlans.head.convertIfNeeded() + ) +} + +/* + * This GpuWindowInPandasExec aims at supporting running Pandas UDF code + * on GPU at Python side. + * + * (Currently it will not run on GPU itself, since the columnar way is not implemented yet.) + * + */ +case class GpuWindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) + extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) with GpuExec { + + override def supportsColumnar = false + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + throw new IllegalStateException(s"Columnar execution is not supported by $this yet") + } + + // Most code is copied from WindowInPandasExec, except two GPU related calls + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MiB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + /* + * Helper functions and data structures for window bounds + * + * It contains: + * (1) Total number of window bound indices in the python input row + * (2) Function from frame index to its lower bound column index in the python input row + * (3) Function from frame index to its upper bound column index in the python input row + * (4) Seq from frame index to its window bound type + */ + private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType]) + + /* + * Enum for window bound types. Used only inside this class. + */ + private sealed case class WindowBoundType(value: String) + private object UnboundedWindow extends WindowBoundType("unbounded") + private object BoundedWindow extends WindowBoundType("bounded") + + private val windowBoundTypeConf = "pandas_window_bound_types" + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /* + * See [[WindowBoundHelpers]] for details. + */ + private def computeWindowBoundHelpers( + factories: Seq[InternalRow => WindowFunctionFrame] + ): WindowBoundHelpers = { + val functionFrames = factories.map(_(EmptyRow)) + + val windowBoundTypes = functionFrames.map { + case _: UnboundedWindowFunctionFrame => UnboundedWindow + case _: UnboundedFollowingWindowFunctionFrame | + _: SlidingWindowFunctionFrame | + _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow + // It should be impossible to get other types of window function frame here + case frame => throw new RuntimeException(s"Unexpected window function frame $frame.") + } + + val requiredIndices = functionFrames.map { + case _: UnboundedWindowFunctionFrame => 0 + case _ => 2 + } + + val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail + + val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) => + if (num == 0) { + // Sentinel values for unbounded window + (-1, -1) + } else { + (upperBoundIndex - 2, upperBoundIndex - 1) + } + } + + def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1 + def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2 + + (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes) + } + + protected override def doExecute(): RDD[InternalRow] = { + lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf) + // Unwrap the expressions and factories from the map. + val expressionsWithFrameIndex = + windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap { + case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex)) + } + + val expressions = expressionsWithFrameIndex.map(_._1) + val expressionIndexToFrameIndex = + expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap + + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + + // Helper functions + val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = + computeWindowBoundHelpers(factories) + val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } + val numFrames = factories.length + + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold + val sessionLocalTimeZone = conf.sessionLocalTimeZone + + // Extract window expressions and window functions + val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e }) + val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + // We shouldn't be chaining anything here. + // All chained python functions should only contain one function. + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + require(pyFuncs.length == expressions.length) + + val udfWindowBoundTypes = pyFuncs.indices.map(i => + frameWindowBoundTypes(expressionIndexToFrameIndex(i))) + val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf) + + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(","))) + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node + // handles UDF inputs. + val dataInputs = new ArrayBuffer[Expression] + val dataInputTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (dataInputs.exists(_.semanticEquals(e))) { + dataInputs.indexWhere(_.semanticEquals(e)) + } else { + dataInputs += e + dataInputTypes += e.dataType + dataInputs.length - 1 + } + }.toArray + }.toArray + + // In addition to UDF inputs, we will prepend window bounds for each UDFs. + // For bounded windows, we prepend lower bound and upper bound. For unbounded windows, + // we no not add window bounds. (strictly speaking, we only need to lower or upper bound + // if the window is bounded only on one side, this can be improved in the future) + + // Setting window bounds for each window frames. Each window frame has different bounds so + // each has its own window bound columns. + val windowBoundsInput = factories.indices.flatMap { frameIndex => + if (isBounded(frameIndex)) { + Seq( + BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false), + BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false) + ) + } else { + Seq.empty + } + } + + // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset + // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...) + // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...) + pyFuncs.indices.foreach { exprIndex => + val frameIndex = expressionIndexToFrameIndex(exprIndex) + if (isBounded(frameIndex)) { + argOffsets(exprIndex) = + Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++ + argOffsets(exprIndex).map(_ + windowBoundsInput.length) + } else { + argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length) + } + } + + val allInputs = windowBoundsInput ++ dataInputs + val allInputTypes = allInputs.map(_.dataType) + + // Start processing. + child.execute().mapPartitions { iter => + val context = TaskContext.get() + + // Get all relevant projections. + val resultProj = createResultProjection(expressions) + val pythonInputProj = UnsafeProjection.create( + allInputs, + windowBoundsInput.map(ref => + AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output + ) + val pythonInputSchema = StructType( + allInputTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + } + ) + val grouping = UnsafeProjection.create(partitionSpec, child.output) + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener[Unit] { _ => + queue.close() + } + + val stream = iter.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + row + } + + val pythonInput = new Iterator[Iterator[UnsafeRow]] { + + // Manage the stream and the grouping. + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow(): Unit = { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next().asInstanceOf[UnsafeRow] + nextGroup = grouping(nextRow) + } else { + nextRow = null + nextGroup = null + } + } + fetchNextRow() + + // Manage the current partition. + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + + val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) + + val frames = factories.map(_(indexRow)) + + private[this] def fetchNextPartition(): Unit = { + // Collect all the rows in the current partition. + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() + + // clear last partition + buffer.clear() + + while (nextRowAvailable && nextGroup == currentGroup) { + buffer.add(nextRow) + fetchNextRow() + } + + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(buffer) + i += 1 + } + + // Setup iteration + rowIndex = 0 + bufferIterator = buffer.generateIterator() + } + + // Iteration + var rowIndex = 0 + + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable + + override final def next(): Iterator[UnsafeRow] = { + // Load the next partition if we need to. + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { + fetchNextPartition() + } + + val join = new JoinedRow + + bufferIterator.zipWithIndex.map { + case (current, index) => + var frameIndex = 0 + while (frameIndex < numFrames) { + frames(frameIndex).write(index, current) + // If the window is unbounded we don't need to write out window bounds. + if (isBounded(frameIndex)) { + indexRow.setInt( + lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound()) + indexRow.setInt( + upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound()) + } + frameIndex += 1 + } + + pythonInputProj(join(indexRow, current)) + } + } + } + + // Start of GPU things + if (isPythonOnGpuEnabled) { + GpuPythonHelper.injectGpuInfo(pyFuncs, isPythonOnGpuEnabled) + PythonWorkerSemaphore.acquireIfNecessary(TaskContext.get()) + } + // End of GPU things + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, + pythonInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/RowUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/RowUtils.scala new file mode 100644 index 00000000000..42844cd216d --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/RowUtils.scala @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution.python + +import java.io._ + +import com.google.common.io.Closeables + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.io.NioBufferedFileInputStream +import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager} +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.memory.MemoryBlock + +// The whole file is copied from Spark `RowQueue` to expose row queue utils to rapids execs + +/** + * A RowQueue is an FIFO queue for UnsafeRow. + * + * This RowQueue is ONLY designed and used for Python UDF, which has only one writer and only one + * reader, the reader ALWAYS ran behind the writer. See the doc of class BatchEvalPythonExec + * on how it works. + */ +private[sql] trait RowQueue { + + /** + * Add a row to the end of it, returns true iff the row has been added to the queue. + */ + def add(row: UnsafeRow): Boolean + + /** + * Retrieve and remove the first row, returns null if it's empty. + * + * It can only be called after add is called, otherwise it will fail (NPE). + */ + def remove(): UnsafeRow + + /** + * Cleanup all the resources. + */ + def close(): Unit +} + +/** + * A RowQueue that is based on in-memory page. UnsafeRows are appended into it until it's full. + * Another thread could read from it at the same time (behind the writer). + * + * The format of UnsafeRow in page: + * [4 bytes to hold length of record (N)] [N bytes to hold record] [...] + * + * -1 length means end of page. + */ +private[sql] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int) + extends RowQueue { + private val base: AnyRef = page.getBaseObject + private val endOfPage: Long = page.getBaseOffset + page.size + // the first location where a new row would be written + private var writeOffset = page.getBaseOffset + // points to the start of the next row to read + private var readOffset = page.getBaseOffset + private val resultRow = new UnsafeRow(numFields) + + def add(row: UnsafeRow): Boolean = synchronized { + val size = row.getSizeInBytes + if (writeOffset + 4 + size > endOfPage) { + // if there is not enough space in this page to hold the new record + if (writeOffset + 4 <= endOfPage) { + // if there's extra space at the end of the page, store a special "end-of-page" length (-1) + Platform.putInt(base, writeOffset, -1) + } + false + } else { + Platform.putInt(base, writeOffset, size) + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, writeOffset + 4, size) + writeOffset += 4 + size + true + } + } + + def remove(): UnsafeRow = synchronized { + assert(readOffset <= writeOffset, "reader should not go beyond writer") + if (readOffset + 4 > endOfPage || Platform.getInt(base, readOffset) < 0) { + null + } else { + val size = Platform.getInt(base, readOffset) + resultRow.pointTo(base, readOffset + 4, size) + readOffset += 4 + size + resultRow + } + } +} + +/** + * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any + * reader has begun reading from the queue. + */ +private[sql] case class DiskRowQueue( + file: File, + fields: Int, + serMgr: SerializerManager) extends RowQueue { + + private var out = new DataOutputStream(serMgr.wrapForEncryption( + new BufferedOutputStream(new FileOutputStream(file.toString)))) + private var unreadBytes = 0L + + private var in: DataInputStream = _ + private val resultRow = new UnsafeRow(fields) + + def add(row: UnsafeRow): Boolean = synchronized { + if (out == null) { + // Another thread is reading, stop writing this one + return false + } + out.writeInt(row.getSizeInBytes) + out.write(row.getBytes) + unreadBytes += 4 + row.getSizeInBytes + true + } + + def remove(): UnsafeRow = synchronized { + if (out != null) { + out.close() + out = null + in = new DataInputStream(serMgr.wrapForEncryption( + new NioBufferedFileInputStream(file))) + } + + if (unreadBytes > 0) { + val size = in.readInt() + val bytes = new Array[Byte](size) + in.readFully(bytes) + unreadBytes -= 4 + size + resultRow.pointTo(bytes, size) + resultRow + } else { + null + } + } + + def close(): Unit = synchronized { + Closeables.close(out, true) + out = null + Closeables.close(in, true) + in = null + if (file.exists()) { + file.delete() + } + } +} + +/** + * A RowQueue that has a list of RowQueues, which could be in memory or disk. + * + * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same + * time. + */ +private[sql] case class HybridRowQueue( + memManager: TaskMemoryManager, + tempDir: File, + numFields: Int, + serMgr: SerializerManager) + extends MemoryConsumer(memManager) with RowQueue { + + // Each buffer should have at least one row + private var queues = new java.util.LinkedList[RowQueue]() + + private var writing: RowQueue = _ + private var reading: RowQueue = _ + + // exposed for testing + private[python] def numQueues(): Int = queues.size() + + def spill(size: Long, trigger: MemoryConsumer): Long = { + if (trigger == this) { + // When it's triggered by itself, it should write upcoming rows into disk instead of copying + // the rows already in the queue. + return 0L + } + var released = 0L + synchronized { + // poll out all the buffers and add them back in the same order to make sure that the rows + // are in correct order. + val newQueues = new java.util.LinkedList[RowQueue]() + while (!queues.isEmpty) { + val queue = queues.remove() + val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) { + val diskQueue = createDiskQueue() + var row = queue.remove() + while (row != null) { + diskQueue.add(row) + row = queue.remove() + } + released += queue.asInstanceOf[InMemoryRowQueue].page.size() + queue.close() + diskQueue + } else { + queue + } + newQueues.add(newQueue) + } + queues = newQueues + } + released + } + + private def createDiskQueue(): RowQueue = { + DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr) + } + + private def createNewQueue(required: Long): RowQueue = { + val page = try { + allocatePage(required) + } catch { + case _: SparkOutOfMemoryError => + null + } + val buffer = if (page != null) { + new InMemoryRowQueue(page, numFields) { + override def close(): Unit = { + freePage(page) + } + } + } else { + createDiskQueue() + } + + synchronized { + queues.add(buffer) + } + buffer + } + + def add(row: UnsafeRow): Boolean = { + if (writing == null || !writing.add(row)) { + writing = createNewQueue(4 + row.getSizeInBytes) + if (!writing.add(row)) { + throw new SparkException(s"failed to push a row into $writing") + } + } + true + } + + def remove(): UnsafeRow = { + var row: UnsafeRow = null + if (reading != null) { + row = reading.remove() + } + if (row == null) { + if (reading != null) { + reading.close() + } + synchronized { + reading = queues.remove() + } + assert(reading != null, s"queue should not be empty") + row = reading.remove() + assert(row != null, s"$reading should have at least one row") + } + row + } + + def close(): Unit = { + if (reading != null) { + reading.close() + reading = null + } + synchronized { + while (!queues.isEmpty) { + queues.remove().close() + } + } + } +} + +private[sql] object HybridRowQueue { + def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = { + HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager) + } +} From 105961fa9b80670c86dea06c41b31e48605c4475 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 11 Sep 2020 08:28:37 -0500 Subject: [PATCH 2/2] Fix collect time metric in CoalesceBatches (#729) * Fix collect time in CoalesceBatches * Also move totalTime * switch to use if !isdefined Signed-off-by: Thomas Graves * remove extra newline Signed-off-by: Thomas Graves --- .../spark/rapids/GpuCoalesceBatches.scala | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala index 336d9b4f5d3..ecdcf9b9450 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala @@ -160,6 +160,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], private val iter = new RemoveEmptyBatchIterator(origIter, numInputBatches) private var onDeck: Option[ColumnarBatch] = None private var batchInitialized: Boolean = false + private var collectMetric: Option[MetricRange] = None + private var totalMetric: Option[MetricRange] = None /** We need to track the sizes of string columns to make sure we don't exceed 2GB */ private val stringFieldIndices: Array[Int] = schema.fields.zipWithIndex @@ -174,7 +176,22 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], Option(TaskContext.get()) .foreach(_.addTaskCompletionListener[Unit](_ => onDeck.foreach(_.close()))) - override def hasNext: Boolean = onDeck.isDefined || iter.hasNext + override def hasNext: Boolean = { + if (!collectMetric.isDefined) { + // use one being not set as indicator that neither are intialized to avoid + // 2 checks or extra initialized variable + collectMetric = Some(new MetricRange(collectTime)) + totalMetric = Some(new MetricRange(totalTime)) + } + val res = onDeck.isDefined || iter.hasNext + if (!res) { + collectMetric.foreach(_.close()) + collectMetric = None + totalMetric.foreach(_.close()) + totalMetric = None + } + res + } /** * Called first to initialize any state needed for a new batch to be created. @@ -236,9 +253,6 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], * @return The coalesced batch */ override def next(): ColumnarBatch = { - - val total = new MetricRange(totalTime) - // reset batch state batchInitialized = false batchRowLimit = 0 @@ -261,7 +275,6 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], numBytes += columnSizes.sum } - val collect = new MetricRange(collectTime) try { // there is a hard limit of 2^31 rows @@ -339,7 +352,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], s"and $numBytes bytes") } finally { - collect.close() + collectMetric.foreach(_.close()) + collectMetric = None } val concatRange = new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime) @@ -351,7 +365,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch], ret } finally { cleanupConcatIsDone() - total.close() + totalMetric.foreach(_.close()) + totalMetric = None } }