diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
index ec47875754a6f..c61eb68db5bfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
@@ -89,7 +89,8 @@ trait AliasHelper {
       a.copy(child = trimAliases(a.child))(
         exprId = a.exprId,
         qualifier = a.qualifier,
-        explicitMetadata = Some(a.metadata))
+        explicitMetadata = Some(a.metadata),
+        deniedMetadataKeys = a.deniedMetadataKeys)
     case a: MultiAlias =>
       a.copy(child = trimAliases(a.child))
     case other => trimAliases(other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2abd9d7bb4423..22aabd3c6b30b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
  * fully qualified way. Consider the examples tableName.name, subQueryAlias.name.
  * tableName and subQueryAlias are possible qualifiers.
  * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
+ * @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when
+ *                           inheriting the metadata from the child.
  */
 case class Alias(child: Expression, name: String)(
     val exprId: ExprId = NamedExpression.newExprId,
     val qualifier: Seq[String] = Seq.empty,
-    val explicitMetadata: Option[Metadata] = None)
+    val explicitMetadata: Option[Metadata] = None,
+    val deniedMetadataKeys: Seq[String] = Seq.empty)
   extends UnaryExpression with NamedExpression {
 
   // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
@@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)(
   override def metadata: Metadata = {
     explicitMetadata.getOrElse {
       child match {
-        case named: NamedExpression => named.metadata
+        case named: NamedExpression =>
+          val builder = new MetadataBuilder().withMetadata(named.metadata)
+          deniedMetadataKeys.foreach(builder.remove)
+          builder.build()
+
         case _ => Metadata.empty
       }
     }
@@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)(
   override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"
 
   override protected final def otherCopyArgs: Seq[AnyRef] = {
-    exprId :: qualifier :: explicitMetadata :: Nil
+    exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil
   }
 
   override def hashCode(): Int = {
@@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)(
   override def equals(other: Any): Boolean = other match {
     case a: Alias =>
       name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
-        explicitMetadata == a.explicitMetadata
+        explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys
     case _ => false
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 95134d9111593..86ba81340272b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1164,7 +1164,10 @@ class Column(val expr: Expression) extends Logging {
    * @since 2.0.0
    */
   def name(alias: String): Column = withExpr {
-    Alias(normalizedExpr(), alias)()
+    // SPARK-33536: The Alias is no longer a column reference after converting to an attribute.
+    // These denied metadata keys are used to strip the column reference related metadata for
+    // the Alias. So it won't be caught as a column reference in DetectAmbiguousSelfJoin.
+    Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2c38a65ac2106..0716043bcf660 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -231,7 +231,8 @@ class Dataset[T] private[sql](
       case _ =>
         queryExecution.analyzed
     }
-    if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
+    if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) &&
+        plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
       plan.setTagValue(Dataset.DATASET_ID_TAG, id)
     }
     plan
@@ -259,15 +260,16 @@ class Dataset[T] private[sql](
   private[sql] def resolve(colName: String): NamedExpression = {
     val resolver = sparkSession.sessionState.analyzer.resolver
     queryExecution.analyzed.resolveQuoted(colName, resolver)
-      .getOrElse {
-        val fields = schema.fieldNames
-        val extraMsg = if (fields.exists(resolver(_, colName))) {
-          s"; did you mean to quote the `$colName` column?"
-        } else ""
-        val fieldsStr = fields.mkString(", ")
-        val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
-        throw new AnalysisException(errorMsg)
-      }
+      .getOrElse(throw resolveException(colName, schema.fieldNames))
+  }
+
+  private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
+    val extraMsg = if (fields.exists(sparkSession.sessionState.analyzer.resolver(_, colName))) {
+      s"; did you mean to quote the `$colName` column?"
+    } else ""
+    val fieldsStr = fields.mkString(", ")
+    val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
+    new AnalysisException(errorMsg)
   }
 
   private[sql] def numericColumns: Seq[Expression] = {
@@ -1083,8 +1085,8 @@ class Dataset[T] private[sql](
     }
 
     // If left/right have no output set intersection, return the plan.
-    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
-    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
+    val lanalyzed = this.queryExecution.analyzed
+    val ranalyzed = right.queryExecution.analyzed
     if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
       return withPlan(plan)
     }
@@ -1092,17 +1094,22 @@ class Dataset[T] private[sql](
     // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
+    val resolver = sparkSession.sessionState.analyzer.resolver
     val cond = plan.condition.map { _.transform {
       case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
         catalyst.expressions.EqualTo(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
       case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
         catalyst.expressions.EqualNullSafe(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
     }}
 
     withPlan {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 3b3b54f75da57..50846d9d12b97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{count, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.SQLTestData.TestData
 
 class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -219,4 +220,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple))
     }
   }
+
+  test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " +
+    "to not break DetectAmbiguousSelfJoin") {
+    val emp1 = Seq[TestData](
+      TestData(1, "sales"),
+      TestData(2, "personnel"),
+      TestData(3, "develop"),
+      TestData(4, "IT")).toDS()
+    val emp2 = Seq[TestData](
+      TestData(1, "sales"),
+      TestData(2, "personnel"),
+      TestData(3, "develop")).toDS()
+    val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
+    assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
+      "left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
+  }
+
+  test("df.show() should also not change dataset_id of LogicalPlan") {
+    val df = Seq[TestData](
+      TestData(1, "sales"),
+      TestData(2, "personnel"),
+      TestData(3, "develop"),
+      TestData(4, "IT")).toDF()
+    val ds_id1 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+    df.show(0)
+    val ds_id2 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+    assert(ds_id1 === ds_id2)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 12abd31b99e93..f02d2041dd7f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -573,8 +573,9 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
 class ColumnarAlias(child: ColumnarExpression, name: String)(
     override val exprId: ExprId = NamedExpression.newExprId,
     override val qualifier: Seq[String] = Seq.empty,
-    override val explicitMetadata: Option[Metadata] = None)
-  extends Alias(child, name)(exprId, qualifier, explicitMetadata)
+    override val explicitMetadata: Option[Metadata] = None,
+    override val deniedMetadataKeys: Seq[String] = Seq.empty)
+  extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys)
   with ColumnarExpression {
 
   override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
@@ -711,7 +712,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
   def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match {
     case a: Alias =>
       new ColumnarAlias(replaceWithColumnarExpression(a.child),
-        a.name)(a.exprId, a.qualifier, a.explicitMetadata)
+        a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys)
     case att: AttributeReference =>
       new ColumnarAttributeReference(att.name, att.dataType, att.nullable,
         att.metadata)(att.exprId, att.qualifier)
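Illustration (not part of the patch): a minimal sketch of what the new `deniedMetadataKeys` parameter does. When an `Alias` inherits metadata from a `NamedExpression` child, every key listed there is dropped via `MetadataBuilder.remove` before the metadata is exposed. The attribute and the `"__dataset_id"` key below are made up for the example; the latter stands in for the literal value of `Dataset.DATASET_ID_KEY`.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.{LongType, MetadataBuilder}

// A child attribute carrying two metadata entries, one of which should not survive aliasing.
val childMetadata = new MetadataBuilder()
  .putLong("__dataset_id", 42L)
  .putString("comment", "user-facing comment")
  .build()
val child = AttributeReference("key", LongType, nullable = false, childMetadata)()

// Without denied keys the alias inherits the child's metadata unchanged.
val plain = Alias(child, "k")()
assert(plain.metadata.contains("__dataset_id"))

// With a denied key, Alias.metadata rebuilds the child's metadata minus that key.
val stripped = Alias(child, "k")(deniedMetadataKeys = Seq("__dataset_id"))
assert(!stripped.metadata.contains("__dataset_id"))
assert(stripped.metadata.contains("comment"))
```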
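And a hedged end-to-end check of the `Column.name()`/`as()` change, assuming a local `SparkSession` and the default `spark.sql.analyzer.failAmbiguousSelfJoin=true` (the conf backing `SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.NamedExpression

val spark = SparkSession.builder().master("local[1]").appName("SPARK-33536-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "sales"), (2, "personnel")).toDF("key", "value")

// A plain column reference carries "__dataset_id" / "__col_position" metadata
// (the literal values of Dataset.DATASET_ID_KEY / Dataset.COL_POS_KEY) that
// DetectAmbiguousSelfJoin uses to trace the column back to `df`.
assert(df("key").expr.asInstanceOf[NamedExpression].metadata.contains("__dataset_id"))

// After as()/name(), the Alias inherits the child's metadata minus the denied keys,
// so the aliased column is no longer mistaken for a column reference of `df`.
assert(!df("key").as("k").expr.asInstanceOf[NamedExpression].metadata.contains("__dataset_id"))
```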