TypedColumn#year and LocalDateTime generator #228

Merged: 10 commits, Jan 31, 2018
@@ -1,8 +1,7 @@
package frameless
package functions

import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => untyped}
import org.apache.spark.sql.{Column, functions => untyped}

import scala.util.matching.Regex

@@ -318,4 +317,14 @@ trait NonAggregateFunctions {
*/
def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(untyped.upper(str.untyped))
}

/**
* Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
*
* Differs from `Column#year` by wrapping its result into an `Option`.
*
* apache/spark
*/
def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
str.typed[Option[Int]](untyped.year(str.untyped))
}
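
For reference, a minimal usage sketch of the new typed `year` column (not part of this diff; `Event` and its sample data are illustrative, and an implicit SparkSession plus frameless encoders are assumed to be in scope, as the test suite below provides):

import frameless.TypedDataset
import frameless.functions.nonAggregate._

case class Event(when: String)

// Assumed setup: an implicit SparkSession and TypedEncoder[Event] are available.
val events = TypedDataset.create(Seq(Event("2018-01-31 10:30"), Event("not a date")))

// The typed column yields Option[Int]: Some(2018) for the first row and None for the
// unparsable one, where the untyped Column#year would surface a null instead.
val years: TypedDataset[Option[Int]] = events.select(year(events('when)))
// years.collect().run() should give Seq(Some(2018), None)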
@@ -5,9 +5,9 @@ import java.io.File

import frameless.functions.nonAggregate._
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{Encoder, SaveMode, functions => untyped}
import org.scalacheck.Gen
import org.apache.spark.sql.{Encoder, Row, SaveMode, functions => untyped}
import org.scalacheck.Prop._
import org.scalacheck.{Gen, Prop}

class NonAggregateFunctionsTests extends TypedDatasetSuite {
val testTempFiles = "target/testoutput"
@@ -33,7 +33,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(_.getAs[B](0))
.collect().toList


val typedDS = TypedDataset.create(values)
val col = typedDS('a)
val res = typedDS
@@ -98,7 +97,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect().toList


val typedDS = TypedDataset.create(values)
val res = typedDS
.select(acos(typedDS('a)))
@@ -119,7 +117,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}


/*
* Currently not all Collection types play nice with the Encoders.
* This test needs to be readdressed and Set re-added to the Collection Typeclass once these issues are resolved.
@@ -174,7 +171,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(_.getAs[Boolean](0))
.collect().toList


val typedDS = TypedDataset.create(List(X1(values)))
val res = typedDS
.select(arrayContains(typedDS('a), contained))
@@ -185,7 +181,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
res ?= resCompare
}


check(
forAll(
Gen.listOfN(listLength, Gen.choose(0,100)),
@@ -286,7 +281,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
res ?= resCompare
}


check(forAll(prop[Int] _))
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
@@ -375,7 +369,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}


check(forAll(prop[Int] _))
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
@@ -660,9 +653,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

test("concat for TypedAggregate") {
val spark = session
import spark.implicits._

import frameless.functions.aggregate._
import spark.implicits._
val pairs = for {
y <- Gen.alphaStr
x <- Gen.nonEmptyListOf(X2(y, y))
@@ -708,9 +700,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

test("concat_ws for TypedAggregate") {
val spark = session
import spark.implicits._

import frameless.functions.aggregate._
import spark.implicits._
val pairs = for {
y <- Gen.alphaStr
x <- Gen.listOfN(10, X2(y, y))
@@ -1037,4 +1028,35 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
// This fails due to issue #239
//check(forAll(prop[Option[Vector[Boolean]], Long] _))
}
}

test("year") {
val spark = session
import spark.implicits._

// Spark's untyped year() returns null when the string cannot be parsed; map the raw Row value to Option[Int].
val nullHandler: Row => Option[Int] = _.get(0) match {
case i: Int => Some(i)
case _ => None
}

def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
val ds = TypedDataset.create(data)

val sparkResult = ds.toDF()
.select(untyped.year($"a"))
.map(nullHandler)
.collect()
.toList

val typed = ds
.select(year(ds[String]('a)))
.collect()
.run()
.toList

typed ?= sparkResult
}

check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
check(forAll(prop _))
}
}
dataset/src/test/scala/frameless/package.scala (22 additions, 0 deletions)
@@ -1,3 +1,6 @@
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime => JavaLocalDateTime}

import org.scalacheck.{Arbitrary, Gen}

package object frameless {
@@ -39,5 +42,24 @@ package object frameless {
} yield new UdtEncodedClass(int, doubles.toArray)
}

val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")

// Day of month is capped at 28 so every generated year/month combination is a valid date.
implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary {
for {
year <- Gen.chooseNum(1900, 2027)
month <- Gen.chooseNum(1, 12)
dayOfMonth <- Gen.chooseNum(1, 28)
hour <- Gen.chooseNum(1, 23)
minute <- Gen.chooseNum(1, 59)
} yield JavaLocalDateTime.of(year, month, dayOfMonth, hour, minute)
}

/** LocalDateTime String Generator to test time related Spark functions */
val dateTimeStringGen: Gen[List[String]] =
for {
listOfDates <- Gen.listOf(localDateArb.arbitrary)
localDate <- listOfDates
} yield localDate.format(dateTimeFormatter)

val TEST_OUTPUT_DIR = "target/test-output"
}
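
As a rough sketch of what the generator produces (not part of this diff; it assumes the package object members above are visible, e.g. from the test sources):

import java.time.{LocalDateTime => JavaLocalDateTime}

// Sampling yields lists of strings shaped like "1984-03-12 07:45"; sample may return None.
val sampled: List[String] = dateTimeStringGen.sample.getOrElse(Nil)

// Each generated string parses back with dateTimeFormatter and re-formats to the same value.
sampled.foreach { s =>
  val parsed = JavaLocalDateTime.parse(s, dateTimeFormatter)
  assert(parsed.format(dateTimeFormatter) == s)
}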