Skip to content

Commit

Permalink
[SPARK-33071][SPARK-33536][SQL] Avoid changing dataset_id of LogicalP…
Browse files Browse the repository at this point in the history
…lan in join() to not break DetectAmbiguousSelfJoin

### What changes were proposed in this pull request?

Currently, `join()` uses `withPlan(logicalPlan)` as a convenient way to call some Dataset functions. However, this makes the `dataset_id` inconsistent between the `logicalPlan` and the original `Dataset` (because `withPlan(logicalPlan)` creates a new Dataset with a new id and resets the `dataset_id` of the `logicalPlan` to that new id). As a result, it breaks the rule `DetectAmbiguousSelfJoin`.

In this PR, we propose to drop the usage of `withPlan` and use the `logicalPlan` directly instead, so that its `dataset_id` doesn't change.

Besides, this PR also removes the related metadata (`DATASET_ID_KEY`, `COL_POS_KEY`) when an `Alias` constructs its own metadata, because the `Alias` is no longer a column reference after being converted to an `Attribute`. To achieve that, we add a new field, `deniedMetadataKeys`, to indicate the metadata keys that need to be removed.

### Why are the changes needed?

The query below returns a wrong result, whereas it should throw an ambiguous self-join exception instead:

```scala
val emp1 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop"),
  TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop")).toDS()
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
  .select(emp1.col("*"), emp3.col("key").as("e2")).show()

// wrong result
+---+---------+---+
|key|    value| e2|
+---+---------+---+
|  1|    sales|  1|
|  2|personnel|  2|
|  3|  develop|  3|
|  4|       IT|  4|
+---+---------+---+
```
This PR fixes the wrong behaviour.

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, users get the ambiguous self-join exception instead of a wrong result.

### How was this patch tested?

Added a new unit test.

Closes #30488 from Ngone51/fix-self-join.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
Ngone51 authored and cloud-fan committed Dec 2, 2020
1 parent 91182d6 commit a082f46
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ trait AliasHelper {
a.copy(child = trimAliases(a.child))(
exprId = a.exprId,
qualifier = a.qualifier,
explicitMetadata = Some(a.metadata))
explicitMetadata = Some(a.metadata),
deniedMetadataKeys = a.deniedMetadataKeys)
case a: MultiAlias =>
a.copy(child = trimAliases(a.child))
case other => trimAliases(other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
* fully qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
* @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when
* inheriting the metadata from the child.
*/
case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Seq[String] = Seq.empty,
val explicitMetadata: Option[Metadata] = None)
val explicitMetadata: Option[Metadata] = None,
val deniedMetadataKeys: Seq[String] = Seq.empty)
extends UnaryExpression with NamedExpression {

// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
Expand All @@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)(
override def metadata: Metadata = {
explicitMetadata.getOrElse {
child match {
case named: NamedExpression => named.metadata
case named: NamedExpression =>
val builder = new MetadataBuilder().withMetadata(named.metadata)
deniedMetadataKeys.foreach(builder.remove)
builder.build()

case _ => Metadata.empty
}
}
Expand All @@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)(
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"

override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: explicitMetadata :: Nil
exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil
}

override def hashCode(): Int = {
Expand All @@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)(
override def equals(other: Any): Boolean = other match {
case a: Alias =>
name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
explicitMetadata == a.explicitMetadata
explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys
case _ => false
}

Expand Down
5 changes: 4 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,10 @@ class Column(val expr: Expression) extends Logging {
* @since 2.0.0
*/
def name(alias: String): Column = withExpr {
Alias(normalizedExpr(), alias)()
// SPARK-33536: The Alias is no longer a column reference after converting to an attribute.
// These denied metadata keys are used to strip the column reference related metadata for
// the Alias. So it won't be caught as a column reference in DetectAmbiguousSelfJoin.
Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))
}

/**
Expand Down
39 changes: 23 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class Dataset[T] private[sql](
case _ =>
queryExecution.analyzed
}
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) &&
plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
plan.setTagValue(Dataset.DATASET_ID_TAG, id)
}
plan
Expand Down Expand Up @@ -259,15 +260,16 @@ class Dataset[T] private[sql](
private[sql] def resolve(colName: String): NamedExpression = {
val resolver = sparkSession.sessionState.analyzer.resolver
queryExecution.analyzed.resolveQuoted(colName, resolver)
.getOrElse {
val fields = schema.fieldNames
val extraMsg = if (fields.exists(resolver(_, colName))) {
s"; did you mean to quote the `$colName` column?"
} else ""
val fieldsStr = fields.mkString(", ")
val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
throw new AnalysisException(errorMsg)
}
.getOrElse(throw resolveException(colName, schema.fieldNames))
}

/**
 * Builds (without throwing) the [[AnalysisException]] that reports `colName` could not be
 * resolved among `fields`. The caller is responsible for throwing the returned exception.
 *
 * @param colName the column name that failed to resolve
 * @param fields the candidate field names listed in the error message
 * @return an exception whose message names `colName`, lists `fields`, and optionally carries
 *         a hint to quote the column name
 */
private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
// If some field matches `colName` under the session's resolver, append a quoting hint.
// NOTE(review): presumably this covers names with special characters (e.g. dots) that only
// resolve when quoted — confirm.
val extraMsg = if (fields.exists(sparkSession.sessionState.analyzer.resolver(_, colName))) {
s"; did you mean to quote the `$colName` column?"
} else ""
val fieldsStr = fields.mkString(", ")
val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
new AnalysisException(errorMsg)
}

private[sql] def numericColumns: Seq[Expression] = {
Expand Down Expand Up @@ -1083,26 +1085,31 @@ class Dataset[T] private[sql](
}

// If left/right have no output set intersection, return the plan.
val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
val lanalyzed = this.queryExecution.analyzed
val ranalyzed = right.queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return withPlan(plan)
}

// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
val resolver = sparkSession.sessionState.analyzer.resolver
val cond = plan.condition.map { _.transform {
case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
if a.sameRef(b) =>
catalyst.expressions.EqualTo(
withPlan(plan.left).resolve(a.name),
withPlan(plan.right).resolve(b.name))
plan.left.resolveQuoted(a.name, resolver)
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
plan.right.resolveQuoted(b.name, resolver)
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
if a.sameRef(b) =>
catalyst.expressions.EqualNullSafe(
withPlan(plan.left).resolve(a.name),
withPlan(plan.right).resolve(b.name))
plan.left.resolveQuoted(a.name, resolver)
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
plan.right.resolveQuoted(b.name, resolver)
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
}}

withPlan {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
Expand Down Expand Up @@ -219,4 +220,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple))
}
}

// Regression test for SPARK-33071/SPARK-33536: joins `emp1` with `emp3`, where `emp3` is
// itself derived from a join involving `emp1`, and passes the resulting query to
// `assertAmbiguousSelfJoin`.
test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " +
"to not break DetectAmbiguousSelfJoin") {
val emp1 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop"),
TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop")).toDS()
// `emp3` keeps only `emp1`'s columns out of the emp1/emp2 join.
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
// `assertAmbiguousSelfJoin` is defined outside this view; presumably it asserts that the
// query fails with the ambiguous self-join error — confirm against the suite.
assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
"left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
}

// Checks that calling `show` on a DataFrame leaves the `DATASET_ID_TAG` value on its
// `logicalPlan` unchanged.
test("df.show() should also not change dataset_id of LogicalPlan") {
val df = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop"),
TestData(4, "IT")).toDF()
// Tag value before the action.
val ds_id1 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
df.show(0)
// Tag value after the action.
val ds_id2 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
// NOTE(review): this compares the two lookups for equality only; if the tag were absent
// both times the assertion would presumably still pass — consider also asserting that the
// tag is present.
assert(ds_id1 === ds_id2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,9 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
class ColumnarAlias(child: ColumnarExpression, name: String)(
override val exprId: ExprId = NamedExpression.newExprId,
override val qualifier: Seq[String] = Seq.empty,
override val explicitMetadata: Option[Metadata] = None)
extends Alias(child, name)(exprId, qualifier, explicitMetadata)
override val explicitMetadata: Option[Metadata] = None,
override val deniedMetadataKeys: Seq[String] = Seq.empty)
extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys)
with ColumnarExpression {

override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
Expand Down Expand Up @@ -711,7 +712,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match {
case a: Alias =>
new ColumnarAlias(replaceWithColumnarExpression(a.child),
a.name)(a.exprId, a.qualifier, a.explicitMetadata)
a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys)
case att: AttributeReference =>
new ColumnarAttributeReference(att.name, att.dataType, att.nullable,
att.metadata)(att.exprId, att.qualifier)
Expand Down

0 comments on commit a082f46

Please sign in to comment.