[SPARK-7462] By default retain group by columns in aggregate
rxin committed May 8, 2015
1 parent 22ab70e commit 1e6e666
Showing 5 changed files with 203 additions and 145 deletions.
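Summary: `GroupedData.agg` now keeps the grouping columns in its output by default, matching SQL `GROUP BY` semantics; the previous behavior can be restored via the new `spark.sql.retainGroupColumns` flag. A minimal sketch of the observable change (the example DataFrame and its column names are illustrative, not from the patch; the implicits needed for `toDF` and `$` are assumed in scope):

```scala
import org.apache.spark.sql.functions._

// Illustrative data: columns "key" and "value".
val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("key", "value")

// New default (spark.sql.retainGroupColumns = true):
// the result carries "key" plus the aggregate column.
df.groupBy("key").agg(sum($"value"))

// Pre-change behavior, restored by disabling the flag:
// the result carries only the aggregate column.
df.sqlContext.setConf("spark.sql.retainGroupColumns", "false")
df.groupBy("key").agg(sum($"value"))
```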
10 changes: 9 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -158,7 +158,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
      case expr: NamedExpression => expr
      case expr: Expression => Alias(expr, expr.prettyString)()
    }
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
  }

/**
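The retained grouping expressions are normalized the same way as the aggregate expressions: already-named expressions pass through, anything else is wrapped in an `Alias` named by its pretty-printed form. A self-contained sketch of that pattern (the `Expr` ADT below is invented for illustration; it is not Catalyst):

```scala
// Stand-ins for Catalyst's NamedExpression / Expression / Alias.
sealed trait Expr { def prettyString: String }
case class Named(name: String) extends Expr { def prettyString: String = name }
case class Computed(prettyString: String) extends Expr

// Same shape as the patch: keep named expressions, alias the rest.
def retain(groupingExprs: Seq[Expr]): Seq[Named] = groupingExprs.map {
  case e: Named => e
  case e: Expr  => Named(e.prettyString) // mirrors Alias(expr, expr.prettyString)()
}

// retain(Seq(Named("a"), Computed("(a + b)")))
//   == Seq(Named("a"), Named("(a + b)"))
```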
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -71,6 +71,9 @@ private[spark] object SQLConf {
// See SPARK-6231.
val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"

+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+
val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
@@ -233,6 +236,9 @@ private[sql] class SQLConf extends Serializable {

private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean

+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean

/** ********************** SQLConf functionality methods ************ */

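The flag is read through `getConf` with a `"true"` default, so retention is on unless explicitly disabled. A hedged sketch of toggling it through the 1.4-era public API (an existing `SQLContext` is assumed; the helper names are illustrative):

```scala
import org.apache.spark.sql.SQLContext

// Restore the pre-SPARK-7462 behavior for a given context.
def disableRetainGroupColumns(sqlContext: SQLContext): Unit =
  sqlContext.setConf("spark.sql.retainGroupColumns", "false")

// Read the flag back; the second argument is the default,
// mirroring dataFrameRetainGroupColumns above.
def retainGroupColumns(sqlContext: SQLContext): Boolean =
  sqlContext.getConf("spark.sql.retainGroupColumns", "true").toBoolean
```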
188 changes: 188 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -0,0 +1,188 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType


class DataFrameAggregateSuite extends QueryTest {

test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
testData2.groupBy("a").agg(count("*")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("*" -> "count")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
.toDF("key", "value1", "value2", "rest")

checkAnswer(
df1.groupBy("key").min(),
df1.groupBy("key").min("value1", "value2").collect()
)
checkAnswer(
df1.groupBy("key").min("value2"),
Seq(Row("a", 0), Row("b", 4))
)
}

test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)

TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Seq(Row(3), Row(3), Row(3))
)
TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
}

test("agg without groups") {
checkAnswer(
testData2.agg(sum('b)),
Row(9)
)
}

test("average") {
checkAnswer(
testData2.agg(avg('a)),
Row(2.0))

checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)

checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)

checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
// non-partial
checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}

test("null average") {
checkAnswer(
testData3.agg(avg('b)),
Row(2.0))

checkAnswer(
testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))

checkAnswer(
testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}

test("zero average") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(avg('a)),
Row(null))

checkAnswer(
emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}

test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())

checkAnswer(
testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}

test("null count") {
checkAnswer(
testData3.groupBy('a).agg(count('b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.groupBy('a).agg(count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)

checkAnswer(
testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}

test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
assert(emptyTableData.count() === 0)

checkAnswer(
emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}

test("zero sum") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(sum('a)),
Row(null))
}

test("zero sum distinct") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(sumDistinct('a)),
Row(null))
}

}
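Not in the patch, but a natural companion to the config test above: asserting on the schema rather than the row values. A hypothetical extra check using the same `testData2` fixture (columns `a` and `b`):

```scala
// Hypothetical assertion: the grouping column survives agg() by default.
val cols = testData2.groupBy("a").agg(sum($"b")).columns
assert(cols.head === "a")
```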
142 changes: 0 additions & 142 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,7 +22,6 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._


@@ -165,48 +164,6 @@ class DataFrameSuite extends QueryTest {
testData.select('key).collect().toSeq)
}

test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b")),
Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
testData2.groupBy("a").agg(col("a"), count("*")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("*" -> "count")),
Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
)
checkAnswer(
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
.toDF("key", "value1", "value2", "rest")

checkAnswer(
df1.groupBy("key").min(),
df1.groupBy("key").min("value1", "value2").collect()
)
checkAnswer(
df1.groupBy("key").min("value2"),
Seq(Row("a", 0), Row("b", 4))
)
}

test("agg without groups") {
checkAnswer(
testData2.agg(sum('b)),
Row(9)
)
}

test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
testData.where($"key" === lit(1)).select($"value"),
@@ -303,105 +260,6 @@
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}

test("average") {
checkAnswer(
testData2.agg(avg('a)),
Row(2.0))

checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)

checkAnswer(
decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)

checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
// non-partial
checkAnswer(
decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}

test("null average") {
checkAnswer(
testData3.agg(avg('b)),
Row(2.0))

checkAnswer(
testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))

checkAnswer(
testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}

test("zero average") {
checkAnswer(
emptyTableData.agg(avg('a)),
Row(null))

checkAnswer(
emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}

test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())

checkAnswer(
testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}

test("null count") {
checkAnswer(
testData3.groupBy('a).agg('a, count('b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.groupBy('a).agg('a, count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)

checkAnswer(
testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)

checkAnswer(
testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}

test("zero count") {
assert(emptyTableData.count() === 0)

checkAnswer(
emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}

test("zero sum") {
checkAnswer(
emptyTableData.agg(sum('a)),
Row(null))
}

test("zero sum distinct") {
checkAnswer(
emptyTableData.agg(sumDistinct('a)),
Row(null))
}

test("except") {
checkAnswer(
lowerCaseData.except(upperCaseData),
2 changes: 0 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -86,8 +86,6 @@ object TestData {
TestData3(2, Some(2)) :: Nil).toDF()
testData3.registerTempTable("testData3")

-  val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
-
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
TestSQLContext.sparkContext.parallelize(
