Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypedColumn#year and LocalDateTime generator #228

Merged
merged 10 commits into from
Jan 31, 2018
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package frameless
package functions

import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => untyped}
import org.apache.spark.sql.{Column, functions => untyped}

import scala.util.matching.Regex

Expand Down Expand Up @@ -279,4 +278,13 @@ trait NonAggregateFunctions {
def upper[T](str: TypedColumn[T, String]): TypedColumn[T, String] = {
new TypedColumn[T, String](untyped.upper(str.untyped))
}
}

/**
* Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
*
* apache/spark
*/
def year[T](col: TypedColumn[T, String]): TypedColumn[T, Int] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably TypedColumn[T, Option[Int]]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree here. This should fail when there is years to be extracted. Is Spark returns null in this case, then this should be encoded as @frosforever suggests. Should be a trivial change.

new TypedColumn[T, Int](untyped.year(col.untyped))
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package frameless
package functions
import java.time.{LocalDateTime => JavaLocalDateTime}

import frameless.functions.nonAggregate._
import org.apache.spark.sql.{ Column, Encoder }
import org.scalacheck.Gen
import org.apache.spark.sql.{Column, Encoder, functions => untyped}
import org.scalacheck.Prop._
import org.apache.spark.sql.{ functions => untyped }
import org.scalacheck.{Gen, Prop}

class NonAggregateFunctionsTests extends TypedDatasetSuite {

Expand Down Expand Up @@ -611,23 +612,46 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(stringFuncProp(upper, untyped.upper))
}

def stringFuncProp[A : Encoder](strFunc: TypedColumn[X1[String], String] => TypedColumn[X1[String], A], sparkFunc: Column => Column) = {
test("year") {
val spark = session
import spark.implicits._

check(dateTimeStringFuncProp(year, untyped.year))
}

def stringFuncProp[A: Encoder](strFunc: TypedColumn[X1[String], String] => TypedColumn[X1[String], A],
sparkFunc: Column => Column): Prop = {
forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)

val sparkResult: List[A] = ds.toDF()
.select(sparkFunc(untyped.col("a")))
.map(_.getAs[A](0))
.collect()
.toList
funcProp(ds)(strFunc, sparkFunc)
}
}

val typed: List[A] = ds
.select(strFunc(ds[String]('a)))
.collect()
.run()
.toList
def dateTimeStringFuncProp[A: Encoder](strFunc: TypedColumn[X1[String], String] => TypedColumn[X1[String], A],
sparkFunc: Column => Column): Prop =
forAll { values: List[JavaLocalDateTime] =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if it's not a nicely formatted date string?

val ds = TypedDataset.create(values.map(v => X1[String](v.format(dateTimeFormatter))))

typed ?= sparkResult
funcProp(ds)(strFunc, sparkFunc)
}

def funcProp[A: Encoder](ds: TypedDataset[X1[String]])
(strFunc: TypedColumn[X1[String], String] => TypedColumn[X1[String], A],
sparkFunc: Column => Column): Prop = {
val sparkResult = ds.toDF()
.select(sparkFunc(untyped.col("a")))
.map(_.getAs[A](0))
.collect()
.toList

val typed = ds
.select(strFunc(ds[String]('a)))
.collect()
.run()
.toList

typed ?= sparkResult
}
}

}
14 changes: 14 additions & 0 deletions dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime => JavaLocalDateTime}

import org.scalacheck.{Arbitrary, Gen}

package object frameless {
Expand Down Expand Up @@ -37,4 +40,15 @@ package object frameless {
} yield new UdtEncodedClass(int, doubles.toArray)
}

val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")

implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary {
for {
year <- Gen.chooseNum(1900, 2027)
month <- Gen.chooseNum(1, 12)
dayOfMonth <- Gen.chooseNum(1, 28)
hour <- Gen.chooseNum(1, 23)
minute <- Gen.chooseNum(1, 59)
} yield JavaLocalDateTime.of(year, month, dayOfMonth, hour, minute)
}
}