Skip to content

Commit

Permalink
support date type
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Oct 10, 2014
1 parent 6f98902 commit 2dfbb5b
Show file tree
Hide file tree
Showing 125 changed files with 813 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
Expand Down Expand Up @@ -77,8 +77,9 @@ object ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,36 @@ trait HiveTypeCoercion {
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))

case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, DateType), p.right))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DateType)))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, DateType), p.right))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(p.left, Cast(p.right, DateType)))

case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a,b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a,b.map(Cast(_,DateType))))
case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
case i @ In(a,b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(a,b.map(Cast(_,DateType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
Expand Down Expand Up @@ -283,6 +299,8 @@ trait HiveTypeCoercion {
// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
// DateType should be null if be cast to boolean.
case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import scala.language.implicitConversions

Expand Down Expand Up @@ -119,6 +119,7 @@ package object dsl {
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def dateToLiteral(d: Date) = Literal(d)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
Expand Down Expand Up @@ -174,6 +175,9 @@ package object dsl {
/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type date */
def date = AttributeReference(s, DateType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.types._

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
override def foldable = child.foldable

override def nullable = (child.dataType, dataType) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (StringType, DateType) => true
case _ => child.nullable
}

Expand All @@ -42,6 +44,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// UDFToString
private[this] def castToString: Any => Any = child.dataType match {
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
case DateType => buildCast[Date](_, dateToString)
case TimestampType => buildCast[Timestamp](_, timestampToString)
case _ => buildCast[Any](_, _.toString)
}
Expand All @@ -56,7 +59,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType =>
buildCast[String](_, _.length() != 0)
case TimestampType =>
buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
case DateType =>
buildCast[Date](_, d => null)
case LongType =>
buildCast[Long](_, _ != 0)
case IntegerType =>
Expand Down Expand Up @@ -95,6 +100,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
buildCast[Short](_, s => new Timestamp(s))
case ByteType =>
buildCast[Byte](_, b => new Timestamp(b))
case DateType =>
buildCast[Date](_, d => Timestamp.valueOf(dateToString(d) + " 00:00:00"))
// TimestampWritable.decimalToTimestamp
case DecimalType =>
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
Expand Down Expand Up @@ -130,7 +137,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def timestampToString(ts: Timestamp): String = {
val timestampString = ts.toString
val formatted = Cast.threadLocalDateFormat.get.format(ts)
val formatted = Cast.threadLocalTimestampFormat.get.format(ts)

if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
formatted + timestampString.substring(19)
Expand All @@ -139,13 +146,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
}

// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def timestampToDateString(ts: Timestamp): String = {
Cast.threadLocalDateFormat.get.format(ts)
}

// DateConverter
private[this] def castToDate: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => if (s.contains(" ")) {
try castToDate(castToTimestamp(s))
catch { case _: java.lang.IllegalArgumentException => null }
} else {
try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
})
case TimestampType =>
buildCast[Timestamp](_, t => Date.valueOf(timestampToDateString(t)))
// TimestampWritable.decimalToDate
case _ =>
_ => null
}

// Date cannot be cast to long, according to hive
private[this] def dateToLong(d: Date) = null

// Date cannot be cast to double, according to hive
private[this] def dateToDouble(d: Date) = null

// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def dateToString(d: Date): String = {
Cast.threadLocalDateFormat.get.format(d)
}

// LongConverter
private[this] def castToLong: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
case DecimalType =>
Expand All @@ -154,13 +196,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}

// IntConverter
private[this] def castToInt: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
case DecimalType =>
Expand All @@ -169,13 +214,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}

// ShortConverter
private[this] def castToShort: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
case DecimalType =>
Expand All @@ -184,13 +232,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}

// ByteConverter
private[this] def castToByte: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
case DateType =>
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
case DecimalType =>
Expand All @@ -199,27 +250,33 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}

// DecimalConverter
private[this] def castToDecimal: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
// Note that we lose precision here.
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
case x: NumericType =>
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
}

// DoubleConverter
private[this] def castToDouble: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
case DecimalType =>
Expand All @@ -228,13 +285,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}

// FloatConverter
private[this] def castToFloat: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toFloat catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
case DateType =>
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
case DecimalType =>
Expand All @@ -245,17 +305,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {

private[this] lazy val cast: Any => Any = dataType match {
case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
case DateType => castToDate
case TimestampType => castToTimestamp
case BooleanType => castToBoolean
case ByteType => castToByte
case ShortType => castToShort
case IntegerType => castToInt
case FloatType => castToFloat
case LongType => castToLong
case DoubleType => castToDouble
case BooleanType => castToBoolean
case ByteType => castToByte
case ShortType => castToShort
case IntegerType => castToInt
case FloatType => castToFloat
case LongType => castToLong
case DoubleType => castToDouble
}

override def eval(input: Row): Any = {
Expand All @@ -267,6 +328,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
object Cast {
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
override def initialValue() = {
new SimpleDateFormat("yyyy-MM-dd")
}
}

// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
override def initialValue() = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.types._

Expand All @@ -33,6 +33,7 @@ object Literal {
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(d, DecimalType)
case t: Timestamp => Literal(t, TimestampType)
case d: Date => Literal(d, DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.types

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
Expand Down Expand Up @@ -250,6 +250,18 @@ case object TimestampType extends NativeType {
}
}

case object DateType extends NativeType {
private[sql] type JvmType = Date

@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }

private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Date, y: Date) = x.compareTo(y)
}

def simpleString: String = "date"
}

abstract class NumericType extends NativeType with PrimitiveType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
Expand Down
Loading

0 comments on commit 2dfbb5b

Please sign in to comment.