diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 70a320fc53b69..85057f37a1817 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -20,7 +20,7 @@
 import tempfile
 
 from pyspark.errors import AnalysisException
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, lit
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -181,6 +181,27 @@ def test_insert_into(self):
         df.write.mode("overwrite").insertInto("test_table", False)
         self.assertEqual(6, self.spark.sql("select * from test_table").count())
 
+    def test_cached_table(self):
+        with self.table("test_cached_table_1"):
+            self.spark.range(10).withColumn(
+                "value_1",
+                lit(1),
+            ).write.saveAsTable("test_cached_table_1")
+
+            with self.table("test_cached_table_2"):
+                self.spark.range(10).withColumnRenamed("id", "index").withColumn(
+                    "value_2", lit(2)
+                ).write.saveAsTable("test_cached_table_2")
+
+                df1 = self.spark.read.table("test_cached_table_1")
+                df2 = self.spark.read.table("test_cached_table_2")
+                df3 = self.spark.read.table("test_cached_table_1")
+
+                join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
+                join2 = df3.join(join1, how="left", on=join1.index == df3.id)
+
+                self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])
+
 
 class ReadwriterV2TestsMixin:
     def test_api(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d8127fe03da4e..1fb5d00bdf39a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1275,16 +1275,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           val key =
             ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq,
               finalTimeTravelSpec)
-          AnalysisContext.get.relationCache.get(key).map(_.transform {
-            case multi: MultiInstanceRelation =>
-              val newRelation = multi.newInstance()
-              newRelation.copyTagsFrom(multi)
-              newRelation
-          }).orElse {
+          AnalysisContext.get.relationCache.get(key).map { cache =>
+            val cachedRelation = cache.transform {
+              case multi: MultiInstanceRelation =>
+                val newRelation = multi.newInstance()
+                newRelation.copyTagsFrom(multi)
+                newRelation
+            }
+            u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+              val cachedConnectRelation = cachedRelation.clone()
+              cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+              cachedConnectRelation
+            }.getOrElse(cachedRelation)
+          }.orElse {
             val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
             val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
             loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
-            loaded
+            u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+              loaded.map { loadedRelation =>
+                val loadedConnectRelation = loadedRelation.clone()
+                loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+                loadedConnectRelation
+              }
+            }.getOrElse(loaded)
           }
         case _ => None
       }