Commit 7425ae9

Avoid showing GpuColumnarToRow transition in plan that does not actually execute

jlowe committed Oct 16, 2023
1 parent 6307e25, commit 7425ae9
Showing 11 changed files with 23 additions and 18 deletions.
File 1:
@@ -510,6 +510,8 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
   private def insertColumnarFromGpu(plan: SparkPlan): SparkPlan = {
     if (plan.supportsColumnar && plan.isInstanceOf[GpuExec]) {
       GpuBringBackToHost(insertColumnarToGpu(plan))
+    } else if (plan.isInstanceOf[ColumnarToRowTransition] && plan.isInstanceOf[GpuExec]) {
+      plan.withNewChildren(plan.children.map(insertColumnarToGpu))
     } else {
       plan.withNewChildren(plan.children.map(insertColumnarFromGpu))
     }
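The new branch targets GPU execs that produce rows but consume their child's columnar output internally (the CTAS/RTAS writes below): their children now go straight through insertColumnarToGpu, so no GpuColumnarToRow is planned above the child, and the transition that used to show up in the plan without ever executing disappears. A hedged way to observe the effect from a Spark shell, assuming the RAPIDS plugin is enabled and using a hypothetical v2 catalog/table name:

  // Print the planned physical plan for an atomic CTAS. Before this commit
  // it could include a GpuColumnarToRow node that never ran; with this
  // change that node no longer appears.
  spark.sql("EXPLAIN CREATE TABLE demo.db.t AS SELECT id FROM range(10)")
    .collect()
    .foreach(row => println(row.getString(0)))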
File 2:
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -59,7 +59,7 @@ case class GpuAtomicCreateTableAsSelectExec(
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec {
+    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   override def supportsColumnar: Boolean = false
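This same one-line change repeats across the shims that follow: each atomic CTAS/RTAS exec reports supportsColumnar = false because its doExecute drives the table write, yet the query it wraps runs columnar on the GPU. Spark's ColumnarToRowTransition trait tags a node as being the columns-to-rows boundary itself, and the GpuTransitionOverrides branch above honors that tag instead of planning a separate GpuColumnarToRow. A minimal, self-contained sketch of the pattern; MyWriteExec and its write logic are hypothetical, not the plugin's code:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.Attribute
  import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, UnaryExecNode}

  // A row-producing node that consumes its child's columnar batches directly.
  // Mixing in ColumnarToRowTransition marks it as the columns-to-rows boundary,
  // so the transition-insertion rules do not add another ColumnarToRow (or,
  // with the override above, GpuColumnarToRow) node on top of it.
  case class MyWriteExec(child: SparkPlan)
      extends UnaryExecNode with ColumnarToRowTransition {

    override def output: Seq[Attribute] = Nil
    override def supportsColumnar: Boolean = false  // this node emits rows

    override protected def doExecute(): RDD[InternalRow] = {
      child.executeColumnar().foreach { batch =>
        // ... hand the batch to the writer here ...
        batch.close()
      }
      sparkContext.emptyRDD[InternalRow]
    }

    override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
      copy(child = newChild)
  }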
File 3:
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -62,7 +62,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     orCreate: Boolean,
     invalidateCache: (TableCatalog, Table, Identifier) => Unit)
-  extends TableWriteExecHelper with GpuExec {
+  extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   override def supportsColumnar: Boolean = false
File 4:
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -58,7 +58,7 @@ case class GpuAtomicCreateTableAsSelectExec(
     query: SparkPlan,
     tableSpec: TableSpec,
     writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec {
+    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 5:
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -63,7 +63,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     orCreate: Boolean,
     invalidateCache: (TableCatalog, Table, Identifier) => Unit)
-  extends TableWriteExecHelper with GpuExec {
+  extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 6:
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -58,7 +58,7 @@ case class GpuAtomicCreateTableAsSelectExec(
     query: SparkPlan,
     tableSpec: TableSpec,
     writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec {
+    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 7:
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -63,7 +63,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     orCreate: Boolean,
     invalidateCache: (TableCatalog, Table, Identifier) => Unit)
-  extends TableWriteExecHelper with GpuExec {
+  extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 8:
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -55,7 +55,7 @@ case class GpuAtomicCreateTableAsSelectExec(
     query: SparkPlan,
     tableSpec: TableSpec,
     writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec {
+    ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 9:
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -60,7 +60,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
     writeOptions: CaseInsensitiveStringMap,
     orCreate: Boolean,
     invalidateCache: (TableCatalog, Table, Identifier) => Unit)
-  extends TableWriteExecHelper with GpuExec {
+  extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 10:
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.ColumnarToRowTransition
 import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -50,7 +51,8 @@ case class GpuAtomicCreateTableAsSelectExec(
     query: LogicalPlan,
     tableSpec: TableSpec,
     writeOptions: Map[String, String],
-    ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec with GpuExec {
+    ifNotExists: Boolean)
+  extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
File 11:
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.ColumnarToRowTransition
 import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -55,7 +56,7 @@ case class GpuAtomicReplaceTableAsSelectExec(
     writeOptions: Map[String, String],
     orCreate: Boolean,
     invalidateCache: (TableCatalog, Table, Identifier) => Unit)
-  extends V2CreateTableAsSelectBaseExec with GpuExec {
+  extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition {

   val properties = CatalogV2Util.convertTableProperties(tableSpec)
