Skip to content

Commit

Permalink
TypedColumn#year and LocalDateTime generator (#228)
Browse files Browse the repository at this point in the history
* added year function
* changed Column#year type to Option[Int]
* cleanup
* Formatting NonAggregateFunctionsTests
* fix merge conflicts
* clean up after merge
* removed curly braces from `year` function
  • Loading branch information
Avasil authored and imarios committed Jan 31, 2018
1 parent 31b6dd5 commit 864fd3f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package frameless
package functions

import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => untyped}
import org.apache.spark.sql.{Column, functions => untyped}

import scala.util.matching.Regex

Expand Down Expand Up @@ -318,4 +317,14 @@ trait NonAggregateFunctions {
*/
def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(untyped.upper(str.untyped))
}

/**
  * Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
  *
  * Differs from `Column#year` by wrapping its result into an `Option`
  * (NOTE(review): presumably `None` corresponds to the null that Spark's untyped
  * `year` yields for unparseable input — confirm against the Spark docs).
  *
  * apache/spark
  */
def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
  str.typed[Option[Int]](untyped.year(str.untyped))
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import java.io.File

import frameless.functions.nonAggregate._
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{Encoder, SaveMode, functions => untyped}
import org.scalacheck.Gen
import org.apache.spark.sql.{Encoder, Row, SaveMode, functions => untyped}
import org.scalacheck.Prop._
import org.scalacheck.{Gen, Prop}

class NonAggregateFunctionsTests extends TypedDatasetSuite {
val testTempFiles = "target/testoutput"
Expand All @@ -33,7 +33,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(_.getAs[B](0))
.collect().toList


val typedDS = TypedDataset.create(values)
val col = typedDS('a)
val res = typedDS
Expand Down Expand Up @@ -98,7 +97,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(DoubleBehaviourUtils.nanNullHandler)
.collect().toList


val typedDS = TypedDataset.create(values)
val res = typedDS
.select(acos(typedDS('a)))
Expand All @@ -119,7 +117,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}


/*
* Currently not all Collection types play nice with the Encoders.
* This test needs to be readdressed and Set re-added to the Collection Typeclass once these issues are resolved.
Expand Down Expand Up @@ -174,7 +171,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.map(_.getAs[Boolean](0))
.collect().toList


val typedDS = TypedDataset.create(List(X1(values)))
val res = typedDS
.select(arrayContains(typedDS('a), contained))
Expand All @@ -185,7 +181,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
res ?= resCompare
}


check(
forAll(
Gen.listOfN(listLength, Gen.choose(0,100)),
Expand Down Expand Up @@ -286,7 +281,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
res ?= resCompare
}


check(forAll(prop[Int] _))
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
Expand Down Expand Up @@ -375,7 +369,6 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}


check(forAll(prop[Int] _))
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
Expand Down Expand Up @@ -660,9 +653,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

test("concat for TypedAggregate") {
val spark = session
import spark.implicits._

import frameless.functions.aggregate._
import spark.implicits._
val pairs = for {
y <- Gen.alphaStr
x <- Gen.nonEmptyListOf(X2(y, y))
Expand Down Expand Up @@ -708,9 +700,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

test("concat_ws for TypedAggregate") {
val spark = session
import spark.implicits._

import frameless.functions.aggregate._
import spark.implicits._
val pairs = for {
y <- Gen.alphaStr
x <- Gen.listOfN(10, X2(y, y))
Expand Down Expand Up @@ -1037,4 +1028,35 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
// This fails due to issue #239
//check(forAll(prop[Option[Vector[Boolean]], Long] _))
}
}

test("year") {
  val spark = session
  import spark.implicits._

  // Spark's untyped `year` produces null rows for unparseable input,
  // so translate each Row into an Option[Int] before comparing.
  val toOptionInt: Row => Option[Int] = row =>
    row.get(0) match {
      case i: Int => Some(i)
      case _      => None
    }

  def prop(data: List[X1[String]])(implicit enc: Encoder[Option[Int]]): Prop = {
    val ds = TypedDataset.create(data)

    // Reference result computed with the untyped Spark API.
    val untypedResult =
      ds.toDF()
        .select(untyped.year($"a"))
        .map(toOptionInt)
        .collect()
        .toList

    // Result of the typed `year` function under test.
    val typedResult =
      ds.select(year(ds[String]('a)))
        .collect()
        .run()
        .toList

    typedResult ?= untypedResult
  }

  // Valid date strings (should parse to Some(year)) and arbitrary strings (mostly None).
  check(forAll(dateTimeStringGen)(dates => prop(dates.map(X1.apply))))
  check(forAll(prop _))
}
}
22 changes: 22 additions & 0 deletions dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime => JavaLocalDateTime}

import org.scalacheck.{Arbitrary, Gen}

package object frameless {
Expand Down Expand Up @@ -39,5 +42,24 @@ package object frameless {
} yield new UdtEncodedClass(int, doubles.toArray)
}

/** Pattern used to render generated `LocalDateTime`s as date/time strings ("yyyy-MM-dd HH:mm"). */
val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")

/**
  * Generates arbitrary `java.time.LocalDateTime` values.
  *
  * Day of month is capped at 28 so every (year, month) pair is valid without
  * tracking month lengths or leap years.
  * NOTE(review): hour and minute start at 1, so midnight / on-the-hour values
  * are never generated — confirm this is intentional.
  */
implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary {
  for {
    year <- Gen.chooseNum(1900, 2027)
    month <- Gen.chooseNum(1, 12)
    dayOfMonth <- Gen.chooseNum(1, 28)
    hour <- Gen.chooseNum(1, 23)
    minute <- Gen.chooseNum(1, 59)
  } yield JavaLocalDateTime.of(year, month, dayOfMonth, hour, minute)
}

/**
  * LocalDateTime String Generator to test time related Spark functions.
  *
  * Each generated `LocalDateTime` is rendered with [[dateTimeFormatter]]
  * ("yyyy-MM-dd HH:mm"). Mapping inside `Gen.listOf` keeps the whole
  * expression in the `Gen` monad — the previous version bound a plain
  * `List` inside a `Gen` for-comprehension, mixing two unrelated monads
  * in one `for`, which is fragile and non-idiomatic. The distribution of
  * produced lists is unchanged.
  */
val dateTimeStringGen: Gen[List[String]] =
  Gen.listOf(localDateArb.arbitrary.map(_.format(dateTimeFormatter)))

/** Target directory for test output files. */
val TEST_OUTPUT_DIR = "target/test-output"
}

0 comments on commit 864fd3f

Please sign in to comment.