Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Apr 28, 2020
1 parent 7216511 commit 663fdc7
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ trait ExpressionWithRandomSeed {
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {
case class Rand(child: Expression, useRandSeed: Boolean = false)
extends RDG with ExpressionWithRandomSeed {

def this() = this(Literal(Utils.random.nextLong(), LongType))
def this() = this(Literal(Utils.random.nextLong(), LongType), true)

def this(child: Expression) = this(child, false)

override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType))

Expand All @@ -103,7 +106,10 @@ case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {

override def freshCopy(): Rand = Rand(child)

override def sql: String = "rand()"
override def flatArguments: Iterator[Any] = Iterator(child)
override def sql: String = {
s"rand(${if (useRandSeed) "" else child.sql})"
}
}

object Rand {
Expand All @@ -128,9 +134,12 @@ object Rand {
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {
case class Randn(child: Expression, useRandSeed: Boolean = false)
extends RDG with ExpressionWithRandomSeed {

def this() = this(Literal(Utils.random.nextLong(), LongType))
def this() = this(Literal(Utils.random.nextLong(), LongType), true)

def this(child: Expression) = this(child, false)

override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))

Expand All @@ -148,7 +157,10 @@ case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {

override def freshCopy(): Randn = Randn(child)

override def sql: String = "randn()"
override def flatArguments: Iterator[Any] = Iterator(child)
override def sql: String = {
s"randn(${if (useRandSeed) "" else child.sql})"
}
}

object Randn {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Rand(5419823303878592871L), 0.7145363364564755)
checkEvaluation(Randn(5419823303878592871L), 0.7816815274533012)
}

test("Do not display the seed of rand/randn with no argument in output schema") {
assert(Rand(Literal(1L), true).sql === "rand()")
assert(Randn(Literal(1L), true).sql === "randn()")
assert(Rand(Literal(1L), false).sql === "rand(1L)")
assert(Randn(Literal(1L), false).sql === "randn(1L)")
}
}
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3425,6 +3425,28 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true)
}
}

test("Do not display the seed of rand/randn with no argument in output schema") {
def checkIfSeedExistsInExplain(df: DataFrame): Unit = {
val output = new java.io.ByteArrayOutputStream()
Console.withOut(output) {
df.explain()
}
output.toString.matches("""randn?\(-?[0-9]+\)""")
}
val df1 = sql("SELECT rand()")
assert(df1.schema.head.name === "rand()")
checkIfSeedExistsInExplain(df1)
val df2 = sql("SELECT rand(1L)")
assert(df2.schema.head.name === "rand(1)")
checkIfSeedExistsInExplain(df2)
val df3 = sql("SELECT randn()")
assert(df3.schema.head.name === "randn()")
checkIfSeedExistsInExplain(df1)
val df4 = sql("SELECT randn(1L)")
assert(df4.schema.head.name === "randn(1)")
checkIfSeedExistsInExplain(df2)
}
}

case class Foo(bar: Option[String])

0 comments on commit 663fdc7

Please sign in to comment.